1515from ..runtime import JITFunction
1616from .errors import (CompilationError , CompileTimeAssertionFailure , UnsupportedLanguageConstruct )
1717from types import ModuleType
18+ from .hintmanager import hint_trigger
1819
1920
2021def mangle_ty (ty ):
@@ -247,11 +248,6 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n
247248 # special handling.
248249 self .visiting_arg_default_value = False
249250
250- # adding unified hint manager init
251- from .hint_manager import HintManager
252- from .hint_manager import hint_get_flagtree_backend
253- self .hint_manager = HintManager (hint_get_flagtree_backend ())
254-
255251 builtin_namespace : Dict [str , Any ] = {_ .__name__ : _ for _ in (len , list , range , float , int , isinstance , getattr )}
256252 builtin_namespace .update ((
257253 ('print' , language .core .device_print ),
@@ -522,7 +518,7 @@ def visit_Assign(self, node):
522518 self .set_value (name , value )
523519
524520 # switch into hintmanager
525- self . hint_manager . handler . trigger ("ext_CodeGenerator_visit_Assign_hint_anno" , self , node , names , values )
521+ hint_trigger ("ext_CodeGenerator_visit_Assign_hint_anno" , self , node , names , values )
526522
527523 def visit_AugAssign (self , node ):
528524 name = node .target .id
@@ -927,9 +923,10 @@ def visit_For(self, node):
927923 # flagtree backend specialization: add more ForOp attributes
928924 for_op_ext_attrs = (False , False , False , False )
929925
926+ # flagtree backend specialization
927+ from triton .runtime .driver import spec
930928 bind_sub_block = None
931- ext_it_class_support = [language .range ] # why?
932- ext_it_class_support += self .hint_manager .handler .trigger ("visit_For_ext_support" )
929+ ext_it_class_support = spec ("visit_For_ext_support" )
933930 ext_it_class_support = [] if ext_it_class_support is None else ext_it_class_support
934931 if IteratorClass in [language .range ] + ext_it_class_support :
935932 iterator = IteratorClass (* iter_args , ** iter_kwargs )
@@ -942,7 +939,10 @@ def visit_For(self, node):
942939 num_stages = iterator .num_stages
943940 loop_unroll_factor = iterator .loop_unroll_factor
944941
945- new_bind_sub_block = self .hint_manager .handler .trigger ("set_bind_sub_block_when_parallel" , IteratorClass , iterator , bind_sub_block )
942+ # flagtree backend specialization
943+ for_op_ext_attrs = spec ("for_op_ext_attrs" , iterator )
944+ # flagtree backend specialization
945+ new_bind_sub_block = spec ("set_bind_sub_block_when_parallel" , IteratorClass , iterator , bind_sub_block )
946946 if new_bind_sub_block is not None :
947947 bind_sub_block = new_bind_sub_block
948948 elif IteratorClass is range :
@@ -955,7 +955,7 @@ def visit_For(self, node):
955955 else :
956956 raise RuntimeError ('Only `range` and `static_range` iterators are currently supported' )
957957
958- new_bind_sub_block = self . hint_manager . handler . trigger ("check_override_bind_sub_block" , self , node , bind_sub_block )
958+ new_bind_sub_block = hint_trigger ("check_override_bind_sub_block" , self , node , bind_sub_block )
959959 if new_bind_sub_block is not None :
960960 bind_sub_block = new_bind_sub_block
961961
@@ -1027,7 +1027,7 @@ def visit_For(self, node):
10271027 spec ("for_op_set_ext_attrs" , for_op , self .builder , for_op_ext_attrs )
10281028 # flagtree backend specialization
10291029 if bind_sub_block :
1030- self . hint_manager . handler . trigger ("forop_setattr_for_bind_sub_block" , self , for_op , bind_sub_block )
1030+ hint_trigger ("forop_setattr_for_bind_sub_block" , self , for_op , bind_sub_block )
10311031
10321032 self .scf_stack .append (node )
10331033 self .builder .set_insertion_point_to_start (for_op .get_body (0 ))
@@ -1112,8 +1112,9 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
11121112 except Exception as e :
11131113 # Wrap the error in the callee with the location of the call.
11141114
1115-
1116- if self .hint_manager .handler .trigger ("need_repr_in_CodeGenerator_CompilationError" ):
1115+ # flagtree backend specialization
1116+ from triton .runtime .driver import spec
1117+ if spec ('need_repr_in_CodeGenerator_CompilationError' ):
11171118 raise CompilationError (self .jit_fn .src , self .cur_node , repr (e )) from e
11181119 raise CompilationError (self .jit_fn .src , self .cur_node , None ) from e
11191120
@@ -1160,7 +1161,9 @@ def visit_Call(self, node):
11601161 # preserve the traceback of the original error, which may e.g.
11611162 # be in core.py.
11621163
1163- if self .hint_manager .handler .trigger ("need_repr_in_CodeGenerator_CompilationError" ):
1164+ #flagtree backend specialization
1165+ from triton .runtime .driver import spec
1166+ if spec ('need_repr_in_CodeGenerator_CompilationError' ):
11641167 raise CompilationError (self .jit_fn .src , node , repr (e )) from e
11651168 raise CompilationError (self .jit_fn .src , node , None ) from e
11661169
0 commit comments