Commit 5bf1aeb

Do not export sampler for device==mps && dtype==bfloat16 (#208)
As titled. Seeing this error when running a Metal-delegated argmax model:

```
E 00:00:13.059103 executorch:et_metal.mm:246] ETMetalShaderLibrary: Failed to compile shader library: program_source:3813:29: error: assigning to 'bfloat' from incompatible type 'float'
    tmp_acc_2 = tmp0;
                ^~~~
E 00:00:13.059124 executorch:et_metal.mm:263] ETMetalShaderLibrary: Library not compiled
E 00:00:13.059126 executorch:et_metal.mm:301] ETMetalShaderLibrary::getKernelFunction: Failed to get pipeline state for 'generated_kernel'
E 00:00:13.059127 executorch:shim_mps.mm:105] aoti_torch_mps_get_kernel_function: Failed to get kernel function 'generated_kernel'
E 00:00:13.059129 executorch:shim_mps.mm:517] aoti_torch_mps_run_command_block: null function handle
```

This PR disables the sampler export code path and lets the runner fall back to the C++ CPU sampler.
1 parent: f8aa919 · commit: 5bf1aeb
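The guard added by this commit keys off the model's device string and dtype. A minimal standalone sketch of the same check (the helper name `should_skip_sampler_export` is hypothetical, used here only for illustration):

```python
import torch

# Hypothetical helper mirroring the guard in this commit: sampler export
# is skipped only for MPS devices combined with bfloat16 weights.
def should_skip_sampler_export(device, dtype) -> bool:
    return str(device).startswith("mps") and dtype == torch.bfloat16

print(should_skip_sampler_export(torch.device("mps"), torch.bfloat16))    # True
print(should_skip_sampler_export(torch.device("mps:0"), torch.bfloat16))  # True ("mps:0" still starts with "mps")
print(should_skip_sampler_export(torch.device("cpu"), torch.bfloat16))    # False
print(should_skip_sampler_export(torch.device("mps"), torch.float16))     # False
```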

File tree: 1 file changed (+17, −5)

optimum/exporters/executorch/integrations.py

```diff
@@ -919,15 +919,27 @@ def export(
             example_cache_position,
         )
 
-        self.exported_sampler = self._export_sampler(
-            torch.randn((1, 1, self.config.vocab_size), dtype=self.model.dtype, device=self.model.device)
-        )
+        # Skip sampler export for MPS + bfloat16 due to Metal shader compilation error
+        # (assigning float to bfloat in generated shader code)
+        is_mps_bfloat16 = str(self.model.device).startswith("mps") and self.model.dtype == torch.bfloat16
+        if is_mps_bfloat16:
+            logging.warning(
+                "Skipping sampler export for MPS + bfloat16 due to Metal shader compilation issues. "
+                "The runner will use CPU-based sampling instead."
+            )
+            self.exported_sampler = None
+        else:
+            self.exported_sampler = self._export_sampler(
+                torch.randn((1, 1, self.config.vocab_size), dtype=self.model.dtype, device=self.model.device)
+            )
 
-        return {
+        result = {
             "encoder": self.exported_encoder,  # Not called "text_encoder" because the encoder could be non-text too, e.g. Whisper.
             "text_decoder": self.exported_decoder,
-            "sampler": self.exported_sampler,
         }
+        if self.exported_sampler is not None:
+            result["sampler"] = self.exported_sampler
+        return result
 
     def generate(self, prompt_token_ids, max_new_tokens):
         with torch.no_grad():
```
