
LoweringException: NoValidChoicesError: No choices to select. #17197

@dmmosh

Description


🐛 Describe the bug

When I try to export and save the model with the ExecuTorch CUDA backend, lowering in to_edge_transform_and_lower fails with this error: LoweringException: NoValidChoicesError: No choices to select. Provided reason: No choices exist for backend. please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice.
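The message points at max_autotune_gemm_backends in torch/_inductor/config.py, and the full code sets it to "ATEN,TRITON,CUTLASS" before exporting. A minimal sketch of how such a setting can be checked, assuming it is a comma-separated list of backend names compared case-insensitively; `gemm_backends` and `allows_aten` are hypothetical helpers, not Inductor APIs.

```python
def gemm_backends(setting: str) -> list[str]:
    """Split a comma-separated backend setting into upper-case names (assumed format)."""
    return [name.strip().upper() for name in setting.split(",") if name.strip()]


def allows_aten(setting: str) -> bool:
    """True if ATEN, the choice the error message asks for, is listed."""
    return "ATEN" in gemm_backends(setting)


print(gemm_backends("ATEN,TRITON,CUTLASS"))  # ['ATEN', 'TRITON', 'CUTLASS']
print(allows_aten("TRITON"))  # False
```

Because ATEN is already listed in the script, a useful next check is which value is actually in effect when lowering runs.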

inputs = (torch.randint(ord('a'), ord('z') + 1, (model.block_size,)),)
print(inputs)
MODEL_NAME = 'libtorch_optimized_model'

# Create the CUDA partitioner
#print(CudaBackend.generate_method_name_compile_spec("forward"))
partitioner = CudaPartitioner(compile_spec=[
    CudaBackend.generate_method_name_compile_spec('forward'),
    CompileSpec(key="triton_kernel_mode", value=b"OFF"),
    # CompileSpec("device", "cuda:0"),
    # CompileSpec("optimization_level", "5"),
    # CompileSpec("precision", "fp16"),
    # CompileSpec("use_tf32", "True"),
    # CompileSpec("max_workspace_size", str(1<<30))  # 1GB for autotuning
])  # compile specs

# Export the model using torch.export
exported_program = torch.export.export(model, inputs)

edge_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[partitioner],
    compile_config=edge_compile_config)

executorch_program = edge_program.to_executorch()

# Convert to executable program and save
filename = MODEL_NAME + '.pte'
#with open(filename, "wb") as file:
#    file.write(executorch_program.buffer)
#exec_program.get_etrecord().save(MODEL_NAME + "_etrecord.bin")

print('optimized model saved to', filename)
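In the snippet above the file write is commented out, so the final print reports a file that is never created. A minimal sketch of writing the serialized program, assuming the bytes come from `executorch_program.buffer` (the attribute used in the commented-out line); `save_program` is a hypothetical helper.

```python
from pathlib import Path


def save_program(buffer: bytes, filename: str) -> Path:
    """Write a serialized program to disk and return its absolute path."""
    path = Path(filename)
    path.write_bytes(buffer)  # binary write: .pte files are not text
    return path.resolve()


# Hypothetical usage inside the script:
# saved = save_program(executorch_program.buffer, MODEL_NAME + ".pte")
# print("optimized model saved to", saved)
```

Printing the returned path only after the write succeeds keeps the log message honest.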

error:

Traceback (most recent call last):
  File "/home/wetsock/coding/lightning-search/pytorch/words/convert.py", line 117, in <module>
    edge_program=to_edge_transform_and_lower(
        exported_program,
        partitioner=[partitioner],
        compile_config=edge_compile_config)
  File "/home/wetsock/coding/lightning-search/.venv/lib/python3.13/site-packages/executorch/exir/program/_program.py", line 115, in wrapper
    return func(*args, **kwargs)
  File "/home/wetsock/coding/lightning-search/.venv/lib/python3.13/site-
...
  File "/home/wetsock/coding/lightning-search/.venv/lib/python3.13/site-packages/torch/_inductor/select_algorithm.py", line 4329, in autotune_select_algorithm
    return cache(*args, **kwargs)
  File "/home/wetsock/coding/lightning-search/.venv/lib/python3.13/site-packages/torch/_inductor/select_algorithm.py", line 2808, in __call__
    raise self.create_no_valid_choices(name, "No choices exist for backend.")
torch._inductor.exc.InductorError: LoweringException: NoValidChoicesError: No choices to select. Provided reason: No choices exist for backend. please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice. 
  target: aten.mm.default
  args[0]: TensorBox(StorageBox(
    ConstantBuffer(name='_tensor_constant0', layout=FixedLayout('cuda:0', torch.float32, size=[1, 120], stride=[120, 1]))
  ))
  args[1]: TensorBox(
    ReinterpretView(
      StorageBox(
        ConstantBuffer(name='model_0_weight', layout=FixedLayout('cuda:0', torch.float32, size=[3000, 120], stride=[120, 1]))
      ),
      FixedLayout('cuda:0', torch.float32, size=[120, 3000], stride=[1, 120]),
      origins=OrderedSet([permute]),
      stack_traces = {,
      File "<eval_with_key>.42", line 5, in forward,
          linear = torch.ops.aten.linear.default(c_model_0_lifted_tensor_0, p_model_0_weight);  c_model_0_lifted_tensor_0 = p_model_0_weight = None,
      }
    )
  )
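The traceback describes the failing aten.mm call: a constant input of size [1, 120] (stride [120, 1]) and a [120, 3000] view with stride [1, 120] over the [3000, 120] weight of the first nn.Linear (model_0_weight), produced by permute, i.e. the transpose that nn.Linear applies to its weight. A small pure-Python sketch (helper names are illustrative) reproduces those sizes and strides and the expected result shape [1, 3000].

```python
def contiguous_strides(size):
    """Row-major strides for a tensor of the given size."""
    strides = []
    step = 1
    for dim in reversed(size):
        strides.append(step)
        step *= dim
    return list(reversed(strides))


def transposed(size, strides):
    """Swap the two dimensions of a 2-D view without copying data."""
    return [size[1], size[0]], [strides[1], strides[0]]


def mm_shape(a, b):
    """Result shape of a 2-D matrix multiply; inner dimensions must agree."""
    if a[1] != b[0]:
        raise ValueError(f"inner dimensions differ: {a} x {b}")
    return [a[0], b[1]]


weight_size = [3000, 120]                      # nn.Linear stores (out_features, in_features)
weight_strides = contiguous_strides(weight_size)                    # [120, 1]
view_size, view_strides = transposed(weight_size, weight_strides)  # [120, 3000], [1, 120]
out_size = mm_shape([1, 120], view_size)                            # [1, 3000]
print(view_size, view_strides, out_size)
```

So the shapes themselves are consistent; the failure is in choosing a kernel for this matmul, not in the matmul's dimensions.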

full code:

import torch
import faulthandler
from header import * # read access
from executorch.backends.cuda.cuda_backend import CudaBackend
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
from executorch.examples.models import MODEL_NAME_TO_MODEL
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
import torch._inductor.config
torch._inductor.config.max_autotune_gemm_backends = "ATEN,TRITON,CUTLASS"
#torch._inductor.config.max_autotune = False

if not os.path.exists(save_path_words):  # if the save path isn't present, exit
    os._exit(1)
faulthandler.enable()
#torch._inductor.config.max_autotune_gemm_backends = ["aten", "triton",'cutlass']

# set to cuda or cpu
if torch.cuda.is_available():
    torch.set_default_device('cuda')
    print("Default device set to CUDA")
else:
    torch.set_default_device('cpu')
    print("CUDA not available, default device set to CPU")
torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])  # mark the Sequential container as safe
#torch.set_float32_matmul_precision('high')

checkpoint = torch.load(save_path_words)

iword = [[]]*len(checkpoint['iword'])
for i,word in checkpoint['iword'].items():
    iword[i] = word

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = checkpoint['embed'][:]
        self.embedding_dimensions = checkpoint['embedding_dimensions']  # number of integers to represent in n-dimensional space
        self.dictionary_size = checkpoint['dictionary_size']
        #print(checkpoint['iword'])
        #self.iword = checkpoint['iword']
        #self.wordi = checkpoint['wordi']  # don't need wordi, only need iword for outputting the output indices
        self.iword = iword[:]
        self.hidden_layer_size = checkpoint['hidden_layer_size']
        self.block_size = checkpoint['block_size']
        self.stop_wordi = checkpoint['stop_wordi']
        self.stop_word=ord(checkpoint['iword'][self.stop_wordi])

        self.model = nn.Sequential(
            nn.Linear(self.embedding_dimensions*self.block_size, self.hidden_layer_size, bias=False), nn.BatchNorm1d(self.hidden_layer_size), nn.Tanh(),
            nn.Linear(self.hidden_layer_size, self.hidden_layer_size, bias=False), nn.BatchNorm1d(self.hidden_layer_size), nn.Tanh(),
            nn.Linear(self.hidden_layer_size, self.dictionary_size, bias=False), nn.BatchNorm1d(self.dictionary_size)
        )
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval().cuda()
        print("Model and related variables loaded successfully.")
        #self.model.eval()
        
        
    #@torch.no_grad()
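    def forward(self, x):
        #print(out)
        #print(embedx.shape)
        # Note: this reads the module-level embedx (defined under __main__), not the
        # argument x, so torch.export captures embedx as a constant tensor.
        with torch.no_grad():
            logits = self.model(embedx)
            return logits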
        
model = NeuralNetwork()

if __name__ == '__main__':  # run a test prediction, then export to ExecuTorch
    test = 'software'
    out = test[-model.block_size:].rjust(model.block_size,chr(model.stop_word))
    out = [ord(c) for c in out]
    print(model.named_parameters())
    
    embedx = model.embed[torch.tensor(out,dtype=torch.long)]
    embedx = embedx.view(1,model.embedding_dimensions*model.block_size)
    logits = model.forward(embedx)
    probs = torch.softmax(logits, dim=1)
    _, top_indices = torch.topk(probs, k=10, dim=1)
    print([model.iword[i] for i in top_indices.tolist()[0]])
    print()    
    
    
    print("available example models:", MODEL_NAME_TO_MODEL.keys())
    print(model)
    
    # Configure edge compilation
    edge_compile_config = EdgeCompileConfig(
        _check_ir_validity=False,
        _skip_dim_order=True,
    )
    
    inputs = (torch.randint(ord('a'),ord('z')+1,(model.block_size,)),)
    print(inputs)
    MODEL_NAME = 'libtorch_optimized_model'
    
    # Create the CUDA partitioner
    #print(CudaBackend.generate_method_name_compile_spec("forward"))
    partitioner = CudaPartitioner(compile_spec=[
        CudaBackend.generate_method_name_compile_spec('forward'),
        CompileSpec(key="triton_kernel_mode", value=b"OFF"),
        # CompileSpec("device", "cuda:0"),
        # CompileSpec("optimization_level", "5"),
        # CompileSpec("precision", "fp16"),
        # CompileSpec("use_tf32", "True"),
        # CompileSpec("max_workspace_size", str(1<<30))  # 1GB for autotuning
    ])  # compile specs

    # Export the model using torch.export
    exported_program = torch.export.export(model, inputs)

    # Lower to the edge dialect, delegating supported subgraphs to the CUDA backend
    edge_program = to_edge_transform_and_lower(
        exported_program,
        partitioner=[partitioner],
        compile_config=edge_compile_config)

    
    executorch_program = edge_program.to_executorch()
    
    
    # Convert to executable program and save
    filename = MODEL_NAME+'.pte'
    #with open(filename, "wb") as file:
    #    file.write(executorch_program.buffer)
    #exec_program.get_etrecord().save(MODEL_NAME+"_etrecord.bin")

    print('optimized model saved to', filename)
    # MODEL_NAME = 'libtorch_ac_model.pt'
    # MODEL_TRACED_NAME = 'libtorch_traced_model.pt'
    
    # model_scripted = torch.jit.script(model)
    # traced_model = torch.jit.trace(model.forward, out)
    
    # model_scripted.save(MODEL_NAME)
    # print('script model saved to',MODEL_NAME)
    # traced_model.save(MODEL_TRACED_NAME)
    # print('traced model saved to',MODEL_TRACED_NAME)

Versions

--2026-02-04 02:16:30-- https://raw.githubusercontent.com/pytorch/pytorch/main/torch/utils/collect_env.py
Loaded CA certificate '/etc/ssl/certs/ca-certificates.crt'
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 31107 (30K) [text/plain]
Saving to: ‘collect_env.py’

collect_env.py 100%[===========>] 30.38K --.-KB/s in 0.01s

2026-02-04 02:16:30 (2.04 MB/s) - ‘collect_env.py’ saved [31107/31107]

Collecting environment information...
PyTorch version: 2.10.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Arch Linux (x86_64)
GCC version: (GCC) 15.2.1 20260103
Clang version: 21.1.6
CMake version: version 4.2.3
Libc version: glibc-2.42

Python version: 3.13.11 (main, Feb 3 2026, 22:12:53) [GCC 15.2.1 20260103] (64-bit runtime)
Python platform: Linux-6.18.7-arch1-1-x86_64-with-glibc2.42
Is CUDA available: True
CUDA runtime version: 13.1.115
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 5060
Nvidia driver version: 590.48.01
cuDNN version: Probably one of the following:
/usr/lib/libcudnn.so.9.18.1
/usr/lib/libcudnn_adv.so.9.18.1
/usr/lib/libcudnn_cnn.so.9.18.1
/usr/lib/libcudnn_engines_precompiled.so.9.18.1
/usr/lib/libcudnn_engines_runtime_compiled.so.9.18.1
/usr/lib/libcudnn_graph.so.9.18.1
/usr/lib/libcudnn_heuristic.so.9.18.1
/usr/lib/libcudnn_ops.so.9.18.1
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 20
On-line CPU(s) list: 0-19
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) i9-10900KF CPU @ 3.70GHz
CPU family: 6
Model: 165
Thread(s) per core: 2
Core(s) per socket: 10
Socket(s): 1
Stepping: 5
CPU(s) scaling MHz: 69%
CPU max MHz: 5300.0000
CPU min MHz: 800.0000
BogoMIPS: 7399.70
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx rdseed adx smap clflushopt intel_pt xsaveopt xsavec xgetbv1 xsaves dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp vnmi pku ospke md_clear flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 320 KiB (10 instances)
L1i cache: 320 KiB (10 instances)
L2 cache: 2.5 MiB (10 instances)
L3 cache: 20 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-19
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Ghostwrite: Not affected
Vulnerability Indirect target selection: Mitigation; Aligned branch/return thunks
Vulnerability Itlb multihit: KVM: Mitigation: Split huge pages
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Old microcode: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Mitigation; Microcode
Vulnerability Tsa: Not affected
Vulnerability Tsx async abort: Not affected
Vulnerability Vmscape: Mitigation; IBPB before exit to userspace

Versions of relevant libraries:
[pip3] executorch==1.1.0
[pip3] numpy==2.4.2
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] pytorch_tokenizers==1.1.0
[pip3] torch==2.10.0
[pip3] torchao==0.15.0
[pip3] triton==3.6.0
[conda] Could not collect
