Add quantized fusion patterns for dense, QKV, and SDPA#60
Open
ajroetker wants to merge 11 commits into gomlx:main from
Conversation
…U patterns

Introduces a graph fusion framework that detects common subgraph patterns (scaled dot-product attention, QKV dense projections, dense+GELU activations) and replaces them with fused ops when the backend supports them. Adds a capability-gated fused SDPA fast path in the MultiHeadAttention op converter.
# Conflicts:
#	go.mod
#	go.sum
#	onnx/ops.go
Bump github.com/gomlx/gomlx from v0.26.1-0.20260211111746-dd3d906b02a6 to v0.26.1-0.20260215082710-429182c8560c.
Replace the FusionType enum, FusionGroup struct, and switch-based dispatch with a FusionCandidate interface and a RegisterFusionDetector registration pattern. Each fusion (SDPA, QKVDense, DenseGelu) is now self-contained with its own candidate type, detector init(), and emit method. Detection uses score-based greedy selection for non-overlapping fusions.

Backend capability checks are removed since the GoMLX wrapper functions (attention.Core, attention.QKVProjection, nn.Dense) handle fused-vs-decomposed fallback internally via InternalFusedOpCaller.

Fixes 4 compiler errors from undefined symbols:
- FusedMultiHeadSDPA → attention.Core
- FusedQKVDense → attention.QKVProjection
- backends.OpTypeFusedMultiHeadSDPA → removed
- backends.OpTypeFusedQKVDense → removed
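The registration-plus-greedy-selection scheme described above can be sketched as a small, self-contained Go program. The names below (FusionCandidate, RegisterFusionDetector, Model, fakeCandidate) follow the commit message but are illustrative stand-ins, not the actual onnx-gomlx API:

```go
package main

import (
	"fmt"
	"sort"
)

// Model is a stand-in for the ONNX model wrapper (hypothetical).
type Model struct{}

// FusionCandidate is the per-fusion interface: each candidate reports the
// graph nodes it would consume, a score for greedy selection, and an emitter.
type FusionCandidate interface {
	Score() int
	Nodes() []string // names of the nodes the fusion would replace
	Emit(m *Model) error
}

// FusionDetector scans the model and returns zero or more candidates.
type FusionDetector func(m *Model) []FusionCandidate

var detectors []FusionDetector

// RegisterFusionDetector would be called from each fusion file's init().
func RegisterFusionDetector(d FusionDetector) { detectors = append(detectors, d) }

// selectFusions runs all detectors, then greedily keeps the highest-scoring
// candidates whose node sets do not overlap.
func selectFusions(m *Model) []FusionCandidate {
	var all []FusionCandidate
	for _, d := range detectors {
		all = append(all, d(m)...)
	}
	sort.Slice(all, func(i, j int) bool { return all[i].Score() > all[j].Score() })
	taken := map[string]bool{}
	var chosen []FusionCandidate
	for _, c := range all {
		overlap := false
		for _, n := range c.Nodes() {
			if taken[n] {
				overlap = true
				break
			}
		}
		if overlap {
			continue
		}
		for _, n := range c.Nodes() {
			taken[n] = true
		}
		chosen = append(chosen, c)
	}
	return chosen
}

// fakeCandidate is a trivial FusionCandidate for demonstration.
type fakeCandidate struct {
	score int
	nodes []string
}

func (f fakeCandidate) Score() int        { return f.score }
func (f fakeCandidate) Nodes() []string   { return f.nodes }
func (f fakeCandidate) Emit(*Model) error { return nil }

func main() {
	RegisterFusionDetector(func(*Model) []FusionCandidate {
		return []FusionCandidate{
			fakeCandidate{90, []string{"matmul1", "softmax", "matmul2"}}, // SDPA-like
			fakeCandidate{40, []string{"matmul1"}},                       // dense, overlaps SDPA
			fakeCandidate{70, []string{"q", "k", "v"}},                   // QKV-like
		}
	})
	for _, c := range selectFusions(&Model{}) {
		fmt.Println(c.Score(), c.Nodes())
	}
}
```

Because candidates are sorted by score first, the SDPA-like candidate (90) claims its nodes before the overlapping dense candidate (40), which is then skipped; the disjoint QKV candidate (70) survives.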
Add ONNX fusion pattern detectors and emitters for quantized operations:
- fusion_quantized_dense.go: Detects DynamicQuantizeLinear + MatMulInteger chains and emits nn.QuantizedDense with Int8 format
- fusion_quantized_qkv.go: Merges three quantized dense projections (Q, K, V) into a single batched QuantizedDense call
- fusion_quantized_sdpa.go: Detects quantized scaled dot-product attention patterns and emits BackendFusedQuantizedScaledDotProductAttention

Also fixes:
- convertDequantizeLinear now takes a Model receiver for fusion-aware access
- Add isZeroInitializer helper for detecting zero-valued ONNX tensors
- Fix GQA (Grouped Query Attention) head mismatch: replicate KV heads before calling attention.Core when kvNumHeads < numHeads
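The GQA fix amounts to repeating each KV head until the key/value head count matches the query head count. A minimal sketch, using plain slices rather than gomlx tensors (replicateKVHeads is a hypothetical helper, not the PR's code):

```go
package main

import "fmt"

// replicateKVHeads expands kv (one vector per KV head) so that each KV head
// is repeated numHeads/numKVHeads times, matching the query head count.
// This mirrors the GQA head-mismatch fix described above.
func replicateKVHeads(kv [][]float32, numHeads int) [][]float32 {
	numKVHeads := len(kv)
	if numHeads%numKVHeads != 0 {
		panic("numHeads must be a multiple of numKVHeads")
	}
	group := numHeads / numKVHeads // query heads sharing one KV head
	out := make([][]float32, 0, numHeads)
	for _, head := range kv {
		for i := 0; i < group; i++ {
			out = append(out, head)
		}
	}
	return out
}

func main() {
	kv := [][]float32{{1, 2}, {3, 4}}   // 2 KV heads
	expanded := replicateKVHeads(kv, 8) // 8 query heads → each KV head ×4
	fmt.Println(len(expanded))          // 8
	fmt.Println(expanded[0], expanded[3], expanded[4])
}
```

After replication the standard multi-head attention path (here, attention.Core) sees matching head counts and needs no GQA-specific logic.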
- Create internal/onnxgraph package for graph helpers (BuildConsumerMap, SoleConsumer, OtherBinaryOpInput, HasExternalConsumers)
- Remove redundant graph/consumers params; store consumers on Model
- Consolidate shape helpers into Model.ShapeForName (new shapes.go)
- Rename DenseGeluParams → DenseActivationParams
- Prefix all SDPA-specific helpers with sdpa; add model-family docs
- Move TensorProtoToScalar/ConstantNodeToScalar to tensor.go; use the float16 package instead of manual half-float conversion
- Pre-concatenate QKV weights during fusion detection; add FreeUnusedVariables method
- Add tests for tensorProtoRawBytes, concatenateTensorProtos, TensorProtoToScalar, ConstantNodeToScalar, FreeUnusedVariables, and the Mul-scaled SDPA pattern
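Consumer-map helpers like BuildConsumerMap and SoleConsumer are the core safety check for fusion: an intermediate tensor can only be folded into a fused op if nothing else reads it. A self-contained sketch under simplified types (Node here stands in for an ONNX NodeProto):

```go
package main

import "fmt"

// Node is a simplified stand-in for an ONNX NodeProto; only the fields
// needed for consumer tracking are modeled.
type Node struct {
	Name   string
	Inputs []string
}

// BuildConsumerMap maps each tensor name to the nodes that read it.
func BuildConsumerMap(nodes []*Node) map[string][]*Node {
	consumers := make(map[string][]*Node)
	for _, n := range nodes {
		for _, in := range n.Inputs {
			consumers[in] = append(consumers[in], n)
		}
	}
	return consumers
}

// SoleConsumer returns the single node consuming tensorName, or nil when the
// tensor has zero or multiple consumers — in which case fusing it away would
// break the other users of the value.
func SoleConsumer(consumers map[string][]*Node, tensorName string) *Node {
	if c := consumers[tensorName]; len(c) == 1 {
		return c[0]
	}
	return nil
}

func main() {
	nodes := []*Node{
		{Name: "matmul", Inputs: []string{"x", "w"}},
		{Name: "gelu", Inputs: []string{"matmul_out"}},
		{Name: "debug", Inputs: []string{"x"}},
	}
	cm := BuildConsumerMap(nodes)
	fmt.Println(SoleConsumer(cm, "matmul_out").Name) // gelu: safe to fuse into matmul
	fmt.Println(SoleConsumer(cm, "x"))               // <nil>: two consumers, not fusable
}
```

A dense+GELU detector, for instance, would only emit a fused candidate when SoleConsumer reports the MatMul output feeds exactly the Gelu node.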
Resolve go.sum conflict and update quantized fusion files to use the new interface-based fusion API:
- FusionDetector signature: func(m *Model) []FusionCandidate
- Helper functions moved to the onnxgraph package (SoleConsumer, OtherBinaryOpInput, HasExternalConsumers)
- Drop the sdpa prefix from shared utility methods (tryGetConstantScalar, isMaskRankAcceptable, extractHeadCounts, matchKTranspose, etc.)
Summary
Adds ONNX fusion pattern detectors for quantized operations, built on the fusion framework from the add-fusion-support branch:
- Detects DynamicQuantizeLinear → MatMulInteger → DequantizeLinear chains and emits nn.QuantizedDense with Int8 format (score: 40)
- Merges three quantized dense projections (Q, K, V) into a single batched QuantizedDense call (score: 70)
- Detects quantized SDPA patterns (DQL → MatMulInteger(Q,K^T) → Cast → Scale → Softmax → MatMulInteger(.,V)) and emits BackendFusedQuantizedScaledDotProductAttention (score: 90)
- Fixes convertGroupQueryAttention: replicates KV heads when kvNumHeads < numHeads before calling attention.Core
- Adds an isZeroInitializer helper and updates convertDequantizeLinear to take a Model receiver

Dependencies
- add-fusion-support branch (included in this PR)
- gomlx backends additions (QuantFormat, FusedQuantizedDense, FusedQuantizedScaledDotProductAttention)

Test plan
- TestGroupQueryAttention/GQA-basic now passes (was failing due to head count mismatch)
- go build ./... compiles cleanly