Commit 0cdaf60

Add ruff formatting, fbw2 test and minor refactoring (#12)
1 parent d70308a commit 0cdaf60

File tree

12 files changed: +277 -96 lines changed

.github/workflows/ruff.yml

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+name: ruff
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+jobs:
+  check-ruff-formatting:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v4
+        with:
+          python-version: 3.8
+          cache: 'pip'
+      - run: pip install ruff==0.11.8
+      - run: ruff format --diff
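
The new workflow only checks formatting: ruff format --diff prints the would-be changes and exits non-zero without modifying any file. Assuming the version pinned in the workflow, the same check can be reproduced locally by running pip install ruff==0.11.8 and then ruff format --diff in the repository root (ruff format without --diff applies the changes).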

.gitignore

Lines changed: 3 additions & 1 deletion
@@ -34,9 +34,11 @@
 # History files
 .history.*
 
-
 # python
 build/
 dist/
 __pycache__
 *.egg-info
+
+# vscode
+.vscode/

i6_native_ops/fast_viterbi/__init__.py

Lines changed: 8 additions & 7 deletions
@@ -5,11 +5,12 @@
 
 try:
     # Package is installed, so ops are already compiled
-    __version__ = get_distribution('i6_native_ops').version
+    __version__ = get_distribution("i6_native_ops").version
     import i6_native_ops.fast_viterbi.fast_viterbi_core as core
-except Exception as e:
+except Exception:
     # otherwise try to build locally
     from torch.utils.cpp_extension import load
+
     base_path = os.path.dirname(__file__)
     core = load(
         name="fast_viterbi_core",
@@ -18,14 +19,15 @@
             os.path.join(base_path, "core.cu"),
         ],
         extra_include_paths=[os.path.join(base_path, "..", "common")],
-        )
+    )
+

 def align_viterbi(
     log_probs: torch.FloatTensor,
     fsa: Tuple[int, torch.IntTensor, torch.FloatTensor, torch.IntTensor],
-    seq_lens: torch.IntTensor
+    seq_lens: torch.IntTensor,
 ) -> Tuple[torch.IntTensor, torch.FloatTensor]:
-    """ Find best path with Viterbi algorithm.
+    """Find best path with Viterbi algorithm.
     :param log_probs: log probabilities of emission model as a (B, T, F)
     :param fsa: weighted finite state automaton as a tuple consisting of:
         * number of states
@@ -40,8 +42,7 @@ def align_viterbi(
     log_probs = log_probs.transpose(0, 1).contiguous()
     num_states, edge_tensor, weight_tensor, start_end_states = fsa
     alignment, scores = core.fast_viterbi(
-        log_probs, edge_tensor, weight_tensor,
-        start_end_states, seq_lens, num_states
+        log_probs, edge_tensor, weight_tensor, start_end_states, seq_lens, num_states
     )
     alignment_batch_major = alignment.transpose(0, 1).contiguous()
     return alignment_batch_major, scores
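
For reference, a minimal usage sketch of the reformatted align_viterbi signature, condensed from the test added in this commit (a CUDA device is assumed, since the op is built from core.cu; the example values are the ones used in the test):

import torch
from i6_native_ops.fast_viterbi import align_viterbi

# (B=1, T=5, F=2) emission log-probabilities
log_probs = torch.tensor(
    [[0.9, 0.1], [0.9, 0.1], [0.4, 0.6], [0.1, 0.9], [0.1, 0.9]], device="cuda"
).unsqueeze(0).log()

# edges as (from, to, emission_idx, sequence_idx), stored column-major as a (4, E) int tensor
edges = torch.tensor(
    [[0, 0, 0, 0], [0, 1, 0, 0], [1, 1, 1, 0]], device="cuda", dtype=torch.int32
).transpose(0, 1).contiguous()
weights = torch.ones(3, device="cuda")  # one weight per edge
start_end_states = torch.tensor([[0], [1]], dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([5], dtype=torch.int32, device="cuda")

fsa = (2, edges, weights, start_end_states)  # 2 states in total
alignment, scores = align_viterbi(log_probs, fsa, seq_lens)  # alignment: (B, T), batch-major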

i6_native_ops/monotonic_rnnt/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
     # Package is installed, so ops are already compiled
     __version__ = get_distribution("i6_native_ops").version
     import i6_native_ops.monotonic_rnnt.monotonic_rnnt_core as core
-except Exception as e:
+except Exception:
     # otherwise try to build locally
     from torch.utils.cpp_extension import load
 
i6_native_ops/warp_rnnt/__init__.py

Lines changed: 71 additions & 45 deletions
@@ -1,34 +1,44 @@
 import os
 import torch
-from typing import Optional, AnyStr, Literal
+from typing import Literal
 from pkg_resources import get_distribution
 
 try:
     # Package is installed, so ops are already compiled
-    __version__ = get_distribution('i6_native_ops').version
+    __version__ = get_distribution("i6_native_ops").version
     import i6_native_ops.warp_rnnt.warp_rnnt_core as core
-except Exception as e:
+except Exception:
     # otherwise try to build locally
     from torch.utils.cpp_extension import load
+
     base_path = os.path.dirname(__file__)
     core = load(
         name="warp_rnnt_core",
         sources=[
-            f"{base_path}/core.cu",
-            f"{base_path}/core_gather.cu",
-            f"{base_path}/core_compact.cu",
-            f"{base_path}/binding.cpp"
-        ]
-    )
+            f"{base_path}/core.cu",
+            f"{base_path}/core_gather.cu",
+            f"{base_path}/core_compact.cu",
+            f"{base_path}/binding.cpp",
+        ],
+    )
 
 
 class RNNTLoss(torch.autograd.Function):
-
     @staticmethod
-    def forward(ctx, log_probs, labels, frames_lengths, labels_lengths, blank=0, fastemit_lambda=0.0):
+    def forward(
+        ctx,
+        log_probs,
+        labels,
+        frames_lengths,
+        labels_lengths,
+        blank=0,
+        fastemit_lambda=0.0,
+    ):
         costs, ctx.grads = core.rnnt_loss(
-            xs=log_probs, ys=labels,
-            xn=frames_lengths, yn=labels_lengths,
+            xs=log_probs,
+            ys=labels,
+            xn=frames_lengths,
+            yn=labels_lengths,
             blank=blank,
             fastemit_lambda=fastemit_lambda,
         )
@@ -39,20 +49,32 @@ def backward(ctx, grads_output):
         grads_output = grads_output.view(-1, 1, 1, 1).to(ctx.grads)
         return ctx.grads.mul_(grads_output), None, None, None, None, None, None
 
-class RNNTLossCompact(torch.autograd.Function):
 
+class RNNTLossCompact(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, log_probs, labels, frames_lengths, labels_lengths, blank=0, fastemit_lambda=0.0, enable_grad: bool = True):
-
+    def forward(
+        ctx,
+        log_probs,
+        labels,
+        frames_lengths,
+        labels_lengths,
+        blank=0,
+        fastemit_lambda=0.0,
+        enable_grad: bool = True,
+    ):
         costs, grads, loc = core.rnnt_loss_compact(
-            xs=log_probs, ys=labels,
-            xn=frames_lengths, yn=labels_lengths,
+            xs=log_probs,
+            ys=labels,
+            xn=frames_lengths,
+            yn=labels_lengths,
             blank=blank,
             fastemit_lambda=fastemit_lambda,
-            required_grad=enable_grad
+            required_grad=enable_grad,
         )
         if enable_grad:
-            cumlen = torch.cumsum(frames_lengths * (labels_lengths+1), dim=0, dtype=torch.int32)
+            cumlen = torch.cumsum(
+                frames_lengths * (labels_lengths + 1), dim=0, dtype=torch.int32
+            )
             ctx.V = log_probs.size(-1)
             ctx.blank = blank
             ctx.save_for_backward(grads, loc, cumlen)
@@ -62,25 +84,24 @@ def forward(ctx, log_probs, labels, frames_lengths, labels_lengths, blank=0, fas
     def backward(ctx, grads_output):
         grads, loc, cumlen = ctx.saved_tensors
         grads_input = core.rnnt_loss_compact_backward(
-            grads_output.contiguous(),
-            grads, cumlen,
-            loc, ctx.V, ctx.blank
+            grads_output.contiguous(), grads, cumlen, loc, ctx.V, ctx.blank
         )
 
         return grads_input, None, None, None, None, None, None
 
 
-def rnnt_loss(log_probs: torch.FloatTensor,
-              labels: torch.IntTensor,
-              frames_lengths: torch.IntTensor,
-              labels_lengths: torch.IntTensor,
-              average_frames: bool = False,
-              reduction: Literal['sum', 'mean', 'none'] = 'none',
-              blank: int = 0,
-              gather: bool = False,
-              fastemit_lambda: float = 0.0,
-              compact: bool = False) -> torch.Tensor:
-
+def rnnt_loss(
+    log_probs: torch.FloatTensor,
+    labels: torch.IntTensor,
+    frames_lengths: torch.IntTensor,
+    labels_lengths: torch.IntTensor,
+    average_frames: bool = False,
+    reduction: Literal["sum", "mean", "none"] = "none",
+    blank: int = 0,
+    gather: bool = False,
+    fastemit_lambda: float = 0.0,
+    compact: bool = False,
+) -> torch.Tensor:
     """The CUDA-Warp RNN-Transducer loss.
 
     Args:
@@ -124,26 +145,31 @@ def rnnt_loss(log_probs: torch.FloatTensor,
 
     if compact:
         costs = RNNTLossCompact.apply(
-            log_probs.float(),
-            labels, frames_lengths,
-            labels_lengths, blank,
-            fastemit_lambda,
-            (log_probs.requires_grad and torch.is_grad_enabled())
+            log_probs.float(),
+            labels,
+            frames_lengths,
+            labels_lengths,
+            blank,
+            fastemit_lambda,
+            (log_probs.requires_grad and torch.is_grad_enabled()),
         )
     else:
         if gather:
-
             N, T, U, V = log_probs.size()
 
-            index = torch.full([N, T, U, 2], blank, device=labels.device, dtype=torch.long)
+            index = torch.full(
+                [N, T, U, 2], blank, device=labels.device, dtype=torch.long
+            )
 
-            index[:, :, :U-1, 1] = labels.unsqueeze(dim=1)
+            index[:, :, : U - 1, 1] = labels.unsqueeze(dim=1)
 
             log_probs = log_probs.gather(dim=3, index=index)
 
             blank = -1
 
-        costs = RNNTLoss.apply(log_probs, labels, frames_lengths, labels_lengths, blank, fastemit_lambda)
+        costs = RNNTLoss.apply(
+            log_probs, labels, frames_lengths, labels_lengths, blank, fastemit_lambda
+        )
 
     if average_frames:
         costs = costs / frames_lengths.to(log_probs)
@@ -156,5 +182,5 @@ def rnnt_loss(log_probs: torch.FloatTensor,
         return costs.mean()
     else:
         raise ValueError(
-            f"Unknown reduction method: {reduction}, expected to be one of ['mean', 'sum', 'none']")
-
+            f"Unknown reduction method: {reduction}, expected to be one of ['mean', 'sum', 'none']"
+        )

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+torch

test/__init__.py

Whitespace-only changes.

test/fast_viterbi.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

test/test_fast_viterbi.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+import unittest
+import torch
+
+from i6_native_ops.fast_viterbi import align_viterbi
+
+
+class TestFastViterbi(unittest.TestCase):
+    def test_best_sequence(self):
+        log_probs = (
+            torch.tensor(
+                [
+                    [0.9, 0.1],
+                    [0.9, 0.1],
+                    [0.4, 0.6],
+                    [0.1, 0.9],
+                    [0.1, 0.9],
+                ],
+                device="cuda",
+                dtype=torch.float32,
+            )
+            .unsqueeze(0)
+            .log()
+        )
+        edges = (
+            torch.tensor(
+                [
+                    # from, to, emission_idx, sequence_idx
+                    [0, 0, 0, 0],  # loop from 0 to 0, emit label 0
+                    [0, 1, 0, 0],  # forward from 0 to 1, emit 0
+                    [1, 1, 1, 0],
+                ],  # loop from 1 to 1, emit 1
+                device="cuda",
+                dtype=torch.int32,
+            )
+            .transpose(0, 1)
+            .contiguous()
+        )
+        weights = torch.tensor([1, 1, 1], device="cuda", dtype=torch.float32)
+        start_end_states = torch.tensor([[0], [1]], dtype=torch.int32, device="cuda")
+        seq_lens = torch.tensor([5], dtype=torch.int32, device="cuda")
+
+        fsa = (2, edges, weights, start_end_states)
+
+        output, scores = align_viterbi(log_probs, fsa, seq_lens)
+        best_sequence = list(output[0, :])
+
+        self.assertEqual(best_sequence, [0, 0, 1, 1, 1])
+
+
+if __name__ == "__main__":
+    unittest.main()
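
Because the test tensors are allocated with device="cuda", running the new test requires a GPU and a successful build of the fast_viterbi extension; with both available it can be launched from the repository root as, for example, python -m unittest test.test_fast_viterbi (the existing test/__init__.py makes the module importable under the test package).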
File renamed without changes.
