Inference Dequantization Pipeline
Key insight: rotate the input, not the weight. Pre-rotating the activation avoids materializing the full weight matrix.
The Key Insight
Naively, dequantization would reconstruct the full weight matrix and compute y = x @ W_hat.T, where each per-group slice is W_hat_g = (norms_g / sqrt(d)) * codebook[idx_g] @ Pi_g. Instead, we pre-rotate the activation:

y_g = ((x_g @ Pi_g.T) @ codebook[idx_g].T) * (norms_g / sqrt(d))
The rotation is applied to x once per group per layer: a (B, d) matrix multiply, versus applying the (M, d) inverse rotation on the weight side.
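The algebraic identity behind this can be checked numerically. A minimal NumPy sketch (shapes B, M, d are illustrative; Pi stands in for the per-group rotation Pi_g):

```python
import numpy as np

# Verify that rotating the (B, d) activation is equivalent to rotating
# the (M, d) weight: x @ (C @ Pi).T == (x @ Pi.T) @ C.T
rng = np.random.default_rng(0)
B, M, d = 4, 8, 16

x = rng.standard_normal((B, d))                    # activation group
C = rng.standard_normal((M, d))                    # dequantized codebook rows
Pi, _ = np.linalg.qr(rng.standard_normal((d, d)))  # orthogonal rotation

naive = x @ (C @ Pi).T   # rotate the weight (M x d work per group)
fast = (x @ Pi.T) @ C.T  # rotate the activation (B x d work per group)

assert np.allclose(naive, fast)
```

The identity holds by transpose associativity, (C @ Pi).T = Pi.T @ C.T, so the rotation can always be moved to whichever operand is smaller.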
Pipeline Overview
Forward Pass Algorithm
output = zeros(B, M)
for each group g in [0, n_groups):
    x_g = x[:, g*d : (g+1)*d]               # (B, d) slice
    x_rot = x_g @ Pi_g.T                    # (B, d) rotate input
    idx_g = unpack_4bit(packed[..., g])     # (M, d) unpack
    W_g = codebook[idx_g]                   # (M, d) lookup
    out_g = x_rot @ W_g.T                   # (B, M) matmul
    out_g = out_g * (norms_g / sqrt(d))     # (B, M) rescale
    output += out_g

Kernel Fusion
Steps 2–5 (unpack → lookup → matmul → rescale) are fused into a single GPU kernel to avoid materializing intermediate tensors.
Naive Pipeline (4 kernel launches)
Fused Kernel (1 kernel launch)
Execution Paths
NVIDIA cuda.tile_experimental API. Shared-memory codebook, FP16/BF16 tensor cores, tile-based prefetching.
Portable alternative. Autotuned block sizes per problem shape, software pipelining, TF32 tensor cores.
Explicit operations: unpack → codebook[indices] → matmul → rescale. Materializes the dequantized weight slice.
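This explicit path can be sketched in NumPy. The packing layout here is an assumption (two 4-bit indices per byte, low nibble first), as are the concrete shapes; `unpack_4bit` and the scalar codebook follow the pseudocode above:

```python
import numpy as np

def unpack_4bit(packed):
    """Unpack uint8 bytes into two 4-bit indices each (low nibble first)."""
    lo = packed & 0x0F
    hi = packed >> 4
    return np.stack([lo, hi], axis=-1).reshape(*packed.shape[:-1], -1)

def forward_group(x_g, packed_g, codebook, Pi_g, norms_g, d):
    """Explicit per-group pass: rotate, unpack, lookup, matmul, rescale."""
    x_rot = x_g @ Pi_g.T                   # (B, d) rotate input
    idx_g = unpack_4bit(packed_g)          # (M, d) unpack indices
    W_g = codebook[idx_g]                  # (M, d) codebook lookup
    out_g = x_rot @ W_g.T                  # (B, M) matmul
    return out_g * (norms_g / np.sqrt(d))  # (B, M) per-row rescale

rng = np.random.default_rng(0)
B, M, d = 2, 4, 8
codebook = rng.standard_normal(16)  # 16 scalar entries for 4-bit codes
packed = rng.integers(0, 256, size=(M, d // 2), dtype=np.uint8)
Pi, _ = np.linalg.qr(rng.standard_normal((d, d)))
norms = rng.random(M) + 0.5
x = rng.standard_normal((B, d))

out = forward_group(x, packed, codebook, Pi, norms, d)
assert out.shape == (B, M)
```

Note that `W_g` is materialized here as a temporary (M, d) array each iteration, which is exactly the allocation the fused kernels avoid.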
Residual Pass Handling
When a layer has residual quantization, the forward method runs _forward_pass twice with different packed data and sums the results:
output = _forward_pass(x, pass1_data)
output += _forward_pass(x, pass2_data)  # if residual
output += bias                          # if present
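The control flow can be sketched with `_forward_pass` stubbed out; the dict-based layer and the names `pass1_data`, `pass2_data`, `bias` are illustrative, not the actual interface:

```python
import numpy as np

def forward(x, layer, _forward_pass):
    """Base pass, then optional residual pass and bias add."""
    output = _forward_pass(x, layer["pass1_data"])
    if layer.get("pass2_data") is not None:  # residual quantization pass
        output = output + _forward_pass(x, layer["pass2_data"])
    if layer.get("bias") is not None:
        output = output + layer["bias"]
    return output

# Stub standing in for the grouped dequantize-and-matmul pass.
fake_pass = lambda x, data: x @ data.T

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))
layer = {
    "pass1_data": rng.standard_normal((4, 8)),
    "pass2_data": rng.standard_normal((4, 8)),
    "bias": rng.standard_normal(4),
}
out = forward(x, layer, fake_pass)
expected = x @ layer["pass1_data"].T + x @ layer["pass2_data"].T + layer["bias"]
assert np.allclose(out, expected)
```

Because both passes share the same activation x, the residual pass costs one extra sweep over the packed data but no extra activation work.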
Memory Profile
The pipeline never materializes the full M×N weight matrix. Peak additional memory:
| Component | Size | Notes |
|---|---|---|
| x_rot | B × d × 4 bytes | Per group, reused |
| W slice (PyTorch only) | M × d × 4 bytes | Per group, reused |
| Output acc. | B × M × 4 bytes | Persistent |
| Rotation matrix | d × d × 4 bytes | Cached |
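For intuition, the table's entries can be totaled at illustrative (assumed) sizes B = 16, M = N = 4096, d = 128, fp32:

```python
# Peak-memory arithmetic for assumed sizes (not from the source):
B, M, N, d = 16, 4096, 4096, 128
bytes_per = 4  # fp32

x_rot = B * d * bytes_per        # per-group rotated activation
w_slice = M * d * bytes_per      # PyTorch path only
out_acc = B * M * bytes_per      # persistent output accumulator
rotation = d * d * bytes_per     # cached rotation matrix

peak = x_rot + w_slice + out_acc + rotation
full_weight = M * N * bytes_per  # what a naive dequantize would materialize

print(f"peak additional: {peak / 2**20:.2f} MiB")   # ~2.32 MiB
print(f"full weight:     {full_weight / 2**20:.2f} MiB")  # 64.00 MiB
```

At these sizes the dominant term is the (M, d) weight slice, and even that is over an order of magnitude smaller than the full weight matrix.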
With fused kernels, the dequantized weight slice exists only in registers/shared memory within the kernel and is never written to global memory.