Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# ideally we wouldn't need any runtime component
from ..runtime import JITFunction
from .._utils import find_paths_if, get_iterable_path, set_iterable_path
from .hint_manager import hint_trigger

from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)

Expand Down Expand Up @@ -1243,10 +1244,7 @@ def visit_Call(self, node):
args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))

# 4. Get current line number and hints
line_num = node.lineno
function_def = self.jit_fn.parse()
line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {})
flagtree_hints = line_flagtree_hints.get(line_num)
flagtree_hints = hint_trigger("get_node_hints", self, node)

# 5. Handle JIT function calls
if isinstance(fn, JITFunction):
Expand All @@ -1261,12 +1259,7 @@ def visit_Call(self, node):
extra_kwargs['_generator'] = self
try:
# Special handling for tl.load with hints
if fn.__name__ == "load" and flagtree_hints is not None:
print(f"[FLAGTREE] tl.load at line {line_num} has annotation {flagtree_hints}")
if 'flagtree_hints' not in kws:
kws['flagtree_hints'] = ""
if flagtree_hints not in kws['flagtree_hints']:
kws['flagtree_hints'] = flagtree_hints
hint_trigger("inject_kwargs_with_hints", fn, flagtree_hints, node.lineno, kws)

ret = fn(*args, **extra_kwargs, **kws)
# builtin functions return plain tuples for readability
Expand Down
151 changes: 151 additions & 0 deletions python/triton/compiler/hint_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import os
import sys
import importlib


class BaseHintHandler:
    """Default hint handler that dispatches named hooks to same-named methods.

    Backend handlers subclass this and define one method per hook they
    support. A hook the handler does not define is a no-op.
    """

    def trigger(self, hook_name, *args, **kwargs):
        """Invoke the hook called ``hook_name`` if this handler implements it.

        Args:
            hook_name: Attribute name of the hook to look up on ``self``.
            *args: Positional arguments forwarded to the hook.
            **kwargs: Keyword arguments forwarded to the hook.

        Returns:
            The hook's return value, or ``None`` when ``self`` has no callable
            attribute named ``hook_name``.

        Raises:
            TypeError: Re-raised unchanged from the hook call. When the
                supplied arguments do not fit the hook's signature (or the
                signature cannot be inspected), a diagnostic is written to
                stderr first.
        """
        method = getattr(self, hook_name, None)
        if not callable(method):
            return None

        try:
            return method(*args, **kwargs)
        except TypeError as exc:
            import inspect

            try:
                signature = inspect.signature(method)
            except (TypeError, ValueError):
                signature = None

            # Only report a mismatch when the arguments really do not bind;
            # a TypeError raised inside the hook body is not a call-site error
            # and must not be labelled as one.
            if signature is None:
                expected = "(unknown)"
                mismatch = True
            else:
                expected = str(signature)
                try:
                    signature.bind(*args, **kwargs)
                except TypeError:
                    mismatch = True
                else:
                    mismatch = False

            if mismatch:
                actual_args = f"{len(args)} positional"
                actual_kwargs = f"keys={list(kwargs.keys())}" if kwargs else "no keywords"

                # Diagnostics go to stderr so they never mix with program output.
                print(f"\n[Hint Trigger Mismatch] {self.__class__.__name__}.{hook_name}", file=sys.stderr)
                print(f"  > Expect : {expected}", file=sys.stderr)
                print(f"  > Actual : {actual_args}, {actual_kwargs}", file=sys.stderr)
                print(f"  > Reason : {exc}\n", file=sys.stderr)

            # Bare raise keeps the original traceback intact.
            raise


class HintManager:
    """Holds the hint handler chosen for a given backend name."""

    def __init__(self, backend_name):
        """Remember ``backend_name`` and load the handler that matches it."""
        self.backend_name = backend_name
        self.handler = self._load_handler(backend_name)

    def _load_handler(self, backend):
        """Return the handler instance for ``backend``.

        Falls back to a plain ``BaseHintHandler`` when the backend has no
        dedicated handler or when importing that handler fails (a warning is
        written to stderr in the latter case).
        """
        # (backend key, module to import, handler class, label used in the warning)
        handler_specs = (
            ('npu', "third_party.ascend.backend.ascend_hint_handler", "AscendHintHandler", "Ascend"),
            ('aipu', "third_party.aipu.backend.aipu_hint_handler", "AipuHintHandler", "aipu"),
        )

        for key, module_path, class_name, label in handler_specs:
            if backend != key:
                continue
            try:
                handler_module = importlib.import_module(module_path)
                return getattr(handler_module, class_name)()
            except ImportError as e:
                print(f"[FlagTree] Warning: Failed to load {label} Hint Handler: {e}", file=sys.stderr)
                return BaseHintHandler()

        return BaseHintHandler()


