diff --git a/dlib/cuda/cuda_dlib.cu b/dlib/cuda/cuda_dlib.cu
index 82522929ef..6275071ac3 100644
--- a/dlib/cuda/cuda_dlib.cu
+++ b/dlib/cuda/cuda_dlib.cu
@@ -108,65 +108,69 @@ namespace dlib
         }
     }

-    // -----------------------------------------------------------------------------------
-
-    // -----------------------------------------------------------------------------------
     // -----------------------------------------------------------------------------------

-    __global__ void _cuda_inverse_norms(float* invnorms, const float* data, size_t nr, size_t nc, const float eps)
+    __global__ void _cuda_inverse_norms_accumulate(
+        float* invnorms,
+        const float* data,
+        size_t nr,
+        size_t nc
+    )
     {
-        // initialize invnorms before we begin.
-        for (auto i : grid_stride_range_y(0, nr))
-            for (auto j : grid_stride_range(0, 1))
-                invnorms[i] = eps;
-        __syncthreads();
-
         for (auto i : grid_stride_range_y(0, nr))
         {
-            auto p = data + i*nc;
+            auto p = data + i * nc;
             float temp = 0;
             for (auto j : grid_stride_range(0, nc))
-                temp += p[j]*p[j];
+                temp += p[j] * p[j];

-            // and store the sum into invnorms[i]
             warp_reduce_atomic_add(invnorms[i], temp);
         }
-        __syncthreads();
+    }

+    __global__ void _cuda_inverse_norms_invert(
+        float* invnorms,
+        size_t nr
+    )
+    {
         for (auto i : grid_stride_range_y(0, nr))
-            for (auto j : grid_stride_range(0, 1))
-                invnorms[i] = 1.0/std::sqrt(invnorms[i]);
+        {
+            if (threadIdx.x == 0)
+                invnorms[i] = 1.0f / std::sqrt(invnorms[i]);
+        }
     }

-    void inverse_norms (
+    void inverse_norms(
         resizable_tensor& invnorms,
         const tensor& data,
         const double eps
    )
     {
-        invnorms.set_size(data.num_samples());
-        launch_kernel(_cuda_inverse_norms, max_jobs(data.size()/data.num_samples(), data.num_samples()),
-            invnorms.device(), data.device(), data.num_samples(), data.size()/data.num_samples(), eps);
+        const auto nr = data.num_samples();
+        const auto nc = data.size() / data.num_samples();
+
+        invnorms.set_size(nr);
+        invnorms = eps;
+
+        launch_kernel(_cuda_inverse_norms_accumulate, max_jobs(nc, nr),
+            invnorms.device(), data.device(), nr, nc);
+
+        launch_kernel(_cuda_inverse_norms_invert, max_jobs(1, nr),
+            invnorms.device(), nr);
     }

     // ----------------------------------------------------------------------------------------

     __global__ void _cuda_dot_prods(float* out, const float* lhs, const float* rhs, size_t nr, size_t nc)
     {
-        // initialize out before we begin.
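+        // out must already be zero-initialized by the caller (dot_prods does this
+        // before launching) since this kernel only accumulates into out[i].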
-        for (auto i : grid_stride_range_y(0, nr))
-            for (auto j : grid_stride_range(0, 1))
-                out[i] = 0;
-        __syncthreads();
-
         for (auto i : grid_stride_range_y(0, nr))
         {
-            auto l = lhs + i*nc;
-            auto r = rhs + i*nc;
+            auto l = lhs + i * nc;
+            auto r = rhs + i * nc;
             float temp = 0;
             for (auto j : grid_stride_range(0, nc))
-                temp += l[j]*r[j];
+                temp += l[j] * r[j];

-            // and store the sum into out[i]
             warp_reduce_atomic_add(out[i], temp);
         }
     }
@@ -175,53 +179,61 @@ namespace dlib
     {
         for (auto i : grid_stride_range_y(0, nr))
         {
-            auto l = lhs + i*nc;
-            auto r = rhs + i*nc;
+            auto l = lhs + i * nc;
+            auto r = rhs + i * nc;
             float temp = 0;
             for (auto j : grid_stride_range(0, nc))
-                temp += l[j]*r[j];
+                temp += l[j] * r[j];

-            // and store the sum into out[i]
             warp_reduce_atomic_add(out[i], temp);
         }
     }

-    void dot_prods (
+    void dot_prods(
         resizable_tensor& out,
         const tensor& lhs,
         const tensor& rhs
     )
     {
-        DLIB_CASSERT(have_same_dimensions(lhs,rhs));
+        DLIB_CASSERT(have_same_dimensions(lhs, rhs));
         out.set_size(lhs.num_samples());

         if (out.size() == 0)
             return;

         const auto nr = lhs.num_samples();
-        const auto nc = lhs.size()/lhs.num_samples();
+        const auto nc = lhs.size() / lhs.num_samples();

-        launch_kernel(_cuda_dot_prods, max_jobs(nc,nr), out.device_write_only(), lhs.device(), rhs.device(), nr, nc);
+        out = 0;
+        launch_kernel(_cuda_dot_prods, max_jobs(nc, nr),
+            out.device(), lhs.device(), rhs.device(), nr, nc);
     }

-    void dot_prods (
+    void dot_prods(
         bool add_to,
         tensor& out,
         const tensor& lhs,
         const tensor& rhs
     )
     {
-        DLIB_CASSERT(have_same_dimensions(lhs,rhs));
+        DLIB_CASSERT(have_same_dimensions(lhs, rhs));
         DLIB_CASSERT(out.k() == 1 && out.nr() == 1 && out.nc() == 1);
-        DLIB_CASSERT(out.size() == lhs.num_samples());
+        DLIB_CASSERT(out.num_samples() == lhs.num_samples());

         const auto nr = lhs.num_samples();
-        const auto nc = lhs.size()/lhs.num_samples();
+        const auto nc = lhs.size() / lhs.num_samples();

         if (add_to)
-            launch_kernel(_cuda_dot_prods_add_to, max_jobs(nc,nr), out.device(), lhs.device(), rhs.device(), nr, nc);
+        {
+            launch_kernel(_cuda_dot_prods_add_to, max_jobs(nc, nr),
+                out.device(), lhs.device(), rhs.device(), nr, nc);
+        }
         else
-            launch_kernel(_cuda_dot_prods, max_jobs(nc,nr), out.device_write_only(), lhs.device(), rhs.device(), nr, nc);
+        {
+            out = 0;
+            launch_kernel(_cuda_dot_prods, max_jobs(nc, nr),
+                out.device(), lhs.device(), rhs.device(), nr, nc);
+        }
     }

     // ----------------------------------------------------------------------------------------

@@ -465,27 +477,21 @@ namespace dlib
     {
         for (auto i : grid_stride_range(0, n))
         {
-            auto k = (i/bs)%ks;
-            d[i] = s1[i]*s2[k];
+            auto k = (i / bs) % ks;
+            d[i] = s1[i] * s2[k];
         }
     }

     __global__ void _cuda_multiply_conv2(float* d, const float* s1, size_t n, const float* s2, size_t bs, size_t ks)
     {
-        // zero initialize d before we begin.
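+        // d must already be zero-initialized by the caller (multiply_conv does this
+        // before launching) since this kernel only accumulates into d[k].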
-        for (auto i : grid_stride_range_y(0, ks))
-            for (auto j : grid_stride_range(0, 1))
-                d[i] = 0;
-        __syncthreads();
-
         // loop over all the image planes
         for (auto i : grid_stride_range_y(0, n))
         {
             // sum all the elements in the i-th image plane
             float temp = 0;
-            for (auto j : grid_stride_range(i*bs, (i+1)*bs))
-                temp += s1[j]*s2[j];
-            auto k = i%ks;
+            for (auto j : grid_stride_range(i * bs, (i + 1) * bs))
+                temp += s1[j] * s2[j];
+            auto k = i % ks;
             // and store the sum into d[k]
             warp_reduce_atomic_add(d[k], temp);
         }
@@ -495,8 +501,8 @@ namespace dlib
     {
         for (auto i : grid_stride_range(0, n))
         {
-            auto k = (i/bs)%ks;
-            d[i] += s1[i]*s2[k];
+            auto k = (i / bs) % ks;
+            d[i] += s1[i] * s2[k];
         }
     }

@@ -507,53 +513,56 @@ namespace dlib
         {
             // sum all the elements in the i-th image plane
             float temp = 0;
-            for (auto j : grid_stride_range(i*bs, (i+1)*bs))
-                temp += s1[j]*s2[j];
-            auto k = i%ks;
+            for (auto j : grid_stride_range(i * bs, (i + 1) * bs))
+                temp += s1[j] * s2[j];
+            auto k = i % ks;
             // and store the sum into d[k]
             warp_reduce_atomic_add(d[k], temp);
         }
     }
-
-    void multiply_conv (
+    void multiply_conv(
         bool add_to,
         tensor& dest,
         const tensor& src1,
         const tensor& src2
     )
     {
-        if (have_same_dimensions(dest,src1))
+        if (have_same_dimensions(dest, src1))
         {
             DLIB_CASSERT(src2.num_samples() == 1 && src2.nr() == 1 && src2.nc() == 1 && src2.k() == src1.k());
-            if (dest.size() == 0)
-                return;
-
             if (add_to)
-                launch_kernel(_cuda_multiply_conv_add_to,max_jobs(dest.size()),
-                    dest.device(), src1.device(), src1.size(), src2.device(), src1.nr()*src1.nc(), src1.k());
+            {
+                launch_kernel(_cuda_multiply_conv_add_to, max_jobs(dest.size()),
+                    dest.device(), src1.device(), src1.size(), src2.device(), src1.nr() * src1.nc(), src1.k());
+            }
             else
-                launch_kernel(_cuda_multiply_conv,max_jobs(dest.size()),
-                    dest.device(), src1.device(), src1.size(), src2.device(), src1.nr()*src1.nc(), src1.k());
+            {
+                launch_kernel(_cuda_multiply_conv, max_jobs(dest.size()),
+                    dest.device(), src1.device(), src1.size(), src2.device(), src1.nr() * src1.nc(), src1.k());
+            }
         }
         else
         {
-            DLIB_CASSERT(have_same_dimensions(src1,src2));
+            DLIB_CASSERT(src1.num_samples() == src2.num_samples() && src1.k() == src2.k() &&
+                         src1.nr() == src2.nr() && src1.nc() == src2.nc());
             DLIB_CASSERT(dest.num_samples() == 1 && dest.nr() == 1 && dest.nc() == 1 && dest.k() == src1.k());
-            if (dest.size() == 0)
-                return;
+            const auto bs = src1.nr() * src1.nc();
+            const auto n = src1.num_samples() * src1.k();

-            const auto bs = src1.nr()*src1.nc();
-            const auto n = src1.num_samples()*src1.k();
             if (add_to)
-                launch_kernel(_cuda_multiply_conv2_add_to, max_jobs(bs,n),
+            {
+                launch_kernel(_cuda_multiply_conv2_add_to, max_jobs(bs, n),
                     dest.device(), src1.device(), n, src2.device(), bs, src1.k());
+            }
             else
-                launch_kernel(_cuda_multiply_conv2, max_jobs(bs,n),
+            {
+                dest = 0;
+                launch_kernel(_cuda_multiply_conv2, max_jobs(bs, n),
                     dest.device(), src1.device(), n, src2.device(), bs, src1.k());
+            }
         }
-
     }

     // ------------------------------------------------------------------------------------

@@ -2210,26 +2219,21 @@ namespace dlib

     // ----------------------------------------------------------------------------------------

-    __global__ void _cuda_layer_normalize(
-        float* out,
-        const float* s,
+    __global__ void _cuda_layer_normalize_accumulate(
         float* m,
         float* v,
-        const float* g,
-        const float* b,
-        float eps,
+        const float* s,
         size_t ns,
         size_t k,
         size_t num
     )
     {
-        // compute means and sum of squares
         for (auto n : grid_stride_range_y(0, ns))
         {
             const auto ps = s + n * k * num;
             float means = 0;
             float invstds = 0;
             for (auto i : grid_stride_range(0, k * num))
             {
                 means += ps[i];
                 invstds += ps[i] * ps[i];
@@ -2237,23 +2241,39 @@ namespace dlib
             warp_reduce_atomic_add(m[n], means / (k * num));
             warp_reduce_atomic_add(v[n], invstds / (k * num));
         }
-        __syncthreads();
+    }

-        // compute variances
+    __global__ void _cuda_layer_normalize_invert(
+        float* m,
+        float* v,
+        float eps,
+        size_t ns
+    )
+    {
         for (auto n : grid_stride_range_y(0, ns))
         {
-            for (auto i : grid_stride_range(0, 1))
-            {
+            if (threadIdx.x == 0)
                 v[n] = 1.0f / std::sqrt(v[n] - m[n] * m[n] + eps);
-            }
         }
-        __syncthreads();
+    }

+    __global__ void _cuda_layer_normalize_apply(
+        float* out,
+        const float* s,
+        const float* m,
+        const float* v,
+        const float* g,
+        const float* b,
+        size_t ns,
+        size_t k,
+        size_t num
+    )
+    {
         for (auto n : grid_stride_range_y(0, ns))
         {
             const auto ps = s + n * k * num;
             const auto pout = out + n * k * num;
             for (auto i : grid_stride_range(0, k * num))
             {
                 pout[i] = (ps[i] - m[n]) * v[n];
                 pout[i] = pout[i] * g[i / num] + b[i / num];
@@ -2261,7 +2281,7 @@ namespace dlib
         }
     }

-    void layer_normalize (
+    void layer_normalize(
         const double eps,
         resizable_tensor& dest,
         resizable_tensor& means,
@@ -2288,47 +2308,56 @@ namespace dlib
             "\neps: " << eps
         );

+        const long ns = src.num_samples();
+        const long ks = src.k();
+
         dest.copy_size(src);
-        means.set_size(src.num_samples());
-        invstds.set_size(src.num_samples());
+        means.set_size(ns);
+        invstds.set_size(ns);
         means = 0;
         invstds = 0;
-        launch_kernel(_cuda_layer_normalize, max_jobs(src.k() * num, src.num_samples()), dest.device(), src.device(),
-            means.device(), invstds.device(), gamma.device(), beta.device(), eps, src.num_samples(), src.k(), num);
+
+        launch_kernel(_cuda_layer_normalize_accumulate, max_jobs(ks * num, ns),
+            means.device(), invstds.device(), src.device(), ns, ks, num);
+
+        launch_kernel(_cuda_layer_normalize_invert, max_jobs(1, ns),
+            means.device(), invstds.device(), eps, ns);
+
+        launch_kernel(_cuda_layer_normalize_apply, max_jobs(ks * num, ns),
+            dest.device(), src.device(), means.device(), invstds.device(),
+            gamma.device(), beta.device(), ns, ks, num);
     }

     // ----------------------------------------------------------------------------------------

-    __global__ void _cuda_layer_normalize_gradient(
-        float* out,
-        float* gg,
+    __global__ void _cuda_layer_normalize_gradient_accumulate(
         float* bg,
+        float* gg,
+        float* dv,
         const float* s,
         const float* gi,
         const float* m,
         const float* v,
         const float* g,
-        float* dm,
-        float* dv,
-        float eps,
         size_t ns,
         size_t ks,
-        size_t num)
+        size_t num
+    )
     {
         for (auto nk : grid_stride_range_y(0, ns * ks))
         {
             const auto n = nk / ks;
             const auto k = nk % ks;
             const auto ps = s + (n * ks + k) * num;
             const auto pgi = gi + (n * ks + k) * num;
-            const float invstd_pow = -0.5 * std::pow(v[n], 3.0f);
+            const float invstd_pow = -0.5f * std::pow(v[n], 3.0f);
             float temp_bg = 0;
             float temp_gg = 0;
             float temp_dv = 0;
             for (auto i : grid_stride_range(0, num))
             {
                 const float x_hat = (ps[i] - m[n]) * v[n];
-                const float dx = pgi[i] * g[i / num];
+                const float dx = pgi[i] * g[k];
                 temp_bg += pgi[i];
                 temp_gg += pgi[i] * x_hat;
                 temp_dv += dx * (ps[i] - m[n]) * invstd_pow;
@@ -2337,29 +2366,57 @@ namespace dlib
             warp_reduce_atomic_add(gg[k], temp_gg);
             warp_reduce_atomic_add(dv[n], temp_dv);
         }
-        __syncthreads();
+    }

+    __global__ void _cuda_layer_normalize_gradient_compute_dm(
+        float* dm,
+        const float* dv,
+        const float* s,
+        const float* gi,
+        const float* m,
+        const float* v,
+        const float* g,
+        size_t ns,
+        size_t ks,
+        size_t num
+    )
+    {
         const float invnum = 1.0f / (ks * num);
         for (auto n : grid_stride_range_y(0, ns))
         {
             const auto ps = s + n * ks * num;
             const auto pgi = gi + n * ks * num;
             float temp_dm = 0;
             for (auto i : grid_stride_range(0, ks * num))
             {
                 const float dx = pgi[i] * g[i / num];
                 temp_dm += -dx * v[n] + dv[n] * -2 * (ps[i] - m[n]) * invnum;
             }
             warp_reduce_atomic_add(dm[n], temp_dm);
         }
-        __syncthreads();
+    }

+    __global__ void _cuda_layer_normalize_gradient_apply(
+        float* out,
+        const float* dm,
+        const float* dv,
+        const float* s,
+        const float* gi,
+        const float* m,
+        const float* v,
+        const float* g,
+        size_t ns,
+        size_t ks,
+        size_t num
+    )
+    {
+        const float invnum = 1.0f / (ks * num);
         for (auto n : grid_stride_range_y(0, ns))
         {
             const auto ps = s + n * ks * num;
             const auto pgi = gi + n * ks * num;
             const auto pout = out + n * ks * num;
             for (auto i : grid_stride_range(0, ks * num))
             {
                 const float dx = pgi[i] * g[i / num];
                 pout[i] += dx * v[n] + dv[n] * 2 * (ps[i] - m[n]) * invnum + dm[n] * invnum;
@@ -2367,7 +2424,7 @@ namespace dlib
         }
     }

-    void layer_normalize_gradient (
+    void layer_normalize_gradient(
         const double eps,
         const tensor& gradient_input,
         const tensor& means,
@@ -2393,26 +2450,37 @@ namespace dlib
         DLIB_CASSERT(have_same_dimensions(gradient_input, src_grad));
         DLIB_CASSERT(eps > 0);

+        const long ns = src.num_samples();
+        const long ks = src.k();
+
         beta_grad = 0;
         gamma_grad = 0;
         dvars.copy_size(invstds);
         dmeans.copy_size(means);
         dvars = 0;
         dmeans = 0;
-        launch_kernel(_cuda_layer_normalize_gradient, max_jobs(src.k() * num, src.num_samples()),
-            src_grad.device(), gamma_grad.device(), beta_grad.device(), src.device(),
-            gradient_input.device(), means.device(), invstds.device(), gamma.device(),
-            dmeans.device(), dvars.device(), eps, src.num_samples(), src.k(), num);
+
+        launch_kernel(_cuda_layer_normalize_gradient_accumulate, max_jobs(ks * num, ns * ks),
+            beta_grad.device(), gamma_grad.device(), dvars.device(),
+            src.device(), gradient_input.device(), means.device(), invstds.device(),
+            gamma.device(), ns, ks, num);
+
+        launch_kernel(_cuda_layer_normalize_gradient_compute_dm, max_jobs(ks * num, ns),
+            dmeans.device(), dvars.device(),
+            src.device(), gradient_input.device(), means.device(), invstds.device(),
+            gamma.device(), ns, ks, num);
+
+        launch_kernel(_cuda_layer_normalize_gradient_apply, max_jobs(ks * num, ns),
+            src_grad.device(), dmeans.device(), dvars.device(),
+            src.device(), gradient_input.device(), means.device(), invstds.device(),
+            gamma.device(), ns, ks, num);
     }

     // ----------------------------------------------------------------------------------------

-    __global__ void _cuda_rms_normalize(
-        float* dest,
+    __global__ void _cuda_rms_normalize_accumulate(
         float* scale,
         const float* src,
-        const float* gamma,
-        float eps,
         size_t ns,
         size_t ks,
         size_t num
@@ -2422,28 +2490,42 @@ namespace dlib
         {
             const auto ps = src + n * ks * num;
             float sum_squares = 0.0f;
             for (auto i : grid_stride_range(0, ks * num))
             {
                 sum_squares += ps[i] * ps[i];
             }
             warp_reduce_atomic_add(scale[n], sum_squares / (ks * num));
         }
-        __syncthreads();
+    }

+    __global__ void _cuda_rms_normalize_invert(
+        float* scale,
+        float eps,
+        size_t ns
+    )
+    {
         for (auto n : grid_stride_range_y(0, ns))
         {
-            for (auto i : grid_stride_range(0, 1))
-            {
+            if (threadIdx.x == 0)
                 scale[n] = 1.0f / std::sqrt(scale[n] + eps);
-            }
         }
-        __syncthreads();
+    }

+    __global__ void _cuda_rms_normalize_apply(
+        float* dest,
+        const float* scale,
+        const float* src,
+        const float* gamma,
+        size_t ns,
+        size_t ks,
+        size_t num
+    )
+    {
         for (auto n : grid_stride_range_y(0, ns))
         {
             const auto ps = src + n * ks * num;
             const auto pd = dest + n * ks * num;
             for (auto i : grid_stride_range(0, ks * num))
             {
                 pd[i] = ps[i] * scale[n] * gamma[i / num];
             }
@@ -2457,7 +2539,7 @@ namespace dlib
         const tensor& src,
         const tensor& gamma
     )
-    {
+    {
         DLIB_CASSERT(
             gamma.k() == src.k() &&
             gamma.nr() == 1 &&
@@ -2478,26 +2560,31 @@ namespace dlib
         scale.set_size(ns);
         scale = 0;

-        launch_kernel(_cuda_rms_normalize, max_jobs(ks * num, ns),
-            dest.device(), scale.device(), src.device(), gamma.device(), eps, ns, ks, num);
+        launch_kernel(_cuda_rms_normalize_accumulate, max_jobs(ks * num, ns),
+            scale.device(), src.device(), ns, ks, num);
+
+        launch_kernel(_cuda_rms_normalize_invert, max_jobs(1, ns),
+            scale.device(), eps, ns);
+
+        launch_kernel(_cuda_rms_normalize_apply, max_jobs(ks * num, ns),
+            dest.device(), scale.device(), src.device(), gamma.device(), ns, ks, num);
     }

     // ----------------------------------------------------------------------------------------

-    __global__ void _cuda_rms_normalize_gradient(
-        float* src_grad,
+    __global__ void _cuda_rms_normalize_gradient_accumulate(
         float* gamma_grad,
         float* dscale,
         const float* src,
         const float* gradient_input,
         const float* scale,
         const float* gamma,
         size_t ns,
         size_t ks,
         size_t num
     )
     {
         for (auto nk : grid_stride_range_y(0, ns * ks))
         {
             const auto n = nk / ks;
             const auto k = nk % ks;
@@ -2509,22 +2596,34 @@ namespace dlib
             for (auto i : grid_stride_range(0, num))
             {
                 const float x_hat = ps[i] * scale[n];
-                const float dx = pgi[i] * gamma[i / num];
+                const float dx = pgi[i] * gamma[k];
                 temp_gg += pgi[i] * x_hat;
                 temp_ds += dx * ps[i] * scale_pow;
             }
             warp_reduce_atomic_add(gamma_grad[k], temp_gg);
             warp_reduce_atomic_add(dscale[n], temp_ds);
         }
-        __syncthreads();
+    }

+    __global__ void _cuda_rms_normalize_gradient_apply(
+        float* src_grad,
+        const float* dscale,
+        const float* src,
+        const float* gradient_input,
+        const float* scale,
+        const float* gamma,
+        size_t ns,
+        size_t ks,
+        size_t num
+    )
+    {
         const float invnum = 1.0f / (ks * num);
         for (auto n : grid_stride_range_y(0, ns))
         {
             const auto ps = src + n * ks * num;
             const auto pgi = gradient_input + n * ks * num;
             const auto psg = src_grad + n * ks * num;
             for (auto i : grid_stride_range(0, ks * num))
             {
                 const float dx = pgi[i] * gamma[i / num];
                 psg[i] += dx * scale[n] + dscale[n] * 2 * ps[i] * invnum;
@@ -2541,7 +2640,7 @@ namespace dlib
         tensor& gamma_grad,
         resizable_tensor& dscale
     )
-    {
+    {
         DLIB_CASSERT(src.num_samples() == scale.size());
         DLIB_CASSERT(have_same_dimensions(gamma, gamma_grad));
         DLIB_CASSERT(gamma.k() == src.k());
@@ -2558,9 +2657,13 @@ namespace dlib
         dscale.copy_size(scale);
         dscale = 0;

-        // Launch the CUDA kernel
-        launch_kernel(_cuda_rms_normalize_gradient, max_jobs(ks * num, ns),
-            src_grad.device(), gamma_grad.device(), dscale.device(),
+        launch_kernel(_cuda_rms_normalize_gradient_accumulate, max_jobs(ks * num, ns * ks),
+            gamma_grad.device(), dscale.device(),
+            src.device(), gradient_input.device(), scale.device(), gamma.device(),
+            ns, ks, num);
+
+        launch_kernel(_cuda_rms_normalize_gradient_apply, max_jobs(ks * num, ns),
+            src_grad.device(), dscale.device(),
             src.device(), gradient_input.device(), scale.device(), gamma.device(),
             ns, ks, num);
     }
@@ -2736,12 +2839,23 @@ namespace dlib
     // ----------------------------------------------------------------------------------------

     // CUDA Kernels for ACT operations
-    __global__ void _cuda_compute_act_halt_probabilities(
-        float* halt_probs,
+
+    // Kernel 1: initialize logits with bias
+    __global__ void _cuda_act_init_logits(
+        float* logits,
+        float b_halt,
+        size_t total_positions
+    )
+    {
+        for (auto pos : grid_stride_range(0, total_positions))
+            logits[pos] = b_halt;
+    }
+
+    // Kernel 2: compute dot product and accumulate into logits
+    __global__ void _cuda_act_accumulate_logits(
         float* logits,
         const float* input_data,
         const float* W_halt,
-        float b_halt,
         size_t batch_size,
         size_t seq_len,
         size_t d_model,
@@ -2751,11 +2865,6 @@ namespace dlib
     {
         const long total_positions = batch_size * seq_len;

-        for (auto pos : grid_stride_range_y(0, total_positions))
-            for (auto i : grid_stride_range(0, 1))
-                logits[pos] = b_halt;
-        __syncthreads();
-
         for (auto pos : grid_stride_range_y(0, total_positions))
         {
             const long n = pos / seq_len;
@@ -2773,12 +2882,17 @@ namespace dlib
             warp_reduce_atomic_add(logits[pos], temp);
         }
-        __syncthreads();
+    }

+    // Kernel 3: apply sigmoid to compute halt probabilities
+    __global__ void _cuda_act_apply_sigmoid(
+        float* halt_probs,
+        const float* logits,
+        size_t total_positions
+    )
+    {
         for (auto pos : grid_stride_range(0, total_positions))
-        {
             halt_probs[pos] = 1.0f / (1.0f + expf(-logits[pos]));
-        }
     }

     void compute_act_halt_probabilities(
@@ -2798,18 +2912,36 @@ namespace dlib
         halt_probs.set_size(total_positions, 1, 1, 1);
         logits.set_size(total_positions, 1, 1, 1);

-        launch_kernel(_cuda_compute_act_halt_probabilities,
+        // Extract bias from halt_params (last element)
+        const float b_halt = halt_params.host()[feature_dim];
+
+        // Phase 1: initialize logits with bias
+        launch_kernel(_cuda_act_init_logits,
+            max_jobs(total_positions),
+            logits.device(),
+            b_halt,
+            total_positions);
+
+        // Phase 2: accumulate dot product into logits
+        // Note: sequential kernel launch provides implicit synchronization
+        launch_kernel(_cuda_act_accumulate_logits,
             max_jobs(feature_dim, total_positions),
-            halt_probs.device(), logits.device(), input_data.device(), halt_params.device(),
-            halt_params.host()[feature_dim], batch_size, seq_len, d_model, num_channels, feature_dim);
+            logits.device(), input_data.device(), halt_params.device(),
+            batch_size, seq_len, d_model, num_channels, feature_dim);
+
+        // Phase 3: apply sigmoid
+        // Note: sequential kernel launch provides implicit synchronization
+        launch_kernel(_cuda_act_apply_sigmoid,
+            max_jobs(total_positions),
+            halt_probs.device(),
+            logits.device(),
+            total_positions);
     }

     __global__ void _cuda_update_act_state(
diff --git a/dlib/dnn/core.h b/dlib/dnn/core.h
index dd4f7eb0fb..a945a41e43 100644
--- a/dlib/dnn/core.h
+++ b/dlib/dnn/core.h
@@ -3226,6 +3226,26 @@ namespace dlib
                 data[i] = rnd.get_random_gaussian()*sigma;
         }

+        struct test_layer_params
+        {
+            /*!
+                WHAT THIS OBJECT REPRESENTS
+                    This object allows specifying constraints on tensor dimensions
+                    when testing layers with test_layer().
+
+                    If a member is set to 0 (the default), the dimension is chosen
+                    randomly during testing.  If a member is strictly positive, the
+                    dimension is fixed to that value for all iterations.
+
+                    This is useful for layers with intrinsic constraints (e.g., k()
+                    must be 1 or nc() must equal a specific d_model).
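+
+                    For example, to test a hypothetical layer type my_layer_ whose
+                    setup() requires k() == 1 and nc() == 128, one would write:
+
+                        timpl::test_layer_params p;
+                        p.k = 1;
+                        p.nc = 128;
+                        my_layer_ l;
+                        auto res = test_layer(l, p);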
+            !*/
+            long num_samples = 0;
+            long k = 0;
+            long nr = 0;
+            long nc = 0;
+        };
+
         class test_layer_subnet
         {
         public:
@@ -3233,21 +3253,15 @@ namespace dlib
                 dlib::rand& rnd_
             ) : rnd(rnd_)
             {
-                // Output and gradient_input have to have the same dimensions in each
-                // layer.
-                const long num_samples = rnd.get_random_32bit_number()%4+3;
-                const long k = rnd.get_random_32bit_number()%4+2;
-                const long nr = ((rnd.get_random_32bit_number()%4)/2)*2+2;
-                const long nc = ((rnd.get_random_32bit_number()%4)/2)*2+2;
-
-                output.set_size(num_samples, k, nr, nc);
-                gradient_input.set_size(num_samples, k, nr, nc);
-
-                // Use a non-zero initial gradient to make sure the layers add to it
-                // rather than assign and blow away the initial value.
-                fill_with_gassuan_random_numbers(gradient_input, rnd, 0.01);
+                init(test_layer_params());
+            }

-                fill_with_gassuan_random_numbers(output, rnd);
+            test_layer_subnet(
+                dlib::rand& rnd_,
+                const test_layer_params& p
+            ) : rnd(rnd_)
+            {
+                init(p);
             }

@@ -3288,6 +3302,24 @@ namespace dlib

         private:

+            void init(const test_layer_params& p)
+            {
+                // If a dimension is fixed in p, use it.  Otherwise, generate random dimensions.
+                const long num_samples = p.num_samples != 0 ? p.num_samples : (rnd.get_random_32bit_number() % 4 + 3);
+                const long k = p.k != 0 ? p.k : (rnd.get_random_32bit_number() % 4 + 2);
+                const long nr = p.nr != 0 ? p.nr : (((rnd.get_random_32bit_number() % 4) / 2) * 2 + 2);
+                const long nc = p.nc != 0 ? p.nc : (((rnd.get_random_32bit_number() % 4) / 2) * 2 + 2);
+
+                output.set_size(num_samples, k, nr, nc);
+                gradient_input.set_size(num_samples, k, nr, nc);
+
+                // Use a non-zero initial gradient to make sure the layers add to it
+                // rather than assign and blow away the initial value.
+                fill_with_gassuan_random_numbers(gradient_input, rnd, 0.01);
+
+                fill_with_gassuan_random_numbers(output, rnd);
+            }
+
             // We lazily initialize sub-layers as needed when someone tries to call
             // subnet()
             void init_sub() const
@@ -3326,7 +3358,8 @@ namespace dlib
         >
     layer_test_results impl_test_layer (
         layer_details_type l,
-        const float base_eps
+        const float base_eps,
+        const timpl::test_layer_params& p
     )
     {
         using namespace timpl;
@@ -3336,7 +3369,8 @@ namespace dlib
         std::ostringstream sout;
         for (int iter = 0; iter < 10; ++iter)
         {
-            test_layer_subnet subnetwork(rnd);
+            // Pass the test_layer_params to the subnet constructor
+            test_layer_subnet subnetwork(rnd, p);
             resizable_tensor output, out2, out3;
             // Run setup() and forward() as well to make sure any calls to subnet() have
             // happened before we start assuming we know how many data elements there are
@@ -3381,7 +3415,7 @@ namespace dlib
             // in in-place mode.
             if (impl::is_inplace_layer(l, subnetwork))
             {
-                test_layer_subnet subnetwork2(rnd);
+                test_layer_subnet subnetwork2(rnd, p);
                 layer_details_type ll(l);
                 ll.setup(subnetwork2);
                 resizable_tensor ip_out;
@@ -3526,21 +3560,33 @@ namespace dlib

     template <
         typename layer_details_type
-        >
-    layer_test_results test_layer (
+    >
+    layer_test_results test_layer(
         layer_details_type l
     )
+    {
+        // Default behavior: use random dimensions (all zeros in params)
+        return test_layer(l, timpl::test_layer_params());
+    }
+
+    template <
+        typename layer_details_type
+    >
+    layer_test_results test_layer(
+        layer_details_type l,
+        const timpl::test_layer_params& params
+    )
     {
         // Try a few different derivative step sizes to see if any work.
        for (float base_eps = 0.0001; base_eps < 0.1; base_eps *= 2)
        {
-            auto result = impl_test_layer(l, base_eps);
+            auto result = impl_test_layer(l, base_eps, params);
            if (result)
                return result;
        }
        // However, if none of the step sizes worked then try this one and probably result
        // in returning an error.
-        return impl_test_layer(l, 0.01);
+        return impl_test_layer(l, 0.01, params);
    }

    // ----------------------------------------------------------------------------------------

diff --git a/dlib/test/dnn.cpp b/dlib/test/dnn.cpp
index a67bdc7236..4b979ce20d 100644
--- a/dlib/test/dnn.cpp
+++ b/dlib/test/dnn.cpp
@@ -2607,6 +2607,9 @@ namespace
        }
        {
            print_spinner();
+            dlib::timpl::test_layer_params p;
+            p.k = 1;
+            p.nc = 128;
            adaptive_computation_time_<6> l;
-            auto res = test_layer(l);
+            auto res = test_layer(l, p);
            DLIB_TEST_MSG(res, res);
diff --git a/examples/slm_advanced_train_ex.cpp b/examples/slm_advanced_train_ex.cpp
index 4ed6ffec22..d4826be851 100644
--- a/examples/slm_advanced_train_ex.cpp
+++ b/examples/slm_advanced_train_ex.cpp
@@ -1265,9 +1265,7 @@ int main(int argc, char** argv)
    while (total_bytes < target_size && next_token != start_of_text && next_token != end_of_text && !g_terminate_flag.load())
    {
        // Predict next token
-        std::vector<matrix<int, 0, 1>> in_tokens = { input_seq, input_seq };
-        auto out_token = net(in_tokens);
-        next_token = static_cast<int>(out_token[0]);
+        next_token = static_cast<int>(net(input_seq));

        token_buffer.push_back(next_token);
        token_count++;