Skip to content

Commit 5ffb5f9

Browse files
SabYicflagtree-botsgjzfzzfGalaxy1458
authored
[Build] Add flagtree config class for triton_v3.5.x (#350)
* modified: python/setup_tools/setup_helper.py modified: python/setup_tools/utils/__init__.py modified: python/setup_tools/utils/aipu.py modified: python/setup_tools/utils/tools.py modified: setup.py * Apply code-format changes --------- Co-authored-by: flagtree-bot <flagtree_ai@163.com> Co-authored-by: Jinjie Liu <jjliu@baai.ac.cn> Co-authored-by: Galaxy1458 <55453380+Galaxy1458@users.noreply.github.com>
1 parent 7447688 commit 5ffb5f9

File tree

5 files changed

+80
-39
lines changed

5 files changed

+80
-39
lines changed

python/setup_tools/setup_helper.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,11 @@
99
import importlib.util
1010
import importlib.metadata
1111
from typing import List, Tuple
12+
from .utils.tools import flagtree_configs as configs
1213

13-
extend_backends = []
14-
default_backends = ["nvidia", "amd"]
15-
plugin_backends = ["cambricon", "ascend", "aipu", "tsingmicro"]
16-
ext_sourcedir = "triton/_C/"
17-
flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower()
18-
flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower()
19-
offline_build = os.getenv("FLAGTREE_PLUGIN", "OFF")
20-
device_mapping = {"xpu": "xpu", "mthreads": "musa", "ascend": "ascend"}
21-
activated_module = utils.activate(flagtree_backend)
2214
downloader = utils.tools.DownloadManager()
15+
configs = configs
16+
flagtree_backend = configs.flagtree_backend
2317

2418
set_llvm_env = lambda path: set_env(
2519
{
@@ -33,26 +27,26 @@
3327

3428
def install_extension(*args, **kargs):
3529
try:
36-
activated_module.install_extension(*args, **kargs)
30+
configs.activated_module.install_extension(*args, **kargs)
3731
except Exception:
3832
pass
3933

4034

4135
def get_backend_cmake_args(*args, **kargs):
4236
try:
43-
return activated_module.get_backend_cmake_args(*args, **kargs)
37+
return configs.activated_module.get_backend_cmake_args(*args, **kargs)
4438
except Exception:
4539
return []
4640

4741

4842
def get_device_name():
49-
return device_mapping[flagtree_backend]
43+
return configs.device_alias_map[flagtree_backend]
5044

5145

5246
def get_extra_packages():
5347
packages = []
5448
try:
55-
packages = activated_module.get_extra_install_packages()
49+
packages = configs.activated_module.get_extra_install_packages()
5650
except Exception:
5751
packages = []
5852
return packages
@@ -61,7 +55,7 @@ def get_extra_packages():
6155
def get_package_data_tools():
6256
package_data = ["compile.h", "compile.c"]
6357
try:
64-
package_data += activated_module.get_package_data_tools()
58+
package_data += configs.activated_module.get_package_data_tools()
6559
except Exception:
6660
package_data
6761
return package_data
@@ -88,15 +82,15 @@ def download_flagtree_third_party(name, condition, required=False, hock=None):
8882
downloader.download(module=submodule, required=required)
8983
if callable(hock):
9084
hock(third_party_base_dir=utils.flagtree_submodule_dir, backend=submodule,
91-
default_backends=default_backends)
85+
default_backends=configs.default_backends)
9286

9387
else:
9488
print(f"\033[1;33m[Note] Skip downloading {name} since USE_{name.upper()} is set to OFF\033[0m")
9589

9690

9791
def post_install():
9892
try:
99-
activated_module.post_install()
93+
configs.activated_module.post_install()
10094
except Exception:
10195
pass
10296

@@ -323,14 +317,14 @@ def skip_package_dir(package):
323317
if 'backends' in package or 'profiler' in package:
324318
return True
325319
try:
326-
return activated_module.skip_package_dir(package)
320+
return configs.activated_module.skip_package_dir(package)
327321
except Exception:
328322
return False
329323

330324
@staticmethod
331325
def get_package_dir(packages):
332326
package_dict = {}
333-
if flagtree_backend and flagtree_backend not in plugin_backends:
327+
if flagtree_backend and flagtree_backend not in configs.plugin_backends:
334328
connection = []
335329
backend_triton_path = f"../third_party/{flagtree_backend}/python/"
336330
for package in packages:
@@ -340,7 +334,7 @@ def get_package_dir(packages):
340334
connection.append(pair)
341335
package_dict.update(connection)
342336
try:
343-
package_dict.update(activated_module.get_package_dir())
337+
package_dict.update(configs.activated_module.get_package_dir())
344338
except Exception:
345339
pass
346340
return package_dict
@@ -350,8 +344,8 @@ def handle_flagtree_backend():
350344
global ext_sourcedir
351345
if flagtree_backend:
352346
print(f"\033[1;32m[INFO] FlagtreeBackend is {flagtree_backend}\033[0m")
353-
extend_backends.append(flagtree_backend)
354-
if "editable_wheel" in sys.argv and flagtree_backend not in plugin_backends:
347+
configs.extend_backends.append(flagtree_backend)
348+
if "editable_wheel" in sys.argv and flagtree_backend not in configs.plugin_backends:
355349
ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/"
356350

357351

@@ -410,7 +404,7 @@ def uninstall_triton():
410404
)
411405

412406
cache.store(
413-
file="iluvatarTritonPlugin.so", condition=("iluvatar" == flagtree_backend) and (not flagtree_plugin), url=
407+
file="iluvatarTritonPlugin.so", condition=("iluvatar" == flagtree_backend) and (not configs.flagtree_plugin), url=
414408
"https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/iluvatarTritonPlugin-cpython3.10-glibc2.30-glibcxx3.4.28-cxxabi1.3.12-ubuntu-x86_64_v0.3.0.tar.gz",
415409
copy_dst_path=f"third_party/{flagtree_backend}", md5_digest="015b9af8")
416410

@@ -449,7 +443,7 @@ def uninstall_triton():
449443
)
450444

