From 6f3d71654c072b51d68a86ddd5cbdc9aa1584874 Mon Sep 17 00:00:00 2001 From: Charley Saint Date: Thu, 4 Dec 2025 01:26:14 +0000 Subject: [PATCH 01/10] add multi-platform and multi-version support to rules_cuda --- cuda/BUILD.bazel | 106 ++++++++++ cuda/extensions.bzl | 103 ++++++++-- cuda/platform_alias_extension.bzl | 289 ++++++++++++++++++++++++++++ cuda/private/redist_json_helper.bzl | 15 +- cuda/private/repositories.bzl | 14 +- cuda/versions/BUILD.bazel | 100 ++++++++++ 6 files changed, 590 insertions(+), 37 deletions(-) create mode 100644 cuda/platform_alias_extension.bzl create mode 100644 cuda/versions/BUILD.bazel diff --git a/cuda/BUILD.bazel b/cuda/BUILD.bazel index e49fa85c..47832197 100644 --- a/cuda/BUILD.bazel +++ b/cuda/BUILD.bazel @@ -38,6 +38,112 @@ config_setting( flag_values = {"@cuda//:valid_toolchain_found": "True"}, ) +string_flag( + name = "version", + build_setting_default = "13.0.2", + values = [ + "12.0.0", + "12.0.1", + "12.1.0", + "12.1.1", + "12.2.0", + "12.2.1", + "12.2.2", + "12.3.0", + "12.3.1", + "12.3.2", + "12.4.0", + "12.4.1", + "12.5.0", + "12.5.1", + "12.6.0", + "12.6.1", + "12.6.2", + "12.6.3", + "12.8.0", + "12.8.1", + "12.9.0", + "12.9.1", + "13.0.0", + "13.0.1", + "13.0.2", + ], +) + +string_flag( + name = "runtime_platform", + build_setting_default = "linux-x86_64", + values = [ + "linux-x86_64", + "linux-sbsa", + "linux-aarch64", + "linux-ppc64le", + "windows-x86_64", + ], +) + +config_setting( + name = "runtime_platform_is_linux_x86_64", + flag_values = {":runtime_platform": "linux-x86_64"}, +) + +config_setting( + name = "runtime_platform_is_linux_aarch64", + flag_values = {":runtime_platform": "linux-aarch64"}, +) + +config_setting( + name = "runtime_platform_is_linux_ppc64le", + flag_values = {":runtime_platform": "linux-ppc64le"}, +) + +config_setting( + name = "runtime_platform_is_linux_sbsa", + flag_values = {":runtime_platform": "linux-sbsa"}, +) + +config_setting( + name = "runtime_platform_is_windows_x86_64", + flag_values = {":runtime_platform": "windows-x86_64"}, +) + +string_flag( + name = "nvcc_platform", + build_setting_default = "linux-x86_64", + values = [ + "linux-x86_64", + "linux-sbsa", + "linux-aarch64", + "linux-ppc64le", + "windows-x86_64", + ], +) + +config_setting( + name = "nvcc_platform_is_linux_x86_64", + flag_values = {":nvcc_platform": "linux-x86_64"}, +) + +config_setting( + name = "nvcc_platform_is_linux_aarch64", + flag_values = {":nvcc_platform": "linux-aarch64"}, +) + +config_setting( + name = "nvcc_platform_is_linux_ppc64le", + flag_values = {":nvcc_platform": "linux-ppc64le"}, +) + +config_setting( + name = "nvcc_platform_is_linux_sbsa", + flag_values = {":nvcc_platform": "linux-sbsa"}, +) + +config_setting( + name = "nvcc_platform_is_windows_x86_64", + flag_values = {":nvcc_platform": "windows-x86_64"}, +) + # Command line flag to specify the list of CUDA architectures to compile for. # # Provides CudaArchsInfo of the list of archs to build. diff --git a/cuda/extensions.bzl b/cuda/extensions.bzl index 4eb5211d..ea0c9a9a 100644 --- a/cuda/extensions.bzl +++ b/cuda/extensions.bzl @@ -2,6 +2,7 @@ load("//cuda/private:redist_json_helper.bzl", "redist_json_helper") load("//cuda/private:repositories.bzl", "cuda_component", "cuda_toolkit") +load("//cuda:platform_alias_extension.bzl", "platform_alias_repo") cuda_component_tag = tag_class(attrs = { "name": attr.string(mandatory = True, doc = "Repo name for the deliverable cuda_component"), @@ -53,6 +54,9 @@ cuda_redist_json_tag = tag_class(attrs = { "URLs are tried in order until one succeeds, so you should list local mirrors first. " + "If all downloads fail, the rule will fail.", ), + "platforms": attr.string_list( + doc = "A list of platforms to generate components for.", + ), "version": attr.string( doc = "Generate a URL by using the specified version." + "This URL will be tried after all URLs specified in the `urls` attribute.", @@ -74,6 +78,40 @@ cuda_toolkit_tag = tag_class(attrs = { ), }) +platform_alias_tag = tag_class( + attrs = { + "name": attr.string( + mandatory = True, + doc = "Name of the alias repository to create", + ), + "component_name": attr.string( + mandatory = True, + doc = "Name of the component to create aliases for", + ), + "linux_x86_64_repo": attr.string( + mandatory = True, + doc = "Name of the repository to use for x86_64 platform", + ), + "linux_aarch64_repo": attr.string( + mandatory = True, + doc = "Name of the repository to use for ARM64/Jetpack platform", + ), + "linux_sbsa_repo": attr.string( + mandatory = True, + doc = "Name of the repository to use for SBSA platform", + ), + "versions": attr.string_list( + mandatory = True, + doc = "List of versions to create aliases for", + ), + }, + doc = """Defines a platform-specific alias repository. + + Each alias tag creates a repository with targets that select between + x86_64 and ARM64 repositories based on the build platform. + """, +) + def _find_modules(module_ctx): root = None our_module = None @@ -95,18 +133,20 @@ def _module_tag_to_dict(t): def _redist_json_impl(module_ctx, attr): url, json_object = redist_json_helper.get(module_ctx, attr) redist_ver = redist_json_helper.get_redist_version(module_ctx, attr, json_object) - component_specs = redist_json_helper.collect_specs(module_ctx, attr, json_object, url) - - mapping = {} - for spec in component_specs: - repo_name = redist_json_helper.get_repo_name(module_ctx, spec) - mapping[spec["component_name"]] = "@" + repo_name - - attr = {key: value for key, value in spec.items()} - attr["name"] = repo_name - cuda_component(**attr) - return redist_ver, mapping + platform_mapping = {} + for platform in attr.platforms: + component_specs = redist_json_helper.collect_specs(module_ctx, attr, platform, json_object, url) + mapping = {} + for spec in component_specs: + repo_name = redist_json_helper.get_repo_name(module_ctx, spec) + mapping[spec["component_name"]] = repo_name + + component_attr = {key: value for key, value in spec.items()} + component_attr["name"] = repo_name + "_" + platform.replace("-", "_") + "_" + redist_ver.replace(".", "_") + cuda_component(**component_attr) + platform_mapping[platform] = mapping + return redist_ver, platform_mapping def _impl(module_ctx): # Toolchain configuration is only allowed in the root module, or in rules_cuda. root, rules_cuda = _find_modules(module_ctx) @@ -117,22 +157,38 @@ def _impl(module_ctx): components = root.tags.component redist_jsons = root.tags.redist_json toolkits = root.tags.toolkit + platform_aliases = root.tags.platform_alias else: components = rules_cuda.tags.component redist_jsons = rules_cuda.tags.redist_json toolkits = rules_cuda.tags.toolkit - + platform_aliases = rules_cuda.tags.platform_alias for component in components: cuda_component(**_module_tag_to_dict(component)) - if len(redist_jsons) > 1: - fail("Using multiple cuda.redist_json is not supported yet.") - redist_version = None components_mapping = None + redist_versions = [] + redist_components_mapping = {} for redist_json in redist_jsons: - redist_version, components_mapping = _redist_json_impl(module_ctx, redist_json) - + components_mapping = {} + redist_version, platform_mapping = _redist_json_impl(module_ctx, redist_json) + redist_versions.append(redist_version) + for platform in platform_mapping.keys(): + for component_name, repo_name in platform_mapping[platform].items(): + redist_components_mapping[component_name] = repo_name + + for component_name in redist_components_mapping.keys(): + platform_alias_repo( + name = redist_components_mapping[component_name], + repo_name = redist_components_mapping[component_name], + component_name = component_name, + linux_x86_64_repo = "cuda_" + component_name + "_linux_x86_64", + linux_aarch64_repo = "cuda_" + component_name + "_linux_aarch64", + linux_sbsa_repo = "cuda_" + component_name + "_linux_sbsa", + versions = redist_versions, + ) + components_mapping[component_name] = "@" + redist_components_mapping[component_name] registrations = {} for toolkit in toolkits: if toolkit.name in registrations.keys(): @@ -152,11 +208,24 @@ def _impl(module_ctx): else: cuda_toolkit(**_module_tag_to_dict(toolkit)) + for alias_tag in platform_aliases: + # Create a repository for each alias tag + platform_alias_repo( + name = alias_tag.name, + repo_name = alias_tag.name, + component_name = alias_tag.component_name, + linux_x86_64_repo = alias_tag.linux_x86_64_repo, + linux_aarch64_repo = alias_tag.linux_aarch64_repo, + linux_sbsa_repo = alias_tag.linux_sbsa_repo, + versions = alias_tag.versions, + ) + toolchain = module_extension( implementation = _impl, tag_classes = { "component": cuda_component_tag, "redist_json": cuda_redist_json_tag, "toolkit": cuda_toolkit_tag, + "platform_alias": platform_alias_tag, }, ) diff --git a/cuda/platform_alias_extension.bzl b/cuda/platform_alias_extension.bzl new file mode 100644 index 00000000..b3c05d14 --- /dev/null +++ b/cuda/platform_alias_extension.bzl @@ -0,0 +1,289 @@ +"""Module extension for creating platform-specific aliases for external dependencies. + +This extension creates repositories with alias targets that select between x86_64 and ARM64 +repositories based on the build platform. +""" + +TARGET_MAPPING = { + "cccl": [ + "cccl_all_files", + "cccl_header_files", + "cccl_headers", + "cub", + "thrust", + "valid_toolchain_found", + ], + "cudart": [ + "cuda", + "cuda_lib", + "cuda_runtime", + "cuda_runtime_static", + "cuda_so", + "cudadevrt_a", + "cudadevrt_lib", + "cudart_all_files", + "cudart_header_files", + "cudart_headers", + "cudart_lib", + "cudart_so", + "no_cuda_runtime", + "valid_toolchain_found", + ], + "nvcc": [ + "compiler_deps", + "compiler_root", + "nvcc_all_files", + "nvcc_header_files", + "nvcc_headers", + "nvptxcompiler", + "nvptxcompiler_lib", + "nvptxcompiler_so", + "valid_toolchain_found", + "nvcc/bin/nvcc", + "nvcc/bin/nvlink", + "nvcc/bin/crt/link.stub", + "nvcc/bin/bin2c", + "nvcc/bin/fatbinary", + "nvvm/bin/cicc", + ], + "cublas": [ + "cublas", + "cublasLt_lib", + "cublasLt_so", + "cublas_all_files", + "cublas_header_files", + "cublas_headers", + "cublas_lib", + "cublas_so", + "valid_toolchain_found", + ], + "cufft": [ + "cufft", + "cufft_all_files", + "cufft_header_files", + "cufft_headers", + "cufft_lib", + "cufft_so", + "cufft_static", + "cufft_static_a", + "cufft_static_nocallback", + "cufft_static_nocallback_a", + "cufftw_lib", + "cufftw_so", + "cufftw_static", + "cufftw_static_a", + "valid_toolchain_found", + ], + "cusolver": [ + "cusolver", + "cusolver_all_files", + "cusolver_header_files", + "cusolver_headers", + "cusolver_lib", + "cusolver_so", + "valid_toolchain_found", + ], + "cusparse": [ + "cusparse", + "cusparse_all_files", + "cusparse_header_files", + "cusparse_headers", + "cusparse_lib", + "cusparse_so", + "valid_toolchain_found", + ], + "npp": [ + "npp_all_files", + "npp_header_files", + "npp_headers", + "nppc", + "nppc_lib", + "nppc_so", + "nppi", + "nppial", + "nppial_lib", + "nppial_so", + "nppicc", + "nppicc_lib", + "nppicc_so", + "nppidei", + "nppidei_lib", + "nppidei_so", + "nppif", + "nppif_lib", + "nppif_so", + "nppig", + "nppig_lib", + "nppig_so", + "nppim", + "nppim_lib", + "nppim_so", + "nppist", + "nppist_lib", + "nppist_so", + "nppisu", + "nppisu_lib", + "nppisu_so", + "nppitc", + "nppitc_lib", + "nppitc_so", + "npps", + "npps_lib", + "npps_so", + "valid_toolchain_found", + ], + "nvrtc": [ + "nvrtc", + "nvrtc_all_files", + "nvrtc_header_files", + "nvrtc_headers", + "nvrtc_lib", + "nvrtc_so", + "valid_toolchain_found", + ], + "nvjitlink": [ + "nvJitLink", + "nvJitLink_lib", + "nvJitLink_so", + "nvJitLink_static", + "nvJitLink_static_a", + "nvjitlink", + "nvjitlink_all_files", + "nvjitlink_header_files", + "nvjitlink_headers", + "nvjitlink_static", + "valid_toolchain_found", + ], + "nvjpeg": [ + "nvjpeg", + "nvjpeg_all_files", + "nvjpeg_header_files", + "nvjpeg_headers", + "nvjpeg_lib", + "nvjpeg_so", + "nvjpeg_static", + "nvjpeg_static_a", + "valid_toolchain_found", + ], + "crt": [ + "crt_all_files", + "crt_header_files", + "crt_headers", + "valid_toolchain_found", + ], + "nvvm": [ + "nvvm_all_files", + "nvvm_header_files", + "nvvm_headers", + "valid_toolchain_found", + "nvvm/nvvm/bin/cicc", + "nvvm/nvvm/libdevice/libdevice.10.bc", + ], + "culibos": [ + "culibos_all_files", + "culibos_license", + "culibos_a", + ], +} + + +def _platform_alias_repo_impl(ctx): + """Implementation of the platform_alias_repo repository rule. + + Args: + ctx: Repository context with attributes x86_repo, arm64_repo, and targets. + """ + # Generate BUILD.bazel content with platform-specific aliases + build_content = ['# Generated by platform_alias_repo rule', ''] + + # Add load statement for alias + build_content.append('load("@bazel_skylib//lib:selects.bzl", "selects")') + build_content.append('') + + # Build a target for the name of the repo + build_content.append('alias(') + build_content.append(' name = "{}",'.format(ctx.attr.repo_name)) + build_content.append(' actual = select({') + build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_x86_64":'.format("nvcc" if ctx.attr.component_name in ["nvcc", "nvvm"] else "runtime")) + build_content.append(' ":linux_x86_64_{}",'.format(ctx.attr.repo_name)) + build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_sbsa":'.format("nvcc" if ctx.attr.component_name in ["nvcc", "nvvm"] else "runtime")) + build_content.append(' ":linux_sbsa_{}",'.format(ctx.attr.repo_name)) + build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_aarch64":'.format("nvcc" if ctx.attr.component_name in ["nvcc", "nvvm"] else "runtime")) + build_content.append(' ":linux_aarch64_{}",'.format(ctx.attr.repo_name)) + build_content.append(' }),') + build_content.append(' visibility = ["//visibility:public"],') + build_content.append(')') + build_content.append('') + + build_content.append('alias(') + build_content.append(' name = "linux_x86_64_{}",'.format(ctx.attr.repo_name)) + build_content.append(' actual = select({') + for version in ctx.attr.versions: + build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@{}_{}//:{}",'.format(ctx.attr.linux_x86_64_repo, version.replace(".", "_"), ctx.attr.repo_name)) + build_content.append(' }),') + build_content.append(' visibility = ["//visibility:public"],') + build_content.append(')') + build_content.append('') + + for target in TARGET_MAPPING[ctx.attr.component_name]: + # Create alias for each target with platform selection + target_name = target if target.find("/") == -1 else target.split("/")[-1] + build_content.append('alias(') + build_content.append(' name = "{}",'.format(target_name)) + build_content.append(' actual = select({') + build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_x86_64":'.format("nvcc" if ctx.attr.component_name in ["nvcc", "nvvm"] else "runtime")) + build_content.append(' ":linux_x86_64_{}",'.format(target_name)) + build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_sbsa":'.format("nvcc" if ctx.attr.component_name in ["nvcc", "nvvm"] else "runtime")) + build_content.append(' ":linux_sbsa_{}",'.format(target_name)) + build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_aarch64":'.format("nvcc" if ctx.attr.component_name in ["nvcc", "nvvm"] else "runtime")) + build_content.append(' ":linux_aarch64_{}",'.format(target_name)) + build_content.append(' }),') + build_content.append(' visibility = ["//visibility:public"],') + build_content.append(')') + build_content.append('') + + build_content.append('alias(') + build_content.append(' name = "linux_x86_64_{}",'.format(target_name)) + build_content.append(' actual = select({') + for version in ctx.attr.versions: + build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@{}_{}//:{}",'.format(ctx.attr.linux_x86_64_repo, version.replace(".", "_"), target_name)) + build_content.append(' }),') + build_content.append(' visibility = ["//visibility:public"],') + build_content.append(')') + build_content.append('') + + # Write the BUILD.bazel file + ctx.file("BUILD.bazel", "\n".join(build_content)) + + +platform_alias_repo = repository_rule( + implementation = _platform_alias_repo_impl, + attrs = { + "repo_name": attr.string( + mandatory = True, + doc = "Original name of the repository", + ), + "component_name": attr.string( + mandatory = True, + doc = "Name of the component", + ), + "linux_x86_64_repo": attr.string( + mandatory = True, + doc = "Name of the repository to use for x86_64 platform", + ), + "linux_aarch64_repo": attr.string( + mandatory = True, + doc = "Name of the repository to use for ARM64/Jetpack platform", + ), + "linux_sbsa_repo": attr.string( + mandatory = True, + doc = "Name of the repository to use for SBSA platform", + ), + "versions": attr.string_list( + mandatory = True, + doc = "List of versions to create aliases for", + ), + }, +) diff --git a/cuda/private/redist_json_helper.bzl b/cuda/private/redist_json_helper.bzl index decdc48d..5e5dd813 100644 --- a/cuda/private/redist_json_helper.bzl +++ b/cuda/private/redist_json_helper.bzl @@ -58,7 +58,7 @@ def _get_redist_version(ctx, attr, redist): return redist_ver -def _collect_specs(ctx, attr, redist, the_url): +def _collect_specs(ctx, attr, platform, redist, the_url): """Convert redistrib_.json content to the specs for instantiating cuda_component repos. List of specs, aka, list of dicts with cuda_component attrs. @@ -66,19 +66,12 @@ def _collect_specs(ctx, attr, redist, the_url): Args: ctx: repository_ctx | module_ctx attr: cuda_component repo attr or cuda_redist_json_tag + platform: string, the platform for which we are collecting specs redist: json object, read from the redistrib_.json file. the_url: string, the very unique url from which we get the redistrib_.json file. """ specs = [] - os = None - if _is_linux(ctx): - os = "linux" - elif _is_windows(ctx): - os = "windows" - - arch = "x86_64" # TODO: support cross compiling - platform = "{os}-{arch}".format(os = os, arch = arch) all_components_on_platform = [k for k, v in FULL_COMPONENT_NAME.items() if v in redist and platform in redist[v]] components = attr.components if attr.components else all_components_on_platform @@ -111,10 +104,6 @@ def _get_repo_name(ctx, spec): """ repo_name = "cuda_" + spec["component_name"] - version = spec.get("version", None) - if version != None: - repo_name = repo_name + "_v" + version - return repo_name redist_json_helper = struct( diff --git a/cuda/private/repositories.bzl b/cuda/private/repositories.bzl index 04d0bdd6..f10de761 100644 --- a/cuda/private/repositories.bzl +++ b/cuda/private/repositories.bzl @@ -111,18 +111,18 @@ def _detect_deliverable_cuda_toolkit(repository_ctx): nvcc_repo = repository_ctx.attr.components_mapping["nvcc"] bin_ext = ".exe" if _is_windows(repository_ctx) else "" - nvcc = "{}//:nvcc/bin/nvcc{}".format(nvcc_repo, bin_ext) - nvlink = "{}//:nvcc/bin/nvlink{}".format(nvcc_repo, bin_ext) - link_stub = "{}//:nvcc/bin/crt/link.stub".format(nvcc_repo) - bin2c = "{}//:nvcc/bin/bin2c{}".format(nvcc_repo, bin_ext) - fatbinary = "{}//:nvcc/bin/fatbinary{}".format(nvcc_repo, bin_ext) + nvcc = "{}//:nvcc{}".format(nvcc_repo, bin_ext) + nvlink = "{}//:nvlink{}".format(nvcc_repo, bin_ext) + link_stub = "{}//:link.stub".format(nvcc_repo) + bin2c = "{}//:bin2c{}".format(nvcc_repo, bin_ext) + fatbinary = "{}//:fatbinary{}".format(nvcc_repo, bin_ext) cicc = None libdevice = None if int(cuda_version_major) >= 13: nvvm_repo = repository_ctx.attr.components_mapping["nvvm"] - cicc = "{}//:nvvm/nvvm/bin/cicc{}".format(nvvm_repo, bin_ext) # TODO: can we use @cuda//:cicc? - libdevice = "{}//:nvvm/nvvm/libdevice/libdevice.10.bc".format(nvvm_repo) # TODO: can we use @cuda//:libdevice? + cicc = "{}//:cicc{}".format(nvvm_repo, bin_ext) # TODO: can we use @cuda//:cicc? + libdevice = "{}//:libdevice.10.bc".format(nvvm_repo) # TODO: can we use @cuda//:libdevice? return struct( path = None, # scattered components diff --git a/cuda/versions/BUILD.bazel b/cuda/versions/BUILD.bazel new file mode 100644 index 00000000..76e048e4 --- /dev/null +++ b/cuda/versions/BUILD.bazel @@ -0,0 +1,100 @@ +config_setting( + name = "version_is_12_0_0", + flag_values = {"@rules_cuda//cuda:version": "12.0.0"}, +) +config_setting( + name = "version_is_12_0_1", + flag_values = {"@rules_cuda//cuda:version": "12.0.1"}, +) +config_setting( + name = "version_is_12_1_0", + flag_values = {"@rules_cuda//cuda:version": "12.1.0"}, +) +config_setting( + name = "version_is_12_1_1", + flag_values = {"@rules_cuda//cuda:version": "12.1.1"}, +) +config_setting( + name = "version_is_12_2_0", + flag_values = {"@rules_cuda//cuda:version": "12.2.0"}, +) +config_setting( + name = "version_is_12_2_1", + flag_values = {"@rules_cuda//cuda:version": "12.2.1"}, +) +config_setting( + name = "version_is_12_2_2", + flag_values = {"@rules_cuda//cuda:version": "12.2.2"}, +) +config_setting( + name = "version_is_12_3_0", + flag_values = {"@rules_cuda//cuda:version": "12.3.0"}, +) +config_setting( + name = "version_is_12_3_1", + flag_values = {"@rules_cuda//cuda:version": "12.3.1"}, +) +config_setting( + name = "version_is_12_3_2", + flag_values = {"@rules_cuda//cuda:version": "12.3.2"}, +) +config_setting( + name = "version_is_12_4_0", + flag_values = {"@rules_cuda//cuda:version": "12.4.0"}, +) +config_setting( + name = "version_is_12_4_1", + flag_values = {"@rules_cuda//cuda:version": "12.4.1"}, +) +config_setting( + name = "version_is_12_5_0", + flag_values = {"@rules_cuda//cuda:version": "12.5.0"}, +) +config_setting( + name = "version_is_12_5_1", + flag_values = {"@rules_cuda//cuda:version": "12.5.1"}, +) +config_setting( + name = "version_is_12_6_0", + flag_values = {"@rules_cuda//cuda:version": "12.6.0"}, +) +config_setting( + name = "version_is_12_6_1", + flag_values = {"@rules_cuda//cuda:version": "12.6.1"}, +) +config_setting( + name = "version_is_12_6_2", + flag_values = {"@rules_cuda//cuda:version": "12.6.2"}, +) +config_setting( + name = "version_is_12_6_3", + flag_values = {"@rules_cuda//cuda:version": "12.6.3"}, +) +config_setting( + name = "version_is_12_8_0", + flag_values = {"@rules_cuda//cuda:version": "12.8.0"}, +) +config_setting( + name = "version_is_12_8_1", + flag_values = {"@rules_cuda//cuda:version": "12.8.1"}, +) +config_setting( + name = "version_is_12_9_0", + flag_values = {"@rules_cuda//cuda:version": "12.9.0"}, +) +config_setting( + name = "version_is_12_9_1", + flag_values = {"@rules_cuda//cuda:version": "12.9.1"}, +) +config_setting( + name = "version_is_13_0_0", + flag_values = {"@rules_cuda//cuda:version": "13.0.0"}, +) +config_setting( + name = "version_is_13_0_1", + flag_values = {"@rules_cuda//cuda:version": "13.0.1"}, +) +config_setting( + name = "version_is_13_0_2", + flag_values = {"@rules_cuda//cuda:version": "13.0.2"}, +) From 235a94aa978a80887de4d5690693724fea6a8c47 Mon Sep 17 00:00:00 2001 From: Charley Saint Date: Thu, 4 Dec 2025 01:40:01 +0000 Subject: [PATCH 02/10] fix alias targets that have a path --- cuda/platform_alias_extension.bzl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuda/platform_alias_extension.bzl b/cuda/platform_alias_extension.bzl index b3c05d14..79760423 100644 --- a/cuda/platform_alias_extension.bzl +++ b/cuda/platform_alias_extension.bzl @@ -248,7 +248,7 @@ def _platform_alias_repo_impl(ctx): build_content.append(' actual = select({') for version in ctx.attr.versions: build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) - build_content.append(' "@{}_{}//:{}",'.format(ctx.attr.linux_x86_64_repo, version.replace(".", "_"), target_name)) + build_content.append(' "@{}_{}//{}",'.format(ctx.attr.linux_x86_64_repo, version.replace(".", "_"), target if target.find(":") != -1 else ":" + target)) build_content.append(' }),') build_content.append(' visibility = ["//visibility:public"],') build_content.append(')') From d6a747849c9ac4a808d7951b2cd8f7aa3e1b3658 Mon Sep 17 00:00:00 2001 From: Lionel Gulich Date: Thu, 4 Dec 2025 11:49:37 +0100 Subject: [PATCH 03/10] Enable multi-version builds --- cuda/dummy/BUILD.bazel | 22 ++ cuda/extensions.bzl | 63 ++-- cuda/platform_alias_extension.bzl | 506 +++++++++++++++------------- cuda/private/redist_json_helper.bzl | 16 +- cuda/private/template_helper.bzl | 9 +- cuda/private/templates/BUILD.cufile | 9 + cuda/private/templates/BUILD.nvcc | 35 ++ cuda/private/templates/BUILD.nvvm | 23 ++ cuda/private/templates/registry.bzl | 8 +- 9 files changed, 420 insertions(+), 271 deletions(-) diff --git a/cuda/dummy/BUILD.bazel b/cuda/dummy/BUILD.bazel index c474f707..085309f8 100644 --- a/cuda/dummy/BUILD.bazel +++ b/cuda/dummy/BUILD.bazel @@ -25,3 +25,25 @@ cc_binary( srcs = ["dummy.cpp"], defines = ["TOOLNAME=fatbinary"], ) + +# Empty cc_library that provides CcInfo for components not available in this CUDA version. +cc_library( + name = "dummy", + srcs = [], + hdrs = [], +) + +# Dummy executable for cicc that doesn't exist in this version. +genrule( + name = "cicc_gen", + outs = ["cicc"], + cmd = "echo '#!/bin/bash' > $@ && echo 'echo Error: cicc not available in this CUDA version' >> $@ && chmod +x $@", + executable = True, +) + +# Dummy bitcode for libdevice that doesn't exist in this version. +genrule( + name = "libdevice_gen", + outs = ["libdevice.10.bc"], + cmd = "echo '; Placeholder bitcode' > $@", +) diff --git a/cuda/extensions.bzl b/cuda/extensions.bzl index ea0c9a9a..d6591f8a 100644 --- a/cuda/extensions.bzl +++ b/cuda/extensions.bzl @@ -147,6 +147,7 @@ def _redist_json_impl(module_ctx, attr): cuda_component(**component_attr) platform_mapping[platform] = mapping return redist_ver, platform_mapping + def _impl(module_ctx): # Toolchain configuration is only allowed in the root module, or in rules_cuda. root, rules_cuda = _find_modules(module_ctx) @@ -170,6 +171,9 @@ def _impl(module_ctx): components_mapping = None redist_versions = [] redist_components_mapping = {} + + # Track all versioned repositories for each component and platform. + versioned_repos = {} for redist_json in redist_jsons: components_mapping = {} redist_version, platform_mapping = _redist_json_impl(module_ctx, redist_json) @@ -178,16 +182,28 @@ def _impl(module_ctx): for component_name, repo_name in platform_mapping[platform].items(): redist_components_mapping[component_name] = repo_name + # Track the versioned repo name for this component/platform/version. + if component_name not in versioned_repos: + versioned_repos[component_name] = {} + if platform not in versioned_repos[component_name]: + versioned_repos[component_name][platform] = {} + versioned_repos[component_name][platform][redist_version] = repo_name + "_" + platform.replace("-", "_") + "_" + redist_version.replace(".", "_") + for component_name in redist_components_mapping.keys(): + # Build dictionaries mapping versions to repo names for each platform. + x86_64_repos = {ver: versioned_repos[component_name]["linux-x86_64"][ver] for ver in redist_versions if "linux-x86_64" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-x86_64"]} + aarch64_repos = {ver: versioned_repos[component_name]["linux-aarch64"][ver] for ver in redist_versions if "linux-aarch64" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-aarch64"]} + sbsa_repos = {ver: versioned_repos[component_name]["linux-sbsa"][ver] for ver in redist_versions if "linux-sbsa" in versioned_repos[component_name] and ver in versioned_repos[component_name]["linux-sbsa"]} + platform_alias_repo( - name = redist_components_mapping[component_name], - repo_name = redist_components_mapping[component_name], - component_name = component_name, - linux_x86_64_repo = "cuda_" + component_name + "_linux_x86_64", - linux_aarch64_repo = "cuda_" + component_name + "_linux_aarch64", - linux_sbsa_repo = "cuda_" + component_name + "_linux_sbsa", - versions = redist_versions, - ) + name = redist_components_mapping[component_name], + repo_name = redist_components_mapping[component_name], + component_name = component_name, + linux_x86_64_repos = x86_64_repos, + linux_aarch64_repos = aarch64_repos, + linux_sbsa_repos = sbsa_repos, + versions = redist_versions, + ) components_mapping[component_name] = "@" + redist_components_mapping[component_name] registrations = {} for toolkit in toolkits: @@ -204,21 +220,30 @@ def _impl(module_ctx): for _, toolkit in registrations.items(): if components_mapping != None: - cuda_toolkit(name = toolkit.name, components_mapping = components_mapping, version = redist_version) + # Always use the maximum version so the toolkit includes all components. + # Components that don't exist in older versions will fall back to dummy. + toolkit_version = redist_versions[0] + for ver in redist_versions: + ver_parts = [int(x) for x in ver.split(".")] + tv_parts = [int(x) for x in toolkit_version.split(".")] + if ver_parts > tv_parts: + toolkit_version = ver + + cuda_toolkit(name = toolkit.name, components_mapping = components_mapping, version = toolkit_version) else: cuda_toolkit(**_module_tag_to_dict(toolkit)) for alias_tag in platform_aliases: - # Create a repository for each alias tag - platform_alias_repo( - name = alias_tag.name, - repo_name = alias_tag.name, - component_name = alias_tag.component_name, - linux_x86_64_repo = alias_tag.linux_x86_64_repo, - linux_aarch64_repo = alias_tag.linux_aarch64_repo, - linux_sbsa_repo = alias_tag.linux_sbsa_repo, - versions = alias_tag.versions, - ) + # Create a repository for each alias tag + platform_alias_repo( + name = alias_tag.name, + repo_name = alias_tag.name, + component_name = alias_tag.component_name, + linux_x86_64_repo = alias_tag.linux_x86_64_repo, + linux_aarch64_repo = alias_tag.linux_aarch64_repo, + linux_sbsa_repo = alias_tag.linux_sbsa_repo, + versions = alias_tag.versions, + ) toolchain = module_extension( implementation = _impl, diff --git a/cuda/platform_alias_extension.bzl b/cuda/platform_alias_extension.bzl index 79760423..ba2b1a36 100644 --- a/cuda/platform_alias_extension.bzl +++ b/cuda/platform_alias_extension.bzl @@ -4,260 +4,266 @@ This extension creates repositories with alias targets that select between x86_6 repositories based on the build platform. """ -TARGET_MAPPING = { - "cccl": [ - "cccl_all_files", - "cccl_header_files", - "cccl_headers", - "cub", - "thrust", - "valid_toolchain_found", - ], - "cudart": [ - "cuda", - "cuda_lib", - "cuda_runtime", - "cuda_runtime_static", - "cuda_so", - "cudadevrt_a", - "cudadevrt_lib", - "cudart_all_files", - "cudart_header_files", - "cudart_headers", - "cudart_lib", - "cudart_so", - "no_cuda_runtime", - "valid_toolchain_found", - ], - "nvcc": [ - "compiler_deps", - "compiler_root", - "nvcc_all_files", - "nvcc_header_files", - "nvcc_headers", - "nvptxcompiler", - "nvptxcompiler_lib", - "nvptxcompiler_so", - "valid_toolchain_found", - "nvcc/bin/nvcc", - "nvcc/bin/nvlink", - "nvcc/bin/crt/link.stub", - "nvcc/bin/bin2c", - "nvcc/bin/fatbinary", - "nvvm/bin/cicc", - ], - "cublas": [ - "cublas", - "cublasLt_lib", - "cublasLt_so", - "cublas_all_files", - "cublas_header_files", - "cublas_headers", - "cublas_lib", - "cublas_so", - "valid_toolchain_found", - ], - "cufft": [ - "cufft", - "cufft_all_files", - "cufft_header_files", - "cufft_headers", - "cufft_lib", - "cufft_so", - "cufft_static", - "cufft_static_a", - "cufft_static_nocallback", - "cufft_static_nocallback_a", - "cufftw_lib", - "cufftw_so", - "cufftw_static", - "cufftw_static_a", - "valid_toolchain_found", - ], - "cusolver": [ - "cusolver", - "cusolver_all_files", - "cusolver_header_files", - "cusolver_headers", - "cusolver_lib", - "cusolver_so", - "valid_toolchain_found", - ], - "cusparse": [ - "cusparse", - "cusparse_all_files", - "cusparse_header_files", - "cusparse_headers", - "cusparse_lib", - "cusparse_so", - "valid_toolchain_found", - ], - "npp": [ - "npp_all_files", - "npp_header_files", - "npp_headers", - "nppc", - "nppc_lib", - "nppc_so", - "nppi", - "nppial", - "nppial_lib", - "nppial_so", - "nppicc", - "nppicc_lib", - "nppicc_so", - "nppidei", - "nppidei_lib", - "nppidei_so", - "nppif", - "nppif_lib", - "nppif_so", - "nppig", - "nppig_lib", - "nppig_so", - "nppim", - "nppim_lib", - "nppim_so", - "nppist", - "nppist_lib", - "nppist_so", - "nppisu", - "nppisu_lib", - "nppisu_so", - "nppitc", - "nppitc_lib", - "nppitc_so", - "npps", - "npps_lib", - "npps_so", - "valid_toolchain_found", - ], - "nvrtc": [ - "nvrtc", - "nvrtc_all_files", - "nvrtc_header_files", - "nvrtc_headers", - "nvrtc_lib", - "nvrtc_so", - "valid_toolchain_found", - ], - "nvjitlink": [ - "nvJitLink", - "nvJitLink_lib", - "nvJitLink_so", - "nvJitLink_static", - "nvJitLink_static_a", - "nvjitlink", - "nvjitlink_all_files", - "nvjitlink_header_files", - "nvjitlink_headers", - "nvjitlink_static", - "valid_toolchain_found", - ], - "nvjpeg": [ - "nvjpeg", - "nvjpeg_all_files", - "nvjpeg_header_files", - "nvjpeg_headers", - "nvjpeg_lib", - "nvjpeg_so", - "nvjpeg_static", - "nvjpeg_static_a", - "valid_toolchain_found", - ], - "crt": [ - "crt_all_files", - "crt_header_files", - "crt_headers", - "valid_toolchain_found", - ], - "nvvm": [ - "nvvm_all_files", - "nvvm_header_files", - "nvvm_headers", - "valid_toolchain_found", - "nvvm/nvvm/bin/cicc", - "nvvm/nvvm/libdevice/libdevice.10.bc", - ], - "culibos": [ - "culibos_all_files", - "culibos_license", - "culibos_a", - ], -} +load("//cuda/private:templates/registry.bzl", "REGISTRY") + +# Use REGISTRY as the source of truth for component targets +TARGET_MAPPING = REGISTRY def _platform_alias_repo_impl(ctx): """Implementation of the platform_alias_repo repository rule. - + Args: ctx: Repository context with attributes x86_repo, arm64_repo, and targets. """ + # Generate BUILD.bazel content with platform-specific aliases - build_content = ['# Generated by platform_alias_repo rule', ''] - + build_content = ["# Generated by platform_alias_repo rule", ""] + # Add load statement for alias build_content.append('load("@bazel_skylib//lib:selects.bzl", "selects")') - build_content.append('') - - # Build a target for the name of the repo - build_content.append('alias(') - build_content.append(' name = "{}",'.format(ctx.attr.repo_name)) - build_content.append(' actual = select({') - build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_x86_64":'.format("nvcc" if ctx.attr.component_name in ["nvcc", "nvvm"] else "runtime")) - build_content.append(' ":linux_x86_64_{}",'.format(ctx.attr.repo_name)) - build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_sbsa":'.format("nvcc" if ctx.attr.component_name in ["nvcc", "nvvm"] else "runtime")) - build_content.append(' ":linux_sbsa_{}",'.format(ctx.attr.repo_name)) - build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_aarch64":'.format("nvcc" if ctx.attr.component_name in ["nvcc", "nvvm"] else "runtime")) - build_content.append(' ":linux_aarch64_{}",'.format(ctx.attr.repo_name)) - build_content.append(' }),') - build_content.append(' visibility = ["//visibility:public"],') - build_content.append(')') - build_content.append('') - - build_content.append('alias(') - build_content.append(' name = "linux_x86_64_{}",'.format(ctx.attr.repo_name)) - build_content.append(' actual = select({') - for version in ctx.attr.versions: - build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) - build_content.append(' "@{}_{}//:{}",'.format(ctx.attr.linux_x86_64_repo, version.replace(".", "_"), ctx.attr.repo_name)) - build_content.append(' }),') - build_content.append(' visibility = ["//visibility:public"],') - build_content.append(')') - build_content.append('') + build_content.append("") - for target in TARGET_MAPPING[ctx.attr.component_name]: - # Create alias for each target with platform selection - target_name = target if target.find("/") == -1 else target.split("/")[-1] - build_content.append('alias(') - build_content.append(' name = "{}",'.format(target_name)) - build_content.append(' actual = select({') - build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_x86_64":'.format("nvcc" if ctx.attr.component_name in ["nvcc", "nvvm"] else "runtime")) - build_content.append(' ":linux_x86_64_{}",'.format(target_name)) - build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_sbsa":'.format("nvcc" if ctx.attr.component_name in ["nvcc", "nvvm"] else "runtime")) - build_content.append(' ":linux_sbsa_{}",'.format(target_name)) - build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_aarch64":'.format("nvcc" if ctx.attr.component_name in ["nvcc", "nvvm"] else "runtime")) - build_content.append(' ":linux_aarch64_{}",'.format(target_name)) - build_content.append(' }),') + # Check if using new dict-based API or old string-based API. + use_dict_api = len(ctx.attr.linux_x86_64_repos) > 0 + + # Build a target for the name of the repo (only if at least one platform is available). + platform_type = "nvcc" if ctx.attr.component_name in ["nvcc", "nvvm"] else "runtime" + + # Check which platforms are available. + platforms_available = [] + if use_dict_api: + if len(ctx.attr.linux_x86_64_repos) > 0: + platforms_available.append("x86_64") + if len(ctx.attr.linux_sbsa_repos) > 0: + platforms_available.append("sbsa") + if len(ctx.attr.linux_aarch64_repos) > 0: + platforms_available.append("aarch64") + else: + if ctx.attr.linux_x86_64_repo != "": + platforms_available.append("x86_64") + if ctx.attr.linux_sbsa_repo != "": + platforms_available.append("sbsa") + if ctx.attr.linux_aarch64_repo != "": + platforms_available.append("aarch64") + + if len(platforms_available) > 0: + build_content.append("alias(") + build_content.append(' name = "{}",'.format(ctx.attr.repo_name)) + build_content.append(" actual = select({") + if "x86_64" in platforms_available: + build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_x86_64":'.format(platform_type)) + build_content.append(' ":linux_x86_64_{}",'.format(ctx.attr.repo_name)) + if "sbsa" in platforms_available: + build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_sbsa":'.format(platform_type)) + build_content.append(' ":linux_sbsa_{}",'.format(ctx.attr.repo_name)) + if "aarch64" in platforms_available: + build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_aarch64":'.format(platform_type)) + build_content.append(' ":linux_aarch64_{}",'.format(ctx.attr.repo_name)) + build_content.append(" }),") + build_content.append(' visibility = ["//visibility:public"],') + build_content.append(")") + build_content.append("") + + # Generate aliases for each platform (only if versions exist for that platform). + # Check if we have any versions for x86_64. + has_x86_64 = False + if use_dict_api: + has_x86_64 = len(ctx.attr.linux_x86_64_repos) > 0 + else: + has_x86_64 = ctx.attr.linux_x86_64_repo != "" + + if has_x86_64: + build_content.append("alias(") + build_content.append(' name = "linux_x86_64_{}",'.format(ctx.attr.repo_name)) + build_content.append(" actual = select({") + if use_dict_api: + for version in ctx.attr.versions: + if version in ctx.attr.linux_x86_64_repos: + repo_name = ctx.attr.linux_x86_64_repos[version] + build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@{}//:{}",'.format(repo_name, ctx.attr.repo_name)) + + # Add default for versions where this component doesn't exist. + build_content.append(' "//conditions:default": "@rules_cuda//cuda/dummy:dummy",') + else: + # Old string-based API for backward compatibility. + for version in ctx.attr.versions: + build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@{}_{}//:{}",'.format(ctx.attr.linux_x86_64_repo, version.replace(".", "_"), ctx.attr.repo_name)) + build_content.append(" }),") + build_content.append(' visibility = ["//visibility:public"],') + build_content.append(")") + build_content.append("") + + # Check if we have any versions for sbsa. + has_sbsa = False + if use_dict_api: + has_sbsa = len(ctx.attr.linux_sbsa_repos) > 0 + else: + has_sbsa = ctx.attr.linux_sbsa_repo != "" + + if has_sbsa: + build_content.append("alias(") + build_content.append(' name = "linux_sbsa_{}",'.format(ctx.attr.repo_name)) + build_content.append(" actual = select({") + if use_dict_api: + for version in ctx.attr.versions: + if version in ctx.attr.linux_sbsa_repos: + repo_name = ctx.attr.linux_sbsa_repos[version] + build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@{}//:{}",'.format(repo_name, ctx.attr.repo_name)) + + # Add default for versions where this component doesn't exist. + build_content.append(' "//conditions:default": "@rules_cuda//cuda/dummy:dummy",') + else: + # Old string-based API for backward compatibility. + for version in ctx.attr.versions: + build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@{}_{}//:{}",'.format(ctx.attr.linux_sbsa_repo, version.replace(".", "_"), ctx.attr.repo_name)) + build_content.append(" }),") build_content.append(' visibility = ["//visibility:public"],') - build_content.append(')') - build_content.append('') - - build_content.append('alias(') - build_content.append(' name = "linux_x86_64_{}",'.format(target_name)) - build_content.append(' actual = select({') - for version in ctx.attr.versions: - build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) - build_content.append(' "@{}_{}//{}",'.format(ctx.attr.linux_x86_64_repo, version.replace(".", "_"), target if target.find(":") != -1 else ":" + target)) - build_content.append(' }),') + build_content.append(")") + build_content.append("") + + # Check if we have any versions for aarch64. + has_aarch64 = False + if use_dict_api: + has_aarch64 = len(ctx.attr.linux_aarch64_repos) > 0 + else: + has_aarch64 = ctx.attr.linux_aarch64_repo != "" + + if has_aarch64: + build_content.append("alias(") + build_content.append(' name = "linux_aarch64_{}",'.format(ctx.attr.repo_name)) + build_content.append(" actual = select({") + if use_dict_api: + for version in ctx.attr.versions: + if version in ctx.attr.linux_aarch64_repos: + repo_name = ctx.attr.linux_aarch64_repos[version] + build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@{}//:{}",'.format(repo_name, ctx.attr.repo_name)) + + # Add default for versions where this component doesn't exist. + build_content.append(' "//conditions:default": "@rules_cuda//cuda/dummy:dummy",') + else: + # Old string-based API for backward compatibility. + for version in ctx.attr.versions: + build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@{}_{}//:{}",'.format(ctx.attr.linux_aarch64_repo, version.replace(".", "_"), ctx.attr.repo_name)) + build_content.append(" }),") build_content.append(' visibility = ["//visibility:public"],') - build_content.append(')') - build_content.append('') + build_content.append(")") + build_content.append("") + + # Only generate target aliases if this component is in TARGET_MAPPING. + if ctx.attr.component_name not in TARGET_MAPPING: + # Write the BUILD.bazel file with just the main alias. + ctx.file("BUILD.bazel", "\n".join(build_content)) + return + + for target in TARGET_MAPPING[ctx.attr.component_name]: + # Create alias for each target with platform selection (only if platforms are available). + target_name = target if target.find("/") == -1 else target.split("/")[-1] + + if len(platforms_available) > 0: + build_content.append("alias(") + build_content.append(' name = "{}",'.format(target_name)) + build_content.append(" actual = select({") + if "x86_64" in platforms_available: + build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_x86_64":'.format(platform_type)) + build_content.append(' ":linux_x86_64_{}",'.format(target_name)) + if "sbsa" in platforms_available: + build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_sbsa":'.format(platform_type)) + build_content.append(' ":linux_sbsa_{}",'.format(target_name)) + if "aarch64" in platforms_available: + build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_aarch64":'.format(platform_type)) + build_content.append(' ":linux_aarch64_{}",'.format(target_name)) + build_content.append(" }),") + build_content.append(' visibility = ["//visibility:public"],') + build_content.append(")") + build_content.append("") + + # Generate platform-specific aliases for this target (only if versions exist). + if "x86_64" in platforms_available: + # Determine appropriate dummy target based on the target name. + dummy_target = "@rules_cuda//cuda/dummy:dummy" + if target_name == "cicc": + dummy_target = "@rules_cuda//cuda/dummy:cicc" + elif target_name == "libdevice.10.bc": + dummy_target = "@rules_cuda//cuda/dummy:libdevice.10.bc" + + build_content.append("alias(") + build_content.append(' name = "linux_x86_64_{}",'.format(target_name)) + build_content.append(" actual = select({") + if use_dict_api: + for version in ctx.attr.versions: + if version in ctx.attr.linux_x86_64_repos: + repo_name = ctx.attr.linux_x86_64_repos[version] + build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@{}//{}",'.format(repo_name, target if target.find(":") != -1 else ":" + target)) + + # Add default for versions where this component doesn't exist. + build_content.append(' "//conditions:default": "{}",'.format(dummy_target)) + else: + # Old string-based API for backward compatibility. + for version in ctx.attr.versions: + build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@{}_{}//{}",'.format(ctx.attr.linux_x86_64_repo, version.replace(".", "_"), target if target.find(":") != -1 else ":" + target)) + build_content.append(" }),") + build_content.append(' visibility = ["//visibility:public"],') + build_content.append(")") + build_content.append("") + + if "sbsa" in platforms_available: + build_content.append("alias(") + build_content.append(' name = "linux_sbsa_{}",'.format(target_name)) + build_content.append(" actual = select({") + if use_dict_api: + for version in ctx.attr.versions: + if version in ctx.attr.linux_sbsa_repos: + repo_name = ctx.attr.linux_sbsa_repos[version] + build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@{}//{}",'.format(repo_name, target if target.find(":") != -1 else ":" + target)) + + # Add default for versions where this component doesn't exist. + build_content.append(' "//conditions:default": "{}",'.format(dummy_target)) + else: + # Old string-based API for backward compatibility. + for version in ctx.attr.versions: + build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@{}_{}//{}",'.format(ctx.attr.linux_sbsa_repo, version.replace(".", "_"), target if target.find(":") != -1 else ":" + target)) + build_content.append(" }),") + build_content.append(' visibility = ["//visibility:public"],') + build_content.append(")") + build_content.append("") + + if "aarch64" in platforms_available: + build_content.append("alias(") + build_content.append(' name = "linux_aarch64_{}",'.format(target_name)) + build_content.append(" actual = select({") + if use_dict_api: + for version in ctx.attr.versions: + if version in ctx.attr.linux_aarch64_repos: + repo_name = ctx.attr.linux_aarch64_repos[version] + build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@{}//{}",'.format(repo_name, target if target.find(":") != -1 else ":" + target)) + + # Add default for versions where this component doesn't exist. + build_content.append(' "//conditions:default": "{}",'.format(dummy_target)) + else: + # Old string-based API for backward compatibility. + for version in ctx.attr.versions: + build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@{}_{}//{}",'.format(ctx.attr.linux_aarch64_repo, version.replace(".", "_"), target if target.find(":") != -1 else ":" + target)) + build_content.append(" }),") + build_content.append(' visibility = ["//visibility:public"],') + build_content.append(")") + build_content.append("") # Write the BUILD.bazel file ctx.file("BUILD.bazel", "\n".join(build_content)) - platform_alias_repo = repository_rule( implementation = _platform_alias_repo_impl, attrs = { @@ -269,17 +275,31 @@ platform_alias_repo = repository_rule( mandatory = True, doc = "Name of the component", ), + # New dict-based API (preferred). + "linux_x86_64_repos": attr.string_dict( + default = {}, + doc = "Dictionary mapping versions to x86_64 repository names", + ), + "linux_aarch64_repos": attr.string_dict( + default = {}, + doc = "Dictionary mapping versions to ARM64/Jetpack repository names", + ), + "linux_sbsa_repos": attr.string_dict( + default = {}, + doc = "Dictionary mapping versions to SBSA repository names", + ), + # Old string-based API (for backward compatibility). "linux_x86_64_repo": attr.string( - mandatory = True, - doc = "Name of the repository to use for x86_64 platform", + default = "", + doc = "Base name of the repository to use for x86_64 platform (deprecated, use linux_x86_64_repos)", ), "linux_aarch64_repo": attr.string( - mandatory = True, - doc = "Name of the repository to use for ARM64/Jetpack platform", + default = "", + doc = "Base name of the repository to use for ARM64/Jetpack platform (deprecated, use linux_aarch64_repos)", ), "linux_sbsa_repo": attr.string( - mandatory = True, - doc = "Name of the repository to use for SBSA platform", + default = "", + doc = "Base name of the repository to use for SBSA platform (deprecated, use linux_sbsa_repos)", ), "versions": attr.string_list( mandatory = True, diff --git a/cuda/private/redist_json_helper.bzl b/cuda/private/redist_json_helper.bzl index 5e5dd813..b9a7f998 100644 --- a/cuda/private/redist_json_helper.bzl +++ b/cuda/private/redist_json_helper.bzl @@ -27,9 +27,21 @@ def _get(ctx, attr): if len(urls) == 0: fail("`urls` or `version` must be specified.") + # Use the name of the repository or tag to avoid collisions when multiple + # redist_json are defined. + name = getattr(attr, "name", None) + if not name and hasattr(ctx, "name"): + name = ctx.name + + # In case name is still not found (unlikely), use a default + if not name: + name = "default" + + output_filename = "redist_{}.json".format(name) + for url in urls: ret = ctx.download( - output = "redist.json", + output = output_filename, integrity = attr.integrity, sha256 = attr.sha256, url = url, @@ -41,7 +53,7 @@ def _get(ctx, attr): if the_url == None: fail("Failed to retrieve the redist json file.") - return the_url, json.decode(ctx.read("redist.json")) + return the_url, json.decode(ctx.read(output_filename)) def _get_redist_version(ctx, attr, redist): """Get version string. diff --git a/cuda/private/template_helper.bzl b/cuda/private/template_helper.bzl index feaec41c..611cbabd 100644 --- a/cuda/private/template_helper.bzl +++ b/cuda/private/template_helper.bzl @@ -28,13 +28,16 @@ def _expand_lctk_cuda(repository_ctx, components): def _expand_dctk_cuda(repository_ctx, components): tpl_label = Label("//cuda/private:templates/BUILD.dctk_cuda") - all_files_srcs_line = [comp + "_all_files" for comp in components.keys()] + # Filter out components with empty REGISTRY entries + valid_components = [comp for comp in components.keys() if len(REGISTRY[comp]) > 0] + + all_files_srcs_line = [comp + "_all_files" for comp in valid_components] all_files_srcs_line = "srcs = " + repr(all_files_srcs_line) - license_srcs_line = [comp + "_license" for comp in components.keys()] + license_srcs_line = [comp + "_license" for comp in valid_components] license_srcs_line = "srcs = " + repr(license_srcs_line) - headers_deps_line = [comp + "_headers" for comp in components.keys()] + headers_deps_line = [comp + "_headers" for comp in valid_components] headers_deps_line = "deps = " + repr(headers_deps_line) substitutions = { diff --git a/cuda/private/templates/BUILD.cufile b/cuda/private/templates/BUILD.cufile index e69de29b..db49e382 100644 --- a/cuda/private/templates/BUILD.cufile +++ b/cuda/private/templates/BUILD.cufile @@ -0,0 +1,9 @@ +exports_files(glob(["**"])) + +cc_library( + name = "cufile", + srcs = glob(["cufile/lib/**/*.so*"]), + hdrs = glob(["cufile/include/**/*.h"]), + includes = ["cufile/include"], + visibility = ["//visibility:public"], +) \ No newline at end of file diff --git a/cuda/private/templates/BUILD.nvcc b/cuda/private/templates/BUILD.nvcc index 5f71863f..75bf0c4c 100644 --- a/cuda/private/templates/BUILD.nvcc +++ b/cuda/private/templates/BUILD.nvcc @@ -1,3 +1,38 @@ +# Export individual executables as simple target names +exports_files([ + "%{component_name}/bin/nvcc", + "%{component_name}/bin/nvlink", + "%{component_name}/bin/bin2c", + "%{component_name}/bin/fatbinary", + "%{component_name}/bin/crt/link.stub", +], visibility = ["//visibility:public"]) + +# Create aliases with simple names for toolchain compatibility +alias( + name = "nvcc", + actual = "%{component_name}/bin/nvcc", +) + +alias( + name = "nvlink", + actual = "%{component_name}/bin/nvlink", +) + +alias( + name = "bin2c", + actual = "%{component_name}/bin/bin2c", +) + +alias( + name = "fatbinary", + actual = "%{component_name}/bin/fatbinary", +) + +alias( + name = "link.stub", + actual = "%{component_name}/bin/crt/link.stub", +) + filegroup( name = "compiler_root", srcs = [":nvcc"], diff --git a/cuda/private/templates/BUILD.nvvm b/cuda/private/templates/BUILD.nvvm index e69de29b..d55f10b4 100644 --- a/cuda/private/templates/BUILD.nvvm +++ b/cuda/private/templates/BUILD.nvvm @@ -0,0 +1,23 @@ +# Export individual files as simple target names +exports_files([ + "%{component_name}/nvvm/bin/cicc", + "%{component_name}/nvvm/libdevice/libdevice.10.bc", +], visibility = ["//visibility:public"]) + +# Create aliases with simple names for toolchain compatibility +alias( + name = "cicc", + actual = "%{component_name}/nvvm/bin/cicc", +) + +alias( + name = "libdevice.10.bc", + actual = "%{component_name}/nvvm/libdevice/libdevice.10.bc", +) + +# Alias for generic libdevice name +alias( + name = "libdevice", + actual = ":libdevice.10.bc", +) + diff --git a/cuda/private/templates/registry.bzl b/cuda/private/templates/registry.bzl index f31ef825..7ba9c958 100644 --- a/cuda/private/templates/registry.bzl +++ b/cuda/private/templates/registry.bzl @@ -1,14 +1,14 @@ # map short component name to consumable targets REGISTRY = { "cudart": ["cudart_all_files", "cudart_license", "cudart_headers", "cuda", "cuda_runtime", "cuda_runtime_static"], - "nvcc": ["nvcc_all_files", "nvcc_license", "nvcc_headers", "compiler_root", "compiler_deps", "nvptxcompiler"], - "nvvm": ["nvvm_all_files", "nvvm_license", "nvvm_headers", "cicc", "libdevice"], + "nvcc": ["nvcc_all_files", "nvcc_license", "nvcc_headers", "compiler_root", "compiler_deps", "nvptxcompiler", "nvcc", "nvlink", "bin2c", "fatbinary", "link.stub"], + "nvvm": ["nvvm_all_files", "nvvm_license", "nvvm_headers", "cicc", "libdevice", "libdevice.10.bc"], "cccl": ["cccl_all_files", "cccl_license", "cccl_headers", "libcudacxx", "cub", "thrust"], "crt": ["crt_all_files", "crt_license", "crt_headers", "crt"], - "culibos": ["culibos_all_files", "culibos_license", "culibos_a"], # culibos_a is not for end users + "culibos": ["culibos_all_files", "culibos_license", "culibos_headers", "culibos_a"], # culibos_a is not for end users. "cublas": ["cublas_all_files", "cublas_license", "cublas_headers", "cublas"], "cufft": ["cufft_all_files", "cufft_license", "cufft_headers", "cufft", "cufft_static"], - "cufile": [], + "cufile": ["cufile", "cufile_all_files", "cufile_header_files", "cufile_headers", "cufile_license"], "cupti": ["cupti_all_files", "cupti_license", "cupti_headers", "cupti", "nvperf_host", "nvperf_target"], "curand": ["curand_all_files", "curand_license", "curand_headers", "curand"], "cusolver": ["cusolver_all_files", "cusolver_license", "cusolver_headers", "cusolver"], From 7bc9c61f527a1b8596dab3edf3b80441e9e8fce4 Mon Sep 17 00:00:00 2001 From: Charley Saint Date: Thu, 4 Dec 2025 19:53:18 +0000 Subject: [PATCH 04/10] add error messages for bad platform/version selection --- cuda/defs.bzl | 3 + cuda/platform_alias_extension.bzl | 170 +++++++----------------------- cuda/private/errors.bzl | 27 +++++ 3 files changed, 69 insertions(+), 131 deletions(-) create mode 100644 cuda/private/errors.bzl diff --git a/cuda/defs.bzl b/cuda/defs.bzl index e3fdbab3..ad0178cc 100644 --- a/cuda/defs.bzl +++ b/cuda/defs.bzl @@ -3,6 +3,7 @@ Core rules for building CUDA projects. """ load("//cuda/private:defs.bzl", _requires_cuda = "requires_cuda") +load("//cuda/private:errors.bzl", _unsupported_cuda_version = "unsupported_cuda_version", _unsupported_cuda_platform = "unsupported_cuda_platform") load("//cuda/private:macros/cuda_binary.bzl", _cuda_binary = "cuda_binary") load("//cuda/private:macros/cuda_test.bzl", _cuda_test = "cuda_test") load("//cuda/private:os_helpers.bzl", _cc_import_versioned_sos = "cc_import_versioned_sos", _if_linux = "if_linux", _if_windows = "if_windows") @@ -47,3 +48,5 @@ if_windows = _if_windows cc_import_versioned_sos = _cc_import_versioned_sos requires_cuda = _requires_cuda +unsupported_cuda_version = _unsupported_cuda_version +unsupported_cuda_platform = _unsupported_cuda_platform \ No newline at end of file diff --git a/cuda/platform_alias_extension.bzl b/cuda/platform_alias_extension.bzl index ba2b1a36..a8145ca6 100644 --- a/cuda/platform_alias_extension.bzl +++ b/cuda/platform_alias_extension.bzl @@ -21,8 +21,9 @@ def _platform_alias_repo_impl(ctx): build_content = ["# Generated by platform_alias_repo rule", ""] # Add load statement for alias - build_content.append('load("@bazel_skylib//lib:selects.bzl", "selects")') + build_content.append('load("@rules_cuda//cuda:defs.bzl", "unsupported_cuda_version", "unsupported_cuda_platform")') build_content.append("") + build_content.append('unsupported_cuda_version(name = "unsupported_cuda_version", component = "{}", available_versions = {})'.format(ctx.attr.component_name, ctx.attr.versions)) # Check if using new dict-based API or old string-based API. use_dict_api = len(ctx.attr.linux_x86_64_repos) > 0 @@ -34,126 +35,21 @@ def _platform_alias_repo_impl(ctx): platforms_available = [] if use_dict_api: if len(ctx.attr.linux_x86_64_repos) > 0: - platforms_available.append("x86_64") + platforms_available.append("linux-x86_64") if len(ctx.attr.linux_sbsa_repos) > 0: - platforms_available.append("sbsa") + platforms_available.append("linux-sbsa") if len(ctx.attr.linux_aarch64_repos) > 0: - platforms_available.append("aarch64") + platforms_available.append("linux-aarch64") else: if ctx.attr.linux_x86_64_repo != "": - platforms_available.append("x86_64") + platforms_available.append("linux-x86_64") if ctx.attr.linux_sbsa_repo != "": - platforms_available.append("sbsa") + platforms_available.append("linux-sbsa") if ctx.attr.linux_aarch64_repo != "": - platforms_available.append("aarch64") + platforms_available.append("linux-aarch64") if len(platforms_available) > 0: - build_content.append("alias(") - build_content.append(' name = "{}",'.format(ctx.attr.repo_name)) - build_content.append(" actual = select({") - if "x86_64" in platforms_available: - build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_x86_64":'.format(platform_type)) - build_content.append(' ":linux_x86_64_{}",'.format(ctx.attr.repo_name)) - if "sbsa" in platforms_available: - build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_sbsa":'.format(platform_type)) - build_content.append(' ":linux_sbsa_{}",'.format(ctx.attr.repo_name)) - if "aarch64" in platforms_available: - build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_aarch64":'.format(platform_type)) - build_content.append(' ":linux_aarch64_{}",'.format(ctx.attr.repo_name)) - build_content.append(" }),") - build_content.append(' visibility = ["//visibility:public"],') - build_content.append(")") - build_content.append("") - - # Generate aliases for each platform (only if versions exist for that platform). - # Check if we have any versions for x86_64. - has_x86_64 = False - if use_dict_api: - has_x86_64 = len(ctx.attr.linux_x86_64_repos) > 0 - else: - has_x86_64 = ctx.attr.linux_x86_64_repo != "" - - if has_x86_64: - build_content.append("alias(") - build_content.append(' name = "linux_x86_64_{}",'.format(ctx.attr.repo_name)) - build_content.append(" actual = select({") - if use_dict_api: - for version in ctx.attr.versions: - if version in ctx.attr.linux_x86_64_repos: - repo_name = ctx.attr.linux_x86_64_repos[version] - build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) - build_content.append(' "@{}//:{}",'.format(repo_name, ctx.attr.repo_name)) - - # Add default for versions where this component doesn't exist. - build_content.append(' "//conditions:default": "@rules_cuda//cuda/dummy:dummy",') - else: - # Old string-based API for backward compatibility. - for version in ctx.attr.versions: - build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) - build_content.append(' "@{}_{}//:{}",'.format(ctx.attr.linux_x86_64_repo, version.replace(".", "_"), ctx.attr.repo_name)) - build_content.append(" }),") - build_content.append(' visibility = ["//visibility:public"],') - build_content.append(")") - build_content.append("") - - # Check if we have any versions for sbsa. - has_sbsa = False - if use_dict_api: - has_sbsa = len(ctx.attr.linux_sbsa_repos) > 0 - else: - has_sbsa = ctx.attr.linux_sbsa_repo != "" - - if has_sbsa: - build_content.append("alias(") - build_content.append(' name = "linux_sbsa_{}",'.format(ctx.attr.repo_name)) - build_content.append(" actual = select({") - if use_dict_api: - for version in ctx.attr.versions: - if version in ctx.attr.linux_sbsa_repos: - repo_name = ctx.attr.linux_sbsa_repos[version] - build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) - build_content.append(' "@{}//:{}",'.format(repo_name, ctx.attr.repo_name)) - - # Add default for versions where this component doesn't exist. - build_content.append(' "//conditions:default": "@rules_cuda//cuda/dummy:dummy",') - else: - # Old string-based API for backward compatibility. - for version in ctx.attr.versions: - build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) - build_content.append(' "@{}_{}//:{}",'.format(ctx.attr.linux_sbsa_repo, version.replace(".", "_"), ctx.attr.repo_name)) - build_content.append(" }),") - build_content.append(' visibility = ["//visibility:public"],') - build_content.append(")") - build_content.append("") - - # Check if we have any versions for aarch64. - has_aarch64 = False - if use_dict_api: - has_aarch64 = len(ctx.attr.linux_aarch64_repos) > 0 - else: - has_aarch64 = ctx.attr.linux_aarch64_repo != "" - - if has_aarch64: - build_content.append("alias(") - build_content.append(' name = "linux_aarch64_{}",'.format(ctx.attr.repo_name)) - build_content.append(" actual = select({") - if use_dict_api: - for version in ctx.attr.versions: - if version in ctx.attr.linux_aarch64_repos: - repo_name = ctx.attr.linux_aarch64_repos[version] - build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) - build_content.append(' "@{}//:{}",'.format(repo_name, ctx.attr.repo_name)) - - # Add default for versions where this component doesn't exist. - build_content.append(' "//conditions:default": "@rules_cuda//cuda/dummy:dummy",') - else: - # Old string-based API for backward compatibility. - for version in ctx.attr.versions: - build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) - build_content.append(' "@{}_{}//:{}",'.format(ctx.attr.linux_aarch64_repo, version.replace(".", "_"), ctx.attr.repo_name)) - build_content.append(" }),") - build_content.append(' visibility = ["//visibility:public"],') - build_content.append(")") + build_content.append('unsupported_cuda_platform(name = "unsupported_cuda_platform", component = "{}", available_platforms = {})'.format(ctx.attr.component_name, platforms_available)) build_content.append("") # Only generate target aliases if this component is in TARGET_MAPPING. @@ -170,29 +66,31 @@ def _platform_alias_repo_impl(ctx): build_content.append("alias(") build_content.append(' name = "{}",'.format(target_name)) build_content.append(" actual = select({") - if "x86_64" in platforms_available: + if "linux-x86_64" in platforms_available: build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_x86_64":'.format(platform_type)) build_content.append(' ":linux_x86_64_{}",'.format(target_name)) - if "sbsa" in platforms_available: + if "linux-sbsa" in platforms_available: build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_sbsa":'.format(platform_type)) build_content.append(' ":linux_sbsa_{}",'.format(target_name)) - if "aarch64" in platforms_available: + if "linux-aarch64" in platforms_available: build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_aarch64":'.format(platform_type)) build_content.append(' ":linux_aarch64_{}",'.format(target_name)) + build_content.append(' "//conditions:default": ":unsupported_cuda_platform",') build_content.append(" }),") build_content.append(' visibility = ["//visibility:public"],') build_content.append(")") build_content.append("") - # Generate platform-specific aliases for this target (only if versions exist). - if "x86_64" in platforms_available: - # Determine appropriate dummy target based on the target name. - dummy_target = "@rules_cuda//cuda/dummy:dummy" - if target_name == "cicc": - dummy_target = "@rules_cuda//cuda/dummy:cicc" - elif target_name == "libdevice.10.bc": - dummy_target = "@rules_cuda//cuda/dummy:libdevice.10.bc" + # Determine appropriate dummy target based on the target name. + dummy_target = "@rules_cuda//cuda/dummy:dummy" + if target_name == "cicc": + dummy_target = "@rules_cuda//cuda/dummy:cicc" + elif target_name == "libdevice.10.bc": + dummy_target = "@rules_cuda//cuda/dummy:libdevice.10.bc" + + # Generate platform-specific aliases for this target (only if versions exist). + if "linux-x86_64" in platforms_available: build_content.append("alias(") build_content.append(' name = "linux_x86_64_{}",'.format(target_name)) build_content.append(" actual = select({") @@ -202,20 +100,23 @@ def _platform_alias_repo_impl(ctx): repo_name = ctx.attr.linux_x86_64_repos[version] build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "@{}//{}",'.format(repo_name, target if target.find(":") != -1 else ":" + target)) - + else: + build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "{}",'.format(dummy_target)) # Add default for versions where this component doesn't exist. - build_content.append(' "//conditions:default": "{}",'.format(dummy_target)) + build_content.append(' "//conditions:default": ":unsupported_cuda_version",') else: # Old string-based API for backward compatibility. for version in ctx.attr.versions: build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "@{}_{}//{}",'.format(ctx.attr.linux_x86_64_repo, version.replace(".", "_"), target if target.find(":") != -1 else ":" + target)) + build_content.append(' "//conditions:default": ":unsupported_cuda_version",') build_content.append(" }),") build_content.append(' visibility = ["//visibility:public"],') build_content.append(")") build_content.append("") - if "sbsa" in platforms_available: + if "linux-sbsa" in platforms_available: build_content.append("alias(") build_content.append(' name = "linux_sbsa_{}",'.format(target_name)) build_content.append(" actual = select({") @@ -225,20 +126,24 @@ def _platform_alias_repo_impl(ctx): repo_name = ctx.attr.linux_sbsa_repos[version] build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "@{}//{}",'.format(repo_name, target if target.find(":") != -1 else ":" + target)) + else: + build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "{}",'.format(dummy_target)) # Add default for versions where this component doesn't exist. - build_content.append(' "//conditions:default": "{}",'.format(dummy_target)) + build_content.append(' "//conditions:default": ":unsupported_cuda_version",') else: # Old string-based API for backward compatibility. for version in ctx.attr.versions: build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "@{}_{}//{}",'.format(ctx.attr.linux_sbsa_repo, version.replace(".", "_"), target if target.find(":") != -1 else ":" + target)) + build_content.append(' "//conditions:default": ":unsupported_cuda_version",') build_content.append(" }),") build_content.append(' visibility = ["//visibility:public"],') build_content.append(")") build_content.append("") - if "aarch64" in platforms_available: + if "linux-aarch64" in platforms_available: build_content.append("alias(") build_content.append(' name = "linux_aarch64_{}",'.format(target_name)) build_content.append(" actual = select({") @@ -248,14 +153,17 @@ def _platform_alias_repo_impl(ctx): repo_name = ctx.attr.linux_aarch64_repos[version] build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "@{}//{}",'.format(repo_name, target if target.find(":") != -1 else ":" + target)) - + else: + build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "{}",'.format(dummy_target)) # Add default for versions where this component doesn't exist. - build_content.append(' "//conditions:default": "{}",'.format(dummy_target)) + build_content.append(' "//conditions:default": ":unsupported_cuda_version",') else: # Old string-based API for backward compatibility. for version in ctx.attr.versions: build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "@{}_{}//{}",'.format(ctx.attr.linux_aarch64_repo, version.replace(".", "_"), target if target.find(":") != -1 else ":" + target)) + build_content.append(' "//conditions:default": ":unsupported_cuda_version",') build_content.append(" }),") build_content.append(' visibility = ["//visibility:public"],') build_content.append(")") diff --git a/cuda/private/errors.bzl b/cuda/private/errors.bzl new file mode 100644 index 00000000..8d358721 --- /dev/null +++ b/cuda/private/errors.bzl @@ -0,0 +1,27 @@ +def _unsupported_cuda_version_impl(ctx): + fail("CUDA component '{}' is not available for the selected CUDA version. Available versions: {}".format( + ctx.attr.component, + ", ".join(ctx.attr.available_versions) + )) + +unsupported_cuda_version = rule( + implementation = _unsupported_cuda_version_impl, + attrs = { + "component": attr.string(mandatory = True), + "available_versions": attr.string_list(mandatory = True), + }, +) + +def _unsupported_cuda_platform_impl(ctx): + fail("CUDA component '{}' is not available for the selected platform. Available platforms: {}".format( + ctx.attr.component, + ", ".join(ctx.attr.available_platforms) + )) + +unsupported_cuda_platform = rule( + implementation = _unsupported_cuda_platform_impl, + attrs = { + "component": attr.string(mandatory = True), + "available_platforms": attr.string_list(mandatory = True), + }, +) From 3e1f2061d928e0c24cf40ee8794cbadb584e4018 Mon Sep 17 00:00:00 2001 From: Charley Saint Date: Fri, 5 Dec 2025 04:15:03 +0000 Subject: [PATCH 05/10] move versions and chaange flag names --- cuda/BUILD.bazel | 126 ++++++++---------------------- cuda/platform_alias_extension.bzl | 20 ++--- cuda/private/cuda_versions.bzl | 27 +++++++ cuda/versions/BUILD.bazel | 100 ------------------------ 4 files changed, 71 insertions(+), 202 deletions(-) create mode 100644 cuda/private/cuda_versions.bzl delete mode 100644 cuda/versions/BUILD.bazel diff --git a/cuda/BUILD.bazel b/cuda/BUILD.bazel index 47832197..45512582 100644 --- a/cuda/BUILD.bazel +++ b/cuda/BUILD.bazel @@ -4,10 +4,19 @@ load( "bool_flag", "string_flag", ) +load("//cuda/private:cuda_versions.bzl", "CUDA_VERSIONS") load("//cuda/private:rules/flags.bzl", "cuda_archs_flag", "repeatable_string_flag") package(default_visibility = ["//visibility:public"]) +CUDA_PLATFORMS = [ + "linux-x86_64", + "linux-sbsa", + "linux-aarch64", + "linux-ppc64le", + "windows-x86_64", +] + bzl_library( name = "bzl_srcs", srcs = glob(["*.bzl"]), @@ -40,109 +49,42 @@ config_setting( string_flag( name = "version", - build_setting_default = "13.0.2", - values = [ - "12.0.0", - "12.0.1", - "12.1.0", - "12.1.1", - "12.2.0", - "12.2.1", - "12.2.2", - "12.3.0", - "12.3.1", - "12.3.2", - "12.4.0", - "12.4.1", - "12.5.0", - "12.5.1", - "12.6.0", - "12.6.1", - "12.6.2", - "12.6.3", - "12.8.0", - "12.8.1", - "12.9.0", - "12.9.1", - "13.0.0", - "13.0.1", - "13.0.2", - ], + build_setting_default = "13.0.0", ) -string_flag( - name = "runtime_platform", - build_setting_default = "linux-x86_64", - values = [ - "linux-x86_64", - "linux-sbsa", - "linux-aarch64", - "linux-ppc64le", - "windows-x86_64", - ], -) +[ + config_setting( + name = "version_is_{}".format(version.replace(".", "_")), + flag_values = {"@rules_cuda//cuda:version": "{}".format(version)}, + ) for version in CUDA_VERSIONS +] -config_setting( - name = "runtime_platform_is_linux_x86_64", - flag_values = {":runtime_platform": "linux-x86_64"}, -) - -config_setting( - name = "runtime_platform_is_linux_aarch64", - flag_values = {":runtime_platform": "linux-aarch64"}, -) - -config_setting( - name = "runtime_platform_is_linux_ppc64le", - flag_values = {":runtime_platform": "linux-ppc64le"}, -) - -config_setting( - name = "runtime_platform_is_linux_sbsa", - flag_values = {":runtime_platform": "linux-sbsa"}, -) - -config_setting( - name = "runtime_platform_is_windows_x86_64", - flag_values = {":runtime_platform": "windows-x86_64"}, -) string_flag( - name = "nvcc_platform", + name = "target_platform", build_setting_default = "linux-x86_64", - values = [ - "linux-x86_64", - "linux-sbsa", - "linux-aarch64", - "linux-ppc64le", - "windows-x86_64", - ], -) - -config_setting( - name = "nvcc_platform_is_linux_x86_64", - flag_values = {":nvcc_platform": "linux-x86_64"}, -) - -config_setting( - name = "nvcc_platform_is_linux_aarch64", - flag_values = {":nvcc_platform": "linux-aarch64"}, + values = CUDA_PLATFORMS, ) -config_setting( - name = "nvcc_platform_is_linux_ppc64le", - flag_values = {":nvcc_platform": "linux-ppc64le"}, -) +[ + config_setting( + name = "target_platform_is_{}".format(platform.replace("-", "_")), + flag_values = {":target_platform": platform}, + ) for platform in CUDA_PLATFORMS +] -config_setting( - name = "nvcc_platform_is_linux_sbsa", - flag_values = {":nvcc_platform": "linux-sbsa"}, +string_flag( + name = "exec_platform", + build_setting_default = "linux-x86_64", + values = CUDA_PLATFORMS, ) -config_setting( - name = "nvcc_platform_is_windows_x86_64", - flag_values = {":nvcc_platform": "windows-x86_64"}, -) +[ + config_setting( + name = "exec_platform_is_{}".format(platform.replace("-", "_")), + flag_values = {":exec_platform": platform}, + ) for platform in CUDA_PLATFORMS +] # Command line flag to specify the list of CUDA architectures to compile for. # diff --git a/cuda/platform_alias_extension.bzl b/cuda/platform_alias_extension.bzl index a8145ca6..3128cf90 100644 --- a/cuda/platform_alias_extension.bzl +++ b/cuda/platform_alias_extension.bzl @@ -29,7 +29,7 @@ def _platform_alias_repo_impl(ctx): use_dict_api = len(ctx.attr.linux_x86_64_repos) > 0 # Build a target for the name of the repo (only if at least one platform is available). - platform_type = "nvcc" if ctx.attr.component_name in ["nvcc", "nvvm"] else "runtime" + platform_type = "exec" if ctx.attr.component_name in ["nvcc", "nvvm"] else "target" # Check which platforms are available. platforms_available = [] @@ -98,17 +98,17 @@ def _platform_alias_repo_impl(ctx): for version in ctx.attr.versions: if version in ctx.attr.linux_x86_64_repos: repo_name = ctx.attr.linux_x86_64_repos[version] - build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@rules_cuda//cuda:version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "@{}//{}",'.format(repo_name, target if target.find(":") != -1 else ":" + target)) else: - build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@rules_cuda//cuda:version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "{}",'.format(dummy_target)) # Add default for versions where this component doesn't exist. build_content.append(' "//conditions:default": ":unsupported_cuda_version",') else: # Old string-based API for backward compatibility. for version in ctx.attr.versions: - build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@rules_cuda//cuda:version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "@{}_{}//{}",'.format(ctx.attr.linux_x86_64_repo, version.replace(".", "_"), target if target.find(":") != -1 else ":" + target)) build_content.append(' "//conditions:default": ":unsupported_cuda_version",') build_content.append(" }),") @@ -124,10 +124,10 @@ def _platform_alias_repo_impl(ctx): for version in ctx.attr.versions: if version in ctx.attr.linux_sbsa_repos: repo_name = ctx.attr.linux_sbsa_repos[version] - build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@rules_cuda//cuda:version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "@{}//{}",'.format(repo_name, target if target.find(":") != -1 else ":" + target)) else: - build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@rules_cuda//cuda:version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "{}",'.format(dummy_target)) # Add default for versions where this component doesn't exist. @@ -135,7 +135,7 @@ def _platform_alias_repo_impl(ctx): else: # Old string-based API for backward compatibility. for version in ctx.attr.versions: - build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@rules_cuda//cuda:version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "@{}_{}//{}",'.format(ctx.attr.linux_sbsa_repo, version.replace(".", "_"), target if target.find(":") != -1 else ":" + target)) build_content.append(' "//conditions:default": ":unsupported_cuda_version",') build_content.append(" }),") @@ -151,17 +151,17 @@ def _platform_alias_repo_impl(ctx): for version in ctx.attr.versions: if version in ctx.attr.linux_aarch64_repos: repo_name = ctx.attr.linux_aarch64_repos[version] - build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@rules_cuda//cuda:version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "@{}//{}",'.format(repo_name, target if target.find(":") != -1 else ":" + target)) else: - build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@rules_cuda//cuda:version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "{}",'.format(dummy_target)) # Add default for versions where this component doesn't exist. build_content.append(' "//conditions:default": ":unsupported_cuda_version",') else: # Old string-based API for backward compatibility. for version in ctx.attr.versions: - build_content.append(' "@rules_cuda//cuda/versions:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' "@rules_cuda//cuda:version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "@{}_{}//{}",'.format(ctx.attr.linux_aarch64_repo, version.replace(".", "_"), target if target.find(":") != -1 else ":" + target)) build_content.append(' "//conditions:default": ":unsupported_cuda_version",') build_content.append(" }),") diff --git a/cuda/private/cuda_versions.bzl b/cuda/private/cuda_versions.bzl new file mode 100644 index 00000000..f837a4e9 --- /dev/null +++ b/cuda/private/cuda_versions.bzl @@ -0,0 +1,27 @@ +CUDA_VERSIONS = [ + "12.0.0", + "12.0.1", + "12.1.0", + "12.1.1", + "12.2.0", + "12.2.1", + "12.2.2", + "12.3.0", + "12.3.1", + "12.3.2", + "12.4.0", + "12.4.1", + "12.5.0", + "12.5.1", + "12.6.0", + "12.6.1", + "12.6.2", + "12.6.3", + "12.8.0", + "12.8.1", + "12.9.0", + "12.9.1", + "13.0.0", + "13.0.1", + "13.0.2", +] diff --git a/cuda/versions/BUILD.bazel b/cuda/versions/BUILD.bazel deleted file mode 100644 index 76e048e4..00000000 --- a/cuda/versions/BUILD.bazel +++ /dev/null @@ -1,100 +0,0 @@ -config_setting( - name = "version_is_12_0_0", - flag_values = {"@rules_cuda//cuda:version": "12.0.0"}, -) -config_setting( - name = "version_is_12_0_1", - flag_values = {"@rules_cuda//cuda:version": "12.0.1"}, -) -config_setting( - name = "version_is_12_1_0", - flag_values = {"@rules_cuda//cuda:version": "12.1.0"}, -) -config_setting( - name = "version_is_12_1_1", - flag_values = {"@rules_cuda//cuda:version": "12.1.1"}, -) -config_setting( - name = "version_is_12_2_0", - flag_values = {"@rules_cuda//cuda:version": "12.2.0"}, -) -config_setting( - name = "version_is_12_2_1", - flag_values = {"@rules_cuda//cuda:version": "12.2.1"}, -) -config_setting( - name = "version_is_12_2_2", - flag_values = {"@rules_cuda//cuda:version": "12.2.2"}, -) -config_setting( - name = "version_is_12_3_0", - flag_values = {"@rules_cuda//cuda:version": "12.3.0"}, -) -config_setting( - name = "version_is_12_3_1", - flag_values = {"@rules_cuda//cuda:version": "12.3.1"}, -) -config_setting( - name = "version_is_12_3_2", - flag_values = {"@rules_cuda//cuda:version": "12.3.2"}, -) -config_setting( - name = "version_is_12_4_0", - flag_values = {"@rules_cuda//cuda:version": "12.4.0"}, -) -config_setting( - name = "version_is_12_4_1", - flag_values = {"@rules_cuda//cuda:version": "12.4.1"}, -) -config_setting( - name = "version_is_12_5_0", - flag_values = {"@rules_cuda//cuda:version": "12.5.0"}, -) -config_setting( - name = "version_is_12_5_1", - flag_values = {"@rules_cuda//cuda:version": "12.5.1"}, -) -config_setting( - name = "version_is_12_6_0", - flag_values = {"@rules_cuda//cuda:version": "12.6.0"}, -) -config_setting( - name = "version_is_12_6_1", - flag_values = {"@rules_cuda//cuda:version": "12.6.1"}, -) -config_setting( - name = "version_is_12_6_2", - flag_values = {"@rules_cuda//cuda:version": "12.6.2"}, -) -config_setting( - name = "version_is_12_6_3", - flag_values = {"@rules_cuda//cuda:version": "12.6.3"}, -) -config_setting( - name = "version_is_12_8_0", - flag_values = {"@rules_cuda//cuda:version": "12.8.0"}, -) -config_setting( - name = "version_is_12_8_1", - flag_values = {"@rules_cuda//cuda:version": "12.8.1"}, -) -config_setting( - name = "version_is_12_9_0", - flag_values = {"@rules_cuda//cuda:version": "12.9.0"}, -) -config_setting( - name = "version_is_12_9_1", - flag_values = {"@rules_cuda//cuda:version": "12.9.1"}, -) -config_setting( - name = "version_is_13_0_0", - flag_values = {"@rules_cuda//cuda:version": "13.0.0"}, -) -config_setting( - name = "version_is_13_0_1", - flag_values = {"@rules_cuda//cuda:version": "13.0.1"}, -) -config_setting( - name = "version_is_13_0_2", - flag_values = {"@rules_cuda//cuda:version": "13.0.2"}, -) From 53f739e857e2dc1f86de98d80cba4dd6fcbcf772 Mon Sep 17 00:00:00 2001 From: Charley Saint Date: Fri, 5 Dec 2025 04:18:56 +0000 Subject: [PATCH 06/10] remove todos --- cuda/private/repositories.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cuda/private/repositories.bzl b/cuda/private/repositories.bzl index f10de761..affa29f6 100644 --- a/cuda/private/repositories.bzl +++ b/cuda/private/repositories.bzl @@ -121,8 +121,8 @@ def _detect_deliverable_cuda_toolkit(repository_ctx): libdevice = None if int(cuda_version_major) >= 13: nvvm_repo = repository_ctx.attr.components_mapping["nvvm"] - cicc = "{}//:cicc{}".format(nvvm_repo, bin_ext) # TODO: can we use @cuda//:cicc? - libdevice = "{}//:libdevice.10.bc".format(nvvm_repo) # TODO: can we use @cuda//:libdevice? + cicc = "{}//:cicc{}".format(nvvm_repo, bin_ext) + libdevice = "{}//:libdevice.10.bc".format(nvvm_repo) return struct( path = None, # scattered components From 4233145fc5e2ff78d871fc3e7264c58f6be5942e Mon Sep 17 00:00:00 2001 From: Charley Saint Date: Fri, 5 Dec 2025 19:07:03 +0000 Subject: [PATCH 07/10] remove list of versions --- cuda/BUILD.bazel | 9 -------- cuda/platform_alias_extension.bzl | 35 ++++++++++++++++++++----------- cuda/private/cuda_versions.bzl | 27 ------------------------ 3 files changed, 23 insertions(+), 48 deletions(-) delete mode 100644 cuda/private/cuda_versions.bzl diff --git a/cuda/BUILD.bazel b/cuda/BUILD.bazel index 45512582..c4517e10 100644 --- a/cuda/BUILD.bazel +++ b/cuda/BUILD.bazel @@ -4,7 +4,6 @@ load( "bool_flag", "string_flag", ) -load("//cuda/private:cuda_versions.bzl", "CUDA_VERSIONS") load("//cuda/private:rules/flags.bzl", "cuda_archs_flag", "repeatable_string_flag") package(default_visibility = ["//visibility:public"]) @@ -52,14 +51,6 @@ string_flag( build_setting_default = "13.0.0", ) -[ - config_setting( - name = "version_is_{}".format(version.replace(".", "_")), - flag_values = {"@rules_cuda//cuda:version": "{}".format(version)}, - ) for version in CUDA_VERSIONS -] - - string_flag( name = "target_platform", build_setting_default = "linux-x86_64", diff --git a/cuda/platform_alias_extension.bzl b/cuda/platform_alias_extension.bzl index 3128cf90..6d7a8221 100644 --- a/cuda/platform_alias_extension.bzl +++ b/cuda/platform_alias_extension.bzl @@ -1,7 +1,7 @@ """Module extension for creating platform-specific aliases for external dependencies. This extension creates repositories with alias targets that select between x86_64 and ARM64 -repositories based on the build platform. +repositories based on the build platform. It also selects between versions of the component. """ load("//cuda/private:templates/registry.bzl", "REGISTRY") @@ -9,7 +9,6 @@ load("//cuda/private:templates/registry.bzl", "REGISTRY") # Use REGISTRY as the source of truth for component targets TARGET_MAPPING = REGISTRY - def _platform_alias_repo_impl(ctx): """Implementation of the platform_alias_repo repository rule. @@ -23,7 +22,18 @@ def _platform_alias_repo_impl(ctx): # Add load statement for alias build_content.append('load("@rules_cuda//cuda:defs.bzl", "unsupported_cuda_version", "unsupported_cuda_platform")') build_content.append("") + + build_content.append("[") + build_content.append(" config_setting(") + build_content.append(' name = "version_is_{}".format(version.replace(".", "_")),') + build_content.append(' flag_values = {"@rules_cuda//cuda:version": "{}".format(version)},') + build_content.append(" )") + build_content.append(" for version in {}".format(ctx.attr.versions)) + build_content.append("]") + build_content.append("") + build_content.append('unsupported_cuda_version(name = "unsupported_cuda_version", component = "{}", available_versions = {})'.format(ctx.attr.component_name, ctx.attr.versions)) + build_content.append("") # Check if using new dict-based API or old string-based API. use_dict_api = len(ctx.attr.linux_x86_64_repos) > 0 @@ -81,7 +91,6 @@ def _platform_alias_repo_impl(ctx): build_content.append(")") build_content.append("") - # Determine appropriate dummy target based on the target name. dummy_target = "@rules_cuda//cuda/dummy:dummy" if target_name == "cicc": @@ -98,17 +107,18 @@ def _platform_alias_repo_impl(ctx): for version in ctx.attr.versions: if version in ctx.attr.linux_x86_64_repos: repo_name = ctx.attr.linux_x86_64_repos[version] - build_content.append(' "@rules_cuda//cuda:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "@{}//{}",'.format(repo_name, target if target.find(":") != -1 else ":" + target)) else: - build_content.append(' "@rules_cuda//cuda:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "{}",'.format(dummy_target)) + # Add default for versions where this component doesn't exist. build_content.append(' "//conditions:default": ":unsupported_cuda_version",') else: # Old string-based API for backward compatibility. for version in ctx.attr.versions: - build_content.append(' "@rules_cuda//cuda:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "@{}_{}//{}",'.format(ctx.attr.linux_x86_64_repo, version.replace(".", "_"), target if target.find(":") != -1 else ":" + target)) build_content.append(' "//conditions:default": ":unsupported_cuda_version",') build_content.append(" }),") @@ -124,10 +134,10 @@ def _platform_alias_repo_impl(ctx): for version in ctx.attr.versions: if version in ctx.attr.linux_sbsa_repos: repo_name = ctx.attr.linux_sbsa_repos[version] - build_content.append(' "@rules_cuda//cuda:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "@{}//{}",'.format(repo_name, target if target.find(":") != -1 else ":" + target)) else: - build_content.append(' "@rules_cuda//cuda:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "{}",'.format(dummy_target)) # Add default for versions where this component doesn't exist. @@ -135,7 +145,7 @@ def _platform_alias_repo_impl(ctx): else: # Old string-based API for backward compatibility. for version in ctx.attr.versions: - build_content.append(' "@rules_cuda//cuda:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "@{}_{}//{}",'.format(ctx.attr.linux_sbsa_repo, version.replace(".", "_"), target if target.find(":") != -1 else ":" + target)) build_content.append(' "//conditions:default": ":unsupported_cuda_version",') build_content.append(" }),") @@ -151,17 +161,18 @@ def _platform_alias_repo_impl(ctx): for version in ctx.attr.versions: if version in ctx.attr.linux_aarch64_repos: repo_name = ctx.attr.linux_aarch64_repos[version] - build_content.append(' "@rules_cuda//cuda:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "@{}//{}",'.format(repo_name, target if target.find(":") != -1 else ":" + target)) else: - build_content.append(' "@rules_cuda//cuda:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "{}",'.format(dummy_target)) + # Add default for versions where this component doesn't exist. build_content.append(' "//conditions:default": ":unsupported_cuda_version",') else: # Old string-based API for backward compatibility. for version in ctx.attr.versions: - build_content.append(' "@rules_cuda//cuda:version_is_{}": '.format(version.replace(".", "_"))) + build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) build_content.append(' "@{}_{}//{}",'.format(ctx.attr.linux_aarch64_repo, version.replace(".", "_"), target if target.find(":") != -1 else ":" + target)) build_content.append(' "//conditions:default": ":unsupported_cuda_version",') build_content.append(" }),") diff --git a/cuda/private/cuda_versions.bzl b/cuda/private/cuda_versions.bzl deleted file mode 100644 index f837a4e9..00000000 --- a/cuda/private/cuda_versions.bzl +++ /dev/null @@ -1,27 +0,0 @@ -CUDA_VERSIONS = [ - "12.0.0", - "12.0.1", - "12.1.0", - "12.1.1", - "12.2.0", - "12.2.1", - "12.2.2", - "12.3.0", - "12.3.1", - "12.3.2", - "12.4.0", - "12.4.1", - "12.5.0", - "12.5.1", - "12.6.0", - "12.6.1", - "12.6.2", - "12.6.3", - "12.8.0", - "12.8.1", - "12.9.0", - "12.9.1", - "13.0.0", - "13.0.1", - "13.0.2", -] From 9fbd1ab75235811ab556a5a8c874634ca7e6b917 Mon Sep 17 00:00:00 2001 From: Charley Saint Date: Mon, 8 Dec 2025 22:28:26 +0000 Subject: [PATCH 08/10] fix issues with dummy files --- cuda/dummy/BUILD.bazel | 23 +++++++---------------- cuda/dummy/libdevice.10.bc | 1 + cuda/private/templates/BUILD.cufile | 22 ++++++++++++++++------ 3 files changed, 24 insertions(+), 22 deletions(-) create mode 100644 cuda/dummy/libdevice.10.bc diff --git a/cuda/dummy/BUILD.bazel b/cuda/dummy/BUILD.bazel index 085309f8..00bc8b9a 100644 --- a/cuda/dummy/BUILD.bazel +++ b/cuda/dummy/BUILD.bazel @@ -6,13 +6,19 @@ cc_binary( defines = ["TOOLNAME=nvcc"], ) +cc_binary( + name = "cicc", + srcs = ["dummy.cpp"], + defines = ["TOOLNAME=cicc"], +) + cc_binary( name = "nvlink", srcs = ["dummy.cpp"], defines = ["TOOLNAME=nvlink"], ) -exports_files(["link.stub"]) +exports_files(["link.stub", "libdevice.10.bc"]) cc_binary( name = "bin2c", @@ -32,18 +38,3 @@ cc_library( srcs = [], hdrs = [], ) - -# Dummy executable for cicc that doesn't exist in this version. -genrule( - name = "cicc_gen", - outs = ["cicc"], - cmd = "echo '#!/bin/bash' > $@ && echo 'echo Error: cicc not available in this CUDA version' >> $@ && chmod +x $@", - executable = True, -) - -# Dummy bitcode for libdevice that doesn't exist in this version. -genrule( - name = "libdevice_gen", - outs = ["libdevice.10.bc"], - cmd = "echo '; Placeholder bitcode' > $@", -) diff --git a/cuda/dummy/libdevice.10.bc b/cuda/dummy/libdevice.10.bc new file mode 100644 index 00000000..cb1b2e62 --- /dev/null +++ b/cuda/dummy/libdevice.10.bc @@ -0,0 +1 @@ +#error link.stub of cuda toolkit does not exist diff --git a/cuda/private/templates/BUILD.cufile b/cuda/private/templates/BUILD.cufile index db49e382..488b85dd 100644 --- a/cuda/private/templates/BUILD.cufile +++ b/cuda/private/templates/BUILD.cufile @@ -1,9 +1,19 @@ -exports_files(glob(["**"])) +cc_import_versioned_sos( + name = "cufile_so", + shared_library = "%{component_name}/%{libpath}/libcufile.so", +) + +cc_import_versioned_sos( + name = "cufile_rdma_so", + shared_library = "%{component_name}/%{libpath}/libcufile_rdma.so", +) cc_library( - name = "cufile", - srcs = glob(["cufile/lib/**/*.so*"]), - hdrs = glob(["cufile/include/**/*.h"]), - includes = ["cufile/include"], - visibility = ["//visibility:public"], + name = "curand", + deps = [ + ":%{component_name}_headers", + ":cufile_so", + ":cufile_rdma_so", + ], + target_compatible_with = ["@platforms//os:linux"], ) \ No newline at end of file From d9cfe1c0fd729f361657e89e6afb97af74996a38 Mon Sep 17 00:00:00 2001 From: Lionel Gulich Date: Thu, 15 Jan 2026 19:38:32 +0100 Subject: [PATCH 09/10] Fix typos --- cuda/dummy/libdevice.10.bc | 2 +- cuda/private/templates/BUILD.cufile | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cuda/dummy/libdevice.10.bc b/cuda/dummy/libdevice.10.bc index cb1b2e62..ae4819a4 100644 --- a/cuda/dummy/libdevice.10.bc +++ b/cuda/dummy/libdevice.10.bc @@ -1 +1 @@ -#error link.stub of cuda toolkit does not exist +#error libdevice.10.bc of cuda toolkit does not exist diff --git a/cuda/private/templates/BUILD.cufile b/cuda/private/templates/BUILD.cufile index 488b85dd..7a64f29b 100644 --- a/cuda/private/templates/BUILD.cufile +++ b/cuda/private/templates/BUILD.cufile @@ -9,11 +9,11 @@ cc_import_versioned_sos( ) cc_library( - name = "curand", + name = "cufile", deps = [ ":%{component_name}_headers", ":cufile_so", ":cufile_rdma_so", ], target_compatible_with = ["@platforms//os:linux"], -) \ No newline at end of file +) From 7eb64fedebc3c1f97c952dd13eff1528b96575c8 Mon Sep 17 00:00:00 2001 From: Lionel Gulich Date: Tue, 27 Jan 2026 15:23:44 +0100 Subject: [PATCH 10/10] Fix platform fallback for components missing on some platforms When a CUDA component (like cuda_crt) doesn't exist for a platform (like linux-aarch64), builds would fail because the select() had no matching condition for that platform. Now platform aliases are generated for ALL platforms, with dummy targets used for platforms where the component doesn't exist. This ensures builds on any platform have matching select conditions. Also consolidates platform definitions into a single source of truth in cuda/private/platforms.bzl and removes unused backward-compatibility code. Fixes JP6 (linux-aarch64, CUDA 12) build failure where cuda_crt (a CUDA 13+ only component) caused select() to fail. --- cuda/BUILD.bazel | 17 +-- cuda/defs.bzl | 2 +- cuda/extensions.bzl | 49 -------- cuda/platform_alias_extension.bzl | 188 ++++++++++-------------------- cuda/private/platforms.bzl | 9 ++ cuda/private/template_helper.bzl | 2 +- 6 files changed, 76 insertions(+), 191 deletions(-) create mode 100644 cuda/private/platforms.bzl diff --git a/cuda/BUILD.bazel b/cuda/BUILD.bazel index c4517e10..b0b2cf8a 100644 --- a/cuda/BUILD.bazel +++ b/cuda/BUILD.bazel @@ -4,18 +4,11 @@ load( "bool_flag", "string_flag", ) +load("//cuda/private:platforms.bzl", "SUPPORTED_PLATFORMS") load("//cuda/private:rules/flags.bzl", "cuda_archs_flag", "repeatable_string_flag") package(default_visibility = ["//visibility:public"]) -CUDA_PLATFORMS = [ - "linux-x86_64", - "linux-sbsa", - "linux-aarch64", - "linux-ppc64le", - "windows-x86_64", -] - bzl_library( name = "bzl_srcs", srcs = glob(["*.bzl"]), @@ -54,27 +47,27 @@ string_flag( string_flag( name = "target_platform", build_setting_default = "linux-x86_64", - values = CUDA_PLATFORMS, + values = SUPPORTED_PLATFORMS, ) [ config_setting( name = "target_platform_is_{}".format(platform.replace("-", "_")), flag_values = {":target_platform": platform}, - ) for platform in CUDA_PLATFORMS + ) for platform in SUPPORTED_PLATFORMS ] string_flag( name = "exec_platform", build_setting_default = "linux-x86_64", - values = CUDA_PLATFORMS, + values = SUPPORTED_PLATFORMS, ) [ config_setting( name = "exec_platform_is_{}".format(platform.replace("-", "_")), flag_values = {":exec_platform": platform}, - ) for platform in CUDA_PLATFORMS + ) for platform in SUPPORTED_PLATFORMS ] # Command line flag to specify the list of CUDA architectures to compile for. diff --git a/cuda/defs.bzl b/cuda/defs.bzl index ad0178cc..c2d3ea1a 100644 --- a/cuda/defs.bzl +++ b/cuda/defs.bzl @@ -49,4 +49,4 @@ cc_import_versioned_sos = _cc_import_versioned_sos requires_cuda = _requires_cuda unsupported_cuda_version = _unsupported_cuda_version -unsupported_cuda_platform = _unsupported_cuda_platform \ No newline at end of file +unsupported_cuda_platform = _unsupported_cuda_platform diff --git a/cuda/extensions.bzl b/cuda/extensions.bzl index d6591f8a..86fde7fc 100644 --- a/cuda/extensions.bzl +++ b/cuda/extensions.bzl @@ -78,40 +78,6 @@ cuda_toolkit_tag = tag_class(attrs = { ), }) -platform_alias_tag = tag_class( - attrs = { - "name": attr.string( - mandatory = True, - doc = "Name of the alias repository to create", - ), - "component_name": attr.string( - mandatory = True, - doc = "Name of the component to create aliases for", - ), - "linux_x86_64_repo": attr.string( - mandatory = True, - doc = "Name of the repository to use for x86_64 platform", - ), - "linux_aarch64_repo": attr.string( - mandatory = True, - doc = "Name of the repository to use for ARM64/Jetpack platform", - ), - "linux_sbsa_repo": attr.string( - mandatory = True, - doc = "Name of the repository to use for SBSA platform", - ), - "versions": attr.string_list( - mandatory = True, - doc = "List of versions to create aliases for", - ), - }, - doc = """Defines a platform-specific alias repository. - - Each alias tag creates a repository with targets that select between - x86_64 and ARM64 repositories based on the build platform. - """, -) - def _find_modules(module_ctx): root = None our_module = None @@ -158,12 +124,10 @@ def _impl(module_ctx): components = root.tags.component redist_jsons = root.tags.redist_json toolkits = root.tags.toolkit - platform_aliases = root.tags.platform_alias else: components = rules_cuda.tags.component redist_jsons = rules_cuda.tags.redist_json toolkits = rules_cuda.tags.toolkit - platform_aliases = rules_cuda.tags.platform_alias for component in components: cuda_component(**_module_tag_to_dict(component)) @@ -233,24 +197,11 @@ def _impl(module_ctx): else: cuda_toolkit(**_module_tag_to_dict(toolkit)) - for alias_tag in platform_aliases: - # Create a repository for each alias tag - platform_alias_repo( - name = alias_tag.name, - repo_name = alias_tag.name, - component_name = alias_tag.component_name, - linux_x86_64_repo = alias_tag.linux_x86_64_repo, - linux_aarch64_repo = alias_tag.linux_aarch64_repo, - linux_sbsa_repo = alias_tag.linux_sbsa_repo, - versions = alias_tag.versions, - ) - toolchain = module_extension( implementation = _impl, tag_classes = { "component": cuda_component_tag, "redist_json": cuda_redist_json_tag, "toolkit": cuda_toolkit_tag, - "platform_alias": platform_alias_tag, }, ) diff --git a/cuda/platform_alias_extension.bzl b/cuda/platform_alias_extension.bzl index 6d7a8221..bd1742b3 100644 --- a/cuda/platform_alias_extension.bzl +++ b/cuda/platform_alias_extension.bzl @@ -4,6 +4,7 @@ This extension creates repositories with alias targets that select between x86_6 repositories based on the build platform. It also selects between versions of the component. """ +load("//cuda/private:platforms.bzl", "SUPPORTED_PLATFORMS") load("//cuda/private:templates/registry.bzl", "REGISTRY") # Use REGISTRY as the source of truth for component targets @@ -35,32 +36,22 @@ def _platform_alias_repo_impl(ctx): build_content.append('unsupported_cuda_version(name = "unsupported_cuda_version", component = "{}", available_versions = {})'.format(ctx.attr.component_name, ctx.attr.versions)) build_content.append("") - # Check if using new dict-based API or old string-based API. - use_dict_api = len(ctx.attr.linux_x86_64_repos) > 0 - # Build a target for the name of the repo (only if at least one platform is available). platform_type = "exec" if ctx.attr.component_name in ["nvcc", "nvvm"] else "target" - # Check which platforms are available. + # Check which platforms are available (have at least one version). platforms_available = [] - if use_dict_api: - if len(ctx.attr.linux_x86_64_repos) > 0: - platforms_available.append("linux-x86_64") - if len(ctx.attr.linux_sbsa_repos) > 0: - platforms_available.append("linux-sbsa") - if len(ctx.attr.linux_aarch64_repos) > 0: - platforms_available.append("linux-aarch64") - else: - if ctx.attr.linux_x86_64_repo != "": - platforms_available.append("linux-x86_64") - if ctx.attr.linux_sbsa_repo != "": - platforms_available.append("linux-sbsa") - if ctx.attr.linux_aarch64_repo != "": - platforms_available.append("linux-aarch64") - - if len(platforms_available) > 0: - build_content.append('unsupported_cuda_platform(name = "unsupported_cuda_platform", component = "{}", available_platforms = {})'.format(ctx.attr.component_name, platforms_available)) - build_content.append("") + if len(ctx.attr.linux_x86_64_repos) > 0: + platforms_available.append("linux-x86_64") + if len(ctx.attr.linux_sbsa_repos) > 0: + platforms_available.append("linux-sbsa") + if len(ctx.attr.linux_aarch64_repos) > 0: + platforms_available.append("linux-aarch64") + + # Always create unsupported_cuda_platform target - it's used as the default case + # in select() when no platform condition matches. + build_content.append('unsupported_cuda_platform(name = "unsupported_cuda_platform", component = "{}", available_platforms = {})'.format(ctx.attr.component_name, platforms_available)) + build_content.append("") # Only generate target aliases if this component is in TARGET_MAPPING. if ctx.attr.component_name not in TARGET_MAPPING: @@ -69,28 +60,12 @@ def _platform_alias_repo_impl(ctx): return for target in TARGET_MAPPING[ctx.attr.component_name]: - # Create alias for each target with platform selection (only if platforms are available). + # Create alias for each target with platform selection. + # Always add conditions for ALL platforms so that builds on any platform + # have a matching select condition. Platforms where the component doesn't + # exist will use a dummy target. target_name = target if target.find("/") == -1 else target.split("/")[-1] - if len(platforms_available) > 0: - build_content.append("alias(") - build_content.append(' name = "{}",'.format(target_name)) - build_content.append(" actual = select({") - if "linux-x86_64" in platforms_available: - build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_x86_64":'.format(platform_type)) - build_content.append(' ":linux_x86_64_{}",'.format(target_name)) - if "linux-sbsa" in platforms_available: - build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_sbsa":'.format(platform_type)) - build_content.append(' ":linux_sbsa_{}",'.format(target_name)) - if "linux-aarch64" in platforms_available: - build_content.append(' "@rules_cuda//cuda:{}_platform_is_linux_aarch64":'.format(platform_type)) - build_content.append(' ":linux_aarch64_{}",'.format(target_name)) - build_content.append(' "//conditions:default": ":unsupported_cuda_platform",') - build_content.append(" }),") - build_content.append(' visibility = ["//visibility:public"],') - build_content.append(")") - build_content.append("") - # Determine appropriate dummy target based on the target name. dummy_target = "@rules_cuda//cuda/dummy:dummy" if target_name == "cicc": @@ -98,83 +73,54 @@ def _platform_alias_repo_impl(ctx): elif target_name == "libdevice.10.bc": dummy_target = "@rules_cuda//cuda/dummy:libdevice.10.bc" - # Generate platform-specific aliases for this target (only if versions exist). - if "linux-x86_64" in platforms_available: - build_content.append("alias(") - build_content.append(' name = "linux_x86_64_{}",'.format(target_name)) - build_content.append(" actual = select({") - if use_dict_api: - for version in ctx.attr.versions: - if version in ctx.attr.linux_x86_64_repos: - repo_name = ctx.attr.linux_x86_64_repos[version] - build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) - build_content.append(' "@{}//{}",'.format(repo_name, target if target.find(":") != -1 else ":" + target)) - else: - build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) - build_content.append(' "{}",'.format(dummy_target)) - - # Add default for versions where this component doesn't exist. - build_content.append(' "//conditions:default": ":unsupported_cuda_version",') + build_content.append("alias(") + build_content.append(' name = "{}",'.format(target_name)) + build_content.append(" actual = select({") + # Add conditions for ALL platforms, using dummy for unavailable ones. + for platform in SUPPORTED_PLATFORMS: + platform_suffix = platform.replace("-", "_") + build_content.append(' "@rules_cuda//cuda:{}_platform_is_{}":'.format(platform_type, platform_suffix)) + if platform in platforms_available: + build_content.append(' ":{}_{}",'.format(platform_suffix, target_name)) else: - # Old string-based API for backward compatibility. - for version in ctx.attr.versions: - build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) - build_content.append(' "@{}_{}//{}",'.format(ctx.attr.linux_x86_64_repo, version.replace(".", "_"), target if target.find(":") != -1 else ":" + target)) - build_content.append(' "//conditions:default": ":unsupported_cuda_version",') - build_content.append(" }),") - build_content.append(' visibility = ["//visibility:public"],') - build_content.append(")") - build_content.append("") + # Platform doesn't have this component, use dummy target. + build_content.append(' "{}",'.format(dummy_target)) + build_content.append(' "//conditions:default": ":unsupported_cuda_platform",') + build_content.append(" }),") + build_content.append(' visibility = ["//visibility:public"],') + build_content.append(")") + build_content.append("") - if "linux-sbsa" in platforms_available: - build_content.append("alias(") - build_content.append(' name = "linux_sbsa_{}",'.format(target_name)) - build_content.append(" actual = select({") - if use_dict_api: - for version in ctx.attr.versions: - if version in ctx.attr.linux_sbsa_repos: - repo_name = ctx.attr.linux_sbsa_repos[version] - build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) - build_content.append(' "@{}//{}",'.format(repo_name, target if target.find(":") != -1 else ":" + target)) - else: - build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) - build_content.append(' "{}",'.format(dummy_target)) - - # Add default for versions where this component doesn't exist. - build_content.append(' "//conditions:default": ":unsupported_cuda_version",') - else: - # Old string-based API for backward compatibility. - for version in ctx.attr.versions: - build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) - build_content.append(' "@{}_{}//{}",'.format(ctx.attr.linux_sbsa_repo, version.replace(".", "_"), target if target.find(":") != -1 else ":" + target)) - build_content.append(' "//conditions:default": ":unsupported_cuda_version",') - build_content.append(" }),") - build_content.append(' visibility = ["//visibility:public"],') - build_content.append(")") - build_content.append("") + # Generate platform-specific aliases for ALL platforms. + # Platforms where the component exists get version-based selection. + # Platforms where it doesn't exist get dummy targets for all versions. + # This ensures builds on any platform have matching select conditions. + + platform_repos_map = { + "linux-x86_64": ctx.attr.linux_x86_64_repos, + "linux-sbsa": ctx.attr.linux_sbsa_repos, + "linux-aarch64": ctx.attr.linux_aarch64_repos, + } + + for platform in SUPPORTED_PLATFORMS: + platform_suffix = platform.replace("-", "_") + repos_dict = platform_repos_map[platform] + platform_available = platform in platforms_available - if "linux-aarch64" in platforms_available: build_content.append("alias(") - build_content.append(' name = "linux_aarch64_{}",'.format(target_name)) + build_content.append(' name = "{}_{}",'.format(platform_suffix, target_name)) build_content.append(" actual = select({") - if use_dict_api: - for version in ctx.attr.versions: - if version in ctx.attr.linux_aarch64_repos: - repo_name = ctx.attr.linux_aarch64_repos[version] - build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) - build_content.append(' "@{}//{}",'.format(repo_name, target if target.find(":") != -1 else ":" + target)) - else: - build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) - build_content.append(' "{}",'.format(dummy_target)) - - # Add default for versions where this component doesn't exist. - build_content.append(' "//conditions:default": ":unsupported_cuda_version",') - else: - # Old string-based API for backward compatibility. - for version in ctx.attr.versions: - build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) - build_content.append(' "@{}_{}//{}",'.format(ctx.attr.linux_aarch64_repo, version.replace(".", "_"), target if target.find(":") != -1 else ":" + target)) - build_content.append(' "//conditions:default": ":unsupported_cuda_version",') + + for version in ctx.attr.versions: + build_content.append(' ":version_is_{}": '.format(version.replace(".", "_"))) + if platform_available and version in repos_dict: + repo_name = repos_dict[version] + build_content.append(' "@{}//{}",'.format(repo_name, target if target.find(":") != -1 else ":" + target)) + else: + # Platform doesn't have this component for this version, use dummy. + build_content.append(' "{}",'.format(dummy_target)) + build_content.append(' "//conditions:default": ":unsupported_cuda_version",') + build_content.append(" }),") build_content.append(' visibility = ["//visibility:public"],') build_content.append(")") @@ -194,7 +140,6 @@ platform_alias_repo = repository_rule( mandatory = True, doc = "Name of the component", ), - # New dict-based API (preferred). "linux_x86_64_repos": attr.string_dict( default = {}, doc = "Dictionary mapping versions to x86_64 repository names", @@ -207,19 +152,6 @@ platform_alias_repo = repository_rule( default = {}, doc = "Dictionary mapping versions to SBSA repository names", ), - # Old string-based API (for backward compatibility). - "linux_x86_64_repo": attr.string( - default = "", - doc = "Base name of the repository to use for x86_64 platform (deprecated, use linux_x86_64_repos)", - ), - "linux_aarch64_repo": attr.string( - default = "", - doc = "Base name of the repository to use for ARM64/Jetpack platform (deprecated, use linux_aarch64_repos)", - ), - "linux_sbsa_repo": attr.string( - default = "", - doc = "Base name of the repository to use for SBSA platform (deprecated, use linux_sbsa_repos)", - ), "versions": attr.string_list( mandatory = True, doc = "List of versions to create aliases for", diff --git a/cuda/private/platforms.bzl b/cuda/private/platforms.bzl new file mode 100644 index 00000000..c4816db3 --- /dev/null +++ b/cuda/private/platforms.bzl @@ -0,0 +1,9 @@ +"""Platform constants for CUDA toolkit support.""" + +# Platforms supported by the CUDA redistribution and alias system. +# These are the platforms for which we generate component aliases. +SUPPORTED_PLATFORMS = [ + "linux-x86_64", + "linux-sbsa", + "linux-aarch64", +] diff --git a/cuda/private/template_helper.bzl b/cuda/private/template_helper.bzl index 611cbabd..e5e1c2ff 100644 --- a/cuda/private/template_helper.bzl +++ b/cuda/private/template_helper.bzl @@ -30,7 +30,7 @@ def _expand_dctk_cuda(repository_ctx, components): # Filter out components with empty REGISTRY entries valid_components = [comp for comp in components.keys() if len(REGISTRY[comp]) > 0] - + all_files_srcs_line = [comp + "_all_files" for comp in valid_components] all_files_srcs_line = "srcs = " + repr(all_files_srcs_line)