feat: add Einsum op, WithConstantVariables, and bug fixes#64

Open
ajroetker wants to merge 6 commits into gomlx:main from ajroetker:feature/new-ops-and-fixes
Conversation


@ajroetker (Contributor) commented Mar 2, 2026

Summary

New Features

  • Einsum op: Add support for ONNX Einsum with 2 operands
  • WithConstantVariables: New option to embed all model variables as graph constants, required for models that use variables in ops needing compile-time values (e.g., Range for rotary embeddings)
  • DynamicDim support: Propagate DynamicDim through broadcast and slice ops — detect DynamicDim axes, preserve axis names for the specialization system, and fall through to existing concrete-dimension logic when all dims are static
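For readers unfamiliar with Einsum, the 2-operand case covers equations like `"ij,jk->ik"` (a plain matrix multiply). The following is a minimal, self-contained Go sketch of what that equation computes over flat row-major slices — purely illustrative of the semantics, not the actual graph-level op this PR adds:

```go
package main

import "fmt"

// einsumIJJK evaluates the 2-operand Einsum equation "ij,jk->ik"
// (a matrix multiply) over flat row-major slices. The real op works
// on graph tensors; this just illustrates the contraction.
func einsumIJJK(a, b []float32, i, j, k int) []float32 {
	out := make([]float32, i*k)
	for ii := 0; ii < i; ii++ {
		for kk := 0; kk < k; kk++ {
			var sum float32
			for jj := 0; jj < j; jj++ { // contract over the shared axis j
				sum += a[ii*j+jj] * b[jj*k+kk]
			}
			out[ii*k+kk] = sum
		}
	}
	return out
}

func main() {
	a := []float32{1, 2, 3, 4} // 2x2
	b := []float32{5, 6, 7, 8} // 2x2
	fmt.Println(einsumIJJK(a, b, 2, 2, 2)) // [19 22 43 50]
}
```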

Bug Fixes

  • CumSum: Use adjustedAxis consistently (CumSum supports negative axes but other calls in the same function were using the raw axis)
  • RotaryEmbedding: Add bounds check requiring at least 4 inputs before accessing inputs[2] and inputs[3]
  • GroupQueryAttention: Add rank check before accessing Dimensions[2] on past KV inputs to prevent index-out-of-range on malformed tensors
  • LayerNormalization: Only allocate biasShape when bias is non-nil, and use bias's own dimensions instead of scale dimensions
  • ScatterND: Fix error message printing data rank instead of indices rank
  • ConstantOfShape: Make value attribute optional per ONNX spec, defaulting to scalar float32 zero when absent
  • LSTM: Fix copy-paste bug reading activation_alpha for activationBeta instead of activation_beta; fix activaitons typo
  • MultiHeadAttention: Fix 2D mask reshape producing (1,batch,kv_seq,1) instead of (batch,1,1,kv_seq) by using Reshape directly
  • convertIf: Use m.onnxWhere for dtype promotion between then/else branches instead of raw Where which panics on mismatched dtypes; panic with descriptive message when If node produces no outputs instead of returning nil
  • QLinearMatMul: Fix nil pointer dereference when yZeroPoint is nil (.DType() was called before the nil check); collapse redundant if/else-if branches for zero-point subtraction
  • convertSubGraph: Fix overwriting shared model maps (nodeOutputToNode, variableNameToValue) without saving/restoring originals — sub-graph names colliding with main graph names could corrupt state; thread *context.Context through so sub-graph nodes can resolve model variables correctly
  • materialize.go: Fix varDesc = append(opsDesc, ...) copy-paste bug that left opsDesc always empty and overwrote varDesc in error messages
  • prettyprint: Fix format verb for shapes.Shape (%d -> %v)
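The QLinearMatMul fix is the classic "method call before nil check" ordering bug. A minimal sketch (with a hypothetical `Node` type standing in for a graph value, where nil models the absent optional y_zero_point input):

```go
package main

import "fmt"

// Node is a hypothetical stand-in for a graph value; nil models an
// absent optional input such as y_zero_point.
type Node struct{ dtype string }

func (n *Node) DType() string { return n.dtype }

// Buggy version (for contrast): calling yZeroPoint.DType() before
// checking yZeroPoint == nil dereferences a nil pointer when the
// optional input is absent.

// Fixed ordering: check nil first; only then is it safe to inspect
// the dtype.
func subtractZeroPoint(y, yZeroPoint *Node) string {
	if yZeroPoint == nil {
		return "no zero-point to subtract"
	}
	return "subtract zero-point of dtype " + yZeroPoint.DType()
}

func main() {
	y := &Node{dtype: "int8"}
	fmt.Println(subtractZeroPoint(y, nil))                     // safe: no deref
	fmt.Println(subtractZeroPoint(y, &Node{dtype: "int8"}))    // normal path
}
```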

Simplifications

  • Remove redundant savedVariableExists/savedNodeExists maps — use Go's comma-ok idiom on the saved maps directly since nil is never a valid value

Test plan

  • Existing tests pass
  • Test Einsum with 2-operand models
  • Test WithConstantVariables with rotary embedding models
  • Test models that exercise convertIf/convertSubGraph paths

Add Einsum op support (2-operand), WithConstantVariables option for
inference-only workloads where variables must be compile-time constants,
and fix several bugs found during code review:

- Fix nil pointer dereference in onnxQLinearMatMul when yZeroPoint is nil
- Collapse redundant if/else-if branches for zero-point subtraction
- Fix convertIf returning nil instead of panicking on zero outputs
- Fix convertSubGraph overwriting shared model maps without restoring originals
- Fix prettyprint format verb for shapes.Shape (%d -> %v)
…ertSubGraph

- Fix opsDesc assignment bug in materialize.go: `varDesc = append(opsDesc, ...)`
  was assigning to the wrong variable, leaving opsDesc always empty and
  overwriting varDesc in error messages.
- Thread *context.Context through convertIf -> convertSubGraph -> convertNode
  so sub-graph nodes that depend on model variables can resolve them correctly
  instead of panicking on nil context.
Use Go's comma-ok idiom on the saved maps directly instead of maintaining
separate boolean maps to track key existence. Nil is never a valid value
in variableNameToValue or nodeOutputToNode, so checking the saved map
entry suffices. Also adds a concurrency-safety note.
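The shadow-and-restore pattern described above can be sketched as follows. Names and types here are hypothetical stand-ins for the converter's shared maps; the point is that a comma-ok lookup on the saved map is the only existence tracking needed, because a saved entry is present exactly when the key existed before the sub-graph shadowed it:

```go
package main

import "fmt"

type value struct{ name string }

// withShadowed temporarily overwrites shared entries with locals and
// restores the originals afterwards. Since nil is never a valid value,
// the comma-ok lookup on saved replaces a separate boolean map.
func withShadowed(shared, locals map[string]*value, body func()) {
	saved := make(map[string]*value, len(locals))
	for name, v := range locals {
		if orig, ok := shared[name]; ok {
			saved[name] = orig // remember the shadowed original
		}
		shared[name] = v
	}
	defer func() {
		for name := range locals {
			if orig, ok := saved[name]; ok {
				shared[name] = orig // key existed: restore it
			} else {
				delete(shared, name) // key was absent: remove it
			}
		}
	}()
	body()
}

func main() {
	shared := map[string]*value{"x": {name: "main-x"}}
	locals := map[string]*value{"x": {name: "sub-x"}, "y": {name: "sub-y"}}
	withShadowed(shared, locals, func() {
		fmt.Println(shared["x"].name, shared["y"].name) // sub-graph view
	})
	fmt.Println(shared["x"].name, len(shared)) // restored main-graph view
}
```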
- ScatterND: fix error message printing data rank (r) instead of indices rank (q)
- ConstantOfShape: make 'value' attribute optional per ONNX spec, defaulting
  to scalar float32 zero when absent
- LSTM: fix copy-paste bug reading "activation_alpha" for activationBeta
  instead of "activation_beta", and fix "activaitons" typo
- MultiHeadAttention: fix 2D mask reshape producing (1,batch,kv_seq,1)
  instead of (batch,1,1,kv_seq) by using Reshape directly
- LayerNormalization: use bias tensor's own dimensions for biasShape
  instead of reusing scale dimensions
- convertIf: use m.onnxWhere for dtype promotion between then/else branches
  instead of raw Where which panics on mismatched dtypes
… LayerNorm bias

- CumSum: use adjustedAxis consistently (CumSum supports negative axes
  but other calls in the same function use adjustedAxis)
- RotaryEmbedding: add bounds check requiring at least 4 inputs before
  accessing inputs[2] and inputs[3]
- GroupQueryAttention: add rank check before accessing Dimensions[2]
  on past KV inputs to prevent index-out-of-range on malformed tensors
- LayerNormalization: only allocate biasShape when bias is non-nil,
  and use bias's own dimensions instead of scale dimensions
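The CumSum fix comes down to normalizing a negative ONNX axis once and using that normalized value everywhere. A short sketch of the normalization (the helper name is hypothetical, not the converter's actual function):

```go
package main

import "fmt"

// adjustAxis normalizes a possibly negative ONNX axis into [0, rank),
// the way ONNX ops such as CumSum interpret axis=-1 as the last axis.
// Hypothetical helper for illustration.
func adjustAxis(axis, rank int) int {
	if axis < 0 {
		axis += rank // e.g. -1 on a rank-3 tensor becomes 2
	}
	if axis < 0 || axis >= rank {
		panic(fmt.Sprintf("axis %d out of range for rank %d", axis, rank))
	}
	return axis
}

func main() {
	fmt.Println(adjustAxis(-1, 3)) // 2
	fmt.Println(adjustAxis(1, 3))  // 1
}
```

The bug was precisely the mix: some call sites in the function used the adjusted value while others still used the raw (possibly negative) axis.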
Prepare onnxBroadcastToCommonShape and convertSlice for future dynamic
dimension support in gomlx. Both functions now detect DynamicDim axes,
preserve axis names for the specialization system, and fall through to
existing concrete-dimension logic when all dims are static.
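The detect-then-fall-through structure can be sketched as below. This is an assumption-laden illustration: dynamic axes are modeled here as a negative sentinel with an attached name, which may not match gomlx's actual DynamicDim representation; only the control-flow shape (handle dynamic axes first, otherwise fall through to the static path) mirrors the change:

```go
package main

import "fmt"

// describeShape mirrors the fall-through structure added to
// onnxBroadcastToCommonShape and convertSlice: if any axis is dynamic
// (modeled as a negative sentinel), preserve its name for the
// specialization system; otherwise fall through to the existing
// concrete-dimension logic.
func describeShape(dims []int, names []string) string {
	for i, d := range dims {
		if d < 0 { // dynamic axis detected
			return fmt.Sprintf("dynamic axis %d (%q): defer to specialization", i, names[i])
		}
	}
	return "all dims static: use existing concrete-dimension logic"
}

func main() {
	fmt.Println(describeShape([]int{2, -1, 4}, []string{"", "seq", ""}))
	fmt.Println(describeShape([]int{2, 3, 4}, []string{"", "", ""}))
}
```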
