
Inference Dequantization Pipeline

Key insight: rotate the input, not the weight. Pre-rotating the activation avoids materializing the full weight matrix.

The Key Insight

Naively, dequantization would reconstruct the full weight matrix Ŵ = (α / √d) · W_q Π and compute y = x Ŵᵀ. Instead, we pre-rotate the activation:

    y = ((x · Πᵀ) · W_qᵀ) · α / √d

The rotation is applied to x once per group per layer: a (B, d) × (d, d) matrix multiply on the activation side instead of an (M, d) × (d, d) inverse rotation on the weight side, where B is typically far smaller than M.
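The identity can be checked numerically. The sketch below uses random stand-ins (a Gaussian W_q and a random orthogonal Π) rather than the module's real buffers, and drops the α / √d scale since it multiplies both sides identically:

```python
import numpy as np

# Hypothetical toy shapes: B = batch, M = output rows, d = group size.
rng = np.random.default_rng(0)
B, M, d = 4, 8, 16
x   = rng.standard_normal((B, d))
W_q = rng.standard_normal((M, d))      # stand-in for the dequantized codes
# Random orthogonal rotation Pi (Q factor of a Gaussian matrix).
Pi, _ = np.linalg.qr(rng.standard_normal((d, d)))

# Weight-side reconstruction: rotate the (M, d) weight, then matmul.
y_naive = x @ (W_q @ Pi).T
# Activation-side rotation: rotate the (B, d) input instead.
y_fast  = (x @ Pi.T) @ W_q.T

assert np.allclose(y_naive, y_fast)    # identical up to float error
```

Both paths compute the same product; only the (B, d) rotation is materialized on the fast path.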

Pipeline Overview

1. 📥 Input x: the (B, N) activation
2. 🔄 Rotate x: x · Πᵀ (cheap: B×d)
3. 📖 Unpack + Lookup: uint8 → codebook[idx]
4. ✖️ Matmul: x_rot @ W_q.T
5. ⚖️ Rescale: × α / √d

Forward Pass Algorithm

output = zeros(B, M)

for each group g in [0, n_groups):
    x_g   = x[:, g*d : (g+1)*d]           # (B, d)
    x_rot = x_g @ Pi_g.T                  # (B, d)  rotate input
    idx_g = unpack_4bit(packed[..., g])    # (M, d)  unpack
    W_g   = codebook[idx_g]               # (M, d)  lookup
    out_g = x_rot @ W_g.T                 # (B, M)  matmul
    out_g = out_g * (norms_g / sqrt(d))   # (B, M)  rescale
    output += out_g
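The loop can be run end-to-end in NumPy with toy shapes. Here codebook, norms, Pi, and the packing layout (two 4-bit codes per byte, high nibble first) are illustrative stand-ins for the module's real buffers:

```python
import numpy as np

rng = np.random.default_rng(0)
B, M, d, n_groups = 2, 6, 8, 3
N = n_groups * d

x        = rng.standard_normal((B, N))
codebook = rng.standard_normal(16)                       # 16 entries for 4-bit codes
idx      = rng.integers(0, 16, size=(M, N), dtype=np.uint8)
packed   = (idx[:, 0::2] << 4) | idx[:, 1::2]            # two codes per byte
norms    = rng.random((n_groups, M)) + 0.5               # per-group, per-row scales
Pi       = np.stack([np.linalg.qr(rng.standard_normal((d, d)))[0]
                     for _ in range(n_groups)])          # orthogonal rotation per group

def unpack_4bit(p):
    """uint8 -> two 4-bit codes per byte (high nibble first)."""
    out = np.empty((p.shape[0], p.shape[1] * 2), dtype=np.int64)
    out[:, 0::2] = p >> 4
    out[:, 1::2] = p & 0x0F
    return out

output = np.zeros((B, M))
for g in range(n_groups):
    x_rot = x[:, g*d:(g+1)*d] @ Pi[g].T                          # (B, d) rotate input
    W_g   = codebook[unpack_4bit(packed[:, g*d//2:(g+1)*d//2])]  # (M, d) unpack + lookup
    output += (x_rot @ W_g.T) * (norms[g] / np.sqrt(d))          # (B, M) matmul + rescale
```

The result matches a dense matmul against the reconstructed weight, without ever building the full (M, N) matrix.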

Kernel Fusion

Steps 2–5 (unpack → lookup → matmul → rescale) are fused into a single GPU kernel to avoid materializing intermediate tensors in global memory.

Naive Pipeline (4 kernel launches)

📦 Unpack uint8 → int64 [Global Memory]
   ↓ write + read
📖 Codebook lookup → float32 [Global Memory]
   ↓ write + read
✖️ Matrix multiply [Global Memory]
   ↓ write + read
⚖️ Rescale [Global Memory]

Fused Kernel (1 kernel launch)

📦 Load packed uint8 [Registers]
   ↓ in-register
🔓 Unpack nibbles (bitwise) [Registers]
   ↓ in-register
📖 Codebook lookup (64B in L1) [Shared Mem]
   ↓ in-register
✖️ Tensor Core MMA + Rescale [Registers]
   ↓ in-register
💾 Store final result [1× Global Write]
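The benefit can be put in rough numbers. The sketch below is a back-of-the-envelope traffic model, not measured data: it counts one write plus one read for every intermediate in the naive path (using the element widths from the diagram above) and a single final write for the fused path, ignoring the input reads common to both:

```python
def naive_traffic_bytes(B, M, d):
    """Rough per-group global-memory traffic on the naive path."""
    unpacked = M * d * 8 * 2   # int64 indices: written, then read by the lookup
    w_deq    = M * d * 4 * 2   # float32 weight slice: written, then read by matmul
    mm_out   = B * M * 4 * 2   # matmul output: written, then read by rescale
    result   = B * M * 4       # final rescaled output: one write
    return unpacked + w_deq + mm_out + result

def fused_traffic_bytes(B, M, d):
    """Fused path: intermediates stay in registers/shared memory; one final write."""
    return B * M * 4
```

Under this model, hypothetical sizes B=8, M=4096, d=64 give on the order of 50× less global-memory traffic for the fused kernel.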

Execution Paths

CuTile (CUDA 13.1+, Ampere+: sm80/sm89/sm100+)

NVIDIA cuda.tile_experimental API. Shared-memory codebook, FP16/BF16 tensor cores, tile-based prefetching.

Triton (Triton ≥ 3.0)

Portable alternative. Autotuned block sizes per problem shape, software pipelining, TF32 tensor cores.

PyTorch fallback (no special dependencies)

Explicit operations: unpack → codebook[indices] → matmul → rescale. Materializes the dequantized weight slice.

Residual Pass Handling

When a layer has residual quantization, the forward method runs _forward_pass twice with different packed data and sums the results:

output  = _forward_pass(x, pass1_data)
output += _forward_pass(x, pass2_data)  # if residual
output += bias                           # if present
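A toy scalar codebook (hypothetical; the real module uses its trained codebook and packed indices) shows why the second pass helps: it quantizes the residual left over by the first pass, and by linearity the two matmul results sum to the same output as using the combined weight:

```python
import numpy as np

rng = np.random.default_rng(0)
codebook = np.linspace(-2, 2, 16)     # toy 4-bit codebook (illustrative)
W = rng.standard_normal((64, 32))     # (M, N) weight to quantize
x = rng.standard_normal((4, 32))      # (B, N) activation

def quantize(A):
    """Snap every element to its nearest codebook entry."""
    return codebook[np.abs(A[..., None] - codebook).argmin(-1)]

pass1 = quantize(W)                   # first pass: quantize W
pass2 = quantize(W - pass1)           # second pass: quantize the residual

# forward() sums the two passes, which equals one matmul with pass1 + pass2
output = x @ pass1.T + x @ pass2.T
```

The combined approximation pass1 + pass2 is never worse than pass1 alone, which is why the residual pass tightens the reconstruction.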

Memory Profile

The pipeline never materializes the full M × N weight matrix. Peak additional memory:

Component                 Size          Notes
x_rot                     B × d × 4B    Per group, reused
W slice (PyTorch only)    M × d × 4B    Per group, reused
Output accumulator        B × M × 4B    Persistent
Rotation matrix           d × d × 4B    Cached
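The table translates into a simple accounting sketch (peak_extra_bytes and full_weight_bytes are illustrative helpers, with fused=False modeling the PyTorch fallback's extra (M, d) slice):

```python
def peak_extra_bytes(B, M, d, fused=True):
    """Peak additional float32 memory per the table (4 bytes per element)."""
    x_rot    = B * d * 4                    # rotated activation, per group, reused
    w_slice  = 0 if fused else M * d * 4    # PyTorch fallback materializes (M, d)
    out_acc  = B * M * 4                    # persistent output accumulator
    rotation = d * d * 4                    # cached rotation matrix
    return x_rot + w_slice + out_acc + rotation

def full_weight_bytes(M, N):
    """What the pipeline avoids: the dense float32 (M, N) weight."""
    return M * N * 4
```

For example, at B=8, M=N=4096, d=64 the fused path needs roughly 150 KB of extra memory versus 64 MiB for the dense weight.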

With fused kernels, the dequantized weight slice exists only in registers and shared memory inside the kernel; it is never written to global memory.

Implementation

module.py → TurboQuantLinear._forward_pass(), TurboQuantLinear.forward()