Skip to content

Optimize JoinDims and SplitDims by canonicalizing to simpler operations (Partial fixes #1843)#1847

Open
mengxingbw wants to merge 7 commits intopymc-devs:mainfrom
mengxingbw:fix/1843-rewrite-optimizations
Open

Optimize JoinDims and SplitDims by canonicalizing to simpler operations (Partial fixes #1843)#1847
mengxingbw wants to merge 7 commits intopymc-devs:mainfrom
mengxingbw:fix/1843-rewrite-optimizations

Conversation

@mengxingbw
Copy link

@mengxingbw mengxingbw commented Jan 13, 2026

Description

This PR implements the 3 out of 4 canonicalization rewrites suggested in #1843:

  • join_dims(x, axis=axis, n_axes=1) → identity (no-op)
  • join_dims(x, axis=axis, n_axes=0)expand_dims(x, axis)
  • split_dims(x, axis=axis, shape=())squeeze(x, axis)
  • split_dims(x, axis=axis, shape=(dim,))specify_shape(...) (see Block section)

Questions

I tried to work on the last requested change:

split_dims(x, axis=axis, shape=(dim,)) -> specify_shape(x, (*[None] * axis, dim, ...))

The issue: specify_shape preserves the input's known shape when it's already concrete, so it doesn't match SplitDims's output type. If the input already has a known shape at a dimension, it uses that shape; and it only uses the specified shape when the input shape is None. This has caused the function to fail.
For this rewrite to work even when the input shape is known, I'd need to use reshape instead of specify_shape, but that defeats the purpose of using specify_shape for shape assertion.

⚠️ Question: How should I proceed?

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

Optimize JoinDims and SplitDims by canonicalizing to simpler
operations (identity, expand_dims, squeeze).

Partial fixes pymc-devs#1843
@mengxingbw mengxingbw changed the title Add canonicalization rewrites for JoinDims/SplitDims Optimize JoinDims and SplitDims by canonicalizing to simpler operations (Partial fixes #1843) Jan 13, 2026
@jessegrabowski
Copy link
Member

My guess is that ricardo meant reshape, not literally specify_shape (which you're right, just adds metadata but doesn't do any computation)

@ricardoV94
Copy link
Member

I meant split dims, when the shape argument has just one entry shape.type.shape==(1,) is equivalent to specify shape on that axis.

That's what the syntax shape=(d,) was supposed to convey. As opposed to empy shape=() which is just a squeeze

x, shape = node.inputs
axis = node.op.axis

if isinstance(shape, Constant) and shape.data.size == 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't need to be constant just static shape of zero shape.type.shape == (0,)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I would merge this with the split-to-reshape rewrite so we don't accidentally run that before this

@mbaldourw
Copy link

Thank you @jessegrabowski and @ricardoV94 for clarifying - so it sounds like we don't need split_dims(x, axis=axis, shape=(dim,)) → specify_shape(...) this function since it will fall into reshape anyways?

I have made the changes according to the comment above.

@ricardoV94
Copy link
Member

reshape should be our last resort, everything we can avoid as reshape we should

@ricardoV94
Copy link
Member

To clarify, none of the changes in this PR were strictly needed, they are an improvement over simple reshape

@mbaldourw
Copy link

To clarify, none of the changes in this PR were strictly needed, they are an improvement over simple reshape

understood. is there anything else to do with the last function:
split_dims(x, axis=axis, shape=(dim,)) → specify_shape(...)?

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making progress, needs a few more tweaks

Comment on lines 26 to 30
# Special case: empty shape -> squeeze
if shape.type.shape == (0,):
squeezed_x = squeeze(x, axis=axis)
copy_stack_trace(x, squeezed_x)
return [squeezed_x]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is duplicated, you meant the case with shape.type.shape == (1,) I presume?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed the redundant block; im not sure how to treat the shape == 1 case without calling reshape, since specify_shape won't help?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean won't help. neither split_dims nor reshape do anything in that case, that's why it's functionally equivalent to a specify shape.

Try to run some cases of such split_dims to get acquainted with the behavior.


@register_canonicalize
@node_rewriter([JoinDims])
def local_join_dims_noop(fgraph, node):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

merge these join dims rewrites in a single one, like we did with SplitDims

@@ -9,11 +11,24 @@
def local_split_dims_to_reshape(fgraph, node):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we don't do only reshape, we should have a more generic name. Same for the join_dims when we merge the special cases

Suggested change
def local_split_dims_to_reshape(fgraph, node):
def local_lower_split_dims(fgraph, node):

Comment on lines 51 to 57
# After rewrite: should have 0 JoinDims nodes
assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 0
# Output should be equivalent to input (identity rewrite)
# The rewrite returns the input variable, so output should match input shape/type
assert fg.outputs[0].type.shape == x.type.shape
assert fg.outputs[0].type.dtype == x.type.dtype
assert fg.outputs[0].type.ndim == x.type.ndim
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use utt.assert_equal_computations to check we have the specific graph that we expect, not just anything without JoinDims

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This recommendation applies to all new tests

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This recommendation applies to all new tests

I can't seem to get it to pass for the first 2 tests. when i looked it up, i got "assert_equal_computations is better suited for cases where the canonical form is a specific operation (like expand_dims, squeeze, or identity) where graph structures match. For basic reshape cases, the rewrite produces a different but equivalent graph structure, so structural checks are sufficient"
Please let me know how to proceed!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest you look at the generated graph, the utility prints it when the assert fails. It shouldn't have anything too strange in it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest you look at the generated graph, the utility prints it when the assert fails. It shouldn't have anything too strange in it

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after excluding:

--------------------------------------- Captured stdout call ----------------------------------------
rewriting: rewrite local_split_dims replaces SplitDims{axis=1}.0 of SplitDims{axis=1}(x, [2 5 1]) with Reshape{5}.0 of Reshape{5}(x, MakeVector{dtype='int64'}.0)
rewriting: rewrite MergeOptimizer replaces 2 of None with 2 of None
rewriting: rewrite MergeOptimizer replaces 0 of None with 0 of None
rewriting: rewrite MergeOptimizer replaces 1 of None with 1 of None
rewriting: rewrite MergeOptimizer replaces Shape.0 of Shape(x) with Shape.0 of Shape(x)
rewriting: rewrite MergeOptimizer replaces 0 of None with 0 of None
rewriting: rewrite MergeOptimizer replaces 2 of None with 2 of None
rewriting: rewrite local_subtensor_remove_broadcastable_index replaces Subtensor{i}.0 of Subtensor{i}(Subtensor{:stop}.0, 0) with Squeeze{axis=0}.0 of Squeeze{axis=0}(Subtensor{:stop}.0)
rewriting: rewrite local_subtensor_remove_broadcastable_index replaces Subtensor{i}.0 of Subtensor{i}(Subtensor{start:}.0, 0) with Squeeze{axis=0}.0 of Squeeze{axis=0}(Subtensor{start:}.0)
rewriting: rewrite constant_folding replaces Subtensor{i}.0 of Subtensor{i}([2 5 1], 2) with 1 of None
rewriting: rewrite constant_folding replaces Subtensor{i}.0 of Subtensor{i}([2 5 1], 1) with 5 of None
rewriting: rewrite constant_folding replaces Subtensor{i}.0 of Subtensor{i}([2 5 1], 0) with 2 of None
rewriting: rewrite local_reshape_to_dimshuffle replaces Reshape{5}.0 of Reshape{5}(x, MakeVector{dtype='int64'}.0) with ExpandDims{axis=3}.0 of ExpandDims{axis=3}(Reshape{4}.0)```

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And how does the rewritten graph look like now (vs the expected)?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E       
E       Rewritten:
E       SpecifyShape [id A] <Tensor5(float64, shape=(2, 2, 5, 1, 3))>
E        ├─ ExpandDims{axis=3} [id B] <Tensor5(float64, shape=(?, 2, 5, 1, ?))>
E        │  └─ Reshape{4} [id C] <Tensor4(float64, shape=(?, 2, 5, ?))>
E        │     ├─ x [id D] <Tensor3(float64, shape=(2, 10, 3))>
E        │     └─ MakeVector{dtype='int64'} [id E] <Vector(int64, shape=(4,))>
E        │        ├─ Squeeze{axis=0} [id F] <Scalar(int64, shape=())>
E        │        │  └─ Subtensor{:stop} [id G] <Vector(int64, shape=(1,))>
E        │        │     ├─ Shape [id H] <Vector(int64, shape=(3,))>
E        │        │     │  └─ x [id D] <Tensor3(float64, shape=(2, 10, 3))>
E        │        │     └─ 1 [id I] <int64>
E        │        ├─ 2 [id J] <Scalar(int64, shape=())>
E        │        ├─ 5 [id K] <Scalar(int64, shape=())>
E        │        └─ Squeeze{axis=0} [id L] <Scalar(int64, shape=())>
E        │           └─ Subtensor{start:} [id M] <Vector(int64, shape=(1,))>
E        │              ├─ Shape [id H] <Vector(int64, shape=(3,))>
E        │              │  └─ ···
E        │              └─ 2 [id N] <int64>
E        ├─ 2 [id O] <Scalar(int8, shape=())>
E        ├─ 2 [id O] <Scalar(int8, shape=())>
E        ├─ 5 [id P] <Scalar(int8, shape=())>
E        ├─ 1 [id Q] <Scalar(int8, shape=())>
E        └─ 3 [id R] <Scalar(int8, shape=())>
E       
E       Expected:
E       ExpandDims{axis=3} [id A] <Tensor5(float64, shape=(2, 2, 5, 1, 3))>
E        └─ Reshape{4} [id B] <Tensor4(float64, shape=(2, 2, 5, 3))>
E           ├─ x [id C] <Tensor3(float64, shape=(2, 10, 3))>
E           └─ MakeVector{dtype='int64'} [id D] <Vector(int64, shape=(4,))>
E              ├─ Subtensor{i} [id E] <Scalar(int64, shape=())>
E              │  ├─ Shape [id F] <Vector(int64, shape=(3,))>
E              │  │  └─ x [id C] <Tensor3(float64, shape=(2, 10, 3))>
E              │  └─ 0 [id G] <int64>
E              ├─ Cast{int64} [id H] <Scalar(int64, shape=())>
E              │  └─ 2 [id I] <Scalar(int8, shape=())>
E              ├─ Cast{int64} [id J] <Scalar(int64, shape=())>
E              │  └─ 5 [id K] <Scalar(int8, shape=())>
E              └─ Subtensor{i} [id L] <Scalar(int64, shape=())>
E                 ├─ Shape [id F] <Vector(int64, shape=(3,))>
E                 │  └─ ···
E                 └─ 2 [id M] <int64>```

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removing local_subtensor_remove_broadcastable_index should bring you closer, and using np.int64(2|5) for the expected shape. That will get rid of the Cast thing, which comes from #1073

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

       
E       Rewritten:
E       SpecifyShape [id A] <Tensor5(float64, shape=(2, 2, 5, 1, 3))>
E        ├─ ExpandDims{axis=3} [id B] <Tensor5(float64, shape=(?, 2, 5, 1, ?))>
E        │  └─ Reshape{4} [id C] <Tensor4(float64, shape=(?, 2, 5, ?))>
E        │     ├─ x [id D] <Tensor3(float64, shape=(2, 10, 3))>
E        │     └─ MakeVector{dtype='int64'} [id E] <Vector(int64, shape=(4,))>
E        │        ├─ Subtensor{i} [id F] <Scalar(int64, shape=())>
E        │        │  ├─ Subtensor{:stop} [id G] <Vector(int64, shape=(1,))>
E        │        │  │  ├─ Shape [id H] <Vector(int64, shape=(3,))>
E        │        │  │  │  └─ x [id D] <Tensor3(float64, shape=(2, 10, 3))>
E        │        │  │  └─ 1 [id I] <int64>
E        │        │  └─ 0 [id J] <int64>
E        │        ├─ 2 [id K] <Scalar(int64, shape=())>
E        │        ├─ 5 [id L] <Scalar(int64, shape=())>
E        │        └─ Subtensor{i} [id M] <Scalar(int64, shape=())>
E        │           ├─ Subtensor{start:} [id N] <Vector(int64, shape=(1,))>
E        │           │  ├─ Shape [id H] <Vector(int64, shape=(3,))>
E        │           │  │  └─ ···
E        │           │  └─ 2 [id O] <int64>
E        │           └─ 0 [id J] <int64>
E        ├─ 2 [id P] <Scalar(int8, shape=())>
E        ├─ 2 [id P] <Scalar(int8, shape=())>
E        ├─ 5 [id Q] <Scalar(int8, shape=())>
E        ├─ 1 [id R] <Scalar(int8, shape=())>
E        └─ 3 [id S] <Scalar(int8, shape=())>
E       
E       Expected:
E       ExpandDims{axis=3} [id A] <Tensor5(float64, shape=(2, 2, 5, 1, 3))>
E        └─ Reshape{4} [id B] <Tensor4(float64, shape=(2, 2, 5, 3))>
E           ├─ x [id C] <Tensor3(float64, shape=(2, 10, 3))>
E           └─ MakeVector{dtype='int64'} [id D] <Vector(int64, shape=(4,))>
E              ├─ Subtensor{i} [id E] <Scalar(int64, shape=())>
E              │  ├─ Shape [id F] <Vector(int64, shape=(3,))>
E              │  │  └─ x [id C] <Tensor3(float64, shape=(2, 10, 3))>
E              │  └─ 0 [id G] <int64>
E              ├─ 2 [id H] <Scalar(int64, shape=())>
E              ├─ 5 [id I] <Scalar(int64, shape=())>
E              └─ Subtensor{i} [id J] <Scalar(int64, shape=())>
E                 ├─ Shape [id F] <Vector(int64, shape=(3,))>
E                 │  └─ ···
E                 └─ 2 [id K] <int64>

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Specialize (join|split)_dims rewrite

4 participants