# Backends for which hints are honoured, each mapped to the set of Triton
# "major.minor" version strings accepted for that backend.
SUPPORTED_CONFIG = {
    "cuda": {"3.5"},
    "npu": {"3.2"},
    "aipu": {"3.3"},
}

# Alternative spellings of a backend name, mapped to the canonical name.
BACKEND_ALIASES = {
    "ascend": "npu",
    "huawei": "npu",
    "nv": "cuda",
}


def normalize_backend_name(name: str) -> str:
    """Return the canonical, lower-case backend name for ``name``.

    Surrounding whitespace is ignored so that values such as ``"ascend "``
    (e.g. from an environment variable) still resolve. Names listed in
    ``BACKEND_ALIASES`` are mapped to their canonical form; any other name is
    returned lower-cased. An empty or ``None`` name yields ``""``.
    """
    if not name:
        return ""
    name = name.strip().lower()
    return BACKEND_ALIASES.get(name, name)


def hint_get_flagtree_backend() -> str:
    """Detect the backend whose hint handler should be used.

    Detection order:
      1. the torch device reported by the active Triton driver,
      2. the first ``torch.<backend>`` module whose ``is_available()`` is true
         ("cuda" is probed last),
      3. the ``FLAGTREE_BACKEND`` environment variable.

    The detected name is normalised with ``normalize_backend_name`` and then
    checked against ``SUPPORTED_CONFIG``.

    Returns:
        The canonical backend name, or ``""`` when nothing was detected, the
        backend is not listed in ``SUPPORTED_CONFIG``, or the running Triton
        "major.minor" version is not accepted for that backend.
    """
    import triton

    # torch is only needed for the torch-based probes; when it is missing
    # those probes are skipped instead of failing the whole lookup.
    try:
        import torch
    except ImportError:
        torch = None

    detected_backend = ""

    # Priority 1: Triton driver.
    try:
        from triton.runtime import driver
        if hasattr(driver, 'active') and hasattr(driver.active, 'get_active_torch_device'):
            device = driver.active.get_active_torch_device()
            if torch is not None and isinstance(device, torch.device):
                detected_backend = device.type
            # A plain string device is accepted as-is (no dedicated support yet).
            elif isinstance(device, str):
                detected_backend = device
    except (ImportError, RuntimeError):
        # NOTE(review): resolving the active driver appears to raise
        # RuntimeError when no single driver is active - confirm. Detection is
        # best-effort, so fall through to the next probe.
        pass

    # Priority 2: torch global state.
    if not detected_backend and torch is not None:
        # sorted() is stable: non-CUDA backends keep their declared order and
        # "cuda" is probed last.
        for candidate in sorted(SUPPORTED_CONFIG, key=lambda name: name == "cuda"):
            module = getattr(torch, candidate, None)
            if module and hasattr(module, "is_available") and module.is_available():
                detected_backend = candidate
                break

    # Priority 3: environment variable.
    # TODO: the original author marked this fallback for removal.
    if not detected_backend:
        detected_backend = os.environ.get("FLAGTREE_BACKEND", "")

    # Normalisation and validation.
    canonical_backend = normalize_backend_name(detected_backend)

    if not canonical_backend or canonical_backend not in SUPPORTED_CONFIG:
        return ""

    # Verify that the backend and the Triton version match. When the version
    # cannot be read the check is skipped and the backend is accepted.
    version = getattr(triton, "__version__", None)
    if not isinstance(version, str):
        return canonical_backend

    current_triton_version = ".".join(version.split(".")[:2])
    supported_versions = SUPPORTED_CONFIG[canonical_backend]
    if current_triton_version not in supported_versions:
        msg = (f"[Flagtree] Hint ignored: Detected backend '{canonical_backend}' but current Triton version "
               f"'{current_triton_version}' matches no supported versions {supported_versions}.")
        print(msg, file=sys.stderr)
        return ""

    return canonical_backend


# Process-wide HintManager; created on the first hint_trigger() call.
_global_hint_manager = None


def hint_trigger(hook_name, *args, **kwargs):
    """Forward ``hook_name`` and its arguments to the global hint handler.

    The global ``HintManager`` is built on first use from the backend returned
    by ``hint_get_flagtree_backend()``. Returns whatever the handler's
    ``trigger`` returns.
    """
    global _global_hint_manager

    manager = _global_hint_manager
    if manager is None:
        manager = HintManager(hint_get_flagtree_backend())
        _global_hint_manager = manager

    return manager.handler.trigger(hook_name, *args, **kwargs)
