
Commit 910b9ee

Bump transformers, improve SDPA, reduce Optimum reliance (#188)
Bump transformers to v5.0.0rc1
1 parent d1ac176 commit 910b9ee

16 files changed: +106 -54 lines changed

.github/workflows/test_models.yml

Lines changed: 8 additions & 2 deletions
@@ -52,11 +52,17 @@ jobs:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies for ExecuTorch
         run: |
+          # Clean up cache to save space
+          pip cache purge || true
+          rm -rf ~/.cache/huggingface/hub/* || true
+
           if [ "${{ matrix.executorch-version }}" == "nightly" ]; then
             python install_dev.py
           else
-            pip install '.[dev]'
-            pip install executorch==${{ matrix.executorch-version }}
+            # Use CPU-only torch to avoid CUDA dependencies (saves ~5GB)
+            pip install --no-cache-dir '.[dev]' \
+              --extra-index-url https://download.pytorch.org/whl/cpu
+            pip install --no-cache-dir executorch==${{ matrix.executorch-version }}
           fi
           pip list
       - name: Run tests
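As a quick sanity check (not part of the workflow above, just an illustrative sketch), a CPU-only torch install can be verified from Python; CPU wheels report no CUDA support and usually carry a `+cpu` local version suffix:

```python
# Hedged sketch: verify the CI environment ended up with a CPU-only torch build.
import torch

print(torch.__version__)          # a CPU wheel usually ends in "+cpu"
print(torch.version.cuda)         # None for CPU-only builds
print(torch.cuda.is_available())  # False when no CUDA runtime is bundled
```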

install_dev.py

Lines changed: 7 additions & 6 deletions
@@ -5,7 +5,7 @@
 
 def install_torch_nightly_deps():
     """Install torch related dependencies from pinned nightly"""
-    EXECUTORCH_NIGHTLY_VERSION = "dev20251003"
+    EXECUTORCH_NIGHTLY_VERSION = "dev20251104"
     TORCHAO_NIGHTLY_VERSION = "dev20251104"
     # Torch nightly is aligned with pinned nightly in https://github.com/pytorch/executorch/blob/main/torch_pin.py#L2
     TORCH_NIGHTLY_VERSION = "dev20251104"
@@ -15,6 +15,7 @@ def install_torch_nightly_deps():
             "-m",
             "pip",
             "install",
+            "--no-cache-dir",  # Prevent cached CUDA packages
             f"executorch==1.1.0.{EXECUTORCH_NIGHTLY_VERSION}",
             f"torch==2.10.0.{TORCH_NIGHTLY_VERSION}",
             f"torchvision==0.25.0.{TORCH_NIGHTLY_VERSION}",
@@ -34,7 +35,7 @@ def install_dep_from_source():
             "-m",
             "pip",
             "install",
-            "git+https://github.com/huggingface/transformers@91393fe4cc3266a05bc0d129e34ff5f761bb46e2#egg=transformers",  # 4.56.1
+            "git+https://github.com/huggingface/transformers@bdc85cb85c8772d37aa29ce447860b44d7fad6ef#egg=transformers",  # v5.0.0rc0
         ]
     )
     subprocess.check_call(
@@ -58,13 +59,13 @@ def main():
     )
     args = parser.parse_args()
 
-    # Install package with dev extras
-    subprocess.check_call([sys.executable, "-m", "pip", "install", ".[dev]"])
-
-    # Install nightly dependencies
+    # Install nightly torch dependencies FIRST to avoid pulling CUDA versions
     if not args.skip_override_torch:
         install_torch_nightly_deps()
 
+    # Install package with dev extras
+    subprocess.check_call([sys.executable, "-m", "pip", "install", ".[dev]"])
+
     # Install source dependencies
     install_dep_from_source()
 
optimum/commands/export/executorch.py

Lines changed: 3 additions & 2 deletions
@@ -17,7 +17,8 @@
 from pathlib import Path
 from typing import TYPE_CHECKING
 
-from ...exporters import TasksManager
+from transformers.pipelines import get_supported_tasks
+
 from ..base import BaseOptimumCLICommand, CommandInfo
 
 
@@ -46,7 +47,7 @@ def parse_args_executorch(parser):
         default="text-generation",
         help=(
             "The task to export the model for. Available tasks depend on the model, but are among:"
-            f" {str(TasksManager.get_all_tasks())}."
+            f" {str(get_supported_tasks())}."
         ),
     )
     required_group.add_argument(
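For context, `get_supported_tasks()` comes straight from transformers' pipeline registry, so the CLI help text no longer depends on Optimum's `TasksManager`. A minimal usage sketch:

```python
# Hedged sketch: list the pipeline tasks transformers knows about.
from transformers.pipelines import get_supported_tasks

tasks = get_supported_tasks()
print("text-generation" in tasks)  # expected to be True in current releases
print(len(tasks))                  # a few dozen task names
```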

optimum/commands/register/register_export.py

Lines changed: 2 additions & 2 deletions
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..export import ExportCommand
-from ..export.executorch import ExecuTorchExportCommand
+from optimum.commands.export.base import ExportCommand
+from optimum.commands.export.executorch import ExecuTorchExportCommand
 
 
 REGISTER_COMMANDS = [(ExecuTorchExportCommand, ExportCommand)]

optimum/executorch/attentions/custom_kv_cache.py

Lines changed: 2 additions & 2 deletions
@@ -45,8 +45,8 @@ def __init__(
             device=device,
             dtype=dtype,
         )
-        num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
-        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+        num_heads = getattr(config, "num_key_value_heads", None) or config.num_attention_heads
         self.early_initialization(
             batch_size=max_batch_size, num_heads=num_heads, head_dim=head_dim, dtype=dtype, device=device
         )
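The swapped fallback style matters when a config defines the attribute but leaves it set to `None` (a plain `getattr` default is only used when the attribute is missing entirely). A small sketch with a stand-in config object, not a real transformers config:

```python
# Hedged sketch of the getattr-vs-or fallback difference using a dummy config.
from types import SimpleNamespace

config = SimpleNamespace(head_dim=None, hidden_size=2048, num_attention_heads=16)

old_style = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
new_style = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

print(old_style)  # None -- the default is ignored because the attribute exists
print(new_style)  # 128  -- falls back whenever the attribute is missing or None
```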

optimum/executorch/attentions/custom_sdpa.py

Lines changed: 54 additions & 4 deletions
@@ -18,12 +18,59 @@
 from executorch.extension.llm.custom_ops.custom_ops import custom_sdpa  # noqa
 
 
+def sdpa_mask_passthrough(
+    batch_size: int,
+    cache_position: torch.Tensor,
+    kv_length: int,
+    kv_offset: int = 0,
+    mask_function: Optional[Callable] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    local_size: Optional[int] = None,
+    allow_is_causal_skip: bool = True,
+    allow_torch_fix: bool = True,
+    **kwargs,
+) -> Optional[torch.Tensor]:
+    """
+    Pass-through for attention mask creation since it is never used:
+    - For regular attention, the custom sdpa op in causal mode creates its own attention mask
+    - For sliding window attention, the attention mask from the attention mask API is ditched and re-created during the attention API since it needs to know about cache internals
+
+    Additionally, there were some vmap export issues with sliding window attention mask creation in Transformers.
+
+    Args:
+        batch_size (`int`):
+            The batch size of the input sequence.
+        cache_position (`torch.Tensor`):
+            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
+        kv_length (`int`):
+            The size that the key and value states will have during the attention computation.
+        kv_offset (`int`, optional):
+            An optional offset to indicate at which first position the key and values states will refer to.
+        mask_function (`Callable`):
+            The mask factory function describing the mask pattern.
+        attention_mask (`torch.Tensor`, optional):
+            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
+        local_size (`int`, optional):
+            The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
+            to try to skip mask creation if possible.
+        allow_is_causal_skip (`bool`, optional):
+            Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
+            `torch.sdpa` instead. Default to `True`.
+        allow_torch_fix (`bool`, optional):
+            Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
+            versions. We need an arg to skip it when using eager. By default `True`.
+
+    """
+    return None
+
+
 def custom_sdpa_with_start_pos_forward(
     module: torch.nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     attention_mask: Union[torch.Tensor, "BlockMask"],  # noqa
+    position_ids: Optional[torch.Tensor] = None,
     scaling: Optional[float] = None,
     softcap: Optional[float] = None,
     head_mask: Optional[torch.Tensor] = None,
@@ -56,10 +103,10 @@ def custom_sdpa_with_start_pos_forward(
         # Calculate the input pos from attention mask.
         # Branch out for float vs bool mask
         # assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix."
-        attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1])
-        first_row_mask = attention_mask[0, :]
-        # [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3
-        start_pos = torch.argmin(first_row_mask.to(torch.long)).item() - 1
+        assert (
+            position_ids is not None
+        ), "position_ids must be provided to find start position for causal attention"
+        start_pos = position_ids[0][0].item()
     else:
         start_pos = 0
 
@@ -95,6 +142,7 @@ def _custom_sdpa_for_ring_kv_cache(
     key: torch.Tensor,
     value: torch.Tensor,
     attention_mask: Union[torch.Tensor, "BlockMask"],  # noqa
+    position_ids: Optional[torch.Tensor] = None,
     scaling: Optional[float] = None,
     softcap: Optional[float] = None,
     head_mask: Optional[torch.Tensor] = None,
@@ -122,6 +170,7 @@ def _custom_sdpa_for_ring_kv_cache(
             key,
             value,
             attention_mask,
+            position_ids,
             scaling,
             softcap,
             head_mask,
@@ -134,6 +183,7 @@ def _custom_sdpa_for_ring_kv_cache(
             key,
             value,
             attention_mask,
+            position_ids,
             scaling,
             softcap,
             head_mask,
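The behavioural change in this file is how the start position for causal attention is found: instead of scanning the first row of a float mask for its first `-inf` entry, it is now read directly from `position_ids`, which the new attention interface forwards to attention functions. A hedged sketch with made-up tensors contrasting the two derivations:

```python
# Hedged sketch contrasting the old mask-scan with the new position_ids lookup.
import torch

# Decode step with 3 tokens already cached and one new query token.
position_ids = torch.tensor([[3]])
start_pos = position_ids[0][0].item()
print(start_pos)  # 3

# Legacy derivation removed above: find the first -inf in the mask row.
first_row_mask = torch.tensor([0.0, 0.0, 0.0, 0.0, float("-inf"), float("-inf")])
legacy_start_pos = torch.argmin(first_row_mask.to(torch.long)).item() - 1
print(legacy_start_pos)  # 3 as well, but it required materializing the mask
```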

optimum/executorch/modeling.py

Lines changed: 6 additions & 12 deletions
@@ -23,7 +23,7 @@
 from typing import Dict, List, Optional, Union
 
 import torch
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, is_offline_mode
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa
 from transformers import (
@@ -34,25 +34,22 @@
     AutoModelForSeq2SeqLM,
     AutoModelForSpeechSeq2Seq,
     PreTrainedTokenizer,
-    add_start_docstrings,
 )
 from transformers.configuration_utils import PretrainedConfig
+from transformers.pipelines import get_task
 from transformers.processing_utils import ProcessorMixin
-from transformers.utils import is_offline_mode
 
 from executorch.extension.pybindings.portable_lib import (
     ExecuTorchModule,
     _load_for_executorch,
 )
 from executorch.kernels import quantized  # noqa
 
-from ..exporters import TasksManager
 from ..exporters.executorch import main_export
 from ..exporters.executorch.utils import (
     process_conversation_inputs,
     verify_eos_tokens_in_pretrained_tokenizer,
 )
-from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel
 from ..utils.file_utils import find_files_matching_pattern
 from .stats import Stats
 
@@ -63,7 +60,7 @@
 logger = logging.getLogger(__name__)
 
 
-class ExecuTorchModelBase(OptimizedModel, ABC):
+class ExecuTorchModelBase(ABC):
     """
     ExecuTorch model for inference using the ExecuTorch Runtime.
 
@@ -99,8 +96,6 @@ def __init__(
         models: Dict[str, "ExecuTorchModule"],
         config: "PretrainedConfig",
     ):
-        super().__init__(model=None, config=config)
-
         if self.__class__.auto_model_class is None:
             raise ValueError(
                 f"Class {self.__class__.__name__} must set auto_model_class. "
@@ -268,6 +263,7 @@ def _export(
         cls,
         model_id: str,
         recipe: str,
+        task: Optional[str] = None,
         config: Optional[PretrainedConfig] = None,
         token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
@@ -278,9 +274,8 @@ def _export(
         local_files_only: bool = False,
         **kwargs,
     ) -> Dict[str, "ExecuTorchModule"]:
-        task = kwargs.pop("task", None)
-        inferred_task = TasksManager.infer_task_from_model(cls.auto_model_class) if not task else task
-        logging.info(f"Inferred task from model class: {inferred_task}")
+        inferred_task = get_task(model_id) if not task else task
+        logging.info(f"Using task: {inferred_task}")
 
         save_dir = TemporaryDirectory(prefix="executorch_export_")
         save_dir_path = Path(save_dir.name)
@@ -316,7 +311,6 @@ def _save_pretrained(self, save_directory):
         raise NotImplementedError
 
     @classmethod
-    @add_start_docstrings(FROM_PRETRAINED_START_DOCSTRING)
     def from_pretrained(
         cls,
         model_id: Union[str, Path],
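Task inference now goes through transformers' pipeline utilities rather than Optimum's `TasksManager`: `get_task` resolves the pipeline tag recorded for the model on the Hugging Face Hub, so it needs network access unless a task is passed explicitly. A hedged sketch with an illustrative model id:

```python
# Hedged sketch of the new task-inference path; the model id is just an example.
from transformers.pipelines import get_task

task = get_task("HuggingFaceTB/SmolLM2-135M")
print(task)  # expected to be "text-generation" for a causal LM
```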

optimum/exporters/executorch/convert.py

Lines changed: 2 additions & 3 deletions
@@ -19,8 +19,7 @@
 from pathlib import Path
 from typing import Union
 
-from transformers.integrations.executorch import sdpa_mask_without_vmap
-from transformers.masking_utils import AttentionMaskInterface
+from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, AttentionMaskInterface
 from transformers.modeling_utils import AttentionInterface
 
 from optimum.executorch.attentions.custom_sdpa import custom_sdpa_with_start_pos_forward
@@ -29,7 +28,7 @@
 
 
 AttentionInterface.register("custom_sdpa", custom_sdpa_with_start_pos_forward)
-AttentionMaskInterface.register("custom_sdpa", sdpa_mask_without_vmap)
+AttentionMaskInterface.register("custom_sdpa", ALL_MASK_ATTENTION_FUNCTIONS["sdpa"])
 
 
 def export_to_executorch(
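Once both registrations run, `"custom_sdpa"` behaves like any other attention implementation name: models can opt into it through `attn_implementation`, and mask construction is looked up under the same key. A hedged sketch of how a caller might select it; the model id is illustrative and loading it downloads weights:

```python
# Hedged sketch: selecting the registered "custom_sdpa" implementation by name.
from transformers import AutoModelForCausalLM

import optimum.exporters.executorch.convert  # noqa: F401  # runs the registrations above

model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M",
    attn_implementation="custom_sdpa",  # resolved through AttentionInterface
)
```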

optimum/exporters/executorch/integrations.py

Lines changed: 8 additions & 10 deletions
@@ -31,12 +31,11 @@
 )
 from transformers.integrations.executorch import (
     TorchExportableModuleForDecoderOnlyLM,
-    sdpa_mask_without_vmap,
 )
 from transformers.masking_utils import AttentionMaskInterface
 from transformers.modeling_utils import AttentionInterface
 
-from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache
+from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache, sdpa_mask_passthrough
 
 from .utils import apply_chat_template_with_fallback, save_config_to_constant_methods
 
@@ -212,7 +211,7 @@ def __init__(
                 additional_metadata_kwargs[f"{modality}_token_id"] = getattr(self.config, "image_token_id")
         self.metadata = save_config_to_constant_methods(
             config=model.config.text_config,
-            generation_config=model.generation_config,
+            generation_config=getattr(model, "generation_config", None),
             processor_config=processor_config,
             get_max_seq_len=max_seq_len,
             **additional_metadata_kwargs,
@@ -269,7 +268,7 @@ def _register_custom_attention(self, exportable_module: torch.nn.Module):
         if self.use_custom_sdpa:
             if self.use_custom_kv_cache:
                 AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache)
-                AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap)
+                AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_passthrough)
                 # Manually set the attention implementation to custom_sdpa_ring_kv_cache
                 # This handles both regular sdpa and one for sliding window/local attention
                 exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache"
@@ -425,7 +424,7 @@ def __init__(
         self.disable_dynamic_shapes = disable_dynamic_shapes
         self.metadata = save_config_to_constant_methods(
             model.config,
-            model.generation_config,
+            generation_config=getattr(model, "generation_config", None),
             get_max_seq_len=max_seq_len,
             enable_dynamic_shape=not self.disable_dynamic_shapes,
         )
@@ -455,7 +454,7 @@ def _prepare_export_inputs(self):
 
         if not self.disable_dynamic_shapes and not is_using_hybrid_cache_wo_custom_sdpa_kv_cache:
             # Prepare inputs with dynamic shapes
-            seq_length = 3  # Sequence length > 1 to avoid specialization issues
+            seq_length = 3  # Sequence length > 1 to avoid specialization issue
             example_input_ids = torch.zeros((1, seq_length), dtype=torch.long, device=self.model.device)
             example_cache_position = torch.arange(seq_length, dtype=torch.long, device=self.model.device)
             max_seq_len = self.metadata.get("get_max_seq_len")
@@ -471,15 +470,14 @@ def _prepare_export_inputs(self):
         return example_input_ids, example_cache_position, dynamic_shapes, strict
 
     def _register_custom_attention(self, exportable_module: torch.nn.Module):
-        from transformers.integrations.executorch import sdpa_mask_without_vmap
         from transformers.masking_utils import AttentionMaskInterface
         from transformers.modeling_utils import AttentionInterface
 
         if self.use_custom_sdpa:
             if self.use_custom_kv_cache:
                 _custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module)
                 AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache)
-                AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap)
+                AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_passthrough)
                 # Manually set the attention implementation to custom_sdpa_ring_kv_cache
                 # This handles both regular sdpa and one for sliding window/local attention
                 exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache"
@@ -554,7 +552,7 @@ def __init__(self, model):
         self.model = model
         self.config = model.config
         # Metadata to be recorded in the pte model file
-        self.metadata = save_config_to_constant_methods(model.config, model.generation_config)
+        self.metadata = save_config_to_constant_methods(model.config, getattr(model, "generation_config", None))
 
     def forward(self, pixel_values):
         print(f"DEBUG: pixel_values: {pixel_values.shape}")
@@ -593,7 +591,7 @@ def __init__(self, model):
         self.model = model
         self.config = model.config
         # Metadata to be recorded in the pte model file
-        self.metadata = save_config_to_constant_methods(model.config, model.generation_config)
+        self.metadata = save_config_to_constant_methods(model.config, getattr(model, "generation_config", None))
 
     def forward(self, input_ids, attention_mask):
         return self.model(input_ids, attention_mask)
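The repeated `getattr(model, "generation_config", None)` change is a defensive guard: some wrapped modules and non-generative models do not expose a `generation_config` attribute, and the metadata helper is expected to tolerate `None`. A tiny sketch with a stand-in model object, not a real transformers class:

```python
# Hedged sketch of the defensive generation_config lookup with a dummy model.
class DummyVisionEncoder:
    """Stand-in for a wrapped module that has a config but no generation_config."""
    config = {"model_type": "dummy"}


model = DummyVisionEncoder()
generation_config = getattr(model, "generation_config", None)
print(generation_config)  # None instead of raising AttributeError
```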
