Skip to content

Commit 6b204a7

Browse files
committed
update hintmanager, wrap additional code into hintmanager, back no-hint-related handler func into spec, update import, change jit implement into hintmanager, simplify trigger call
1 parent ded0530 commit 6b204a7

File tree

4 files changed

+86
-67
lines changed

4 files changed

+86
-67
lines changed

python/triton/compiler/code_generator.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ..runtime import JITFunction
1616
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
1717
from types import ModuleType
18+
from .hintmanager import hint_trigger
1819

1920

2021
def mangle_ty(ty):
@@ -247,11 +248,6 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n
247248
# special handling.
248249
self.visiting_arg_default_value = False
249250

250-
# adding unified hint manager init
251-
from .hint_manager import HintManager
252-
from .hint_manager import hint_get_flagtree_backend
253-
self.hint_manager = HintManager(hint_get_flagtree_backend())
254-
255251
builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)}
256252
builtin_namespace.update((
257253
('print', language.core.device_print),
@@ -522,7 +518,7 @@ def visit_Assign(self, node):
522518
self.set_value(name, value)
523519

524520
# switch into hintmanager
525-
self.hint_manager.handler.trigger("ext_CodeGenerator_visit_Assign_hint_anno", self, node, names, values)
521+
hint_trigger("ext_CodeGenerator_visit_Assign_hint_anno", self, node, names, values)
526522

527523
def visit_AugAssign(self, node):
528524
name = node.target.id
@@ -927,9 +923,10 @@ def visit_For(self, node):
927923
# flagtree backend specialization: add more ForOp attributes
928924
for_op_ext_attrs = (False, False, False, False)
929925

926+
# flagtree backend specialization
927+
from triton.runtime.driver import spec
930928
bind_sub_block = None
931-
ext_it_class_support = [language.range] # why?
932-
ext_it_class_support += self.hint_manager.handler.trigger("visit_For_ext_support")
929+
ext_it_class_support = spec("visit_For_ext_support")
933930
ext_it_class_support = [] if ext_it_class_support is None else ext_it_class_support
934931
if IteratorClass in [language.range] + ext_it_class_support:
935932
iterator = IteratorClass(*iter_args, **iter_kwargs)
@@ -942,7 +939,10 @@ def visit_For(self, node):
942939
num_stages = iterator.num_stages
943940
loop_unroll_factor = iterator.loop_unroll_factor
944941

945-
new_bind_sub_block = self.hint_manager.handler.trigger("set_bind_sub_block_when_parallel", IteratorClass, iterator, bind_sub_block)
942+
# flagtree backend specialization
943+
for_op_ext_attrs = spec("for_op_ext_attrs", iterator)
944+
# flagtree backend specialization
945+
new_bind_sub_block = spec("set_bind_sub_block_when_parallel", IteratorClass, iterator, bind_sub_block)
946946
if new_bind_sub_block is not None:
947947
bind_sub_block = new_bind_sub_block
948948
elif IteratorClass is range:
@@ -955,7 +955,7 @@ def visit_For(self, node):
955955
else:
956956
raise RuntimeError('Only `range` and `static_range` iterators are currently supported')
957957

958-
new_bind_sub_block = self.hint_manager.handler.trigger("check_override_bind_sub_block", self, node, bind_sub_block)
958+
new_bind_sub_block = hint_trigger("check_override_bind_sub_block", self, node, bind_sub_block)
959959
if new_bind_sub_block is not None:
960960
bind_sub_block = new_bind_sub_block
961961

@@ -1027,7 +1027,7 @@ def visit_For(self, node):
10271027
spec("for_op_set_ext_attrs", for_op, self.builder, for_op_ext_attrs)
10281028
# flagtree backend specialization
10291029
if bind_sub_block:
1030-
self.hint_manager.handler.trigger("forop_setattr_for_bind_sub_block", self, for_op, bind_sub_block)
1030+
hint_trigger("forop_setattr_for_bind_sub_block", self, for_op, bind_sub_block)
10311031

10321032
self.scf_stack.append(node)
10331033
self.builder.set_insertion_point_to_start(for_op.get_body(0))
@@ -1112,8 +1112,9 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
11121112
except Exception as e:
11131113
# Wrap the error in the callee with the location of the call.
11141114

1115-
1116-
if self.hint_manager.handler.trigger("need_repr_in_CodeGenerator_CompilationError"):
1115+
# flagtree backend specialization
1116+
from triton.runtime.driver import spec
1117+
if spec('need_repr_in_CodeGenerator_CompilationError'):
11171118
raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from e
11181119
raise CompilationError(self.jit_fn.src, self.cur_node, None) from e
11191120

@@ -1160,7 +1161,9 @@ def visit_Call(self, node):
11601161
# preserve the traceback of the original error, which may e.g.
11611162
# be in core.py.
11621163

1163-
if self.hint_manager.handler.trigger("need_repr_in_CodeGenerator_CompilationError"):
1164+
#flagtree backend specialization
1165+
from triton.runtime.driver import spec
1166+
if spec('need_repr_in_CodeGenerator_CompilationError'):
11641167
raise CompilationError(self.jit_fn.src, node, repr(e)) from e
11651168
raise CompilationError(self.jit_fn.src, node, None) from e
11661169

python/triton/compiler/hint_manager.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,16 @@ def hint_get_flagtree_backend() -> str:
152152
return canonical_backend
153153
else:
154154
# version and backend mismatch
155-
logging.warning(
155+
msg = (
156156
f"[Flagtree] Hint ignored: Detected backend '{canonical_backend}' but current Triton version "
157157
f"'{current_triton_version}' matches no supported versions {supported_versions}."
158158
)
159-
return ""
159+
print(msg, file=sys.stderr)
160+
return ""
161+
# lazy load after first call hint trigger
162+
_global_hint_manager = None
163+
164+
def hint_trigger(hook_name, *args, **kwargs):
165+
if _global_hint_manager is None:
166+
_global_hint_manager = HintManager(hint_get_flagtree_backend())
167+
return _global_hint_manager.handler.trigger(hook_name, *args, **kwargs)

python/triton/runtime/jit.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple
1212
from ..runtime.driver import driver
1313
from types import ModuleType
14+
from ..compiler.hintmanager import hint_trigger
1415

1516
TRITON_MODULE = __name__[:-len(".runtime.jit")]
1617

@@ -826,26 +827,16 @@ def get_flagtree_backend():
826827
# the user might want to monkey-patch self.src dynamically.
827828
# Our unit tests do this, for example.
828829
def parse(self):
829-
# remove flagtree backend specialization, because the implementation of 2 method is totally same
830-
line_flagtree_hints = {}
831-
code_str = self.src
832-
g = tokenize.generate_tokens(StringIO(code_str).readline)
833-
for tok_type, tok_text, start, end, _ in g:
834-
if tok_type == tokenize.COMMENT:
835-
comment = tok_text.replace(" ", "").strip()
836-
if comment.startswith('#@hint:'):
837-
flagtree_hints = comment[len('#@hint:'):].strip()
838-
# Record the line number of the comment
839-
line_num = start[0]
840-
line_flagtree_hints[line_num] = flagtree_hints
830+
# after removing flagtree backend specialization, hiding the implementation into hintmanager
831+
line_flagtree_hints = hint_trigger("maps_line_numbers_to_comment_hints", self)
841832

842833
tree = ast.parse(self.src)
843834
assert isinstance(tree, ast.Module)
844835
assert len(tree.body) == 1
845836
assert isinstance(tree.body[0], ast.FunctionDef)
846837

847838
# Attach the line number to comment mapping to the function definition node
848-
tree.body[0].line_flagtree_hints = line_flagtree_hints
839+
hint_trigger('attach_line_number_to_comment_mapping', tree, line_flagtree_hints)
849840

850841
return tree
851842

third_party/ascend/backend/ascend_hint_handler.py

Lines changed: 55 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,44 +6,36 @@
66

77
class AscendHintHandler(BaseHintHandler):
88

9-
def ext_CodeGenerator_visit_Assign_hint_anno(self, code_generator, node, names, values):
10-
import ast
11-
from triton.compiler.code_generator import _is_triton_value
12-
# flagtree: After normal processing, check if we need to add hint annotation
13-
if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'):
14-
line_num = node.lineno
15-
# TODO: reparse needed in case we need to deal with complex cases, will be redesigned later
16-
function_def = code_generator.jit_fn.parse()
17-
line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {})
18-
flagtree_hints = line_flagtree_hints.get(line_num)
19-
20-
# Check if this is a tl.load call with dot_pad_only_k hint
21-
if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and
22-
isinstance(node.value, ast.Call) and
23-
isinstance(node.value.func, ast.Attribute) and
24-
isinstance(node.value.func.value, ast.Name) and
25-
node.value.func.value.id == 'tl' and
26-
node.value.func.attr == 'load'):
27-
28-
# Add hint annotation to the loaded tensor(s)
29-
for name, value in zip(names, values):
30-
if _is_triton_value(value):
31-
# print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}")
32-
# Create hint annotation
33-
hint_val = code_generator.builder.get_unit_attr()
34-
code_generator.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val)
9+
@staticmethod
10+
def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values):
11+
import ast
12+
from triton.compiler.code_generator import _is_triton_value
13+
# flagtree: After normal processing, check if we need to add hint annotation
14+
if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'):
15+
line_num = node.lineno
16+
# TODO: reparse needed in case we need to deal with complex cases, will be redesigned later
17+
function_def = code_generator.jit_fn.parse()
18+
line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {})
19+
flagtree_hints = line_flagtree_hints.get(line_num)
3520