19 changes: 5 additions & 14 deletions python/triton/runtime/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
from ..runtime.driver import driver
from types import ModuleType
from .._utils import find_paths_if, get_iterable_path
import tokenize
from io import StringIO

TRITON_MODULE = __name__[:-len(".runtime.jit")]

Expand Down Expand Up @@ -705,26 +703,19 @@ def preload(self, specialization_data):
# the user might want to monkey-patch self.src dynamically.
# Our unit tests do this, for example.
def parse(self):
from ..compiler.hint_manager import hint_trigger
# Maps line numbers to comment hints
line_flagtree_hints = {}
code_str = self.src
g = tokenize.generate_tokens(StringIO(code_str).readline)
for tok_type, tok_text, start, end, _ in g:
if tok_type == tokenize.COMMENT:
comment = tok_text.replace(" ", "").strip()
if comment.startswith('#@hint:'):
flagtree_hints = comment[len('#@hint:'):].strip()
# Record the line number of the comment
line_num = start[0]
line_flagtree_hints[line_num] = flagtree_hints
line_flagtree_hints = hint_trigger("maps_line_numbers_to_comment_hints", self)
if line_flagtree_hints is None:
line_flagtree_hints = {}

tree = ast.parse(self.src)
assert isinstance(tree, ast.Module)
assert len(tree.body) == 1
assert isinstance(tree.body[0], ast.FunctionDef)

# Attach the line number to comment mapping to the function definition node
tree.body[0].line_flagtree_hints = line_flagtree_hints
hint_trigger("attach_line_number_to_comment_mapping", tree, line_flagtree_hints)
return tree

def __call__(self, *args, **kwargs):
Expand Down
53 changes: 53 additions & 0 deletions third_party/aipu/backend/aipu_hint_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# This file is expected to live in third_party/aipu/backend/.
from triton.compiler.hint_manager import BaseHintHandler
import triton.language as language
import ast
from triton.compiler.code_generator import _is_triton_value


class AipuHintHandler(BaseHintHandler):
    """Hint hooks for the aipu backend.

    Author's note: aipu differs from ascend in two ways. It has no
    backend_spec and instead modifies the Triton sources directly, and it
    changes the builder, semantic, core and so on, which cannot be routed
    through the hint manager. Only the changes that lived in the code
    generator and the JIT are therefore moved into this handler.
    """

    @staticmethod
    def get_node_hints(code_generator, node):
        """Return the hint recorded for ``node``'s line, or ``None`` if there is none."""
        target_line = node.lineno
        parsed = code_generator.jit_fn.parse()
        hints_by_line = getattr(parsed.body[0], 'line_flagtree_hints', {})
        return hints_by_line.get(target_line)

    @staticmethod
    def inject_kwargs_with_hints(fn, flagtree_hints, line_num, kws):
        """Store ``flagtree_hints`` in ``kws`` for calls to a function named ``load``.

        Does nothing for other functions or when no hint is given. ``kws`` is
        modified in place; an existing value is kept only when it already
        contains the hint.
        """
        if fn.__name__ != "load" or flagtree_hints is None:
            return

        print(f"[FLAGTREE] tl.load at line {line_num} has annotation {flagtree_hints}")
        current = kws.setdefault('flagtree_hints', "")
        if flagtree_hints not in current:
            kws['flagtree_hints'] = flagtree_hints

    @staticmethod
    def maps_line_numbers_to_comment_hints(jit_fn):
        """Scan ``jit_fn.src`` for ``#@hint:`` comments.

        Returns a dict mapping the line number of each such comment to the
        hint text that follows the prefix (with all spaces removed).
        """
        import tokenize
        from io import StringIO

        marker = '#@hint:'
        hints_by_line = {}
        source_reader = StringIO(jit_fn.src).readline
        for token in tokenize.generate_tokens(source_reader):
            if token.type != tokenize.COMMENT:
                continue
            # Spaces are dropped so "# @hint: x" and "#@hint:x" are treated alike.
            compact = token.string.replace(" ", "").strip()
            if compact.startswith(marker):
                # Key on the line where the comment starts.
                hints_by_line[token.start[0]] = compact[len(marker):].strip()

        return hints_by_line

    @staticmethod
    def attach_line_number_to_comment_mapping(tree, line_flagtree_hints):
        """Attach the line-to-hint mapping to the first node of ``tree.body``, if any."""
        if not tree.body:
            return
        tree.body[0].line_flagtree_hints = line_flagtree_hints
Loading