Backporting FP8 to the RTX 3090 (No H100 Required)
NVIDIA’s FP8 story is usually told like this: “If you want to experiment with FP8 numerics, you need an H100 (or at least a very new GPU with FP8 support, like an RTX 4090).”
I disagree.
Call it: backporting FP8-style numerics experiments to the RTX 3090.
Not because Ampere magically does FP8 compute (it doesn’t), and not because this makes an RTX 3090 “faster” than Hopper (it won’t).
But because a lot of FP8 research and engineering is really about:
- how you store weights (bytes on the wire)
- when and where you expand them (decode)
- what scaling/quantization contract you enforce
You can explore a surprising amount of that on consumer Ampere, if you’re willing to treat FP8 as a storage format and map the math onto hardware that is available.
Quick note: if you see an acronym you don’t recognize, jump to the glossary.
Code: https://github.com/poad42/cuda-fp8-ampere
The plan
Ampere (sm_86) has extremely capable tensor cores, but it doesn’t have native FP8 tensor-core MMA. What it does have is a very fast path for INT8 tensor cores (IMMA / WMMA).
So the project becomes:
Keep weights stored as 1-byte FP8 bit patterns in VRAM, decode/scale/quantize on the fly, and use INT8 tensor cores for the matmul.
That’s the whole framing: democratize FP8 research by making the storage + numerics experimentable on hardware people actually have.
FP8-as-storage, in one paragraph
I am not trying to do “FP8 compute.” I’m trying to store weights in a compact FP8 format and only expand them when needed.
The VRAM part is simple: FP16/BF16 weights cost 2 bytes/weight, while FP8 weights cost 1 byte/weight. So for large weight matrices, storing FP8 can cut the resident weight footprint (and the bandwidth to stream it) by close to 2×.
In practice you also store scale factors (e.g. one FP16 scale per output channel), but that overhead is tiny compared to the full $N\times K$ weight matrix.
Conceptually:
- Store weights as FP8 bytes (E4M3): literally `uint8` bit patterns.
- Decode FP8 → FP16 on the fly using a 256-entry LUT.
- Apply per-output-channel (per-column) scale.
- Quantize to INT8 so the tensor cores can consume it.
- Run IMMA (INT8×INT8→INT32 accumulate), then write FP16 output.
That’s the whole “FP8 without FP8 MMA” idea.
What’s actually new here (and what isn’t)
Three honesty bullets up front:
- This is not a claim that Ampere beats BF16/FP16 cuBLAS. In fact, for pure compute, cuBLAS is usually hard to beat.
- This is not full FP8 training. There’s no backward pass here.
- This project focuses on FP8(E4M3) storage. Extending to E5M2 is conceptually similar (another decode path), but I didn’t build it into this writeup.
So what is interesting?
Bit-level FP8 handling (LUT decode)
I store FP8 weights as raw uint8 bit patterns and decode them with a 256-entry LUT. Since there are only 256 possible FP8 bytes, decode is conceptually:
- `u8 → fp16` via `LUT[u8]`
No `__byte_perm` tricks here; it's mostly about making that decode cheap enough to hide behind the tensor-core pipe.
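For concreteness, here is one way that table can be built on the host and pushed to constant memory. This is a sketch assuming the OCP E4M3 encoding (bias 7, no infinities, only 0x7F/0xFF encode NaN); names like `d_fp8_lut` and `upload_fp8_lut` are illustrative, not the repo's actual API.
// Sketch: build the 256-entry E4M3 -> fp16 decode table and copy it to constant memory.
#include <cstdint>
#include <cmath>
#include <cuda_fp16.h>

__constant__ __half d_fp8_lut[256];  // indexed directly by the raw FP8 byte

static float e4m3_to_float(uint8_t b) {
    const int sign = (b >> 7) & 0x1;
    const int exp  = (b >> 3) & 0xF;
    const int man  =  b       & 0x7;
    float v;
    if (exp == 0xF && man == 0x7) v = NAN;                                    // E4M3's only NaN pattern
    else if (exp == 0)            v = std::ldexp(man / 8.0f, -6);             // subnormal: man/8 * 2^-6
    else                          v = std::ldexp(1.0f + man / 8.0f, exp - 7); // normal, exponent bias 7
    return sign ? -v : v;
}

void upload_fp8_lut() {
    __half h_lut[256];
    for (int i = 0; i < 256; ++i)
        h_lut[i] = __float2half(e4m3_to_float(static_cast<uint8_t>(i)));
    cudaMemcpyToSymbol(d_fp8_lut, h_lut, sizeof(h_lut));
}
Inside a kernel, decode is then a single constant-memory load: `__half w = d_fp8_lut[b];`.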
Scaling + quantization as a first-class contract
The weights aren’t “just FP8.” They’re FP8 bits + per-output-channel scale. The kernel makes that explicit: decode → apply scale → saturating quantize to int8 → IMMA.
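A rough sketch of that contract as a per-element device helper (assuming the `d_fp8_lut` table from the sketch above and a hypothetical quantization scale `inv_qscale` chosen so the scaled weights land in [-127, 127]; the real kernels fold all of this into the tile loaders):
// Sketch of the per-element contract: FP8 byte -> fp16 via LUT -> per-column scale
// -> saturating int8. inv_qscale is a hypothetical name; the repo fuses this into tiling.
__device__ __forceinline__ int8_t decode_scale_quantize(uint8_t fp8_bits,
                                                        __half  col_scale,
                                                        float   inv_qscale) {
    float w = __half2float(d_fp8_lut[fp8_bits]) * __half2float(col_scale);  // decode + scale
    float q = rintf(w * inv_qscale);                                        // round to nearest
    q = fminf(fmaxf(q, -127.0f), 127.0f);                                   // saturate to int8
    return static_cast<int8_t>(q);
}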
Stochastic rounding (SR): important, but not implemented here
If you’re interested in FP8 training dynamics, stochastic rounding matters a lot. This project doesn’t implement SR (no backward pass), but if I were pushing this toward “training-like” experiments on older GPUs, SR would be near the top of the list.
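For reference only, since it is not implemented here: the core of SR is a few lines. A sketch of stochastically rounding a scaled value to int8, assuming a per-thread cuRAND state initialized elsewhere:
// Sketch only (not in this repo): stochastic rounding to int8.
// Rounds up with probability equal to the fractional part, so the rounding error is zero-mean.
#include <cstdint>
#include <curand_kernel.h>

__device__ __forceinline__ int8_t quantize_sr(float x, curandStatePhilox4_32_10_t* rng) {
    float lo   = floorf(x);
    float frac = x - lo;                      // distance to the integer below
    float q    = (curand_uniform(rng) < frac) ? lo + 1.0f : lo;
    q = fminf(fmaxf(q, -127.0f), 127.0f);     // saturate to the int8 range
    return static_cast<int8_t>(q);
}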
Glossary (quick definitions)
- FP8 (E4M3): an 8-bit float format (1 sign bit, 4 exponent bits, 3 mantissa bits). Great for storage, not great for high-accuracy math.
- MMA: matrix multiply-accumulate (the tensor core instruction family).
- IMMA / WMMA: NVIDIA’s int8 tensor-core path; IMMA is the integer MMA instruction family, WMMA is the CUDA C++ API that emits it.
- cuBLAS / cuBLASLt: NVIDIA’s GPU linear algebra libraries (GEMM).
- cp.async: an Ampere instruction to asynchronously copy from global memory to shared memory.
- l2pin: using “persisting L2” cache hints to keep hot tensors resident longer.
- Per-column scale: one scale factor per output channel; common in quantized inference.
- LUT decode: since there are only 256 FP8 bit patterns, decode can be a table lookup.
The pipeline, in one diagram
A (fp16/bf16) B (uint8 fp8-e4m3 bits) col_scales (u16 bits)
[M,K] row-major [N,K] (represents KxN col-major) [N]
| | |
| | (LUT in __constant__) |
| v |
| fp8 -> fp16 decode |
| | |
| +-----------(per-column)--------+
| scale
| |
| v
| fp16 -> int8 (sat)
| |
+--------------- int8 A --------+
(act quant)
|
v
WMMA/IMMA (int8) accumulate (int32)
|
v
D (fp16) written as [N,M]
(represents MxN col-major)
If you’ve never written CUDA kernels: that diagram is basically the whole story.
Baseline: PyTorch decode + matmul
Before writing any custom kernel, I wanted a baseline that matches the real workflow:
# weights stored as FP8 (E4M3) bytes
B_u8 = ...  # [N, K] uint8
# decode fp8 -> fp16 every iteration: 256-entry fp16 LUT + per-column scale
B_fp16 = LUT[B_u8.long()] * scales[:, None]  # index with long: uint8 indices act as a mask
# compute in fp16 using a standard matmul (cuBLAS under the hood)
out = A @ B_fp16.T
That’s the easiest FP8-as-storage implementation: store FP8 bytes, decode on demand, then use cuBLAS.
Two additional baselines are useful:
- Decode + matmul + downcast output to FP8: what the pipeline looks like if you want to store the output/activation in FP8.
- Matmul-only with fp16 weights cached: not apples-to-apples (you’re no longer storing FP8), but it’s a useful upper bound.
Baseline numbers (PyTorch)
Measured on the CUDA-visible RTX 3090 Ti (sm_86), shape $M=N=K=4096$.
| Path | What it includes | Time / iter | Effective TOPS | Peak alloc |
|---|---|---|---|---|
| Fused extension | custom kernel (fp8imma_ext.imma_fp8_v4_act) | 2.914 ms | 47.17 | 120.1 MiB |
| Naive Torch | decode FP8→fp16 each iter + fp16 matmul | 2.267 ms | 60.63 | 248.1 MiB |
| Naive Torch (end-to-end) | decode + fp16 matmul + downcast output to FP8 | 2.322 ms | 59.18 | 248.1 MiB |
| Torch matmul only | fp16 weights cached (no decode) | 1.828 ms | 75.17 | 120.1 MiB |
Notes (important, and easy to misread):
- The “matmul only” baseline assumes fp16 weights are already resident. That defeats the FP8 VRAM savings.
- “Peak alloc” here is per-call peak allocated bytes; it does not include already-resident fp16 cached weights.
The naive decode+matmul being fast is not a paradox — cuBLAS is extremely optimized, and the decode step is embarrassingly parallel. My main motivation for the fused kernel is controlling memory traffic and keeping the pipeline “weight storage = FP8 bytes” end-to-end.
Fusing it into one kernel
Once you accept that IMMA wants int8 fragments, the kernel is a pipeline problem:
- Where does decode happen? (constant memory LUT vs texture vs global)
- Where does scaling happen? (apply scale in fp16, or bake it into an int8 conversion)
- Where does activation quant happen? (register path vs shared-memory staging)
- How do you feed tensor cores continuously? (avoid stalls from decode/scale/quant)
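Whatever the answers, the consumer at the end of the pipeline looks roughly like this in the WMMA API. A stripped-down, single-warp, single-tile sketch (the int8 operands are assumed to already sit in shared memory via the decode → scale → quantize path; the real kernels tile K and pipeline the loads):
#include <cstdint>
#include <mma.h>
using namespace nvcuda;

// Sketch: one warp accumulating a 16x16 int32 output tile from int8 inputs.
// a_s8 is a row-major [16, K] tile, b_s8 a col-major [K, 16] tile; ld_* are leading dims.
__device__ void imma_tile(const int8_t* a_s8, const int8_t* b_s8, int32_t* c_i32,
                          int ld_a, int ld_b, int ld_c, int K) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, signed char, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, signed char, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, int> acc;
    wmma::fill_fragment(acc, 0);

    for (int k = 0; k < K; k += 16) {
        wmma::load_matrix_sync(a_frag, a_s8 + k, ld_a);   // next int8 A sub-tile
        wmma::load_matrix_sync(b_frag, b_s8 + k, ld_b);   // next int8 B sub-tile
        wmma::mma_sync(acc, a_frag, b_frag, acc);         // int8 x int8 -> int32 accumulate
    }
    wmma::store_matrix_sync(c_i32, acc, ld_c, wmma::mem_row_major);
}
The fused kernels then fold the quantization scales back in when converting the int32 accumulator to the fp16 output.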
I ended up implementing variants as a way to test hypotheses.
Variants (experiments)
- v2: baseline fused path (FP8→INT8 JIT + IMMA). Keep it simple and measure.
- v2_i8lut: “what if I precompute a per-column FP8→INT8 table in shared memory?” (sounds clever; didn’t win).
- v3_act_f16: fused activation quantization, register path.
- v4_act_f16: cp.async staging for activations + shared-memory quantization, then IMMA.
- texscale: load per-column scales via TEX.
- l2pin: persisting-L2 hints for B/scales (sketched below).
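The l2pin variants are nothing exotic: they use CUDA's access-policy-window API to ask L2 to keep the weight bytes (and scales) around. A host-side sketch, with placeholder names (`d_B`, `n_bytes`) and without the cleanup that resets the window afterwards:
// Sketch: mark a weight buffer as "persisting" in L2 for all kernels launched on a stream.
#include <algorithm>
#include <cuda_runtime.h>

void pin_weights_in_l2(void* d_B, size_t n_bytes, cudaStream_t stream) {
    int device = 0, max_persist = 0, max_window = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&max_persist, cudaDevAttrMaxPersistingL2CacheSize, device);
    cudaDeviceGetAttribute(&max_window,  cudaDevAttrMaxAccessPolicyWindowSize, device);

    // Reserve a slice of L2 for persisting accesses (clamped to what the device allows).
    cudaDeviceSetLimit(cudaLimitPersistingL2CacheSize,
                       std::min(n_bytes, static_cast<size_t>(max_persist)));

    cudaStreamAttrValue attr = {};
    attr.accessPolicyWindow.base_ptr  = d_B;
    attr.accessPolicyWindow.num_bytes = std::min(n_bytes, static_cast<size_t>(max_window));
    attr.accessPolicyWindow.hitRatio  = 1.0f;                     // try to keep the whole window
    attr.accessPolicyWindow.hitProp   = cudaAccessPropertyPersisting;
    attr.accessPolicyWindow.missProp  = cudaAccessPropertyStreaming;
    cudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &attr);
}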
Kernel benchmark numbers
Measured via ./build/gpu_bench on RTX 3090 Ti (sm_86), driver 590.48.01, CUDA 13.1.
Shape: M=N=K=4096, --warmup 10 --iters 50.
| Benchmark | Time / iter | Throughput |
|---|---|---|
| imma_fp8_jit_v2 | 2.714 ms | 50.63 TOPS |
| imma_fp8_jit_v2_l2pin | 2.744 ms | 50.09 TOPS |
| imma_fp8_jit_v4_act_f16 | 2.818 ms | 48.77 TOPS |
| imma_fp8_jit_v4_act_f16_l2pin | 2.851 ms | 48.21 TOPS |
| imma_fp8_jit_v4_act_f16_texscale | 2.824 ms | 48.66 TOPS |
| imma_fp8_jit_v4_act_f16_texscale_l2pin | 2.854 ms | 48.16 TOPS |
| imma_fp8_jit_v2_i8lut | 3.369 ms | 40.79 TOPS |
| imma_fp8_jit_v3_act_f16 | 5.606 ms | 24.52 TOPS |
| int8gemm (cuBLASLt baseline) | 0.018 ms | 118.06 TOPS |
Notes:
- `*_l2pin` can vary with driver/GPU state and other workloads.
- The `int8gemm` cuBLASLt number is not FP8-as-storage; it's a ceiling for int8 TC GEMM on this machine.
Why do this at all?
After seeing the tables, the fair question is:
If naive Torch is already fast, why bother?
Because “fast” depends on what you’re measuring.
On pure matmul throughput, a highly tuned fp16/bf16 GEMM can absolutely win. This project is about a different constraint: weight storage and weight movement.
If your weights are truly stored in FP8 (1 byte/weight), then compared to fp16/bf16 (2 bytes/weight) you're targeting roughly half the weight traffic and half the resident weight footprint. That's a real lever for memory-bound inference workloads, even if you pay some extra compute to decode/scale/quantize.
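To put numbers on that for the shape used above: the $4096\times4096$ weight matrix B is about 16.8M entries, so 16 MiB as FP8 bytes versus 32 MiB as FP16, plus an 8 KiB vector of 4096 FP16 per-column scales. If B has to be streamed from VRAM on every call, that difference is exactly the traffic this pipeline is trying to halve.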
Practically, the “democratizing FP8 research” win is:
- you can keep the storage format honest (FP8 bytes in VRAM)
- you can experiment with scaling/quantization contracts
- you can measure the cost of decode/quant instead of hiding it in a pre-processing step
So I view this as a tool for exploration: FP8-as-storage end-to-end, on hardware that doesn’t officially “support FP8.”
Try it yourself
Repo: https://github.com/poad42/cuda-fp8-ampere
Build:
git submodule update --init --recursive
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build -j
Run tests:
cd build
ctest --output-on-failure
Run the kernel benches:
./build/gpu_bench --bench imma_fp8_jit_v2 --M 4096 --N 4096 --K 4096 --warmup 10 --iters 50
./build/gpu_bench --bench imma_fp8_jit_v4_act_f16 --M 4096 --N 4096 --K 4096 --warmup 10 --iters 50
./build/gpu_bench --bench imma_fp8_jit_v4_act_f16_texscale --M 4096 --N 4096 --K 4096 --warmup 10 --iters 50
Run the Torch baselines (including end-to-end downcast):
. .venv_torch_cuda312/bin/activate
python scripts/bench_torch_vs_fp8imma.py --M 4096 --N 4096 --K 4096 --kChunk 32 --report_mem --downcast_out_fp8
Next steps
If I had another weekend:
- Add a tiny numerical correctness harness (reference decode + GEMM with tolerances).
- Report a more honest memory metric: resident weights + peak workspace, not just per-call peak alloc.
- Try more realistic shapes (transformer-ish M, larger N, varying K) instead of only 4096³.
If you want to dig into the code, the repo contains:
- a CUDA kernel library (C++ API + C ABI)
- a benchmark harness (`gpu_bench`)
- a minimal PyTorch extension
- smoke tests (CTest + torch compile/import test)