Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions i6_native_ops/monotonic_rnnt/include/gpu_rnnt_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ __global__ void compute_grad_kernel(Tp *grads, const Tp *const acts, const Tp *c
const int *const S_max, const int *const T_max, const int *const min_allowed_s,
const int *const max_allowed_s, const int *const V, const int blank_) {
int v = static_cast<int>(threadIdx.x);
int bts = static_cast<int>(blockIdx.x); // b, t, s packed
int64_t bts = static_cast<int64_t>(blockIdx.x); // b, t, s packed

int b = 0;
while (b < *B - 1 && denom_start_indices[b + 1] <= bts) {
Expand All @@ -259,7 +259,7 @@ __global__ void compute_grad_kernel(Tp *grads, const Tp *const acts, const Tp *c
const int *min_allowed_s_b = min_allowed_s + b * *T_max;
const int *max_allowed_s_b = max_allowed_s + b * *T_max;

int ts = bts - denom_start_indices[b];
int64_t ts = bts - denom_start_indices[b];
int t = ts / (S_b + 1);
int s = ts % (S_b + 1);

Expand Down