36-
def visit_For_ext_support(self):
37-
import triton.language as language
38-
return [language.parallel]
21+
# Check if this is a tl.load call with dot_pad_only_k hint
22+
if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and
23+
isinstance(node.value, ast.Call) and
24+
isinstance(node.value.func, ast.Attribute) and
25+
isinstance(node.value.func.value, ast.Name) and
26+
node.value.func.value.id == 'tl' and
27+
node.value.func.attr == 'load'):
3928

40-
def set_bind_sub_block_when_parallel(self, IteratorClass, iterator, bind_sub_block):
41-
import triton.language as language
42-
if (IteratorClass is language.parallel):
43-
return iterator.bind_sub_block
44-
return bind_sub_block
29+
# Add hint annotation to the loaded tensor(s)
30+
for name, value in zip(names, values):
31+
if _is_triton_value(value):
32+
# print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}")
33+
# Create hint annotation
34+
hint_val = code_generator.builder.get_unit_attr()
35+
code_generator.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val)
4536

46-
def check_override_bind_sub_block(self, code_generator, node, bind_sub_block):
37+
@staticmethod
38+
def check_override_bind_sub_block(code_generator, node, bind_sub_block):
4739
# flagtree: After normal processing, check if we need to override bind_sub_block
4840
if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'):
4941
line_num = node.lineno
@@ -58,8 +50,33 @@ def check_override_bind_sub_block(self, code_generator, node, bind_sub_block):
5850
# print(f"[FLAGTREE] Found bind_sub_block hint at line {line_num}")
5951
return bind_sub_block
6052

