19 changes: 19 additions & 0 deletions .github/workflows/ruff.yml
@@ -0,0 +1,19 @@
name: ruff
on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
jobs:
  check-ruff-formatting:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v4
        with:
          python-version: 3.8
          cache: 'pip'
      - run: pip install ruff==0.11.8
      - run: ruff format --diff
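For local use, the same check the workflow runs can be reproduced before pushing (assuming a Python environment with pip): install the pinned formatter with pip install ruff==0.11.8 and run ruff format --diff from the repository root; the command exits non-zero whenever ruff would reformat a file, which is what fails the CI job.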
4 changes: 3 additions & 1 deletion .gitignore
@@ -34,9 +34,11 @@
# History files
.history.*


# python
build/
dist/
__pycache__
*.egg-info

# vscode
.vscode/
15 changes: 8 additions & 7 deletions i6_native_ops/fast_viterbi/__init__.py
@@ -5,11 +5,12 @@

try:
    # Package is installed, so ops are already compiled
-    __version__ = get_distribution('i6_native_ops').version
+    __version__ = get_distribution("i6_native_ops").version
    import i6_native_ops.fast_viterbi.fast_viterbi_core as core
-except Exception as e:
+except Exception:
    # otherwise try to build locally
    from torch.utils.cpp_extension import load

    base_path = os.path.dirname(__file__)
    core = load(
        name="fast_viterbi_core",
@@ -18,14 +19,15 @@
            os.path.join(base_path, "core.cu"),
        ],
        extra_include_paths=[os.path.join(base_path, "..", "common")],
-        )
+    )


def align_viterbi(
    log_probs: torch.FloatTensor,
    fsa: Tuple[int, torch.IntTensor, torch.FloatTensor, torch.IntTensor],
-    seq_lens: torch.IntTensor
+    seq_lens: torch.IntTensor,
) -> Tuple[torch.IntTensor, torch.FloatTensor]:
-    """ Find best path with Viterbi algorithm.
+    """Find best path with Viterbi algorithm.
    :param log_probs: log probabilities of emission model as a (B, T, F)
    :param fsa: weighted finite state automaton as a tuple consisting of:
        * number of states
@@ -40,8 +42,7 @@ def align_viterbi(
    log_probs = log_probs.transpose(0, 1).contiguous()
    num_states, edge_tensor, weight_tensor, start_end_states = fsa
    alignment, scores = core.fast_viterbi(
-        log_probs, edge_tensor, weight_tensor,
-        start_end_states, seq_lens, num_states
+        log_probs, edge_tensor, weight_tensor, start_end_states, seq_lens, num_states
    )
    alignment_batch_major = alignment.transpose(0, 1).contiguous()
    return alignment_batch_major, scores
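For context, a minimal usage sketch of this function (not part of the diff): the edge layout, devices, and shapes follow the docstring above and the new test/test_fast_viterbi.py added below, and a CUDA-capable GPU is assumed.

import torch
from i6_native_ops.fast_viterbi import align_viterbi

# Batch-major emission log-probabilities (B, T, F): one sequence, five frames, two labels.
log_probs = torch.rand(1, 5, 2, device="cuda").log_softmax(dim=-1)
# Edges as a contiguous (4, E) int tensor: rows are from-state, to-state, emission index, sequence index.
edges = torch.tensor([[0, 0, 1], [0, 1, 1], [0, 0, 1], [0, 0, 0]], dtype=torch.int32, device="cuda")
weights = torch.ones(3, device="cuda", dtype=torch.float32)  # one weight per edge
start_end_states = torch.tensor([[0], [1]], dtype=torch.int32, device="cuda")  # start and end state
seq_lens = torch.tensor([5], dtype=torch.int32, device="cuda")

fsa = (2, edges, weights, start_end_states)  # (num_states, edges, weights, start/end states)
alignment, scores = align_viterbi(log_probs, fsa, seq_lens)  # alignment comes back batch-major (B, T)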
2 changes: 1 addition & 1 deletion i6_native_ops/monotonic_rnnt/__init__.py
@@ -11,7 +11,7 @@
    # Package is installed, so ops are already compiled
    __version__ = get_distribution("i6_native_ops").version
    import i6_native_ops.monotonic_rnnt.monotonic_rnnt_core as core
-except Exception as e:
+except Exception:
    # otherwise try to build locally
    from torch.utils.cpp_extension import load

116 changes: 71 additions & 45 deletions i6_native_ops/warp_rnnt/__init__.py
@@ -1,34 +1,44 @@
import os
import torch
-from typing import Optional, AnyStr, Literal
+from typing import Literal
from pkg_resources import get_distribution

try:
    # Package is installed, so ops are already compiled
-    __version__ = get_distribution('i6_native_ops').version
+    __version__ = get_distribution("i6_native_ops").version
    import i6_native_ops.warp_rnnt.warp_rnnt_core as core
-except Exception as e:
+except Exception:
    # otherwise try to build locally
    from torch.utils.cpp_extension import load

    base_path = os.path.dirname(__file__)
    core = load(
        name="warp_rnnt_core",
        sources=[
-        f"{base_path}/core.cu",
-        f"{base_path}/core_gather.cu",
-        f"{base_path}/core_compact.cu",
-        f"{base_path}/binding.cpp"
-        ]
-        )
+            f"{base_path}/core.cu",
+            f"{base_path}/core_gather.cu",
+            f"{base_path}/core_compact.cu",
+            f"{base_path}/binding.cpp",
+        ],
+    )


class RNNTLoss(torch.autograd.Function):

    @staticmethod
-    def forward(ctx, log_probs, labels, frames_lengths, labels_lengths, blank=0, fastemit_lambda=0.0):
+    def forward(
+        ctx,
+        log_probs,
+        labels,
+        frames_lengths,
+        labels_lengths,
+        blank=0,
+        fastemit_lambda=0.0,
+    ):
        costs, ctx.grads = core.rnnt_loss(
-            xs=log_probs, ys=labels,
-            xn=frames_lengths, yn=labels_lengths,
+            xs=log_probs,
+            ys=labels,
+            xn=frames_lengths,
+            yn=labels_lengths,
            blank=blank,
            fastemit_lambda=fastemit_lambda,
        )
@@ -39,20 +49,32 @@ def backward(ctx, grads_output):
        grads_output = grads_output.view(-1, 1, 1, 1).to(ctx.grads)
        return ctx.grads.mul_(grads_output), None, None, None, None, None, None

-class RNNTLossCompact(torch.autograd.Function):

+class RNNTLossCompact(torch.autograd.Function):
    @staticmethod
-    def forward(ctx, log_probs, labels, frames_lengths, labels_lengths, blank=0, fastemit_lambda=0.0, enable_grad: bool = True):

+    def forward(
+        ctx,
+        log_probs,
+        labels,
+        frames_lengths,
+        labels_lengths,
+        blank=0,
+        fastemit_lambda=0.0,
+        enable_grad: bool = True,
+    ):
        costs, grads, loc = core.rnnt_loss_compact(
-            xs=log_probs, ys=labels,
-            xn=frames_lengths, yn=labels_lengths,
+            xs=log_probs,
+            ys=labels,
+            xn=frames_lengths,
+            yn=labels_lengths,
            blank=blank,
            fastemit_lambda=fastemit_lambda,
-            required_grad=enable_grad
+            required_grad=enable_grad,
        )
        if enable_grad:
-            cumlen = torch.cumsum(frames_lengths * (labels_lengths+1), dim=0, dtype=torch.int32)
+            cumlen = torch.cumsum(
+                frames_lengths * (labels_lengths + 1), dim=0, dtype=torch.int32
+            )
            ctx.V = log_probs.size(-1)
            ctx.blank = blank
            ctx.save_for_backward(grads, loc, cumlen)
@@ -62,25 +84,24 @@ def forward(ctx, log_probs, labels, frames_lengths, labels_lengths, blank=0, fas
    def backward(ctx, grads_output):
        grads, loc, cumlen = ctx.saved_tensors
        grads_input = core.rnnt_loss_compact_backward(
-            grads_output.contiguous(),
-            grads, cumlen,
-            loc, ctx.V, ctx.blank
+            grads_output.contiguous(), grads, cumlen, loc, ctx.V, ctx.blank
        )

        return grads_input, None, None, None, None, None, None


-def rnnt_loss(log_probs: torch.FloatTensor,
-              labels: torch.IntTensor,
-              frames_lengths: torch.IntTensor,
-              labels_lengths: torch.IntTensor,
-              average_frames: bool = False,
-              reduction: Literal['sum', 'mean', 'none'] = 'none',
-              blank: int = 0,
-              gather: bool = False,
-              fastemit_lambda: float = 0.0,
-              compact: bool = False) -> torch.Tensor:

+def rnnt_loss(
+    log_probs: torch.FloatTensor,
+    labels: torch.IntTensor,
+    frames_lengths: torch.IntTensor,
+    labels_lengths: torch.IntTensor,
+    average_frames: bool = False,
+    reduction: Literal["sum", "mean", "none"] = "none",
+    blank: int = 0,
+    gather: bool = False,
+    fastemit_lambda: float = 0.0,
+    compact: bool = False,
+) -> torch.Tensor:
    """The CUDA-Warp RNN-Transducer loss.

    Args:
@@ -124,26 +145,31 @@ def rnnt_loss(log_probs: torch.FloatTensor,

    if compact:
        costs = RNNTLossCompact.apply(
-        log_probs.float(),
-        labels, frames_lengths,
-        labels_lengths, blank,
-        fastemit_lambda,
-        (log_probs.requires_grad and torch.is_grad_enabled())
+            log_probs.float(),
+            labels,
+            frames_lengths,
+            labels_lengths,
+            blank,
+            fastemit_lambda,
+            (log_probs.requires_grad and torch.is_grad_enabled()),
        )
    else:
        if gather:

            N, T, U, V = log_probs.size()

-            index = torch.full([N, T, U, 2], blank, device=labels.device, dtype=torch.long)
+            index = torch.full(
+                [N, T, U, 2], blank, device=labels.device, dtype=torch.long
+            )

-            index[:, :, :U-1, 1] = labels.unsqueeze(dim=1)
+            index[:, :, : U - 1, 1] = labels.unsqueeze(dim=1)

            log_probs = log_probs.gather(dim=3, index=index)

            blank = -1

-        costs = RNNTLoss.apply(log_probs, labels, frames_lengths, labels_lengths, blank, fastemit_lambda)
+        costs = RNNTLoss.apply(
+            log_probs, labels, frames_lengths, labels_lengths, blank, fastemit_lambda
+        )

    if average_frames:
        costs = costs / frames_lengths.to(log_probs)
@@ -156,5 +182,5 @@ def rnnt_loss(log_probs: torch.FloatTensor,
        return costs.mean()
    else:
        raise ValueError(
-            f"Unknown reduction method: {reduction}, expected to be one of ['mean', 'sum', 'none']")

+            f"Unknown reduction method: {reduction}, expected to be one of ['mean', 'sum', 'none']"
+        )
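For context, a minimal usage sketch of rnnt_loss (not part of the diff): the (N, T, U, V) log-probability layout with U = max label length + 1 is inferred from the gather branch above, the example values are hypothetical, and a CUDA-capable GPU is assumed.

import torch
from i6_native_ops.warp_rnnt import rnnt_loss

N, T, U, V = 2, 10, 6, 20  # batch, frames, max label length + 1, vocabulary size
log_probs = torch.randn(N, T, U, V, device="cuda").log_softmax(dim=-1).requires_grad_()
labels = torch.randint(1, V, (N, U - 1), dtype=torch.int32, device="cuda")  # blank index 0 excluded
frames_lengths = torch.tensor([10, 8], dtype=torch.int32, device="cuda")
labels_lengths = torch.tensor([5, 3], dtype=torch.int32, device="cuda")

loss = rnnt_loss(log_probs, labels, frames_lengths, labels_lengths, reduction="mean")
loss.backward()  # gradients flow through the RNNTLoss autograd Function defined above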
1 change: 1 addition & 0 deletions requirements.txt
@@ -0,0 +1 @@
torch
Empty file added test/__init__.py
42 changes: 0 additions & 42 deletions test/fast_viterbi.py

This file was deleted.

51 changes: 51 additions & 0 deletions test/test_fast_viterbi.py
@@ -0,0 +1,51 @@
import unittest
import torch

from i6_native_ops.fast_viterbi import align_viterbi


class TestFastViterbi(unittest.TestCase):
    def test_best_sequence(self):
        log_probs = (
            torch.tensor(
                [
                    [0.9, 0.1],
                    [0.9, 0.1],
                    [0.4, 0.6],
                    [0.1, 0.9],
                    [0.1, 0.9],
                ],
                device="cuda",
                dtype=torch.float32,
            )
            .unsqueeze(0)
            .log()
        )
        edges = (
            torch.tensor(
                [
                    # from, to, emission_idx, sequence_idx
                    [0, 0, 0, 0],  # loop from 0 to 0, emit label 0
                    [0, 1, 0, 0],  # forward from 0 to 1, emit 0
                    [1, 1, 1, 0],
                ],  # loop from 1 to 1, emit 1
                device="cuda",
                dtype=torch.int32,
            )
            .transpose(0, 1)
            .contiguous()
        )
        weights = torch.tensor([1, 1, 1], device="cuda", dtype=torch.float32)
        start_end_states = torch.tensor([[0], [1]], dtype=torch.int32, device="cuda")
        seq_lens = torch.tensor([5], dtype=torch.int32, device="cuda")

        fsa = (2, edges, weights, start_end_states)

        output, scores = align_viterbi(log_probs, fsa, seq_lens)
        best_sequence = list(output[0, :])

        self.assertEqual(best_sequence, [0, 0, 1, 1, 1])


if __name__ == "__main__":
    unittest.main()
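As a usage note: with test/__init__.py added, the suite can be run from the repository root with python -m unittest discover test, or this file alone with python -m unittest test.test_fast_viterbi; a CUDA-capable GPU is required, since the tensors above are created on "cuda".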
File renamed without changes.