
Adding support for bias addition + rescaling with token weights to grouped_gemm#5280

Open
metastableB wants to merge 1 commit into pytorch:main from metastableB:export-D89699751

Conversation

@metastableB
Contributor


Summary:
Adds support for providing bias and token weights as optional arguments to fbgemm's Triton grouped GEMM. The changes were made in the `_grouped_gemm` protected kernel implementation and are exposed through a new public function, `grouped_gemm_bias_scale`; the original `grouped_gemm` signature remains untouched.
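As a rough illustration, a call to the new entry point might look like the sketch below. The summary does not reproduce the exact signature, so the argument names (`bias`, `token_weights`), tensor shapes, and import path are assumptions inferred from the description, not the landed API.

```python
# Usage sketch only: argument names, shapes, and the import path are
# assumptions inferred from the PR summary, not the actual signature.
import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
    grouped_gemm_bias_scale,
)

G, K, N = 4, 256, 512                            # groups, inner dim, output dim
m_sizes = torch.tensor([128, 64, 32, 96], dtype=torch.int32, device="cuda")
M = int(m_sizes.sum())                           # total tokens across all groups

x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)          # stacked token activations
w = torch.randn(G * N, K, device="cuda", dtype=torch.bfloat16)      # per-group weights, stacked
bias = torch.randn(G, N, device="cuda", dtype=torch.bfloat16)       # per-group bias (assumed shape)
token_weights = torch.rand(M, device="cuda", dtype=torch.bfloat16)  # per-token scale (assumed shape)

# Fused per group g: (x_g @ w_g.T + bias_g) * token_weights_g
out = grouped_gemm_bias_scale(x, w, m_sizes, bias, token_weights)
```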

For internal testing, use:
```
buck test -c fbcode.nvcc_arch=h100a -c fbcode.enable_gpu_sections=true fbcode//deeplearning/fbgemm/fbgemm_gpu/experimental/gemm/test:grouped_gemm_test -- test_grouped_gemm_bias_scale
```

For basic benchmarking on H100 nodes, use:
```
buck run -c fbcode.nvcc_arch=h100a -c fbcode.enable_gpu_sections=true fbcode//deeplearning/fbgemm/fbgemm_gpu/experimental/gemm/test:grouped_gemm_bias_scale_benchmark 2>/dev/null
```

Benchmark Results:

| Config | fused (ms) | triton+torch (ms) | torch (ms) | Speedup vs torch | Speedup vs triton+torch |
|--------|-----------:|------------------:|-----------:|-----------------:|------------------------:|
| Small  | 0.009      | 0.017             | 0.027      | 3.03x            | 1.91x                   |
| Medium | 0.017      | 0.036             | 0.049      | 2.82x            | 2.04x                   |
| Large  | 0.048      | 0.091             | 0.142      | 2.97x            | 1.91x                   |
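For context on what is being fused: the unfused baselines presumably materialize the grouped GEMM output and then apply the bias and token-weight rescale as separate ops. A pure-PyTorch reference for the assumed semantics (a sketch, not the actual benchmark script) could look like:

```python
# Assumed semantics of bias addition + token-weight rescaling, written as an
# unfused per-group PyTorch loop. Shapes and op ordering are assumptions.
import torch

def grouped_gemm_bias_scale_ref(
    x: torch.Tensor,              # (M, K) stacked token activations
    w: torch.Tensor,              # (G * N, K) stacked per-group weights
    m_sizes: torch.Tensor,        # (G,) tokens per group, summing to M
    bias: torch.Tensor,           # (G, N) per-group bias (assumed shape)
    token_weights: torch.Tensor,  # (M,) per-token scale (assumed shape)
) -> torch.Tensor:
    G = m_sizes.numel()
    N = w.shape[0] // G
    outs, start = [], 0
    for g in range(G):
        end = start + int(m_sizes[g])
        yg = x[start:end] @ w[g * N : (g + 1) * N].T  # per-group GEMM
        yg = yg + bias[g]                             # bias addition
        yg = yg * token_weights[start:end, None]      # token-weight rescale
        outs.append(yg)
        start = end
    return torch.cat(outs, dim=0)
```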

Differential Revision: D89699751