import os
import torch
from typing import Literal
from pkg_resources import get_distribution

try:
    # Package is installed, so ops are already compiled
    __version__ = get_distribution("i6_native_ops").version
    import i6_native_ops.warp_rnnt.warp_rnnt_core as core
except Exception:
    # otherwise try to build locally
    from torch.utils.cpp_extension import load

    base_path = os.path.dirname(__file__)
    core = load(
        name="warp_rnnt_core",
        sources=[
            f"{base_path}/core.cu",
            f"{base_path}/core_gather.cu",
            f"{base_path}/core_compact.cu",
            f"{base_path}/binding.cpp",
        ],
    )


class RNNTLoss(torch.autograd.Function):
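    """Autograd binding for the CUDA RNN-T kernel on padded (N, T, U, V) inputs.

    The kernel returns the per-sequence costs together with the gradients
    w.r.t. log_probs; the gradients are cached on ``ctx`` so that backward only
    has to rescale them by the upstream gradient.
    """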
    @staticmethod
    def forward(
        ctx,
        log_probs,
        labels,
        frames_lengths,
        labels_lengths,
        blank=0,
        fastemit_lambda=0.0,
    ):
        costs, ctx.grads = core.rnnt_loss(
            xs=log_probs,
            ys=labels,
            xn=frames_lengths,
            yn=labels_lengths,
            blank=blank,
            fastemit_lambda=fastemit_lambda,
        )
        return costs

    @staticmethod
    def backward(ctx, grads_output):
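        # grads_output holds one value per sequence; broadcast it over
        # (T, U, V) and scale the cached gradients in place.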
        grads_output = grads_output.view(-1, 1, 1, 1).to(ctx.grads)
        return ctx.grads.mul_(grads_output), None, None, None, None, None, None


class RNNTLossCompact(torch.autograd.Function):
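    """Autograd binding for the compact-layout kernel (core.rnnt_loss_compact).

    Besides costs and gradients, the kernel returns index information
    (``loc``); backward passes it, together with the cumulative
    frames_lengths * (labels_lengths + 1) offsets, to the compact backward
    kernel.
    """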
    @staticmethod
    def forward(
        ctx,
        log_probs,
        labels,
        frames_lengths,
        labels_lengths,
        blank=0,
        fastemit_lambda=0.0,
        enable_grad: bool = True,
    ):
        costs, grads, loc = core.rnnt_loss_compact(
            xs=log_probs,
            ys=labels,
            xn=frames_lengths,
            yn=labels_lengths,
            blank=blank,
            fastemit_lambda=fastemit_lambda,
            required_grad=enable_grad,
        )
        if enable_grad:
            cumlen = torch.cumsum(
                frames_lengths * (labels_lengths + 1), dim=0, dtype=torch.int32
            )
            ctx.V = log_probs.size(-1)
            ctx.blank = blank
            ctx.save_for_backward(grads, loc, cumlen)
        return costs

    @staticmethod
    def backward(ctx, grads_output):
        grads, loc, cumlen = ctx.saved_tensors
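        # The compact backward kernel combines the upstream per-sequence
        # gradients with the cached gradients using the saved locations and
        # cumulative offsets.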
        grads_input = core.rnnt_loss_compact_backward(
            grads_output.contiguous(), grads, cumlen, loc, ctx.V, ctx.blank
        )

        return grads_input, None, None, None, None, None, None


def rnnt_loss(
    log_probs: torch.FloatTensor,
    labels: torch.IntTensor,
    frames_lengths: torch.IntTensor,
    labels_lengths: torch.IntTensor,
    average_frames: bool = False,
    reduction: Literal["sum", "mean", "none"] = "none",
    blank: int = 0,
    gather: bool = False,
    fastemit_lambda: float = 0.0,
    compact: bool = False,
) -> torch.Tensor:
84105 """The CUDA-Warp RNN-Transducer loss.
85106
86107 Args:
        ...
    """

    if compact:
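        # Gradients are only prepared when they will actually be needed,
        # i.e. log_probs requires grad and grad mode is enabled.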
        costs = RNNTLossCompact.apply(
            log_probs.float(),
            labels,
            frames_lengths,
            labels_lengths,
            blank,
            fastemit_lambda,
            (log_probs.requires_grad and torch.is_grad_enabled()),
        )
    else:
        if gather:
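            # Gather only the blank and target-label scores (V -> 2) before
            # calling the kernel to save memory; blank = -1 below signals the
            # pre-gathered layout.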
            N, T, U, V = log_probs.size()

            index = torch.full(
                [N, T, U, 2], blank, device=labels.device, dtype=torch.long
            )
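            # Column 0 keeps the blank score everywhere; column 1 selects the
            # score of the reference label at each output position (the last
            # position has no label and falls back to blank).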
            index[:, :, : U - 1, 1] = labels.unsqueeze(dim=1)

            log_probs = log_probs.gather(dim=3, index=index)

            blank = -1

        costs = RNNTLoss.apply(
            log_probs, labels, frames_lengths, labels_lengths, blank, fastemit_lambda
        )

    if average_frames:
        costs = costs / frames_lengths.to(log_probs)
    if reduction == "none":
        return costs
    elif reduction == "sum":
        return costs.sum()
    elif reduction == "mean":
        return costs.mean()
    else:
        raise ValueError(
            f"Unknown reduction method: {reduction}, expected to be one of ['mean', 'sum', 'none']"
        )
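

# Minimal usage sketch (not part of the original module). It assumes this file
# is i6_native_ops/warp_rnnt/__init__.py, that a CUDA device is available, and
# the usual RNN-T shapes: log_probs as (N, T, U, V) log-softmax scores, labels
# as int32 (N, U - 1), and int32 length tensors of shape (N,).
#
#   import torch
#   from i6_native_ops.warp_rnnt import rnnt_loss
#
#   N, T, U, V = 2, 10, 5, 7
#   log_probs = (
#       torch.randn(N, T, U, V, device="cuda").log_softmax(dim=-1).requires_grad_()
#   )
#   labels = torch.randint(1, V, (N, U - 1), dtype=torch.int32, device="cuda")
#   frames_lengths = torch.full((N,), T, dtype=torch.int32, device="cuda")
#   labels_lengths = torch.full((N,), U - 1, dtype=torch.int32, device="cuda")
#
#   loss = rnnt_loss(log_probs, labels, frames_lengths, labels_lengths, reduction="mean")
#   loss.backward()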