Inference Dequantization Pipeline

Key insight: rotate the input, not the weight. Pre-rotating the activation avoids materializing the full weight matrix.

The Key Insight

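Naively, dequantization would reconstruct the full weight matrix $\hat{W}$ and compute $y = x \hat{W}^\top$. Instead, we pre-rotate the activation. For each group $g$, with $Q_g = \mathrm{codebook}[\mathrm{idx}_g]$:

$$x_g \hat{W}_g^\top = \left( x_g \Pi_g^\top \right) Q_g^\top \cdot \frac{\alpha_g}{\sqrt{d}}$$

The rotation is applied to x once per group per layer: a (B, d) matrix multiply, versus the (M, d) inverse rotation on the weight side, which is cheaper whenever B < M.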
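The equivalence between rotating the input and un-rotating the weight can be checked numerically. The following is a minimal NumPy sketch, not the library's code: it assumes a random orthogonal matrix for the rotation, a random stand-in for the codebook values, and one norm per output row. All variable names are illustrative.

```python
import math
import numpy as np

rng = np.random.default_rng(0)
B, M, d = 2, 8, 4                      # batch, output rows, group size (toy values)

x = rng.standard_normal((B, d))        # one activation group
Q = rng.standard_normal((M, d))        # stand-in for codebook[idx] values
alpha = rng.uniform(0.5, 2.0, size=M)  # assumed: one norm per output row
Pi, _ = np.linalg.qr(rng.standard_normal((d, d)))  # random orthogonal rotation

# Weight-side path: undo the rotation on the weight, then multiply.
W_hat = (alpha[:, None] / math.sqrt(d)) * (Q @ Pi)  # (M, d) reconstructed weight
y_weight_side = x @ W_hat.T                         # (B, M)

# Input-side path: rotate the activation, multiply by Q, rescale the output.
y_input_side = ((x @ Pi.T) @ Q.T) * (alpha / math.sqrt(d))

identity_holds = np.allclose(y_weight_side, y_input_side)
print("paths agree:", identity_holds)
```

The weight-side path multiplies an (M, d) matrix by the rotation, while the input-side path multiplies only a (B, d) matrix, which is where the saving comes from when B is much smaller than M.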

Pipeline Overview

Input x: (B, N) activation
→ Rotate x: x · Πᵀ (cheap: B×d)
→ Unpack + Lookup: uint8 → codebook[idx]
→ Matmul: x_rot @ W_q.T
→ Rescale: × α / √d

Forward Pass Algorithm

output = zeros(B, M)

for each group g in [0, n_groups):
    x_g   = x[:, g*d : (g+1)*d]           # (B, d)
    x_rot = x_g @ Pi_g.T                  # (B, d)  rotate input
    idx_g = unpack_4bit(packed[..., g])    # (M, d)  unpack
    W_g   = codebook[idx_g]               # (M, d)  lookup
    out_g = x_rot @ W_g.T                 # (B, M)  matmul
    out_g = out_g * (norms_g / sqrt(d))   # (B, M)  rescale
    output += out_g
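
The loop above can be made runnable with NumPy. This is a sketch under stated assumptions rather than the library's implementation: the packing layout (two indices per byte, low nibble first), the tensor shapes, and the per-row norms are all assumed, and `forward_pass` and `unpack_4bit` here are illustrative stand-ins.

```python
import math
import numpy as np

def unpack_4bit(packed):
    """Split each uint8 into two 4-bit indices (low nibble first: an assumed layout)."""
    low = packed & 0x0F
    high = packed >> 4
    # Interleave so that byte k yields indices 2k and 2k+1.
    return np.stack([low, high], axis=-1).reshape(*packed.shape[:-1], -1)

def forward_pass(x, packed, codebook, norms, Pi):
    """x: (B, N); packed: (G, M, d//2) uint8; codebook: (16,);
    norms: (G, M); Pi: (G, d, d). Returns (B, M)."""
    G, M, _ = packed.shape
    d = Pi.shape[-1]
    output = np.zeros((x.shape[0], M), dtype=np.float32)
    for g in range(G):
        x_g = x[:, g * d:(g + 1) * d]          # (B, d) slice of the input
        x_rot = x_g @ Pi[g].T                  # (B, d) rotate input
        idx_g = unpack_4bit(packed[g])         # (M, d) 4-bit indices
        W_g = codebook[idx_g]                  # (M, d) codebook lookup
        output += (x_rot @ W_g.T) * (norms[g] / math.sqrt(d))  # matmul + rescale
    return output

# Toy data for a quick comparison against a dense reference.
rng = np.random.default_rng(0)
B, M, d, G = 2, 6, 8, 3
x = rng.standard_normal((B, G * d)).astype(np.float32)
packed = rng.integers(0, 256, size=(G, M, d // 2), dtype=np.uint8)
codebook = np.linspace(-1.0, 1.0, 16, dtype=np.float32)
norms = rng.uniform(0.5, 2.0, size=(G, M)).astype(np.float32)
Pi = np.stack(
    [np.linalg.qr(rng.standard_normal((d, d)))[0] for _ in range(G)]
).astype(np.float32)

y = forward_pass(x, packed, codebook, norms, Pi)

# Dense reference: materialize the full (M, N) weight and multiply once.
W_full = np.concatenate(
    [(norms[g][:, None] / math.sqrt(d)) * (codebook[unpack_4bit(packed[g])] @ Pi[g])
     for g in range(G)],
    axis=1,
)
y_ref = x @ W_full.T
print("matches dense reference:", np.allclose(y, y_ref, atol=1e-4))
```

The dense reference is there only to check the result; the grouped loop itself never builds more than one (M, d) slice at a time.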

Kernel Fusion

Steps 2–5 (unpack → lookup → matmul → rescale) are fused into a single GPU kernel to avoid intermediate tensor materialization.

Naive Pipeline (4 kernel launches)

Unpack uint8 → int64 [Global Memory]
↓ write + read
Codebook lookup → float32 [Global Memory]
↓ write + read
Matrix multiply [Global Memory]
↓ write + read
Rescale [Global Memory]

Fused Kernel (1 kernel launch)

Load packed uint8 [Registers]
↓ in-register
Unpack nibbles (bitwise) [Registers]
↓ in-register
Codebook (64B in L1) [Shared Mem]
↓ in-register
Tensor Core MMA + Rescale [Registers]
↓ in-register
Store final result [1× Global Write]
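
To get a feel for what the two pipelines above cost, the following back-of-the-envelope sketch counts global-memory bytes moved for one group. The function names are hypothetical and the model is deliberately crude: it assumes int64 indices and float32 values as in the naive pipeline, and it ignores caching as well as the reads of x, the norms, and the codebook.

```python
def naive_traffic_bytes(B, M, d):
    """Global-memory bytes moved for one group by the four-kernel pipeline."""
    packed = M * d // 2          # two 4-bit indices per uint8
    idx = M * d * 8              # unpacked int64 indices
    w = M * d * 4                # dequantized float32 weight slice
    out = B * M * 4              # float32 output tile
    unpack = packed + idx        # read packed, write indices
    lookup = idx + w             # read indices, write weights
    matmul = w + out             # read weights, write output
    rescale = out + out          # read output, write rescaled output
    return unpack + lookup + matmul + rescale

def fused_traffic_bytes(B, M, d):
    """Global-memory bytes moved for one group by the fused kernel."""
    packed = M * d // 2          # read packed indices once
    out = B * M * 4              # single global write of the result
    return packed + out

B, M, d = 1, 4096, 128           # illustrative sizes, not taken from the source
naive = naive_traffic_bytes(B, M, d)
fused = fused_traffic_bytes(B, M, d)
print(f"naive: {naive} B, fused: {fused} B, ratio: {naive / fused:.1f}x")
```

Under this model most of the naive traffic comes from writing and re-reading the unpacked indices and the dequantized slice, which are exactly the intermediates the fused kernel keeps on chip.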

Execution Paths

CuTile (CUDA 13.1+, Ampere+: sm80/sm89/sm100+)

NVIDIA cuda.tile_experimental API. Shared-memory codebook, FP16/BF16 tensor cores, tile-based prefetching.

Triton (Triton ≥ 3.0)

Portable alternative. Autotuned block sizes per problem shape, software pipelining, TF32 tensor cores.

PyTorch (fallback, no special dependencies)

Explicit operations: unpack → codebook[indices] → matmul → rescale. Materializes the dequantized weight slice.

Residual Pass Handling

When a layer has residual quantization, the forward method runs _forward_pass twice with different packed data and sums the results:

output  = _forward_pass(x, pass1_data)
output += _forward_pass(x, pass2_data)  # if residual
output += bias                           # if present
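
Summing the two passes works because the layer is linear in its weights: if pass 2 encodes what pass 1 left over, the two partial outputs add up to the output of the combined weight. The sketch below illustrates this with NumPy, assuming dense stand-ins for the two dequantized passes and a crude rounding quantizer as a placeholder; it is not the library's quantizer, and `coarse_quantize` and `forward` are hypothetical names.

```python
import numpy as np

rng = np.random.default_rng(0)
B, M, N = 2, 4, 8
x = rng.standard_normal((B, N))
W = rng.standard_normal((M, N))
bias = rng.standard_normal(M)

def coarse_quantize(w, step):
    """Placeholder quantizer: round to a uniform grid with the given step."""
    return np.round(w / step) * step

W1 = coarse_quantize(W, step=0.5)          # pass 1: coarse approximation of W
W2 = coarse_quantize(W - W1, step=0.05)    # pass 2: approximation of the residual

def forward(x, passes, bias=None):
    """Sum the per-pass outputs, then add the bias if present."""
    output = x @ passes[0].T
    for W_pass in passes[1:]:
        output = output + x @ W_pass.T
    if bias is not None:
        output = output + bias
    return output

y_two_pass = forward(x, [W1, W2], bias)
y_combined = x @ (W1 + W2).T + bias        # same result from the summed weight

weight_err_one = np.abs(W - W1).max()        # at most half the coarse step
weight_err_two = np.abs(W - W1 - W2).max()   # at most half the fine step
print("passes sum correctly:", np.allclose(y_two_pass, y_combined))
print("max weight error:", weight_err_one, "->", weight_err_two)
```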

CPU Offload for Pass 2

When GPU VRAM is limited, pass 2 (residual) data can be offloaded to CPU while pass 1 stays on GPU. This halves the on-GPU quantized weight footprint at the cost of ~10% latency overhead from pipelined host-to-device (H2D) copies.

Architecture

GPU (always resident)

  • Pass 1 data (indices, norms, codebook)
  • SharedScratchPool (2 double-buffered slots)
  • Embedding (bf16 or INT4/INT8)
  • Activations / KV cache

CPU (pinned memory)

  • Pass 2 data (indices, norms, codebook)
  • Async H2D via copy_stream per layer

Prefetch Chain (per layer)

1. Fence: record a CUDA event on the default stream and make copy_stream wait on it
2. Async H2D: copy the next layer's pass2 data to the alternate scratch slot via copy_stream
3. Pass 1 compute: runs on the default stream (overlaps with the H2D copy)
4. Wait: the default stream waits for this layer's pass2 copy (started by the previous layer)
5. Pass 2 compute: uses the scratch slot now populated with pass2 data

# Enable at quantization time
config = TurboQuantConfig(..., cpu_offload_pass2=True)
model  = quantize_model(model, config)

# Or override at load time
model = load_quantized(model_name, path, cpu_offload_pass2=True)
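
The ordering of the prefetch chain and the alternation between the two scratch slots can be traced without a GPU. The following pure-Python sketch only simulates the schedule; it issues no CUDA calls. `prefetch_trace` is a hypothetical helper, and two details are assumptions not stated above: layer 0's pass-2 data is copied into slot 0 before the sweep starts, and the last layer prefetches nothing.

```python
def prefetch_trace(n_layers, n_slots=2):
    """Return the ordered (layer, step, detail) events for one forward sweep.

    Assumption: layer 0's pass-2 data is copied into slot 0 before the sweep
    starts, and the last layer has no next layer to prefetch.
    """
    trace = [(-1, "prime", "H2D layer 0 pass2 -> slot 0")]
    for layer in range(n_layers):
        slot = layer % n_slots                 # slot holding this layer's pass-2 data
        next_slot = (layer + 1) % n_slots      # alternate slot for the next layer
        trace.append((layer, "fence", "copy_stream waits on default-stream event"))
        if layer + 1 < n_layers:
            trace.append(
                (layer, "async_h2d", f"layer {layer + 1} pass2 -> slot {next_slot}")
            )
        trace.append((layer, "pass1", "compute on default stream"))
        trace.append((layer, "wait", f"default stream waits for slot {slot}"))
        trace.append((layer, "pass2", f"compute from slot {slot}"))
    return trace

for layer, step, detail in prefetch_trace(3):
    print(layer, step, detail)
```

With two slots, the copy issued by layer i targets the slot that layer i-1 just finished reading, which is presumably why the chain starts with a fence before each copy.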

Memory Profile

The pipeline never materializes the full M×N weight matrix. Peak additional memory:

Component               Size           Notes
x_rot                   B × d × 4 B    Per group, reused
W slice (PyTorch only)  M × d × 4 B    Per group, reused
Output acc.             B × M × 4 B    Persistent
Rotation matrix         d × d × 4 B    Cached

With fused kernels, the dequantized weight slice only exists in registers/shared memory within the kernel and is never written to global memory.
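
As a worked example of the sizes listed above, the sketch below plugs in illustrative dimensions (B = 1, M = N = 4096, d = 128, chosen here for the example and not taken from the source) and compares the totals against a full float32 weight matrix. `peak_extra_bytes` is a hypothetical helper.

```python
def peak_extra_bytes(B, M, d, fused=True, elem_bytes=4):
    """Per-component peak additional memory in bytes, following the table."""
    sizes = {
        "x_rot": B * d * elem_bytes,          # per group, reused
        "output_acc": B * M * elem_bytes,     # persistent accumulator
        "rotation": d * d * elem_bytes,       # cached rotation matrix
    }
    if not fused:
        # Only the PyTorch fallback materializes the dequantized slice.
        sizes["w_slice"] = M * d * elem_bytes
    return sizes

B, M, N, d = 1, 4096, 4096, 128
fused_total = sum(peak_extra_bytes(B, M, d, fused=True).values())
fallback_total = sum(peak_extra_bytes(B, M, d, fused=False).values())
full_weight = M * N * 4                       # full float32 weight, for comparison

print(f"fused:    {fused_total / 1024:.1f} KiB")
print(f"fallback: {fallback_total / 1024:.1f} KiB")
print(f"full W:   {full_weight / 1024:.1f} KiB")
```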

Implementation

module.py → TurboQuantLinear._forward_pass(), TurboQuantLinear.forward()