@@ -9,8 +9,8 @@ class SiluAndMulBenchmark(Benchmark):
99
1010 # Workload: small
1111 def setup_small (self ):
12- self .x = torch .randn (1 , 128 , 512 , device = "cuda" , dtype = torch .float16 )
13- self .out = torch .empty (1 , 128 , 256 , device = "cuda" , dtype = torch .float16 )
12+ self .x = torch .randn (8 , 1024 , 2048 , device = self . device , dtype = torch .float16 )
13+ self .out = torch .empty (8 , 1024 , 1024 , device = self . device , dtype = torch .float16 )
1414
1515 def benchmark_small (self ):
1616 self .kernel .silu_and_mul (self .out , self .x )
@@ -21,8 +21,8 @@ def verify_small(self) -> torch.Tensor:
2121
2222 # Workload: medium
2323 def setup_medium (self ):
24- self .x = torch .randn (4 , 512 , 1024 , device = "cuda" , dtype = torch .float16 )
25- self .out = torch .empty (4 , 512 , 512 , device = "cuda" , dtype = torch .float16 )
24+ self .x = torch .randn (8 , 2048 , 4096 , device = self . device , dtype = torch .float16 )
25+ self .out = torch .empty (8 , 2048 , 2048 , device = self . device , dtype = torch .float16 )
2626
2727 def benchmark_medium (self ):
2828 self .kernel .silu_and_mul (self .out , self .x )
@@ -33,12 +33,53 @@ def verify_medium(self) -> torch.Tensor:
3333
3434 # Workload: large
3535 def setup_large (self ):
36- self .x = torch .randn (8 , 1024 , 2048 , device = "cuda" , dtype = torch .float16 )
37- self .out = torch .empty (8 , 1024 , 1024 , device = "cuda" , dtype = torch .float16 )
36+ self .x = torch .randn (8 , 4096 , 8192 , device = self . device , dtype = torch .float16 )
37+ self .out = torch .empty (8 , 4096 , 4096 , device = self . device , dtype = torch .float16 )
3838
3939 def benchmark_large (self ):
4040 self .kernel .silu_and_mul (self .out , self .x )
41+ self .kernel .silu_and_mul (self .out , self .x )
4142
4243 def verify_large (self ) -> torch .Tensor :
4344 d = self .x .shape [- 1 ] // 2
4445 return F .silu (self .x [..., :d ]) * self .x [..., d :]
class GeluAndMulBenchmark(Benchmark):
    """Benchmark for a fused GELU-and-multiply kernel.

    `x` packs two equal halves along its last dimension; the kernel is
    expected to write gelu(first_half) * second_half into `out`, which is
    what the verify_* reference implementations compute.
    """

    seed: int = 42

    def _alloc(self, batch: int, seq: int, hidden: int) -> None:
        # Input carries both halves; output holds half the last dimension.
        self.x = torch.randn(batch, seq, hidden, device=self.device, dtype=torch.float16)
        self.out = torch.empty(batch, seq, hidden // 2, device=self.device, dtype=torch.float16)

    def _reference(self) -> torch.Tensor:
        # Same split the kernel performs: GELU-gated first half times second half.
        gate, up = self.x.chunk(2, dim=-1)
        return F.gelu(gate) * up

    # Workload: small
    def setup_small(self):
        self._alloc(8, 1024, 2048)

    def benchmark_small(self):
        self.kernel.gelu_and_mul(self.out, self.x)

    def verify_small(self) -> torch.Tensor:
        return self._reference()

    # Workload: medium
    def setup_medium(self):
        self._alloc(8, 2048, 4096)

    def benchmark_medium(self):
        self.kernel.gelu_and_mul(self.out, self.x)

    def verify_medium(self) -> torch.Tensor:
        return self._reference()

    # Workload: large
    def setup_large(self):
        self._alloc(8, 4096, 8192)

    def benchmark_large(self):
        self.kernel.gelu_and_mul(self.out, self.x)

    def verify_large(self) -> torch.Tensor:
        return self._reference()
0 commit comments