feat: add Mod op, dtype promotion, and If improvements for Florence-2 #62
Open
ajroetker wants to merge 8 commits into gomlx:main from
Add ONNX fusion pattern detectors and emitters for quantized operations:
- fusion_quantized_dense.go: Detects DynamicQuantizeLinear + MatMulInteger chains and emits nn.QuantizedDense with Int8 format
- fusion_quantized_qkv.go: Merges three quantized dense projections (Q, K, V) into a single batched QuantizedDense call
- fusion_quantized_sdpa.go: Detects quantized scaled dot-product attention patterns and emits BackendFusedQuantizedScaledDotProductAttention

Also fixes:
- convertDequantizeLinear now takes Model receiver for fusion-aware access
- Add isZeroInitializer helper for detecting zero-valued ONNX tensors
- Fix GQA (Grouped Query Attention) head mismatch: replicate KV heads before calling attention.Core when kvNumHeads < numHeads
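The GQA fix in this commit replicates KV heads when kvNumHeads < numHeads. A minimal sketch of that idea in plain Go follows; it does not use the GoMLX API, and `repeatKVHeads`, the `[heads][headDim]` layout, and the consecutive (interleaved) repeat order are all assumptions made for illustration, not code from this PR.

```go
package main

import "fmt"

// repeatKVHeads is a hypothetical helper: it expands kv, laid out as
// [kvNumHeads][headDim], to numHeads heads by repeating each KV head
// numHeads/kvNumHeads times in a row. The repeated rows share their
// backing arrays with the input; copy them if they must be mutated.
func repeatKVHeads(kv [][]float32, numHeads int) ([][]float32, error) {
	kvNumHeads := len(kv)
	if kvNumHeads == 0 || numHeads%kvNumHeads != 0 {
		return nil, fmt.Errorf("numHeads=%d is not a multiple of kvNumHeads=%d", numHeads, kvNumHeads)
	}
	group := numHeads / kvNumHeads // query heads served by each KV head
	out := make([][]float32, 0, numHeads)
	for _, head := range kv {
		for i := 0; i < group; i++ {
			out = append(out, head)
		}
	}
	return out, nil
}

func main() {
	kv := [][]float32{{1, 2}, {3, 4}} // 2 KV heads, headDim = 2
	expanded, err := repeatKVHeads(kv, 4)
	if err != nil {
		panic(err)
	}
	fmt.Println(expanded) // prints [[1 2] [1 2] [3 4] [3 4]]
}
```

After replication, the query and key/value tensors have the same head count, which is what a standard multi-head attention core expects.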
# Conflicts:
#	go.mod
#	go.sum
- Add Mod operator with fmod=0 (Python-style) and fmod=1 (C-style)
- Add ensureFloat() to promote integers for unary math ops (Sqrt, Exp, Log, etc.)
- Add dtype alignment in Concat when dtype promotion is enabled
- Loosen isVariableConstant to accept float variables with "const" in name
- Fix sub-graph name shadowing: save/restore parent entries instead of deleting
- Rework convertIf to use GoMLX native If with closures instead of Where
- Add static condition resolution in convertIf via tryMaterializeBool
- Isolate branch closure convertedOutputs maps to prevent cross-contamination
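To make the two Mod conventions concrete, here is a small scalar sketch in plain Go. It is not the converter's implementation: `modPy` and `modC` are hypothetical names, and the sketch assumes the usual ONNX pairing of fmod=0 with integer inputs and fmod=1 with floating-point inputs.

```go
package main

import (
	"fmt"
	"math"
)

// modPy sketches fmod=0: the result takes the sign of the divisor.
// Go's % truncates toward zero, so the remainder is shifted by b when
// its sign disagrees with b. Like Go's %, it panics when b == 0.
func modPy(a, b int64) int64 {
	r := a % b
	if r != 0 && (r < 0) != (b < 0) {
		r += b
	}
	return r
}

// modC sketches fmod=1: the result takes the sign of the dividend,
// which is the behavior of math.Mod (C's fmod).
func modC(a, b float64) float64 {
	return math.Mod(a, b)
}

func main() {
	fmt.Println(modPy(-7, 3), modPy(7, -3)) // prints 2 -2
	fmt.Println(modC(-7, 3), modC(7, -3))   // prints -1 1
}
```

The two conventions agree whenever both operands have the same sign and differ only when the signs are mixed.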
…ction
- Always promote integer inputs to float for float-only ops (Sqrt, Exp, etc.) regardless of allowDTypePromotion setting
- Cast Pow(x, ±0.5) results back to original integer dtype
- Guard Ceil/Floor to be identity on integer inputs per ONNX spec
- Check both Mul inputs for scalar constant (commutative)
- Bump dependencies (go 1.26, tokenizers, onnxruntime_go, x/ packages)
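The integer handling listed in this commit can be illustrated on scalars. The sketch below is plain Go with hypothetical helper names; it uses float64 rather than a tensor dtype, and it assumes the cast back to integer truncates, which this commit message does not specify.

```go
package main

import (
	"fmt"
	"math"
)

// sqrtPromoted promotes an integer input to float before a float-only op.
func sqrtPromoted(x int64) float64 {
	return math.Sqrt(float64(x))
}

// powHalfInt64 computes x^0.5 in floating point and casts the result back
// to the original integer type. Truncation is an assumption of this sketch;
// negative inputs yield NaN and are not handled here.
func powHalfInt64(x int64) int64 {
	return int64(math.Pow(float64(x), 0.5))
}

// ceilInt64 shows the guard: Ceil (and Floor) on an integer is the identity,
// so no float round trip is needed.
func ceilInt64(x int64) int64 {
	return x
}

func main() {
	fmt.Println(sqrtPromoted(2))
	fmt.Println(powHalfInt64(16), powHalfInt64(10), ceilInt64(7)) // prints 4 3 7
}
```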
Resolve go.sum conflict and update quantized fusion files to use the new interface-based fusion API:
- FusionDetector signature: func(m *Model) []FusionCandidate
- Helper functions moved to onnxgraph package (SoleConsumer, OtherBinaryOpInput, HasExternalConsumers)
- Drop sdpa prefix from shared utility methods (tryGetConstantScalar, isMaskRankAcceptable, extractHeadCounts, matchKTranspose, etc.)
Resolve conflicts in go.mod (keep newer deps from florence2-ops, add float16 as direct dependency), go.sum (regenerated via go mod tidy), and fusion_sdpa.go (keep non-prefixed extractScaleFromMul).
Summary
Adds several ops and fixes needed to run Florence-2 models.
- `Mod` operator: supports `fmod=1` (C-style, sign follows dividend) and `fmod=0` (Python-style, sign follows divisor), with broadcasting and dtype promotion.
- Integer-to-float promotion: `Sqrt`, `Exp`, `Log`, `Erf`, `Tanh`, `Sin`, `Cos`, `Sigmoid`, and `Pow` special cases now promote integer inputs to Float32 when dtype promotion is enabled, matching ONNX Runtime behavior.
- `isVariableConstant` loosening: Float variables with "const" in the name are now accepted as materializable constants (needed when Concat dtype promotion casts Float32 constants to Int64).
- Sub-graph name shadowing: `convertSubGraph` now saves and restores parent entries in `nodeOutputToNode`/`variableNameToValue` instead of unconditionally deleting them on cleanup.
- `convertIf` rework: Uses GoMLX's native `If` with closures instead of the `Where`-based approach, so only the taken branch executes at runtime. Includes static condition resolution via `tryMaterializeBool` to skip dead branches entirely when the condition is a compile-time constant. Each branch gets an isolated copy of `convertedOutputs` to prevent cross-contamination during graph tracing.

Test plan
- `TestMod` covers fmod=0, fmod=1, int, float, and broadcasting cases
- `go test ./onnx/...`
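The save-and-restore approach to sub-graph name shadowing described in the summary can be sketched with an ordinary Go map. This is an illustration only: `withShadowedNames` and the string-valued table are hypothetical stand-ins, not the actual `convertSubGraph` code or its data structures.

```go
package main

import "fmt"

// withShadowedNames binds the sub-graph's local names in table, runs body,
// and then restores whatever the parent scope had under those names.
// Names that did not exist before are deleted; names that did are put back,
// rather than being deleted unconditionally.
func withShadowedNames(table, local map[string]string, body func()) {
	type saved struct {
		value   string
		existed bool
	}
	prev := make(map[string]saved, len(local))
	for name, v := range local {
		old, ok := table[name]
		prev[name] = saved{old, ok}
		table[name] = v
	}
	defer func() {
		for name, s := range prev {
			if s.existed {
				table[name] = s.value
			} else {
				delete(table, name)
			}
		}
	}()
	body()
}

func main() {
	table := map[string]string{"x": "parent-x"}
	local := map[string]string{"x": "branch-x", "tmp": "branch-tmp"}
	withShadowedNames(table, local, func() {
		fmt.Println(table["x"], table["tmp"]) // prints branch-x branch-tmp
	})
	_, hasTmp := table["tmp"]
	fmt.Println(table["x"], hasTmp) // prints parent-x false
}
```

Deleting on cleanup would have removed the parent's `x` entry here, which is the kind of loss the save-and-restore approach avoids.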