61-
def forop_setattr_for_bind_sub_block(self, code_generator, for_op, bind_sub_block):
53+
@staticmethod
54+
def forop_setattr_for_bind_sub_block(code_generator, for_op, bind_sub_block):
6255
for_op.set_attr("bind_sub_block", code_generator.builder.get_bool_attr(bind_sub_block))
6356

64-
def need_repr_in_CodeGenerator_CompilationError(self):
65-
return True
57+
58+
@staticmethod
59+
def maps_line_numbers_to_comment_hints(jit_fn):
60+
import tokenize
61+
from io import StringIO
62+
# Maps line numbers to comment hints
63+
line_flagtree_hints = {}
64+
code_str = jit_fn.src
65+
g = tokenize.generate_tokens(StringIO(code_str).readline)
66+
for tok_type, tok_text, start, end, _ in g:
67+
if tok_type == tokenize.COMMENT:
68+
comment = tok_text.replace(" ", "").strip()
69+
if comment.startswith('#@hint:'):
70+
flagtree_hints = comment[len('#@hint:'):].strip()
71+
# Record the line number of the comment
72+
line_num = start[0]
73+
line_flagtree_hints[line_num] = flagtree_hints
74+
75+
# print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}")
76+
77+
return line_flagtree_hints
78+
79+
@staticmethod
80+
def attach_line_number_to_comment_mapping(tree, line_flagtree_hints):
81+
# Attach the line number to comment mapping to the function definition node
82+
tree.body[0].line_flagtree_hints = line_flagtree_hints

0 commit comments

Comments
 (0)