-
-
Notifications
You must be signed in to change notification settings - Fork 64
Add support for multi-arch and multi-platform cuda toolchains #422
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
charleysaintNV
wants to merge
11
commits into
bazel-contrib:main
Choose a base branch
from
charleysaintNV:csaint/multi-arch
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
6f3d716
add multi-platform and multi-version support to rules_cuda
charleysaintNV 235a94a
fix alias targets that have a path
charleysaintNV d6a7478
Enable multi-version builds
lgulich 5e605df
Merge branch 'lgulich/enable-multi-version-builds' into 'csaint/multi…
charleysaintNV 7bc9c61
add error messages for bad platform/version selection
charleysaintNV 3e1f206
move versions and chaange flag names
charleysaintNV 53f739e
remove todos
charleysaintNV 4233145
remove list of versions
charleysaintNV 9fbd1ab
fix issues with dummy files
charleysaintNV d9cfe1c
Fix typos
lgulich 7eb64fe
Fix platform fallback for components missing on some platforms
lgulich File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| #error libdevice.10.bc of cuda toolkit does not exist |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,160 @@ | ||
| """Module extension for creating platform-specific aliases for external dependencies. | ||
|
|
||
| This extension creates repositories with alias targets that select between x86_64 and ARM64 | ||
| repositories based on the build platform. It also selects between versions of the component. | ||
| """ | ||
|
|
||
| load("//cuda/private:platforms.bzl", "SUPPORTED_PLATFORMS") | ||
| load("//cuda/private:templates/registry.bzl", "REGISTRY") | ||
|
|
||
| # Use REGISTRY as the source of truth for component targets | ||
| TARGET_MAPPING = REGISTRY | ||
|
|
||
| def _platform_alias_repo_impl(ctx): | ||
| """Implementation of the platform_alias_repo repository rule. | ||
|
|
||
| Args: | ||
| ctx: Repository context with attributes x86_repo, arm64_repo, and targets. | ||
| """ | ||
|
|
||
| # Generate BUILD.bazel content with platform-specific aliases | ||
| build_content = ["# Generated by platform_alias_repo rule", ""] | ||
|
|
||
| # Add load statement for alias | ||
| build_content.append('load("@rules_cuda//cuda:defs.bzl", "unsupported_cuda_version", "unsupported_cuda_platform")') | ||
| build_content.append("") | ||
|
|
||
| build_content.append("[") | ||
| build_content.append(" config_setting(") | ||
| build_content.append(' name = "version_is_{}".format(version.replace(".", "_")),') | ||
| build_content.append(' flag_values = {"@rules_cuda//cuda:version": "{}".format(version)},') | ||
| build_content.append(" )") | ||
| build_content.append(" for version in {}".format(ctx.attr.versions)) | ||
| build_content.append("]") | ||
| build_content.append("") | ||
|
|
||
| build_content.append('unsupported_cuda_version(name = "unsupported_cuda_version", component = "{}", available_versions = {})'.format(ctx.attr.component_name, ctx.attr.versions)) | ||
| build_content.append("") | ||
|
|
||
| # Build a target for the name of the repo (only if at least one platform is available). | ||
| platform_type = "exec" if ctx.attr.component_name in ["nvcc", "nvvm"] else "target" | ||
|
|
||
| # Check which platforms are available (have at least one version). | ||
| platforms_available = [] | ||
| if len(ctx.attr.linux_x86_64_repos) > 0: | ||
| platforms_available.append("linux-x86_64") | ||
| if len(ctx.attr.linux_sbsa_repos) > 0: | ||
| platforms_available.append("linux-sbsa") | ||
| if len(ctx.attr.linux_aarch64_repos) > 0: | ||
| platforms_available.append("linux-aarch64") | ||
|
|
||
| # Always create unsupported_cuda_platform target - it's used as the default case | ||
| # in select() when no platform condition matches. | ||
| build_content.append('unsupported_cuda_platform(name = "unsupported_cuda_platform", component = "{}", available_platforms = {})'.format(ctx.attr.component_name, platforms_available)) | ||
| build_content.append("") | ||
|
|
||
| # Only generate target aliases if this component is in TARGET_MAPPING. | ||
| if ctx.attr.component_name not in TARGET_MAPPING: | ||
| # Write the BUILD.bazel file with just the main alias. | ||
| ctx.file("BUILD.bazel", "\n".join(build_content)) | ||
| return | ||
|
|
||
| for target in TARGET_MAPPING[ctx.attr.component_name]: | ||
| # Create alias for each target with platform selection. | ||
| # Always add conditions for ALL platforms so that builds on any platform | ||
| # have a matching select condition. Platforms where the component doesn't | ||
| # exist will use a dummy target. | ||
| target_name = target if target.find("/") == -1 else target.split("/")[-1] | ||
|
|
||
| # Determine appropriate dummy target based on the target name. | ||
| dummy_target = "@rules_cuda//cuda/dummy:dummy" | ||
| if target_name == "cicc": | ||
| dummy_target = "@rules_cuda//cuda/dummy:cicc" | ||
| elif target_name == "libdevice.10.bc": | ||
| dummy_target = "@rules_cuda//cuda/dummy:libdevice.10.bc" | ||
|
|
||
| build_content.append("alias(") | ||
| build_content.append(' name = "{}",'.format(target_name)) | ||
| build_content.append(" actual = select({") | ||
| # Add conditions for ALL platforms, using dummy for unavailable ones. | ||
| for platform in SUPPORTED_PLATFORMS: | ||
| platform_suffix = platform.replace("-", "_") | ||
| build_content.append(' "@rules_cuda//cuda:{}_platform_is_{}":'.format(platform_type, platform_suffix)) | ||
| if platform in platforms_available: | ||
| build_content.append(' ":{}_{}",'.format(platform_suffix, target_name)) | ||
| else: | ||
| # Platform doesn't have this component, use dummy target. | ||
| build_content.append(' "{}",'.format(dummy_target)) | ||
| build_content.append(' "//conditions:default": ":unsupported_cuda_platform",') | ||
| build_content.append(" }),") | ||
| build_content.append(' visibility = ["//visibility:public"],') | ||
| build_content.append(")") | ||
| build_content.append("") | ||
|
|
||
| # Generate platform-specific aliases for ALL platforms. | ||
| # Platforms where the component exists get version-based selection. | ||
| # Platforms where it doesn't exist get dummy targets for all versions. | ||
| # This ensures builds on any platform have matching select conditions. | ||
|
|
||
| platform_repos_map = { | ||
| "linux-x86_64": ctx.attr.linux_x86_64_repos, | ||
| "linux-sbsa": ctx.attr.linux_sbsa_repos, | ||
| "linux-aarch64": ctx.attr.linux_aarch64_repos, | ||
| } | ||
|
|
||
| for platform in SUPPORTED_PLATFORMS: | ||
| platform_suffix = platform.replace("-", "_") | ||
| repos_dict = platform_repos_map[platform] | ||
| platform_available = platform in platforms_available | ||
|
|
||
| build_content.append("alias(") | ||
| build_content.append(' name = "{}_{}",'.format(platform_suffix, target_name)) | ||
| build_content.append(" actual = select({") | ||
|
|
||
| for version in ctx.attr.versions: | ||
| build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) | ||
| if platform_available and version in repos_dict: | ||
| repo_name = repos_dict[version] | ||
| build_content.append(' "@{}//{}",'.format(repo_name, target if target.find(":") != -1 else ":" + target)) | ||
| else: | ||
| # Platform doesn't have this component for this version, use dummy. | ||
| build_content.append(' "{}",'.format(dummy_target)) | ||
| build_content.append(' "//conditions:default": ":unsupported_cuda_version",') | ||
|
|
||
| build_content.append(" }),") | ||
| build_content.append(' visibility = ["//visibility:public"],') | ||
| build_content.append(")") | ||
| build_content.append("") | ||
|
|
||
| # Write the BUILD.bazel file | ||
| ctx.file("BUILD.bazel", "\n".join(build_content)) | ||
|
|
||
| platform_alias_repo = repository_rule( | ||
| implementation = _platform_alias_repo_impl, | ||
| attrs = { | ||
| "repo_name": attr.string( | ||
| mandatory = True, | ||
| doc = "Original name of the repository", | ||
| ), | ||
| "component_name": attr.string( | ||
| mandatory = True, | ||
| doc = "Name of the component", | ||
| ), | ||
| "linux_x86_64_repos": attr.string_dict( | ||
| default = {}, | ||
| doc = "Dictionary mapping versions to x86_64 repository names", | ||
| ), | ||
| "linux_aarch64_repos": attr.string_dict( | ||
| default = {}, | ||
| doc = "Dictionary mapping versions to ARM64/Jetpack repository names", | ||
| ), | ||
| "linux_sbsa_repos": attr.string_dict( | ||
| default = {}, | ||
| doc = "Dictionary mapping versions to SBSA repository names", | ||
| ), | ||
| "versions": attr.string_list( | ||
| mandatory = True, | ||
| doc = "List of versions to create aliases for", | ||
| ), | ||
| }, | ||
| ) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not quite true if CTK delete some component in the future. I think a union across all CTK versions will be a little bit more robust.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could take some work, right now the version in cuda_toolkit isn't going to necessarily be correct since it's pointing to @cuda which can point to any number of versioned cuda repos, but I don't know if that gets used anywhere in the rules so I'll try removing it and see what falls out...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic is pretty deeply embedded in the repository rules where I can't use the value of a flag. I might need to go back and add the ability to register multiple toolkits to get everything to work as expected...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lets leave it for future improvement, just point it out :)