451445
cache.store(
452-
file="mthreadsTritonPlugin.so", condition=("mthreads" == flagtree_backend) and (not flagtree_plugin), url=
446+
file="mthreadsTritonPlugin.so", condition=("mthreads" == flagtree_backend) and (not configs.flagtree_plugin), url=
453447
"https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/mthreadsTritonPlugin-cpython3.10-glibc2.35-glibcxx3.4.30-cxxabi1.3.13-ubuntu-x86_64_v0.3.0.tar.gz",
454448
copy_dst_path=f"third_party/{flagtree_backend}", md5_digest="2a9ca0b8")
455449

python/setup_tools/utils/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22
import importlib.util
33
import os
44
from . import tools, default, aipu
5-
from .tools import flagtree_submodule_dir, OfflineBuildManager
5+
from .tools import flagtree_configs, OfflineBuildManager
66

77
flagtree_submodules = {
88
"triton_shared":
99
tools.Module(name="triton_shared", url="https://github.com/microsoft/triton-shared.git",
1010
commit_id="5842469a16b261e45a2c67fbfc308057622b03ee",
11-
dst_path=os.path.join(flagtree_submodule_dir, "triton_shared")),
11+
dst_path=os.path.join(flagtree_configs.flagtree_submodule_dir, "triton_shared")),
1212
"flir":
1313
tools.Module(name="flir", url="https://github.com/FlagTree/flir.git",
14-
dst_path=os.path.join(flagtree_submodule_dir, "flir")),
14+
dst_path=os.path.join(flagtree_configs.flagtree_submodule_dir, "flir")),
1515
}
1616

1717

python/setup_tools/utils/aipu.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
def precompile_hock(*args, **kargs):
22
default_backends = kargs["default_backends"]
3-
default_backends.append('flir')
3+
default_backends_list = [*default_backends, "flir"]
4+
kargs["default_backends"] = tuple(default_backends_list)
5+
default_backends = tuple(default_backends_list)
6+
return default_backends

python/setup_tools/utils/tools.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,54 @@
1010

1111
from python.build_helpers import get_base_dir
1212
import platform
13+
from typing import Mapping
14+
from types import MappingProxyType
15+
import importlib.util
16+
from dataclasses import field
1317

14-
flagtree_root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
15-
flagtree_submodule_dir = os.path.join(flagtree_root_dir, "third_party")
16-
flagtree_backend = os.environ.get("FLAGTREE_BACKEND")
17-
use_cuda_toolkit = ["aipu"]
18+
19+
def _get_flagtree_root() -> str:
20+
return str(Path(__file__).resolve().parents[3])
21+
22+
23+
@dataclass
24+
class FlagtreeConfigs:
25+
default_backends: tuple = ("nvidia", "amd")
26+
plugin_backends: tuple = ("cambricon", "ascend", "aipu", "tsingmicro", "enflame")
27+
use_cuda_toolkit_backends: tuple = ('aipu', )
28+
language_extra_backends: tuple = ('xpu', 'mthreads', "cambricon")
29+
ext_sourcedir: str = "triton/_C/"
30+
flagtree_root_dir: str = field(default_factory=_get_flagtree_root)
31+
flagtree_backend: str = field(default_factory=lambda: os.environ.get("FLAGTREE_BACKEND"))
32+
flagtree_plugin: str = field(default_factory=lambda: os.environ.get("FLAGTREE_PLUGIN"))
33+
extend_backends: list = field(default_factory=list)
34+
activated_module: any = None
35+
flagtree_submodule_dir: str = ''
36+
device_alias_map: Mapping[str, str] = field(default_factory=lambda: MappingProxyType({
37+
"xpu": "xpu",
38+
"mthreads": "musa",
39+
"ascend": "ascend",
40+
"cambricon": "mlu",
41+
}))
42+
43+
def __post_init__(self):
44+
self.flagtree_submodule_dir = os.path.join(self.flagtree_root_dir, "third_party")
45+
self.activated_module = self._activate_device_module()
46+
47+
def _activate_device_module(self, suffix=".py"):
48+
backend = self.flagtree_backend or "default"
49+
module_path = Path(os.path.dirname(__file__)) / backend
50+
module_path = str(module_path) + suffix
51+
spec = importlib.util.spec_from_file_location("module", module_path)
52+
module = importlib.util.module_from_spec(spec)
53+
try:
54+
spec.loader.exec_module(module)
55+
except (AttributeError, FileNotFoundError, ImportError, ModuleNotFoundError):
56+
pass
57+
return module
58+
59+
60+
flagtree_configs = FlagtreeConfigs()
1861

