Written: August 11, 2024
TLDR: A quick highlight of some of the gotchas you might encounter when trying to benchmark pt2 code
Intro
During my day-to-day development of PyTorch I very often need to measure the performance of different units of code at different levels of coarseness. Often this entails taking a CUDA kernel in isolation and measuring its performance relative to another implementation.
I do this so much in fact that I have some helper utilities defined here: https://github.com/drisspg/transformer_nuggets/tree/main/transformer_nuggets/utils
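For kernel-level timings the pattern is always the same: warm up, then time the GPU work with CUDA events so Python overhead doesn't pollute the measurement. A minimal sketch of that pattern (not the actual helpers in the repo above, just the shape of them):

import torch

def benchmark_cuda_fn(fn, *args, warmup: int = 10, iters: int = 100) -> float:
    # Mean runtime of fn(*args) in milliseconds, measured with CUDA events.
    for _ in range(warmup):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters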
The tools and data this generates are only valuable, though, if you can trust the numbers, and I was encountering something strange with this benchmark:
https://github.com/drisspg/transformer_nuggets/blob/main/benchmarks/fp8_sat_cast.py
What's the Problem?
In FP8 workloads it is common to use a technique known as “delayed scaling” to more efficiently compute the statistics needed for numerical stability during training. I wanted to write a small triton kernel to manually fuse the operations that we currently have torch.compile fuse for us in this repo: https://github.com/pytorch-labs/float8_experimental
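To make the fusion target concrete, the op is roughly: scale the input, optionally clamp (“saturate”) it to the fp8 dtype's representable range, cast, and record the amax that delayed scaling consumes when computing the next step's scale. A rough sketch, with names of my own choosing (the real eager_scaled_quant lives in transformer_nuggets/fp8/scaled_quant.py and may differ in the details):

import torch

def scaled_quant_sketch(
    x: torch.Tensor,
    scale: torch.Tensor,
    abs_max: torch.Tensor,
    fp8_dtype: torch.dtype = torch.float8_e4m3fn,
    saturated: bool = True,
) -> torch.Tensor:
    # Statistic that delayed scaling consumes to build the next iteration's scale.
    abs_max.copy_(x.abs().max())
    out = x.float() * scale
    if saturated:
        # Clamp to the fp8 dtype's max so large values saturate instead of overflowing.
        fp8_max = torch.finfo(fp8_dtype).max
        out = out.clamp(min=-fp8_max, max=fp8_max)
    return out.to(fp8_dtype)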
Okay, so we have a triton kernel, the eager code, and the compiled code, and we want to compare them. Let's loop over different tensor shapes, saturation settings, and dtypes to get a sense of the performance characteristics.
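A sketch of that sweep, assuming a hypothetical triton_scaled_quant wrapper around the hand-written kernel alongside eager_scaled_quant (the signatures here are illustrative), and using triton.testing.do_bench for the timing; the exact shape/dtype grid is in the benchmark script linked above:

import itertools
import torch
from triton.testing import do_bench

numels = [2**21, 2**23, 2**25, 2**27]  # placeholder sizes, not the script's exact grid
high_precision_dtypes = [torch.bfloat16, torch.float32]
low_precision_dtypes = [torch.float8_e4m3fn, torch.float8_e5m2]
saturation = [True, False]

results = []
for numel, hp_dtype, lp_dtype, saturated in itertools.product(
    numels, high_precision_dtypes, low_precision_dtypes, saturation
):
    x = torch.randn(numel, dtype=hp_dtype, device="cuda")
    scale = torch.rand(1, dtype=torch.float32, device="cuda")
    # do_bench handles warmup and reports a runtime per call; the compiled
    # variant gets added to the sweep the same way in the next section.
    triton_time = do_bench(lambda: triton_scaled_quant(x, scale, lp_dtype, saturated))
    pytorch_time = do_bench(lambda: eager_scaled_quant(x, scale, lp_dtype, saturated))
    results.append((numel, hp_dtype, lp_dtype, saturated, triton_time, pytorch_time))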
Slapping torch.compile over a code block
It is not uncommon that you will want to microbenchmark some code that has graph breaks:
compiled_pytorch_fn = torch.compile(eager_scaled_quant)
This produces the following numbers with our microbenchmark, along with a log about “cache_size_limit reached”:
[2024-01-08 16:35:27,627] torch._dynamo.convert_frame: [WARNING] function: 'eager_scaled_quant' (/home/drisspg/meta/transformer_nuggets/transformer_nuggets/fp8/scaled_quant.py:70)
The numbers look pretty reasonable (truncated for brevity): https://gist.github.com/drisspg/2bea2b1122793dc58c9ce66440bba12a#file-bad_output-md
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [02:18<00:00, 4.31s/it]
numel high_precision_dtype low_precision_dtype saturated triton_time pytorch_time compiled_pytorch_time
-------- ---------------------- --------------------- ----------- ------------- -------------- -----------------------
2097152 torch.bfloat16 torch.float8_e4m3fn True 33.0514 55.897 59.4491
2097152 torch.bfloat16 torch.float8_e4m3fn False 32.5205 27.1579 22.5592
2097152 torch.bfloat16 torch.float8_e5m2 True 32.6485 55.9769 59.7592
2097152 torch.bfloat16 torch.float8_e5m2 False 32.5321 27.3613 22.6497
2097152 torch.float32 torch.float8_e4m3fn True 32.9345 53.8681 57.449

And most importantly, what you would glean is that it appears the “triton kernel” implementation is universally faster! Cool, let's call this custom kernel from now on and ship some wins!
Hold on a second!
That log was kinda weird, something about a cache… Well, I think the code that I am attempting to compile should be ‘fullgraph’ compilable, so let's try setting that:
compiled_pytorch_fn = torch.compile(eager_scaled_quant, fullgraph=True)
Now we get an error!
File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 353, in _convert_frame_assert
unimplemented("cache_size_limit reached")
File "/home/drisspg/meta/pytorch/torch/_dynamo/exc.py", line 193, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: cache_size_limit reached

Crawling (asking voz) through the dynamo config, you will find this argument:
torch._dynamo.config.cache_size_limit
with comment:
# controls the maximum number of cache entries with a guard on same ID_MATCH'd
# object. It also controls the maximum size of cache entries if they don't have
# any ID_MATCH'd guards.

Nothing about perf behavior in here, but I know that for my example I am iterating over 32 configs, so let's try setting this to something larger, like 1000 (go big or go home).
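Concretely, that is a one-liner before the sweep:

import torch._dynamo

# Give the sweep enough cache entries that dynamo never hits the limit mid-run.
torch._dynamo.config.cache_size_limit = 1000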
This now gives the following results:
numel high_precision_dtype low_precision_dtype saturated triton_time pytorch_time compiled_pytorch_time
-------- ---------------------- --------------------- ----------- ------------- -------------- -----------------------
2097152 torch.bfloat16 torch.float8_e4m3fn True 33.2983 55.8643 59.6987
2097152 torch.bfloat16 torch.float8_e4m3fn False 33.0075 26.9028 22.0985
2097152 torch.bfloat16 torch.float8_e5m2 True 33.1423 55.8323 59.4687
2097152 torch.bfloat16 torch.float8_e5m2 False 32.4296 26.9288 22.1912
2097152 torch.float32 torch.float8_e4m3fn True 32.9567 53.8027 58.6836
2097152 torch.float32 torch.float8_e4m3fn False 32.6021 25.575 22.4197
2097152 torch.float32 torch.float8_e5m2 True 33.1044 53.8684 22.212

https://gist.github.com/drisspg/2bea2b1122793dc58c9ce66440bba12a#file-actual_numbers-md
Same experiment, different conclusions
Re-running the same microbenchmark sweep, now with fullgraph=True and an increased dynamo cache size limit, we come to the opposite conclusion from the naive benchmark: for all but a few cases the inductor-generated code appears to be more performant. The reason the first run was misleading is that once the cache_size_limit was hit, dynamo stopped compiling new variants and silently fell back to eager for the remaining configs, so part of the “compiled” column was never measuring inductor's generated code at all. With the limit raised we are actually measuring inductor's output, which is what we wanted to compare against the triton kernel in the first place.
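If I were rewriting the benchmark today, the guard rails would look roughly like this (treat the exact logging knobs as ones to double-check against your PyTorch version):

import torch
import torch._dynamo

# Fail loudly instead of silently graph breaking / falling back to eager.
compiled_pytorch_fn = torch.compile(eager_scaled_quant, fullgraph=True)

# Print every recompilation and the guard that triggered it.
torch._logging.set_logs(recompiles=True)

# Make room for every config the sweep will see, or call torch._dynamo.reset()
# between configs so each one starts from a clean cache.
torch._dynamo.config.cache_size_limit = 1000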
I hope this little note serves someone well, and inspires you not to ignore logs when things seem slightly awry.