Skip to content

feat: add Mod op, dtype promotion, and If improvements for Florence-2#62

Open
ajroetker wants to merge 8 commits intogomlx:mainfrom
ajroetker:feature/florence2-ops
Open

feat: add Mod op, dtype promotion, and If improvements for Florence-2#62
ajroetker wants to merge 8 commits intogomlx:mainfrom
ajroetker:feature/florence2-ops

Conversation

@ajroetker
Copy link
Contributor

Summary

Adds several ops and fixes needed to run Florence-2 models.

Depends on: #56 and #61

  • Mod operator: Supports both fmod=1 (C-style, sign follows dividend) and fmod=0 (Python-style, sign follows divisor) with broadcasting and dtype promotion.
  • Integer-to-float promotion for unary math ops: Sqrt, Exp, Log, Erf, Tanh, Sin, Cos, Sigmoid, and Pow special cases now promote integer inputs to Float32 when dtype promotion is enabled, matching ONNX Runtime behavior.
  • Concat dtype alignment: When dtype promotion is enabled, all Concat operands are cast to the first operand's dtype, preserving Int64 for shape/index tensors.
  • isVariableConstant loosening: Float variables with "const" in the name are now accepted as materializable constants (needed when Concat dtype promotion casts Float32 constants to Int64).
  • Sub-graph name shadowing fix: convertSubGraph now saves and restores parent entries in nodeOutputToNode / variableNameToValue instead of unconditionally deleting them on cleanup.
  • convertIf rework: Uses GoMLX's native If with closures instead of the Where-based approach, so only the taken branch executes at runtime. Includes static condition resolution via tryMaterializeBool to skip dead branches entirely when the condition is a compile-time constant. Each branch gets an isolated copy of convertedOutputs to prevent cross-contamination during graph tracing.

Test plan

  • New TestMod covers fmod=0, fmod=1, int, float, and broadcasting cases
  • Existing tests pass (go test ./onnx/...)

Add ONNX fusion pattern detectors and emitters for quantized operations:

- fusion_quantized_dense.go: Detects DynamicQuantizeLinear + MatMulInteger
  chains and emits nn.QuantizedDense with Int8 format
- fusion_quantized_qkv.go: Merges three quantized dense projections (Q, K, V)
  into a single batched QuantizedDense call
- fusion_quantized_sdpa.go: Detects quantized scaled dot-product attention
  patterns and emits BackendFusedQuantizedScaledDotProductAttention

Also fixes:
- convertDequantizeLinear now takes Model receiver for fusion-aware access
- Add isZeroInitializer helper for detecting zero-valued ONNX tensors
- Fix GQA (Grouped Query Attention) head mismatch: replicate KV heads
  before calling attention.Core when kvNumHeads < numHeads
- Add Mod operator with fmod=0 (Python-style) and fmod=1 (C-style)
- Add ensureFloat() to promote integers for unary math ops (Sqrt, Exp, Log, etc.)
- Add dtype alignment in Concat when dtype promotion is enabled
- Loosen isVariableConstant to accept float variables with "const" in name
- Fix sub-graph name shadowing: save/restore parent entries instead of deleting
- Rework convertIf to use GoMLX native If with closures instead of Where
- Add static condition resolution in convertIf via tryMaterializeBool
- Isolate branch closure convertedOutputs maps to prevent cross-contamination
…ction

- Always promote integer inputs to float for float-only ops (Sqrt, Exp,
  etc.) regardless of allowDTypePromotion setting
- Cast Pow(x, ±0.5) results back to original integer dtype
- Guard Ceil/Floor to be identity on integer inputs per ONNX spec
- Check both Mul inputs for scalar constant (commutative)
- Bump dependencies (go 1.26, tokenizers, onnxruntime_go, x/ packages)
Resolve go.sum conflict and update quantized fusion files to use the new
interface-based fusion API:
- FusionDetector signature: func(m *Model) []FusionCandidate
- Helper functions moved to onnxgraph package (SoleConsumer,
  OtherBinaryOpInput, HasExternalConsumers)
- Drop sdpa prefix from shared utility methods (tryGetConstantScalar,
  isMaskRankAcceptable, extractHeadCounts, matchKTranspose, etc.)
Resolve conflicts in go.mod (keep newer deps from florence2-ops, add
float16 as direct dependency), go.sum (regenerated via go mod tidy),
and fusion_sdpa.go (keep non-prefixed extractScaleFromMul).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant