Skip to content
Open
32 changes: 32 additions & 0 deletions cuda/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ load(
"bool_flag",
"string_flag",
)
load("//cuda/private:platforms.bzl", "SUPPORTED_PLATFORMS")
load("//cuda/private:rules/flags.bzl", "cuda_archs_flag", "repeatable_string_flag")

package(default_visibility = ["//visibility:public"])
Expand Down Expand Up @@ -38,6 +39,37 @@ config_setting(
flag_values = {"@cuda//:valid_toolchain_found": "True"},
)

string_flag(
name = "version",
build_setting_default = "13.0.0",
)

string_flag(
name = "target_platform",
build_setting_default = "linux-x86_64",
values = SUPPORTED_PLATFORMS,
)

[
config_setting(
name = "target_platform_is_{}".format(platform.replace("-", "_")),
flag_values = {":target_platform": platform},
) for platform in SUPPORTED_PLATFORMS
]

string_flag(
name = "exec_platform",
build_setting_default = "linux-x86_64",
values = SUPPORTED_PLATFORMS,
)

[
config_setting(
name = "exec_platform_is_{}".format(platform.replace("-", "_")),
flag_values = {":exec_platform": platform},
) for platform in SUPPORTED_PLATFORMS
]

# Command line flag to specify the list of CUDA architectures to compile for.
#
# Provides CudaArchsInfo of the list of archs to build.
Expand Down
3 changes: 3 additions & 0 deletions cuda/defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Core rules for building CUDA projects.
"""

load("//cuda/private:defs.bzl", _requires_cuda = "requires_cuda")
load("//cuda/private:errors.bzl", _unsupported_cuda_version = "unsupported_cuda_version", _unsupported_cuda_platform = "unsupported_cuda_platform")
load("//cuda/private:macros/cuda_binary.bzl", _cuda_binary = "cuda_binary")
load("//cuda/private:macros/cuda_test.bzl", _cuda_test = "cuda_test")
load("//cuda/private:os_helpers.bzl", _cc_import_versioned_sos = "cc_import_versioned_sos", _if_linux = "if_linux", _if_windows = "if_windows")
Expand Down Expand Up @@ -47,3 +48,5 @@ if_windows = _if_windows
cc_import_versioned_sos = _cc_import_versioned_sos

requires_cuda = _requires_cuda
unsupported_cuda_version = _unsupported_cuda_version
unsupported_cuda_platform = _unsupported_cuda_platform
15 changes: 14 additions & 1 deletion cuda/dummy/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@ cc_binary(
defines = ["TOOLNAME=nvcc"],
)

cc_binary(
name = "cicc",
srcs = ["dummy.cpp"],
defines = ["TOOLNAME=cicc"],
)

cc_binary(
name = "nvlink",
srcs = ["dummy.cpp"],
defines = ["TOOLNAME=nvlink"],
)

exports_files(["link.stub"])
exports_files(["link.stub", "libdevice.10.bc"])

cc_binary(
name = "bin2c",
Expand All @@ -25,3 +31,10 @@ cc_binary(
srcs = ["dummy.cpp"],
defines = ["TOOLNAME=fatbinary"],
)

# Empty cc_library that provides CcInfo for components not available in this CUDA version.
cc_library(
name = "dummy",
srcs = [],
hdrs = [],
)
1 change: 1 addition & 0 deletions cuda/dummy/libdevice.10.bc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#error libdevice.10.bc of cuda toolkit does not exist
77 changes: 61 additions & 16 deletions cuda/extensions.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

load("//cuda/private:redist_json_helper.bzl", "redist_json_helper")
load("//cuda/private:repositories.bzl", "cuda_component", "cuda_toolkit")
load("//cuda:platform_alias_extension.bzl", "platform_alias_repo")

cuda_component_tag = tag_class(attrs = {
"name": attr.string(mandatory = True, doc = "Repo name for the deliverable cuda_component"),
Expand Down Expand Up @@ -53,6 +54,9 @@ cuda_redist_json_tag = tag_class(attrs = {
"URLs are tried in order until one succeeds, so you should list local mirrors first. " +
"If all downloads fail, the rule will fail.",
),
"platforms": attr.string_list(
doc = "A list of platforms to generate components for.",
),
"version": attr.string(
doc = "Generate a URL by using the specified version." +
"This URL will be tried after all URLs specified in the `urls` attribute.",
Expand Down Expand Up @@ -95,17 +99,20 @@ def _module_tag_to_dict(t):
def _redist_json_impl(module_ctx, attr):
url, json_object = redist_json_helper.get(module_ctx, attr)
redist_ver = redist_json_helper.get_redist_version(module_ctx, attr, json_object)
component_specs = redist_json_helper.collect_specs(module_ctx, attr, json_object, url)

mapping = {}
for spec in component_specs:
repo_name = redist_json_helper.get_repo_name(module_ctx, spec)
mapping[spec["component_name"]] = "@" + repo_name
platform_mapping = {}
for platform in attr.platforms:
component_specs = redist_json_helper.collect_specs(module_ctx, attr, platform, json_object, url)
mapping = {}
for spec in component_specs:
repo_name = redist_json_helper.get_repo_name(module_ctx, spec)
mapping[spec["component_name"]] = repo_name

attr = {key: value for key, value in spec.items()}
attr["name"] = repo_name
cuda_component(**attr)
return redist_ver, mapping
component_attr = {key: value for key, value in spec.items()}
component_attr["name"] = repo_name + "_" + platform.replace("-", "_") + "_" + redist_ver.replace(".", "_")
cuda_component(**component_attr)
platform_mapping[platform] = mapping
return redist_ver, platform_mapping

def _impl(module_ctx):
# Toolchain configuration is only allowed in the root module, or in rules_cuda.
Expand All @@ -121,18 +128,47 @@ def _impl(module_ctx):
components = rules_cuda.tags.component
redist_jsons = rules_cuda.tags.redist_json
toolkits = rules_cuda.tags.toolkit

for component in components:
cuda_component(**_module_tag_to_dict(component))

if len(redist_jsons) > 1:
fail("Using multiple cuda.redist_json is not supported yet.")

redist_version = None
components_mapping = None
for redist_json in redist_jsons:
redist_version, components_mapping = _redist_json_impl(module_ctx, redist_json)
redist_versions = []
redist_components_mapping = {}

# Track all versioned repositories for each component and platform.
versioned_repos = {}
for redist_json in redist_jsons:
components_mapping = {}
redist_version, platform_mapping = _redist_json_impl(module_ctx, redist_json)
redist_versions.append(redist_version)
for platform in platform_mapping.keys():
for component_name, repo_name in platform_mapping[platform].items():
redist_components_mapping[component_name] = repo_name

# Track the versioned repo name for this component/platform/version.
if component_name not in versioned_repos:
versioned_repos[component_name] = {}
if platform not in versioned_repos[component_name]:
versioned_repos[component_name][platform] = {}
versioned_repos[component_name][platform][redist_version] = repo_name + "_" + platform.replace("-", "_") + "_" + redist_version.replace(".", "_")

for component_name in redist_components_mapping.keys():
# Build dictionaries mapping versions to repo names for each platform.
x86_64_repos = {ver: versioned_repos[component_name]["linux-x86_64"][ver] for ver in redist_versions if "linux-x86_64" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-x86_64"]}
aarch64_repos = {ver: versioned_repos[component_name]["linux-aarch64"][ver] for ver in redist_versions if "linux-aarch64" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-aarch64"]}
sbsa_repos = {ver: versioned_repos[component_name]["linux-sbsa"][ver] for ver in redist_versions if "linux-sbsa" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-sbsa"]}

platform_alias_repo(
name = redist_components_mapping[component_name],
repo_name = redist_components_mapping[component_name],
component_name = component_name,
linux_x86_64_repos = x86_64_repos,
linux_aarch64_repos = aarch64_repos,
linux_sbsa_repos = sbsa_repos,
versions = redist_versions,
)
components_mapping[component_name] = "@" + redist_components_mapping[component_name]
registrations = {}
for toolkit in toolkits:
if toolkit.name in registrations.keys():
Expand All @@ -148,7 +184,16 @@ def _impl(module_ctx):

for _, toolkit in registrations.items():
if components_mapping != None:
cuda_toolkit(name = toolkit.name, components_mapping = components_mapping, version = redist_version)
# Always use the maximum version so the toolkit includes all components.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not quite true if CTK delete some component in the future. I think a union across all CTK versions will be a little bit more robust.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could take some work, right now the version in cuda_toolkit isn't going to necessarily be correct since it's pointing to @cuda which can point to any number of versioned cuda repos, but I don't know if that gets used anywhere in the rules so I'll try removing it and see what falls out...

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic is pretty deeply embedded in the repository rules where I can't use the value of a flag. I might need to go back and add the ability to register multiple toolkits to get everything to work as expected...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets leave it for future improvement, just point it out :)

# Components that don't exist in older versions will fall back to dummy.
toolkit_version = redist_versions[0]
for ver in redist_versions:
ver_parts = [int(x) for x in ver.split(".")]
tv_parts = [int(x) for x in toolkit_version.split(".")]
if ver_parts > tv_parts:
toolkit_version = ver

cuda_toolkit(name = toolkit.name, components_mapping = components_mapping, version = toolkit_version)
else:
cuda_toolkit(**_module_tag_to_dict(toolkit))

Expand Down
160 changes: 160 additions & 0 deletions cuda/platform_alias_extension.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
"""Module extension for creating platform-specific aliases for external dependencies.

