@@ -142,7 +142,7 @@ class phase1launchFactory
142142 std::cout << " B TYpe: " << B->type << std::endl;
143143// // (1) create the semiring code and name
144144
145- // // (2) ensure the jitifier has "GB_semiring_[mysemiring.sr_code].h"
145+ // // (2) ensure the jitifier has "GB_semiring_[mysemiring.sr_code].h"
146146 jit::GBJitCache filecache = jit::GBJitCache::Instance () ;
147147 filecache.getFile (semiring_factory_) ;
148148
@@ -162,6 +162,11 @@ class phase1launchFactory
162162 dim3 grid (get_number_of_blocks (M));
163163 dim3 block (get_threads_per_block ());
164164
165+ // for (auto s:compiler_flags)
166+ // {
167+ // std::cout << "Compiler Flags: " << s << std::endl ;
168+ // }
169+
165170 jit::launcher ( hashable_name + " _" + M->type ->name + " _" + sr_code,
166171 string_to_be_jitted.str (),
167172 header_names,
@@ -199,6 +204,13 @@ class phase2launchFactory
199204 return (ntasks + threads_per_block - 1 ) / threads_per_block ;
200205 }
201206
207+ int get_number_of_phase1_blocks ( GrB_Matrix M){
208+ const int64_t mnz = GB_nnz (M) ;
209+ int number_of_sms = GB_Global_gpu_sm_get (0 );
210+ int nblks = ( GB_nnz (M) + chunk_size - 1 )/chunk_size;
211+ return GB_IMIN ( nblks, 128 * number_of_sms);
212+ }
213+
202214 bool jitGridBlockLaunch (// parameters to AxB_phase2:
203215 int64_t *blockBucket, int64_t *offset, GrB_Matrix M) {
204216
@@ -224,7 +236,7 @@ class phase2launchFactory
224236 .set_kernel_inst ( kernel_name, {})
225237 .configure (grid, block)
226238 // parameters to AxB_phase2:
227- .launch ( blockBucket, offset, get_number_of_blocks (M));
239+ .launch ( blockBucket, offset, get_number_of_phase1_blocks (M));
228240
229241 checkCudaErrors ( cudaDeviceSynchronize () );
230242 result= true ;
@@ -319,9 +331,9 @@ class phase3launchFactory
319331 // ----------------------------------------------------------------------
320332 // phase3: do the numerical work
321333 // ----------------------------------------------------------------------
334+
322335 C->jumbled = true ;
323- C->nzombies = bucketp[1 ]; // set pre-zombie counts
324- const int64_t Cnz = GB_nnz (C) ;
336+ const int64_t nz = end - start; // number of dots in this bucket
325337 const int64_t mnvec = M->nvec ;
326338
327339 int gridsz, blocksz, sz = 4 ;
@@ -332,10 +344,13 @@ class phase3launchFactory
332344 /* *
333345 * Configure geometry and kernel function name based on sparsity of C and number of vectors in M
334346 */
335- configure (Cnz, mnvec, final_kernel_name_ss, blocksz, gridsz, sz);
347+ configure ( nz, mnvec, final_kernel_name_ss, blocksz, gridsz, sz);
348+
349+ auto sr_code = std::to_string (semiring_factory_.sr_code );
336350
337351 std::string hashable_name = base_name + " _" + final_kernel_name_ss.str ();
338352 std::stringstream string_to_be_jitted ;
353+ std::vector<std::string> template_types = {C->type ->name , A->type ->name , B->type ->name };
339354
340355 jit::GBJitCache filecache = jit::GBJitCache::Instance () ;
341356 filecache.getFile (semiring_factory_) ;
@@ -347,17 +362,16 @@ class phase3launchFactory
347362 dim3 grid (gridsz);
348363 dim3 block (blocksz);
349364
350- C->nzombies = 0 ;
351365 GBURBLE (" (GPU phase3 launch st,end=%ld,%ld nblocks,blocksize= %d,%d )\n " ,start,end,gridsz,blocksz) ;
352- jit::launcher ( hashable_name,
366+ jit::launcher ( hashable_name + " _ " + M-> type -> name + " _ " + sr_code ,
353367 string_to_be_jitted.str (),
354368 header_names,
355369 compiler_flags,
356370 file_callback)
357- .set_kernel_inst (final_kernel_name_ss.str (),
358- { C->type ->name ,
359- A->type ->name ,
360- B->type ->name })
371+ .set_kernel_inst (final_kernel_name_ss.str (), template_types )
372+ // { C->type->name,
373+ // A->type->name,
374+ // B->type->name })
361375 .configure (grid, block) // if commented, use implicit 1D configure in launch
362376 .launch (
363377 start, // input/output:
@@ -386,6 +400,7 @@ class phase3launchFactory
386400 int number_of_sms = GB_Global_gpu_sm_get (0 ) ;
387401
388402 std::string Opname;
403+ // TODO: make sure this works with different geometry
389404
390405 printf (" LAUNCHING BUCKET CODE: %d\n " , (int )bucket_code_);
391406 switch (bucket_code_)
@@ -706,4 +721,4 @@ inline bool GB_cuda_reduce(GrB_Matrix A, void *output, GrB_Monoid op) {
706721//
707722//
708723#endif // C++11
709- #endif
724+ #endif
0 commit comments