1962

2063
@dataclass
@@ -41,7 +84,8 @@ def dir_rollback(deep, base_path):
4184

4285

4386
def is_skip_cuda_toolkits():
44-
return flagtree_backend and (flagtree_backend not in use_cuda_toolkit)
87+
return flagtree_configs.flagtree_backend and (flagtree_configs.flagtree_backend
88+
not in flagtree_configs.use_cuda_toolkit_backends)
4589

4690

4791
def remove_triton_in_modules(model):
@@ -216,7 +260,7 @@ def is_offline_build(self) -> bool:
216260
return os.getenv("TRITON_OFFLINE_BUILD", "OFF") == "ON" or os.getenv("FLAGTREE_OFFLINE_BUILD_DIR")
217261

218262
def copy_to_flagtree_project(self, kargs):
219-
dst_path = os.path.join(flagtree_root_dir,
263+
dst_path = os.path.join(_get_flagtree_root(),
220264
kargs['dst_path']) if 'dst_path' in kargs and kargs['dst_path'] else None
221265
src_path = self.src
222266
if not dst_path:
@@ -265,7 +309,7 @@ def handle_triton_origin_toolkits(self):
265309
shutil.copytree(src_path, toolkit_cache_path, dirs_exist_ok=True)
266310
else:
267311
raise RuntimeError(
268-
f"\n\n \033[31m[ERROR]:\033[0m The {flagtree_backend} offline build dependency \033[93m{src_path}\033[0m does not exist.\n"
312+
f"\n\n \033[31m[ERROR]:\033[0m The {flagtree_configs.flagtree_backend} offline build dependency \033[93m{src_path}\033[0m does not exist.\n"
269313
)
270314

271315
def validate_offline_build_dir(self, path, required=False):
@@ -280,7 +324,7 @@ def validate_offline_build_deps(self, path, kargs, required=False):
280324
url = kargs.get('url', None)
281325
if (not path or not os.path.exists(path)) and required:
282326
raise RuntimeError(
283-
f"\n\n \033[31m[ERROR]:\033[0m The {flagtree_backend} offline build dependency \033[93m{path}\033[0m does not exist.\n"
327+
f"\n\n \033[31m[ERROR]:\033[0m The {flagtree_configs.flagtree_backend} offline build dependency \033[93m{path}\033[0m does not exist.\n"
284328
f" And you can download the dependency package from the \n \033[93m{url}\033[0m \n"
285329
f" then extract it to the \033[93m{self.offline_build_dir}\033[0m directory you specified !\033[0m\n\n")
286330

@@ -301,7 +345,7 @@ def single_build(self, *args, **kargs):
301345
self.copy_to_flagtree_project(kargs)
302346
self.handle_flagtree_hock(kargs)
303347
if is_skip_cuda_toolkits():
304-
print(f"[INFO] Skipping CUDA toolkits for {flagtree_backend} backend in offline build.")
348+
print(f"[INFO] Skipping CUDA toolkits for {flagtree_configs.flagtree_backend} backend in offline build.")
305349
else:
306350
self.handle_triton_origin_toolkits()
307351
return True

setup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -620,13 +620,13 @@ def download_and_copy_dependencies():
620620
if helper.flagtree_backend:
621621
if helper.flagtree_backend in ("aipu", "tsingmicro"):
622622
backends = [
623-
*BackendInstaller.copy(helper.default_backends + helper.extend_backends),
623+
*BackendInstaller.copy(helper.configs.default_backends + tuple(helper.configs.extend_backends)),
624624
*BackendInstaller.copy_externals(),
625625
]
626626
else:
627-
backends = [*BackendInstaller.copy(helper.extend_backends), *BackendInstaller.copy_externals()]
627+
backends = [*BackendInstaller.copy(helper.configs.extend_backends), *BackendInstaller.copy_externals()]
628628
else:
629-
print(helper.default_backends)
629+
print(helper.configs.default_backends)
630630
backends = [*BackendInstaller.copy(["nvidia", "amd"]), *BackendInstaller.copy_externals()]
631631

632632
#backends = [*BackendInstaller.copy(["nvidia", "amd"]), *BackendInstaller.copy_externals()]

0 commit comments

Comments
 (0)