Refactor EL models with standardized training loop #108

Merged
ferzcam merged 3 commits into main from el
Jan 15, 2026

Conversation


@ferzcam ferzcam commented Jan 15, 2026

Move generic training logic to EmbeddingELModel base class, allowing subclasses to customize behavior through hooks instead of duplicating the entire training loop.

New customization hooks in EmbeddingELModel:

  • get_negative_sampling_config(): Configure which GCIs need negatives
  • generate_negatives(): Custom negative sampling strategy
  • compute_loss(): Custom loss computation (e.g., MSE loss for ELBE)
  • get_regularization_loss(): Add regularization from module
  • get_optimizer(): Use different optimizer
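
As an illustration of how these hooks compose, here is a minimal sketch of a subclass. The hook names (init_module, get_negative_sampling_config, get_optimizer) come from this PR, but the signatures, the return format of the sampling config, and the base-class attributes used below (class_index_dict, embed_dim, device) are assumptions, not the actual implementation:

```python
import torch as th
import torch.nn as nn

from mowl.base_models import EmbeddingELModel  # import path may differ


class TinyELModule(nn.Module):
    """Toy scoring module, only here to make the sketch self-contained."""

    def __init__(self, nb_classes, embed_dim):
        super().__init__()
        self.class_embed = nn.Embedding(nb_classes, embed_dim)

    def forward(self, data, gci_name):
        # Dummy score: distance between the first and last class of each axiom.
        sub = self.class_embed(data[:, 0])
        sup = self.class_embed(data[:, -1])
        return th.linalg.norm(sub - sup, dim=1)


class MyELModel(EmbeddingELModel):
    """Hypothetical subclass: implement init_module() and override hooks as needed."""

    def init_module(self):
        self.module = TinyELModule(len(self.class_index_dict), self.embed_dim).to(self.device)

    def get_negative_sampling_config(self):
        # Assumed return format: which GCIs should get sampled negatives.
        return {"gci2": True}

    def get_optimizer(self):
        # Swap in a different optimizer (hook name from the PR; the lr is illustrative).
        return th.optim.AdamW(self.module.parameters(), lr=1e-3)
```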

API changes:

  • epochs parameter moved from __init__() to train() (see the example below)
  • Subclasses now only need to implement init_module() and override hooks as needed
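
With epochs moved to train(), usage looks roughly like the lines below. MyELModel is the hypothetical subclass from the previous sketch, PPIYeastSlimDataset is one of mOWL's built-in datasets, and the constructor keywords are assumptions:

```python
import mowl
mowl.init_jvm("10g")  # mOWL needs the JVM before loading datasets

from mowl.datasets.builtin import PPIYeastSlimDataset

dataset = PPIYeastSlimDataset()

# Before this PR: epochs was fixed at construction time, e.g.
#   model = MyELModel(dataset, embed_dim=50, batch_size=4096, epochs=200)

# After this PR: epochs is passed when training starts.
model = MyELModel(dataset, embed_dim=50, batch_size=4096)
model.train(epochs=200)
```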

Models simplified:

  • ELBE: Only overrides compute_loss() to use an MSE-style loss (see the sketch after this list)
  • ELEmbeddings: Uses the base class loop entirely (its regularization is picked up from the module)
  • BoxSquaredEL: Uses the base class loop entirely (its regularization is picked up from the module)
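
For ELBE the only override is the loss hook. A sketch of what that can look like is below; the compute_loss signature and the treatment of negatives are assumptions, and the only point taken from the PR is that ELBE replaces the default loss with an MSE-style one:

```python
import torch as th
import torch.nn.functional as F

from mowl.base_models import EmbeddingELModel  # import path may differ


class ELBELike(EmbeddingELModel):
    """Hypothetical ELBE-style model: everything except the loss comes from the base class."""

    def init_module(self):
        ...  # build the box-embedding module here

    def compute_loss(self, gci_name, pos_scores, neg_scores=None):
        # Assumed signature. MSE-style loss: push positive GCI scores toward 0.
        loss = F.mse_loss(pos_scores, th.zeros_like(pos_scores))
        if neg_scores is not None:
            # Illustrative target for negatives; not taken from the actual code.
            loss = loss + F.mse_loss(neg_scores, th.ones_like(neg_scores))
        return loss
```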

Example models updated to use the new API:

  • ELBEPPI, ELBEGDA
  • ELEmPPI, ELEmGDA
  • BoxSquaredELPPI, BoxSquaredELGDA

Code reduction: ~640 lines removed, ~450 lines added (net -190 lines)

ferzcam and others added 3 commits January 15, 2026 13:51
Move generic training logic to EmbeddingELModel base class, allowing
subclasses to customize behavior through hooks instead of duplicating
the entire training loop.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Remove gci2 as the implicit default evaluation method in EmbeddingELModel.
Users must now explicitly set model.eval_gci_name before training with
validation or calling eval_method(), enabling evaluation on any GCI type
(gci0, gci1, gci2, gci3, etc.).

- Add eval_gci_name property with validation in EmbeddingELModel
- Update train() to require eval_gci_name when validation is used
- Update eval_method() to use configured GCI type instead of hardcoded gci2
- Update all example models (ELEmbeddings, ELBE, BoxSquaredEL) to set
  eval_gci_name="gci2" in their __init__ methods
- Update tests to explicitly set eval_gci_name

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
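
Put together with the first commit, the evaluation change means something along these lines at call time. MyELModel, dataset, and the constructor keywords are the same assumptions as in the earlier sketches; only the eval_gci_name property and the accepted GCI names are taken from the commit:

```python
# Continuing from the earlier sketches (MyELModel, dataset).
model = MyELModel(dataset, embed_dim=50, batch_size=4096)

# gci2 is no longer the implicit default: pick the GCI type to evaluate on.
model.eval_gci_name = "gci0"   # any of "gci0", "gci1", "gci2", "gci3", ...

# Must be set before training with validation or before eval_method() is used.
model.train(epochs=200)
```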
@ferzcam ferzcam merged commit b61fc44 into main Jan 15, 2026
4 checks passed
@ferzcam ferzcam deleted the el branch January 15, 2026 12:37