Add quantized fusion patterns for dense, QKV, and SDPA#60

Open
ajroetker wants to merge 11 commits into gomlx:main from ajroetker:quantized-fused-ops

Conversation

@ajroetker
Contributor

Summary

Adds ONNX fusion pattern detectors for quantized operations, built on the fusion framework from the add-fusion-support branch:

  • Quantized Dense: Detects DynamicQuantizeLinear → MatMulInteger → DequantizeLinear chains and emits nn.QuantizedDense with Int8 format (score: 40)
  • Quantized QKV: Merges three quantized dense Q/K/V projections into a single batched QuantizedDense call (score: 70)
  • Quantized SDPA: Detects full quantized attention pattern (DQL → MatMulInteger(Q,K^T) → Cast → Scale → Softmax → MatMulInteger(.,V)) and emits BackendFusedQuantizedScaledDotProductAttention (score: 90)
  • Fixes GQA head mismatch in convertGroupQueryAttention — replicates KV heads when kvNumHeads < numHeads before calling attention.Core
  • Adds isZeroInitializer helper and updates convertDequantizeLinear to take Model receiver

Dependencies

  • Requires the fusion framework from add-fusion-support branch (included in this PR)
  • Requires gomlx#350 (adds QuantFormat, FusedQuantizedDense, FusedQuantizedScaledDotProductAttention to gomlx backends)

Test plan

  • All existing fusion tests pass (SDPA, DenseGelu, QKVDense detection + integration)
  • All existing onnx op tests pass
  • TestGroupQueryAttention/GQA-basic now passes (was failing due to head count mismatch)
  • Full go build ./... compiles cleanly

…U patterns

Introduces a graph fusion framework that detects common subgraph patterns
(scaled dot-product attention, QKV dense projections, dense+GELU activations)
and replaces them with fused ops when the backend supports them. Adds a
capability-gated fused SDPA fast path in the MultiHeadAttention op converter.
# Conflicts:
#	go.mod
#	go.sum
#	onnx/ops.go

Bump github.com/gomlx/gomlx from v0.26.1-0.20260211111746-dd3d906b02a6
to v0.26.1-0.20260215082710-429182c8560c.

Replace FusionType enum, FusionGroup struct, and switch-based dispatch
with a FusionCandidate interface and RegisterFusionDetector registration
pattern. Each fusion (SDPA, QKVDense, DenseGelu) is now self-contained
with its own candidate type, detector init(), and emit method.
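The interface-plus-registration pattern described above can be sketched as follows. Names mirror the commit message (FusionCandidate, RegisterFusionDetector); the method set and signatures are assumptions for illustration, not the exact gomlx-onnx API:

```go
package main

import "fmt"

// Model stands in for the onnx Model type; fields elided.
type Model struct{}

// FusionCandidate is what each self-contained fusion implements:
// it reports its priority score, the nodes it covers, and knows
// how to emit the fused op in place of the matched subgraph.
type FusionCandidate interface {
	Score() int          // higher-scoring fusions win overlap conflicts
	Nodes() []string     // node names covered by this candidate
	Emit(m *Model) error // replace the matched subgraph with the fused op
}

// FusionDetector scans a model and returns zero or more candidates.
type FusionDetector func(m *Model) []FusionCandidate

var fusionDetectors []FusionDetector

// RegisterFusionDetector is called from each fusion file's init().
func RegisterFusionDetector(d FusionDetector) {
	fusionDetectors = append(fusionDetectors, d)
}

// sdpaCandidate sketches one candidate type, as in the SDPA fusion file.
type sdpaCandidate struct{ nodes []string }

func (c *sdpaCandidate) Score() int          { return 90 }
func (c *sdpaCandidate) Nodes() []string     { return c.nodes }
func (c *sdpaCandidate) Emit(m *Model) error { return nil }

func init() {
	RegisterFusionDetector(func(m *Model) []FusionCandidate {
		// A real detector would walk m's graph; here we return a fixed match.
		return []FusionCandidate{&sdpaCandidate{nodes: []string{"MatMul_0", "Softmax_0"}}}
	})
}

func main() {
	m := &Model{}
	for _, det := range fusionDetectors {
		for _, c := range det(m) {
			fmt.Println(c.Score(), c.Nodes())
		}
	}
}
```

Replacing the old enum-and-switch dispatch with this shape lets each fusion file register itself without touching a central switch statement.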

Detection uses score-based greedy selection for non-overlapping fusions.
Backend capability checks are removed since the GoMLX wrapper functions
(attention.Core, attention.QKVProjection, nn.Dense) handle fused-vs-
decomposed fallback internally via InternalFusedOpCaller.
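The score-based greedy selection mentioned above can be sketched as: sort candidates by descending score, then accept each one only if none of its nodes was already claimed by a higher-scoring candidate. This is an illustrative reduction of candidates to score plus node names, not the PR's actual data types:

```go
package main

import (
	"fmt"
	"sort"
)

type candidate struct {
	score int
	nodes []string
}

// selectFusions greedily picks non-overlapping candidates, preferring
// higher scores (e.g. quantized SDPA at 90 beats quantized dense at 40
// when both cover the same DynamicQuantizeLinear node).
func selectFusions(cands []candidate) []candidate {
	sort.SliceStable(cands, func(i, j int) bool { return cands[i].score > cands[j].score })
	claimed := map[string]bool{}
	var selected []candidate
	for _, c := range cands {
		overlaps := false
		for _, n := range c.nodes {
			if claimed[n] {
				overlaps = true
				break
			}
		}
		if overlaps {
			continue // a higher-scoring fusion already owns one of these nodes
		}
		for _, n := range c.nodes {
			claimed[n] = true
		}
		selected = append(selected, c)
	}
	return selected
}

func main() {
	cands := []candidate{
		{40, []string{"dql0", "mmi0"}},          // quantized dense
		{90, []string{"dql0", "mmi0", "soft0"}}, // quantized SDPA over the same nodes
	}
	for _, c := range selectFusions(cands) {
		fmt.Println(c.score) // prints 90: the dense candidate overlaps and is dropped
	}
}
```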

Fixes 4 compiler errors from undefined symbols:
- FusedMultiHeadSDPA → attention.Core
- FusedQKVDense → attention.QKVProjection
- backends.OpTypeFusedMultiHeadSDPA → removed
- backends.OpTypeFusedQKVDense → removed

Add ONNX fusion pattern detectors and emitters for quantized operations:

- fusion_quantized_dense.go: Detects DynamicQuantizeLinear + MatMulInteger
  chains and emits nn.QuantizedDense with Int8 format
- fusion_quantized_qkv.go: Merges three quantized dense projections (Q, K, V)
  into a single batched QuantizedDense call
- fusion_quantized_sdpa.go: Detects quantized scaled dot-product attention
  patterns and emits BackendFusedQuantizedScaledDotProductAttention

Also fixes:
- convertDequantizeLinear now takes Model receiver for fusion-aware access
- Add isZeroInitializer helper for detecting zero-valued ONNX tensors
- Fix GQA (Grouped Query Attention) head mismatch: replicate KV heads
  before calling attention.Core when kvNumHeads < numHeads
- Create internal/onnxgraph package for graph helpers (BuildConsumerMap,
  SoleConsumer, OtherBinaryOpInput, HasExternalConsumers)
- Remove redundant graph/consumers params; store consumers on Model
- Consolidate shape helpers into Model.ShapeForName (new shapes.go)
- Rename DenseGeluParams → DenseActivationParams
- Prefix all SDPA-specific helpers with sdpa, add model-family docs
- Move TensorProtoToScalar/ConstantNodeToScalar to tensor.go, use
  float16 package instead of manual half-float conversion
- Pre-concatenate QKV weights during fusion detection, add
  FreeUnusedVariables method
- Add tests for tensorProtoRawBytes, concatenateTensorProtos,
  TensorProtoToScalar, ConstantNodeToScalar, FreeUnusedVariables,
  and Mul-scaled SDPA pattern

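The graph helpers listed above (BuildConsumerMap, SoleConsumer) can be sketched as follows. The node representation and signatures are assumptions; the real internal/onnxgraph package operates on ONNX protos:

```go
package main

import "fmt"

// node is a stand-in for an ONNX NodeProto: a name plus input tensor names.
type node struct {
	name   string
	inputs []string
}

// BuildConsumerMap maps each tensor name to the nodes that consume it.
func BuildConsumerMap(nodes []node) map[string][]node {
	consumers := map[string][]node{}
	for _, n := range nodes {
		for _, in := range n.inputs {
			consumers[in] = append(consumers[in], n)
		}
	}
	return consumers
}

// SoleConsumer returns the single consumer of a tensor, or false when the
// tensor feeds zero or multiple nodes. Fusion patterns need this check:
// an intermediate tensor with an external consumer cannot be fused away.
func SoleConsumer(consumers map[string][]node, tensor string) (node, bool) {
	c := consumers[tensor]
	if len(c) != 1 {
		return node{}, false
	}
	return c[0], true
}

func main() {
	nodes := []node{
		{name: "MatMulInteger_0", inputs: []string{"q8", "w8"}},
		{name: "DequantizeLinear_0", inputs: []string{"acc32"}},
		{name: "Add_0", inputs: []string{"q8", "bias"}},
	}
	cm := BuildConsumerMap(nodes)
	if n, ok := SoleConsumer(cm, "acc32"); ok {
		fmt.Println(n.name) // DequantizeLinear_0
	}
	_, ok := SoleConsumer(cm, "q8") // two consumers, so not sole
	fmt.Println(ok)                 // false
}
```

Storing the consumer map on Model, as the commit does, avoids rebuilding it in every detector.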
Resolve go.sum conflict and update quantized fusion files to use the new
interface-based fusion API:
- FusionDetector signature: func(m *Model) []FusionCandidate
- Helper functions moved to onnxgraph package (SoleConsumer,
  OtherBinaryOpInput, HasExternalConsumers)
- Drop sdpa prefix from shared utility methods (tryGetConstantScalar,
  isMaskRankAcceptable, extractHeadCounts, matchKTranspose, etc.)