This extension creates repositories with alias targets that select between x86_64 and ARM64
repositories based on the build platform. It also selects between versions of the component.
"""

load("//cuda/private:platforms.bzl", "SUPPORTED_PLATFORMS")
load("//cuda/private:templates/registry.bzl", "REGISTRY")

# Use REGISTRY as the source of truth for component targets
TARGET_MAPPING = REGISTRY

def _platform_alias_repo_impl(ctx):
"""Implementation of the platform_alias_repo repository rule.

Args:
ctx: Repository context with attributes x86_repo, arm64_repo, and targets.
"""

# Generate BUILD.bazel content with platform-specific aliases
build_content = ["# Generated by platform_alias_repo rule", ""]

# Add load statement for alias
build_content.append('load("@rules_cuda//cuda:defs.bzl", "unsupported_cuda_version", "unsupported_cuda_platform")')
build_content.append("")

build_content.append("[")
build_content.append(" config_setting(")
build_content.append(' name = "version_is_{}".format(version.replace(".", "_")),')
build_content.append(' flag_values = {"@rules_cuda//cuda:version": "{}".format(version)},')
build_content.append(" )")
build_content.append(" for version in {}".format(ctx.attr.versions))
build_content.append("]")
build_content.append("")

build_content.append('unsupported_cuda_version(name = "unsupported_cuda_version", component = "{}", available_versions = {})'.format(ctx.attr.component_name, ctx.attr.versions))
build_content.append("")

# Build a target for the name of the repo (only if at least one platform is available).
platform_type = "exec" if ctx.attr.component_name in ["nvcc", "nvvm"] else "target"

# Check which platforms are available (have at least one version).
platforms_available = []
if len(ctx.attr.linux_x86_64_repos) > 0:
platforms_available.append("linux-x86_64")
if len(ctx.attr.linux_sbsa_repos) > 0:
platforms_available.append("linux-sbsa")
if len(ctx.attr.linux_aarch64_repos) > 0:
platforms_available.append("linux-aarch64")

# Always create unsupported_cuda_platform target - it's used as the default case
# in select() when no platform condition matches.
build_content.append('unsupported_cuda_platform(name = "unsupported_cuda_platform", component = "{}", available_platforms = {})'.format(ctx.attr.component_name, platforms_available))
build_content.append("")

# Only generate target aliases if this component is in TARGET_MAPPING.
if ctx.attr.component_name not in TARGET_MAPPING:
# Write the BUILD.bazel file with just the main alias.
ctx.file("BUILD.bazel", "\n".join(build_content))
return

for target in TARGET_MAPPING[ctx.attr.component_name]:
# Create alias for each target with platform selection.
# Always add conditions for ALL platforms so that builds on any platform
# have a matching select condition. Platforms where the component doesn't
# exist will use a dummy target.
target_name = target if target.find("/") == -1 else target.split("/")[-1]

# Determine appropriate dummy target based on the target name.
dummy_target = "@rules_cuda//cuda/dummy:dummy"
if target_name == "cicc":
dummy_target = "@rules_cuda//cuda/dummy:cicc"
elif target_name == "libdevice.10.bc":
dummy_target = "@rules_cuda//cuda/dummy:libdevice.10.bc"

build_content.append("alias(")
build_content.append(' name = "{}",'.format(target_name))
build_content.append(" actual = select({")
# Add conditions for ALL platforms, using dummy for unavailable ones.
for platform in SUPPORTED_PLATFORMS:
platform_suffix = platform.replace("-", "_")
build_content.append(' "@rules_cuda//cuda:{}_platform_is_{}":'.format(platform_type, platform_suffix))
if platform in platforms_available:
build_content.append(' ":{}_{}",'.format(platform_suffix, target_name))
else:
# Platform doesn't have this component, use dummy target.
build_content.append(' "{}",'.format(dummy_target))
build_content.append(' "//conditions:default": ":unsupported_cuda_platform",')
build_content.append(" }),")
build_content.append(' visibility = ["//visibility:public"],')
build_content.append(")")
build_content.append("")

# Generate platform-specific aliases for ALL platforms.
# Platforms where the component exists get version-based selection.
# Platforms where it doesn't exist get dummy targets for all versions.
# This ensures builds on any platform have matching select conditions.

platform_repos_map = {
"linux-x86_64": ctx.attr.linux_x86_64_repos,
"linux-sbsa": ctx.attr.linux_sbsa_repos,
"linux-aarch64": ctx.attr.linux_aarch64_repos,
}

for platform in SUPPORTED_PLATFORMS:
platform_suffix = platform.replace("-", "_")
repos_dict = platform_repos_map[platform]
platform_available = platform in platforms_available

build_content.append("alias(")
build_content.append(' name = "{}_{}",'.format(platform_suffix, target_name))
build_content.append(" actual = select({")

for version in ctx.attr.versions:
build_content.append(' ":version_is_{}": '.format(version.replace(".", "_")))
if platform_available and version in repos_dict:
repo_name = repos_dict[version]
build_content.append(' "@{}//{}",'.format(repo_name, target if target.find(":") != -1 else ":" + target))
else:
# Platform doesn't have this component for this version, use dummy.
build_content.append(' "{}",'.format(dummy_target))
build_content.append(' "//conditions:default": ":unsupported_cuda_version",')

build_content.append(" }),")
build_content.append(' visibility = ["//visibility:public"],')
build_content.append(")")
build_content.append("")

# Write the BUILD.bazel file
ctx.file("BUILD.bazel", "\n".join(build_content))

platform_alias_repo = repository_rule(
implementation = _platform_alias_repo_impl,
attrs = {
"repo_name": attr.string(
mandatory = True,
doc = "Original name of the repository",
),
"component_name": attr.string(
mandatory = True,
doc = "Name of the component",
),
"linux_x86_64_repos": attr.string_dict(
default = {},
doc = "Dictionary mapping versions to x86_64 repository names",
),
"linux_aarch64_repos": attr.string_dict(
default = {},
doc = "Dictionary mapping versions to ARM64/Jetpack repository names",
),
"linux_sbsa_repos": attr.string_dict(
default = {},
doc = "Dictionary mapping versions to SBSA repository names",
),
"versions": attr.string_list(
mandatory = True,
doc = "List of versions to create aliases for",
),
},
)
Loading
Loading