From a0205c99b52725d1cc2e122eaf052941d9d5a533 Mon Sep 17 00:00:00 2001 From: Taylor Howell Date: Fri, 23 Jan 2026 15:27:53 +0000 Subject: [PATCH 1/7] Add tree structure fields to Model for island discovery Model fields: - ntree: number of kinematic trees - tree_dofadr: start address of tree's DOFs (ntree,) - tree_dofnum: number of DOFs in tree (ntree,) - tree_bodynum: number of bodies in tree (ntree,) - body_treeid: id of body's tree; -1 for static (nbody,) - dof_treeid: id of dof's tree (nv,) All fields are imported directly from mjModel via the dataclass field matching pattern in put_model(). Test: test_tree_structure_fields verifies all fields match MuJoCo. --- mujoco_warp/_src/io_test.py | 12 ++++++++++++ mujoco_warp/_src/types.py | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/mujoco_warp/_src/io_test.py b/mujoco_warp/_src/io_test.py index 91f81611e..4885b6e36 100644 --- a/mujoco_warp/_src/io_test.py +++ b/mujoco_warp/_src/io_test.py @@ -645,6 +645,18 @@ def test_eq_active(self, active, make_data): _assert_eq(d.eq_active.numpy()[0], mjd.eq_active, "eq_active") + def test_tree_structure_fields(self): + """Tests that tree structure fields match between types.Model and mjModel.""" + mjm, _, m, _ = test_data.fixture("pendula.xml") + + # verify fields match MuJoCo + for field in ["ntree", "tree_dofadr", "tree_dofnum", "tree_bodynum", "body_treeid", "dof_treeid"]: + m_val = getattr(m, field) + mjm_val = getattr(mjm, field) + if isinstance(m_val, wp.array): + m_val = m_val.numpy() + np.testing.assert_array_equal(m_val, mjm_val, err_msg=f"mismatch: {field}") + if __name__ == "__main__": wp.init() diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 78aef0b8a..dde674e58 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -749,6 +749,7 @@ class Model: nbody: number of bodies noct: number of total octree cells in all meshes njnt: number of joints + ntree: number of kinematic trees nM: number of non-zeros in sparse inertia matrix nC: number of non-zeros in sparse body-dof matrix ngeom: number of geoms @@ -794,6 +795,7 @@ class Model: body_jntadr: start addr of joints; -1: no joints (nbody,) body_dofnum: number of motion degrees of freedom (nbody,) body_dofadr: start addr of dofs; -1: no dofs (nbody,) + body_treeid: id of body's tree; -1: static (nbody,) body_geomnum: number of geoms (nbody,) body_geomadr: start addr of geoms; -1: no geoms (nbody,) body_pos: position offset rel. to parent body (*, nbody, 3) @@ -828,6 +830,7 @@ class Model: dof_bodyid: id of dof's body (nv,) dof_jntid: id of dof's joint (nv,) dof_parentid: id of dof's parent; -1: none (nv,) + dof_treeid: id of dof's tree (nv,) dof_Madr: dof address in M-diagonal (nv,) dof_solref: constraint solver reference: frictionloss (*, nv, NREF) dof_solimp: constraint solver impedance: frictionloss (*, nv, NIMP) @@ -835,6 +838,9 @@ class Model: dof_armature: dof armature inertia/mass (*, nv) dof_damping: damping coefficient (*, nv) dof_invweight0: diag. inverse inertia in qpos0 (*, nv) + tree_bodynum: number of bodies in tree (incl. root) (ntree,) + tree_dofadr: start address of tree's dofs (ntree,) + tree_dofnum: number of dofs in tree (ntree,) geom_type: geometric type (GeomType) (ngeom,) geom_contype: geom contact type (ngeom,) geom_conaffinity: geom contact affinity (ngeom,) @@ -1100,6 +1106,7 @@ class Model: nbody: int noct: int njnt: int + ntree: int nM: int nC: int ngeom: int @@ -1145,6 +1152,7 @@ class Model: body_jntadr: array("nbody", int) body_dofnum: array("nbody", int) body_dofadr: array("nbody", int) + body_treeid: array("nbody", int) body_geomnum: array("nbody", int) body_geomadr: array("nbody", int) body_pos: array("*", "nbody", wp.vec3) @@ -1179,6 +1187,7 @@ class Model: dof_bodyid: array("nv", int) dof_jntid: array("nv", int) dof_parentid: array("nv", int) + dof_treeid: array("nv", int) dof_Madr: array("nv", int) dof_solref: array("*", "nv", wp.vec2) dof_solimp: array("*", "nv", vec5) @@ -1186,6 +1195,9 @@ class Model: dof_armature: array("*", "nv", float) dof_damping: array("*", "nv", float) dof_invweight0: array("*", "nv", float) + tree_bodynum: array("ntree", int) + tree_dofadr: array("ntree", int) + tree_dofnum: array("ntree", int) geom_type: array("ngeom", int) geom_contype: array("ngeom", int) geom_conaffinity: array("ngeom", int) From 15ac55d75eb6b76bce155e2838ab33dae9492517 Mon Sep 17 00:00:00 2001 From: Taylor Howell Date: Fri, 23 Jan 2026 16:57:25 +0000 Subject: [PATCH 2/7] Add edge discovery kernel for constraint islands - _find_tree_edges kernel scans Jacobian for tree-tree connections - On-device deduplication using radix sort and prefix sum compaction - Helper kernels: _compute_edge_keys, _init_indices, _mark_unique_edges, _compact_edges - find_tree_edges host function orchestrates the full pipeline - 5 tests: single edge, self-edge, chain, deduplication, no constraints --- mujoco_warp/_src/island.py | 259 ++++++++++++++++++++++++++++++++ mujoco_warp/_src/island_test.py | 184 +++++++++++++++++++++++ 2 files changed, 443 insertions(+) create mode 100644 mujoco_warp/_src/island.py create mode 100644 mujoco_warp/_src/island_test.py diff --git a/mujoco_warp/_src/island.py b/mujoco_warp/_src/island.py new file mode 100644 index 000000000..442602993 --- /dev/null +++ b/mujoco_warp/_src/island.py @@ -0,0 +1,259 @@ +# Copyright 2026 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import warp as wp + +from mujoco_warp._src import types + + +@wp.kernel +def _find_tree_edges( + # Model: + nv: int, + dof_treeid: wp.array(dtype=int), + # Data in: + nefc_in: wp.array(dtype=int), + efc_type_in: wp.array2d(dtype=int), + efc_id_in: wp.array2d(dtype=int), + efc_J_in: wp.array3d(dtype=float), + njmax_in: int, + # Out: + edges_out: wp.array2d(dtype=int), + nedge_out: wp.array(dtype=int), +): + worldid, efcid = wp.tid() + + # skip if beyond active constraints + if efcid >= wp.min(njmax_in, nefc_in[worldid]): + return + + # skip continuation rows (same constraint type and id as previous) + # this avoids duplicate edges from multi-row constraints (e.g., 3D contacts) + if efcid > 0: + if efc_type_in[worldid, efcid] == efc_type_in[worldid, efcid - 1]: + if efc_id_in[worldid, efcid] == efc_id_in[worldid, efcid - 1]: + return + + # collect trees touched by this constraint row + prev_tree = int(-2) # -1 is valid (static), so use -2 as sentinel + first_tree = int(-2) + + for dof in range(nv): + J_val = efc_J_in[worldid, efcid, dof] + + if J_val != 0.0: + tree = dof_treeid[dof] + + if first_tree == -2: + first_tree = tree + prev_tree = tree + elif tree != prev_tree and tree >= 0: + # found a new tree, add edge between prev_tree and tree + if prev_tree >= 0: + idx = wp.atomic_add(nedge_out, 0, 1) + if idx < njmax_in: + # store as ordered pair for deduplication + t1 = wp.min(prev_tree, tree) + t2 = wp.max(prev_tree, tree) + edges_out[idx, 0] = t1 + edges_out[idx, 1] = t2 + prev_tree = tree + + # add self-edge if only one tree found (constrained to itself) + if first_tree >= 0 and prev_tree == first_tree: + idx = wp.atomic_add(nedge_out, 0, 1) + if idx < njmax_in: + edges_out[idx, 0] = first_tree + edges_out[idx, 1] = first_tree + + +@wp.kernel +def _compute_edge_keys( + # Model: + ntree: int, + # In: + nedge_in: wp.array(dtype=int), + edges_in: wp.array2d(dtype=int), + # Out: + keys_out: wp.array(dtype=int), +): + """Compute sort keys for edges: t1 * ntree + t2.""" + i = wp.tid() + if i >= nedge_in[0]: + keys_out[i] = 2147483647 # max int, sort to end + return + keys_out[i] = edges_in[i, 0] * ntree + edges_in[i, 1] + + +@wp.kernel +def _init_indices( + # Out: + indices_out: wp.array(dtype=int), +): + """Initialize indices to [0, 1, 2, ...].""" + i = wp.tid() + indices_out[i] = i + + +@wp.kernel +def _mark_unique_edges( + # In: + nedge_in: wp.array(dtype=int), + sorted_indices_in: wp.array(dtype=int), + edges_in: wp.array2d(dtype=int), + # Out: + unique_mask_out: wp.array(dtype=int), +): + """Mark unique edges after sorting.""" + i = wp.tid() + n = nedge_in[0] + if i >= n: + unique_mask_out[i] = 0 + return + + if i == 0: + unique_mask_out[i] = 1 + return + + idx = sorted_indices_in[i] + prev_idx = sorted_indices_in[i - 1] + + # check if different from previous + if edges_in[idx, 0] != edges_in[prev_idx, 0] or edges_in[idx, 1] != edges_in[prev_idx, 1]: + unique_mask_out[i] = 1 + else: + unique_mask_out[i] = 0 + + +@wp.kernel +def _compact_edges( + # In: + nedge_in: wp.array(dtype=int), + sorted_indices_in: wp.array(dtype=int), + unique_prefix_in: wp.array(dtype=int), + unique_mask_in: wp.array(dtype=int), + edges_in: wp.array2d(dtype=int), + # Out: + edges_out: wp.array2d(dtype=int), + nedge_out: wp.array(dtype=int), +): + """Compact edges using prefix sum of unique mask.""" + i = wp.tid() + n = nedge_in[0] + if i >= n: + return + + if unique_mask_in[i] == 1: + # exclusive prefix sum gives destination index + dst = unique_prefix_in[i] + src = sorted_indices_in[i] + edges_out[dst, 0] = edges_in[src, 0] + edges_out[dst, 1] = edges_in[src, 1] + + # last unique element writes the count + if i == n - 1 or unique_mask_in[i + 1] == 0: + # count is prefix + 1 for this element + pass + + # the last thread writes the unique count + if i == n - 1: + nedge_out[0] = unique_prefix_in[n - 1] + unique_mask_in[n - 1] + + +def find_tree_edges( + m: types.Model, + d: types.Data, +) -> tuple[wp.array, wp.array]: + """Find tree-tree edges from the constraint Jacobian. + + Args: + m: The model containing kinematic and dynamic information. + d: The data object containing the current state and output arrays. + + Returns: + Tuple of (edges, nedge) arrays on device. + edges has shape (njmax, 2) where each row is an ordered (t1, t2) pair. + nedge is a (1,) array with the number of unique edges. + """ + # allocate outputs + edges = wp.zeros((d.njmax, 2), dtype=int) + nedge = wp.zeros(1, dtype=int) + + # find edges + wp.launch( + kernel=_find_tree_edges, + dim=(d.nworld, d.njmax), + inputs=[ + m.nv, + m.dof_treeid, + d.nefc, + d.efc.type, + d.efc.id, + d.efc.J, + d.njmax, + edges, + nedge, + ], + ) + + # compute sort keys (need 2x size for radix sort scratch space) + keys = wp.zeros(2 * d.njmax, dtype=int) + wp.launch( + kernel=_compute_edge_keys, + dim=d.njmax, + inputs=[m.ntree, nedge, edges, keys], + ) + + # sort by keys using Warp's radix sort + # radix_sort_pairs sorts (keys, values) pairs - values must be initialized to indices + sorted_indices = wp.empty(2 * d.njmax, dtype=int) + wp.launch( + kernel=_init_indices, + dim=2 * d.njmax, + inputs=[sorted_indices], + ) + wp.utils.radix_sort_pairs(keys, sorted_indices, count=d.njmax) + del keys + + # mark unique edges + unique_mask = wp.zeros(d.njmax, dtype=int) + wp.launch( + kernel=_mark_unique_edges, + dim=d.njmax, + inputs=[nedge, sorted_indices, edges, unique_mask], + ) + + # prefix sum for compaction addresses + unique_prefix = wp.zeros(d.njmax, dtype=int) + wp.utils.array_scan(unique_mask, unique_prefix, inclusive=False) + + # compact edges + edges_unique = wp.zeros((d.njmax, 2), dtype=int) + nedge_unique = wp.zeros(1, dtype=int) + wp.launch( + kernel=_compact_edges, + dim=d.njmax, + inputs=[ + nedge, + sorted_indices, + unique_prefix, + unique_mask, + edges, + edges_unique, + nedge_unique, + ], + ) + + return edges_unique, nedge_unique diff --git a/mujoco_warp/_src/island_test.py b/mujoco_warp/_src/island_test.py new file mode 100644 index 000000000..e85a68a96 --- /dev/null +++ b/mujoco_warp/_src/island_test.py @@ -0,0 +1,184 @@ +# Copyright 2026 The Newton Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for island discovery.""" + +import warp as wp +from absl.testing import absltest + +import mujoco_warp as mjwarp +from mujoco_warp import test_data +from mujoco_warp._src import island + + +class IslandEdgeDiscoveryTest(absltest.TestCase): + """Tests for edge discovery from constraint Jacobian.""" + + def test_single_constraint_two_trees(self): + """A single weld constraint between two bodies creates one edge.""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + """ + ) + + # run forward to populate constraints + mjwarp.forward(m, d) + + # find edges + edges, nedge = island.find_tree_edges(m, d) + + # should have exactly 1 edge between tree 0 and tree 1 + self.assertEqual(nedge.numpy()[0], 1) + edge_np = edges.numpy() + self.assertEqual(edge_np[0, 0], 0) # tree 0 + self.assertEqual(edge_np[0, 1], 1) # tree 1 + + def test_constraint_within_single_tree_creates_self_edge(self): + """A constraint within a single tree creates a self-edge.""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + """ + ) + + mjwarp.forward(m, d) + edges, nedge = island.find_tree_edges(m, d) + + # should have exactly 1 self-edge for tree 0 + self.assertEqual(nedge.numpy()[0], 1) + edge_np = edges.numpy() + self.assertEqual(edge_np[0, 0], 0) # tree 0 + self.assertEqual(edge_np[0, 1], 0) # tree 0 (self-edge) + + def test_three_bodies_chain(self): + """Three bodies with constraints A-B and B-C should have 2 edges.""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + + + + + + """ + ) + + mjwarp.forward(m, d) + edges, nedge = island.find_tree_edges(m, d) + + # should have 2 edges: (0,1) and (1,2) + n = nedge.numpy()[0] + self.assertEqual(n, 2) + edge_np = edges.numpy()[:n] + edges_set = set(tuple(e) for e in edge_np) + self.assertIn((0, 1), edges_set) + self.assertIn((1, 2), edges_set) + + def test_deduplication(self): + """Repeated constraints between same trees should be deduplicated.""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + + """ + ) + + mjwarp.forward(m, d) + edges, nedge = island.find_tree_edges(m, d) + + # should have 1 unique edge (0,1) despite 2 constraints + self.assertEqual(nedge.numpy()[0], 1) + + def test_no_constraints(self): + """No constraints should produce no edges.""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + """ + ) + + mjwarp.forward(m, d) + edges, nedge = island.find_tree_edges(m, d) + + self.assertEqual(nedge.numpy()[0], 0) + + +if __name__ == "__main__": + wp.init() + absltest.main() From 5ccd02eb05be6e5548e79c7f62df98ef12cf9661 Mon Sep 17 00:00:00 2001 From: Taylor Howell Date: Tue, 3 Feb 2026 09:57:08 +0000 Subject: [PATCH 3/7] kernel fusion --- mujoco_warp/_src/island.py | 132 ++++++++++--------------------------- 1 file changed, 36 insertions(+), 96 deletions(-) diff --git a/mujoco_warp/_src/island.py b/mujoco_warp/_src/island.py index 442602993..ee6008460 100644 --- a/mujoco_warp/_src/island.py +++ b/mujoco_warp/_src/island.py @@ -80,97 +80,63 @@ def _find_tree_edges( @wp.kernel -def _compute_edge_keys( - # Model: +def _compute_keys_and_indices( ntree: int, - # In: nedge_in: wp.array(dtype=int), edges_in: wp.array2d(dtype=int), - # Out: + njmax: int, keys_out: wp.array(dtype=int), -): - """Compute sort keys for edges: t1 * ntree + t2.""" - i = wp.tid() - if i >= nedge_in[0]: - keys_out[i] = 2147483647 # max int, sort to end - return - keys_out[i] = edges_in[i, 0] * ntree + edges_in[i, 1] - - -@wp.kernel -def _init_indices( - # Out: indices_out: wp.array(dtype=int), ): - """Initialize indices to [0, 1, 2, ...].""" + """Compute sort keys and initialize indices. Launch: 2 * njmax.""" i = wp.tid() - indices_out[i] = i - -@wp.kernel -def _mark_unique_edges( - # In: - nedge_in: wp.array(dtype=int), - sorted_indices_in: wp.array(dtype=int), - edges_in: wp.array2d(dtype=int), - # Out: - unique_mask_out: wp.array(dtype=int), -): - """Mark unique edges after sorting.""" - i = wp.tid() - n = nedge_in[0] - if i >= n: - unique_mask_out[i] = 0 - return + # Always init index + indices_out[i] = i - if i == 0: - unique_mask_out[i] = 1 + # Only compute keys for first njmax elements + if i >= njmax: return - - idx = sorted_indices_in[i] - prev_idx = sorted_indices_in[i - 1] - - # check if different from previous - if edges_in[idx, 0] != edges_in[prev_idx, 0] or edges_in[idx, 1] != edges_in[prev_idx, 1]: - unique_mask_out[i] = 1 + if i >= nedge_in[0]: + keys_out[i] = 2147483647 # sort to end else: - unique_mask_out[i] = 0 + keys_out[i] = edges_in[i, 0] * ntree + edges_in[i, 1] @wp.kernel -def _compact_edges( - # In: +def _deduplicate_edges( nedge_in: wp.array(dtype=int), sorted_indices_in: wp.array(dtype=int), - unique_prefix_in: wp.array(dtype=int), - unique_mask_in: wp.array(dtype=int), edges_in: wp.array2d(dtype=int), - # Out: edges_out: wp.array2d(dtype=int), nedge_out: wp.array(dtype=int), ): - """Compact edges using prefix sum of unique mask.""" + """Mark unique and compact in one pass using atomics. Launch: njmax.""" i = wp.tid() n = nedge_in[0] + if i >= n: return - if unique_mask_in[i] == 1: - # exclusive prefix sum gives destination index - dst = unique_prefix_in[i] + # Check if this edge is unique (different from previous) + is_unique = int(0) + if i == 0: + is_unique = 1 + else: + idx = sorted_indices_in[i] + prev_idx = sorted_indices_in[i - 1] + if edges_in[idx, 0] != edges_in[prev_idx, 0]: + is_unique = 1 + elif edges_in[idx, 1] != edges_in[prev_idx, 1]: + is_unique = 1 + + if is_unique == 1: + # Atomic add to get output index + dst = wp.atomic_add(nedge_out, 0, 1) src = sorted_indices_in[i] edges_out[dst, 0] = edges_in[src, 0] edges_out[dst, 1] = edges_in[src, 1] - # last unique element writes the count - if i == n - 1 or unique_mask_in[i + 1] == 0: - # count is prefix + 1 for this element - pass - - # the last thread writes the unique count - if i == n - 1: - nedge_out[0] = unique_prefix_in[n - 1] + unique_mask_in[n - 1] - def find_tree_edges( m: types.Model, @@ -208,52 +174,26 @@ def find_tree_edges( ], ) - # compute sort keys (need 2x size for radix sort scratch space) + # compute sort keys and init indices (fused) keys = wp.zeros(2 * d.njmax, dtype=int) - wp.launch( - kernel=_compute_edge_keys, - dim=d.njmax, - inputs=[m.ntree, nedge, edges, keys], - ) - - # sort by keys using Warp's radix sort - # radix_sort_pairs sorts (keys, values) pairs - values must be initialized to indices sorted_indices = wp.empty(2 * d.njmax, dtype=int) wp.launch( - kernel=_init_indices, + kernel=_compute_keys_and_indices, dim=2 * d.njmax, - inputs=[sorted_indices], + inputs=[m.ntree, nedge, edges, d.njmax, keys, sorted_indices], ) + + # sort by keys using Warp's radix sort wp.utils.radix_sort_pairs(keys, sorted_indices, count=d.njmax) del keys - # mark unique edges - unique_mask = wp.zeros(d.njmax, dtype=int) - wp.launch( - kernel=_mark_unique_edges, - dim=d.njmax, - inputs=[nedge, sorted_indices, edges, unique_mask], - ) - - # prefix sum for compaction addresses - unique_prefix = wp.zeros(d.njmax, dtype=int) - wp.utils.array_scan(unique_mask, unique_prefix, inclusive=False) - - # compact edges + # deduplicate edges (fused, no prefix sum) edges_unique = wp.zeros((d.njmax, 2), dtype=int) nedge_unique = wp.zeros(1, dtype=int) wp.launch( - kernel=_compact_edges, + kernel=_deduplicate_edges, dim=d.njmax, - inputs=[ - nedge, - sorted_indices, - unique_prefix, - unique_mask, - edges, - edges_unique, - nedge_unique, - ], + inputs=[nedge, sorted_indices, edges, edges_unique, nedge_unique], ) return edges_unique, nedge_unique From 2b78ca9fa305a4d61b63bab4aafa1112ed8b4b26 Mon Sep 17 00:00:00 2001 From: Taylor Howell Date: Tue, 3 Feb 2026 10:53:23 +0000 Subject: [PATCH 4/7] per-world fix --- mujoco_warp/_src/island.py | 106 +++++++++++++++++++------------------ 1 file changed, 55 insertions(+), 51 deletions(-) diff --git a/mujoco_warp/_src/island.py b/mujoco_warp/_src/island.py index ee6008460..7f64b9272 100644 --- a/mujoco_warp/_src/island.py +++ b/mujoco_warp/_src/island.py @@ -29,9 +29,9 @@ def _find_tree_edges( efc_id_in: wp.array2d(dtype=int), efc_J_in: wp.array3d(dtype=float), njmax_in: int, - # Out: - edges_out: wp.array2d(dtype=int), - nedge_out: wp.array(dtype=int), + # Out (per-world): + edges_out: wp.array3d(dtype=int), # (nworld, njmax, 2) + nedge_out: wp.array(dtype=int), # (nworld,) ): worldid, efcid = wp.tid() @@ -62,58 +62,58 @@ def _find_tree_edges( elif tree != prev_tree and tree >= 0: # found a new tree, add edge between prev_tree and tree if prev_tree >= 0: - idx = wp.atomic_add(nedge_out, 0, 1) + idx = wp.atomic_add(nedge_out, worldid, 1) if idx < njmax_in: # store as ordered pair for deduplication t1 = wp.min(prev_tree, tree) t2 = wp.max(prev_tree, tree) - edges_out[idx, 0] = t1 - edges_out[idx, 1] = t2 + edges_out[worldid, idx, 0] = t1 + edges_out[worldid, idx, 1] = t2 prev_tree = tree # add self-edge if only one tree found (constrained to itself) if first_tree >= 0 and prev_tree == first_tree: - idx = wp.atomic_add(nedge_out, 0, 1) + idx = wp.atomic_add(nedge_out, worldid, 1) if idx < njmax_in: - edges_out[idx, 0] = first_tree - edges_out[idx, 1] = first_tree + edges_out[worldid, idx, 0] = first_tree + edges_out[worldid, idx, 1] = first_tree @wp.kernel def _compute_keys_and_indices( ntree: int, - nedge_in: wp.array(dtype=int), - edges_in: wp.array2d(dtype=int), + nedge_in: wp.array(dtype=int), # (nworld,) + edges_in: wp.array3d(dtype=int), # (nworld, njmax, 2) njmax: int, - keys_out: wp.array(dtype=int), - indices_out: wp.array(dtype=int), + keys_out: wp.array2d(dtype=int), # (nworld, 2*njmax) + indices_out: wp.array2d(dtype=int), # (nworld, 2*njmax) ): - """Compute sort keys and initialize indices. Launch: 2 * njmax.""" - i = wp.tid() + """Compute sort keys and initialize indices per world. Launch: (nworld, 2*njmax).""" + worldid, i = wp.tid() # Always init index - indices_out[i] = i + indices_out[worldid, i] = i # Only compute keys for first njmax elements if i >= njmax: return - if i >= nedge_in[0]: - keys_out[i] = 2147483647 # sort to end + if i >= nedge_in[worldid]: + keys_out[worldid, i] = 2147483647 # sort to end else: - keys_out[i] = edges_in[i, 0] * ntree + edges_in[i, 1] + keys_out[worldid, i] = edges_in[worldid, i, 0] * ntree + edges_in[worldid, i, 1] @wp.kernel def _deduplicate_edges( - nedge_in: wp.array(dtype=int), - sorted_indices_in: wp.array(dtype=int), - edges_in: wp.array2d(dtype=int), - edges_out: wp.array2d(dtype=int), - nedge_out: wp.array(dtype=int), + nedge_in: wp.array(dtype=int), # (nworld,) + sorted_indices_in: wp.array2d(dtype=int), # (nworld, 2*njmax) + edges_in: wp.array3d(dtype=int), # (nworld, njmax, 2) + edges_out: wp.array3d(dtype=int), # (nworld, njmax, 2) + nedge_out: wp.array(dtype=int), # (nworld,) ): - """Mark unique and compact in one pass using atomics. Launch: njmax.""" - i = wp.tid() - n = nedge_in[0] + """Mark unique and compact in one pass using atomics. Launch: (nworld, njmax).""" + worldid, i = wp.tid() + n = nedge_in[worldid] if i >= n: return @@ -123,19 +123,19 @@ def _deduplicate_edges( if i == 0: is_unique = 1 else: - idx = sorted_indices_in[i] - prev_idx = sorted_indices_in[i - 1] - if edges_in[idx, 0] != edges_in[prev_idx, 0]: + idx = sorted_indices_in[worldid, i] + prev_idx = sorted_indices_in[worldid, i - 1] + if edges_in[worldid, idx, 0] != edges_in[worldid, prev_idx, 0]: is_unique = 1 - elif edges_in[idx, 1] != edges_in[prev_idx, 1]: + elif edges_in[worldid, idx, 1] != edges_in[worldid, prev_idx, 1]: is_unique = 1 if is_unique == 1: # Atomic add to get output index - dst = wp.atomic_add(nedge_out, 0, 1) - src = sorted_indices_in[i] - edges_out[dst, 0] = edges_in[src, 0] - edges_out[dst, 1] = edges_in[src, 1] + dst = wp.atomic_add(nedge_out, worldid, 1) + src = sorted_indices_in[worldid, i] + edges_out[worldid, dst, 0] = edges_in[worldid, src, 0] + edges_out[worldid, dst, 1] = edges_in[worldid, src, 1] def find_tree_edges( @@ -150,14 +150,14 @@ def find_tree_edges( Returns: Tuple of (edges, nedge) arrays on device. - edges has shape (njmax, 2) where each row is an ordered (t1, t2) pair. - nedge is a (1,) array with the number of unique edges. + edges has shape (nworld, njmax, 2) where each row is an ordered (t1, t2) pair. + nedge is a (nworld,) array with the number of unique edges per world. """ - # allocate outputs - edges = wp.zeros((d.njmax, 2), dtype=int) - nedge = wp.zeros(1, dtype=int) + # allocate per-world outputs + edges = wp.zeros((d.nworld, d.njmax, 2), dtype=int) + nedge = wp.zeros(d.nworld, dtype=int) - # find edges + # find edges (per-world) wp.launch( kernel=_find_tree_edges, dim=(d.nworld, d.njmax), @@ -174,25 +174,29 @@ def find_tree_edges( ], ) - # compute sort keys and init indices (fused) - keys = wp.zeros(2 * d.njmax, dtype=int) - sorted_indices = wp.empty(2 * d.njmax, dtype=int) + # compute sort keys and init indices (per-world, fused) + keys = wp.zeros((d.nworld, 2 * d.njmax), dtype=int) + sorted_indices = wp.empty((d.nworld, 2 * d.njmax), dtype=int) wp.launch( kernel=_compute_keys_and_indices, - dim=2 * d.njmax, + dim=(d.nworld, 2 * d.njmax), inputs=[m.ntree, nedge, edges, d.njmax, keys, sorted_indices], ) - # sort by keys using Warp's radix sort - wp.utils.radix_sort_pairs(keys, sorted_indices, count=d.njmax) + # sort by keys using Warp's radix sort (per-world) + # radix_sort_pairs requires 1D arrays, so we need to sort each world separately + for w in range(d.nworld): + keys_w = keys[w] + indices_w = sorted_indices[w] + wp.utils.radix_sort_pairs(keys_w, indices_w, count=d.njmax) del keys - # deduplicate edges (fused, no prefix sum) - edges_unique = wp.zeros((d.njmax, 2), dtype=int) - nedge_unique = wp.zeros(1, dtype=int) + # deduplicate edges (per-world, fused, no prefix sum) + edges_unique = wp.zeros((d.nworld, d.njmax, 2), dtype=int) + nedge_unique = wp.zeros(d.nworld, dtype=int) wp.launch( kernel=_deduplicate_edges, - dim=d.njmax, + dim=(d.nworld, d.njmax), inputs=[nedge, sorted_indices, edges, edges_unique, nedge_unique], ) From aff7f4a02493268db377b240cf4cd4c1e13d318e Mon Sep 17 00:00:00 2001 From: Taylor Howell Date: Tue, 3 Feb 2026 11:12:38 +0000 Subject: [PATCH 5/7] constraint parallel edge discovery --- mujoco_warp/_src/island.py | 263 +++++++++++++++----------------- mujoco_warp/_src/island_test.py | 255 +++++++++++++++++++++++++++---- 2 files changed, 347 insertions(+), 171 deletions(-) diff --git a/mujoco_warp/_src/island.py b/mujoco_warp/_src/island.py index 7f64b9272..e51a6c402 100644 --- a/mujoco_warp/_src/island.py +++ b/mujoco_warp/_src/island.py @@ -16,188 +16,163 @@ import warp as wp from mujoco_warp._src import types +from mujoco_warp._src.types import ConstraintType +from mujoco_warp._src.types import EqType +from mujoco_warp._src.types import ObjType @wp.kernel -def _find_tree_edges( +def _tree_edges( # Model: nv: int, + body_treeid: wp.array(dtype=int), + jnt_dofadr: wp.array(dtype=int), dof_treeid: wp.array(dtype=int), + geom_bodyid: wp.array(dtype=int), + site_bodyid: wp.array(dtype=int), + eq_type: wp.array(dtype=int), + eq_obj1id: wp.array(dtype=int), + eq_obj2id: wp.array(dtype=int), + eq_objtype: wp.array(dtype=int), # Data in: nefc_in: wp.array(dtype=int), + contact_geom_in: wp.array(dtype=wp.vec2i), efc_type_in: wp.array2d(dtype=int), efc_id_in: wp.array2d(dtype=int), efc_J_in: wp.array3d(dtype=float), njmax_in: int, - # Out (per-world): - edges_out: wp.array3d(dtype=int), # (nworld, njmax, 2) - nedge_out: wp.array(dtype=int), # (nworld,) + # Out: + tree_tree: wp.array3d(dtype=int), # kernel_analyzer: off ): + """Find tree edges. Launch: (nworld, njmax).""" worldid, efcid = wp.tid() # skip if beyond active constraints if efcid >= wp.min(njmax_in, nefc_in[worldid]): return - # skip continuation rows (same constraint type and id as previous) - # this avoids duplicate edges from multi-row constraints (e.g., 3D contacts) - if efcid > 0: - if efc_type_in[worldid, efcid] == efc_type_in[worldid, efcid - 1]: - if efc_id_in[worldid, efcid] == efc_id_in[worldid, efcid - 1]: - return + efc_type = efc_type_in[worldid, efcid] + efc_id = efc_id_in[worldid, efcid] + + tree0 = int(-1) + tree1 = int(-1) + use_generic = int(0) + + # equality (connect/weld) + if efc_type == ConstraintType.EQUALITY: + eq_t = eq_type[efc_id] + + if eq_t == EqType.CONNECT or eq_t == EqType.WELD: + b1 = eq_obj1id[efc_id] + b2 = eq_obj2id[efc_id] + + # site semantics + if eq_objtype[efc_id] == ObjType.SITE: + b1 = site_bodyid[b1] + b2 = site_bodyid[b2] + + tree0 = body_treeid[b1] + tree1 = body_treeid[b2] + else: + # JOINT, TENDON, FLEX + use_generic = 1 + + # joint friction + elif efc_type == ConstraintType.FRICTION_DOF: + tree0 = dof_treeid[efc_id] + + # joint limit + elif efc_type == ConstraintType.LIMIT_JOINT: + tree0 = dof_treeid[jnt_dofadr[efc_id]] + + # contact + elif ( + efc_type == ConstraintType.CONTACT_FRICTIONLESS + or efc_type == ConstraintType.CONTACT_PYRAMIDAL + or efc_type == ConstraintType.CONTACT_ELLIPTIC + ): + geom_pair = contact_geom_in[efc_id] + g1 = geom_pair[0] + g2 = geom_pair[1] + + # flex contacts have negative geom ids + if g1 >= 0 and g2 >= 0: + tree0 = body_treeid[geom_bodyid[g1]] + tree1 = body_treeid[geom_bodyid[g2]] + else: + use_generic = 1 + + # generic + else: + use_generic = 1 + + # handle static bodies + if use_generic == 0: + # swap so tree0 is non-negative if possible + if tree0 < 0 and tree1 >= 0: + tree0 = tree1 + tree1 = -1 + + # mark the edge + if tree0 >= 0: + if tree1 < 0 or tree0 == tree1: + # self-edge + wp.atomic_max(tree_tree, worldid, tree0, tree0, 1) + else: + # cross-tree edge + t1 = wp.min(tree0, tree1) + t2 = wp.max(tree0, tree1) + wp.atomic_max(tree_tree, worldid, t1, t2, 1) + wp.atomic_max(tree_tree, worldid, t2, t1, 1) + return - # collect trees touched by this constraint row - prev_tree = int(-2) # -1 is valid (static), so use -2 as sentinel - first_tree = int(-2) + # generic: scan Jacobian row + first_tree = int(-1) + has_cross_edge = int(0) for dof in range(nv): + # TODO(team): sparse efc_J + # TODO(team): tree dof skip J_val = efc_J_in[worldid, efcid, dof] - if J_val != 0.0: tree = dof_treeid[dof] - - if first_tree == -2: + if tree < 0: + continue + if first_tree == -1: first_tree = tree - prev_tree = tree - elif tree != prev_tree and tree >= 0: - # found a new tree, add edge between prev_tree and tree - if prev_tree >= 0: - idx = wp.atomic_add(nedge_out, worldid, 1) - if idx < njmax_in: - # store as ordered pair for deduplication - t1 = wp.min(prev_tree, tree) - t2 = wp.max(prev_tree, tree) - edges_out[worldid, idx, 0] = t1 - edges_out[worldid, idx, 1] = t2 - prev_tree = tree - - # add self-edge if only one tree found (constrained to itself) - if first_tree >= 0 and prev_tree == first_tree: - idx = wp.atomic_add(nedge_out, worldid, 1) - if idx < njmax_in: - edges_out[worldid, idx, 0] = first_tree - edges_out[worldid, idx, 1] = first_tree + elif tree != first_tree: + t1 = wp.min(first_tree, tree) + t2 = wp.max(first_tree, tree) + wp.atomic_max(tree_tree, worldid, t1, t2, 1) + has_cross_edge = 1 + if first_tree >= 0 and has_cross_edge == 0: + wp.atomic_max(tree_tree, worldid, first_tree, first_tree, 1) -@wp.kernel -def _compute_keys_and_indices( - ntree: int, - nedge_in: wp.array(dtype=int), # (nworld,) - edges_in: wp.array3d(dtype=int), # (nworld, njmax, 2) - njmax: int, - keys_out: wp.array2d(dtype=int), # (nworld, 2*njmax) - indices_out: wp.array2d(dtype=int), # (nworld, 2*njmax) -): - """Compute sort keys and initialize indices per world. Launch: (nworld, 2*njmax).""" - worldid, i = wp.tid() - - # Always init index - indices_out[worldid, i] = i - # Only compute keys for first njmax elements - if i >= njmax: - return - if i >= nedge_in[worldid]: - keys_out[worldid, i] = 2147483647 # sort to end - else: - keys_out[worldid, i] = edges_in[worldid, i, 0] * ntree + edges_in[worldid, i, 1] - - -@wp.kernel -def _deduplicate_edges( - nedge_in: wp.array(dtype=int), # (nworld,) - sorted_indices_in: wp.array2d(dtype=int), # (nworld, 2*njmax) - edges_in: wp.array3d(dtype=int), # (nworld, njmax, 2) - edges_out: wp.array3d(dtype=int), # (nworld, njmax, 2) - nedge_out: wp.array(dtype=int), # (nworld,) -): - """Mark unique and compact in one pass using atomics. Launch: (nworld, njmax).""" - worldid, i = wp.tid() - n = nedge_in[worldid] - - if i >= n: - return - - # Check if this edge is unique (different from previous) - is_unique = int(0) - if i == 0: - is_unique = 1 - else: - idx = sorted_indices_in[worldid, i] - prev_idx = sorted_indices_in[worldid, i - 1] - if edges_in[worldid, idx, 0] != edges_in[worldid, prev_idx, 0]: - is_unique = 1 - elif edges_in[worldid, idx, 1] != edges_in[worldid, prev_idx, 1]: - is_unique = 1 - - if is_unique == 1: - # Atomic add to get output index - dst = wp.atomic_add(nedge_out, worldid, 1) - src = sorted_indices_in[worldid, i] - edges_out[worldid, dst, 0] = edges_in[worldid, src, 0] - edges_out[worldid, dst, 1] = edges_in[worldid, src, 1] - - -def find_tree_edges( - m: types.Model, - d: types.Data, -) -> tuple[wp.array, wp.array]: - """Find tree-tree edges from the constraint Jacobian. - - Args: - m: The model containing kinematic and dynamic information. - d: The data object containing the current state and output arrays. - - Returns: - Tuple of (edges, nedge) arrays on device. - edges has shape (nworld, njmax, 2) where each row is an ordered (t1, t2) pair. - nedge is a (nworld,) array with the number of unique edges per world. - """ - # allocate per-world outputs - edges = wp.zeros((d.nworld, d.njmax, 2), dtype=int) - nedge = wp.zeros(d.nworld, dtype=int) - - # find edges (per-world) +def tree_edges(m: types.Model, d: types.Data, tree_tree: wp.array3d(dtype=int)): + """Compute tree-tree adjacency matrix.""" + tree_tree.zero_() wp.launch( - kernel=_find_tree_edges, + kernel=_tree_edges, dim=(d.nworld, d.njmax), inputs=[ m.nv, + m.body_treeid, + m.jnt_dofadr, m.dof_treeid, + m.geom_bodyid, + m.site_bodyid, + m.eq_type, + m.eq_obj1id, + m.eq_obj2id, + m.eq_objtype, d.nefc, + d.contact.geom, d.efc.type, d.efc.id, d.efc.J, d.njmax, - edges, - nedge, ], + outputs=[tree_tree], ) - - # compute sort keys and init indices (per-world, fused) - keys = wp.zeros((d.nworld, 2 * d.njmax), dtype=int) - sorted_indices = wp.empty((d.nworld, 2 * d.njmax), dtype=int) - wp.launch( - kernel=_compute_keys_and_indices, - dim=(d.nworld, 2 * d.njmax), - inputs=[m.ntree, nedge, edges, d.njmax, keys, sorted_indices], - ) - - # sort by keys using Warp's radix sort (per-world) - # radix_sort_pairs requires 1D arrays, so we need to sort each world separately - for w in range(d.nworld): - keys_w = keys[w] - indices_w = sorted_indices[w] - wp.utils.radix_sort_pairs(keys_w, indices_w, count=d.njmax) - del keys - - # deduplicate edges (per-world, fused, no prefix sum) - edges_unique = wp.zeros((d.nworld, d.njmax, 2), dtype=int) - nedge_unique = wp.zeros(d.nworld, dtype=int) - wp.launch( - kernel=_deduplicate_edges, - dim=(d.nworld, d.njmax), - inputs=[nedge, sorted_indices, edges, edges_unique, nedge_unique], - ) - - return edges_unique, nedge_unique diff --git a/mujoco_warp/_src/island_test.py b/mujoco_warp/_src/island_test.py index e85a68a96..00a082eaf 100644 --- a/mujoco_warp/_src/island_test.py +++ b/mujoco_warp/_src/island_test.py @@ -15,6 +15,7 @@ """Tests for island discovery.""" +import numpy as np import warp as wp from absl.testing import absltest @@ -26,6 +27,8 @@ class IslandEdgeDiscoveryTest(absltest.TestCase): """Tests for edge discovery from constraint Jacobian.""" + # TODO(team): add test for additional constraint types to test special cases + def test_single_constraint_two_trees(self): """A single weld constraint between two bodies creates one edge.""" mjm, mjd, m, d = test_data.fixture( @@ -48,17 +51,14 @@ def test_single_constraint_two_trees(self): """ ) - # run forward to populate constraints mjwarp.forward(m, d) - # find edges - edges, nedge = island.find_tree_edges(m, d) + treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) + island.tree_edges(m, d, treetree) - # should have exactly 1 edge between tree 0 and tree 1 - self.assertEqual(nedge.numpy()[0], 1) - edge_np = edges.numpy() - self.assertEqual(edge_np[0, 0], 0) # tree 0 - self.assertEqual(edge_np[0, 1], 1) # tree 1 + tt = treetree.numpy() + self.assertEqual(tt[0, 0, 1], 1) + self.assertEqual(tt[0, 1, 0], 1) def test_constraint_within_single_tree_creates_self_edge(self): """A constraint within a single tree creates a self-edge.""" @@ -83,13 +83,12 @@ def test_constraint_within_single_tree_creates_self_edge(self): ) mjwarp.forward(m, d) - edges, nedge = island.find_tree_edges(m, d) - # should have exactly 1 self-edge for tree 0 - self.assertEqual(nedge.numpy()[0], 1) - edge_np = edges.numpy() - self.assertEqual(edge_np[0, 0], 0) # tree 0 - self.assertEqual(edge_np[0, 1], 0) # tree 0 (self-edge) + treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) + island.tree_edges(m, d, treetree) + + tt = treetree.numpy() + self.assertEqual(tt[0, 0, 0], 1) # self-edge for tree 0 def test_three_bodies_chain(self): """Three bodies with constraints A-B and B-C should have 2 edges.""" @@ -119,15 +118,15 @@ def test_three_bodies_chain(self): ) mjwarp.forward(m, d) - edges, nedge = island.find_tree_edges(m, d) - # should have 2 edges: (0,1) and (1,2) - n = nedge.numpy()[0] - self.assertEqual(n, 2) - edge_np = edges.numpy()[:n] - edges_set = set(tuple(e) for e in edge_np) - self.assertIn((0, 1), edges_set) - self.assertIn((1, 2), edges_set) + treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) + island.tree_edges(m, d, treetree) + + tt = treetree.numpy() + self.assertEqual(tt[0, 0, 1], 1) + self.assertEqual(tt[0, 1, 0], 1) + self.assertEqual(tt[0, 1, 2], 1) + self.assertEqual(tt[0, 2, 1], 1) def test_deduplication(self): """Repeated constraints between same trees should be deduplicated.""" @@ -153,10 +152,14 @@ def test_deduplication(self): ) mjwarp.forward(m, d) - edges, nedge = island.find_tree_edges(m, d) - # should have 1 unique edge (0,1) despite 2 constraints - self.assertEqual(nedge.numpy()[0], 1) + treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) + island.tree_edges(m, d, treetree) + + tt = treetree.numpy() + self.assertEqual(tt[0, 0, 1], 1) + self.assertEqual(tt[0, 1, 0], 1) + self.assertEqual(np.sum(tt[0]), 2) def test_no_constraints(self): """No constraints should produce no edges.""" @@ -174,9 +177,207 @@ def test_no_constraints(self): ) mjwarp.forward(m, d) - edges, nedge = island.find_tree_edges(m, d) - self.assertEqual(nedge.numpy()[0], 0) + treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) + island.tree_edges(m, d, treetree) + + tt = treetree.numpy() + self.assertEqual(np.sum(tt[0]), 0) + + def test_multi_world_parallel(self): + """Each world's edges should be computed independently.""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + """, + nworld=2, + ) + + mjwarp.forward(m, d) + + treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) + island.tree_edges(m, d, treetree) + + tt = treetree.numpy() + self.assertEqual(tt[0, 0, 1], 1) + self.assertEqual(tt[0, 1, 0], 1) + self.assertEqual(tt[1, 0, 1], 1) + self.assertEqual(tt[1, 1, 0], 1) + + def test_contact_constraint_edges(self): + """Contact constraints between geoms should create edges.""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + """, + nworld=2, + ) + + mjwarp.forward(m, d) + + nefc = d.nefc.numpy() + if nefc[0] > 0: + treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) + island.tree_edges(m, d, treetree) + + tt = treetree.numpy() + self.assertEqual(tt[0, 0, 1], 1) + self.assertEqual(tt[0, 1, 0], 1) + + def test_isolated_tree_no_edge(self): + """A floating body with no constraints should produce no edges.""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + """, + nworld=2, + ) + + mjwarp.forward(m, d) + + treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) + island.tree_edges(m, d, treetree) + + tt = treetree.numpy() + self.assertEqual(np.sum(tt[0]), 0) + self.assertEqual(np.sum(tt[1]), 0) + + def test_mixed_equality_and_contact(self): + """Both equality and contact constraints should contribute to edges.""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + + + + + """, + nworld=2, + ) + + mjwarp.forward(m, d) + + treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) + island.tree_edges(m, d, treetree) + + tt = treetree.numpy() + self.assertEqual(tt[0, 1, 2], 1) + self.assertEqual(tt[0, 2, 1], 1) + + def test_worldbody_dofs_ignored(self): + """Constraints involving worldbody (tree < 0) should not cause spurious edges.""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + """, + nworld=2, + ) + + mjwarp.forward(m, d) + + treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) + island.tree_edges(m, d, treetree) + + tt = treetree.numpy() + self.assertEqual(tt[0, 0, 0], 1) # self-edge for floating tree + + def test_constraint_touches_three_trees(self): + """Multiple constraints sharing a body create a star topology.""" + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + + + + + + """, + nworld=2, + ) + + mjwarp.forward(m, d) + + treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) + island.tree_edges(m, d, treetree) + + tt = treetree.numpy() + self.assertEqual(tt[0, 0, 1], 1) + self.assertEqual(tt[0, 1, 0], 1) + self.assertEqual(tt[0, 0, 2], 1) + self.assertEqual(tt[0, 2, 0], 1) if __name__ == "__main__": From 9f4a7251c46077354c633d38c6f474e34d4f9d14 Mon Sep 17 00:00:00 2001 From: Taylor Howell Date: Sun, 25 Jan 2026 14:48:42 +0000 Subject: [PATCH 6/7] island discovery with flood fill algorithm --- mujoco_warp/_src/forward.py | 3 + mujoco_warp/_src/io.py | 19 +++ mujoco_warp/_src/io_test.py | 2 + mujoco_warp/_src/island.py | 92 ++++++++++- mujoco_warp/_src/island_test.py | 284 ++++++++++++++++++++++++++++++-- mujoco_warp/_src/types.py | 8 +- 6 files changed, 395 insertions(+), 13 deletions(-) diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index 8facf4995..b4ce36700 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -20,6 +20,7 @@ from mujoco_warp._src import collision_driver from mujoco_warp._src import constraint from mujoco_warp._src import derivative +from mujoco_warp._src import island from mujoco_warp._src import math from mujoco_warp._src import passive from mujoco_warp._src import sensor @@ -517,6 +518,8 @@ def fwd_position(m: Model, d: Data, factorize: bool = True): if m.opt.run_collision_detection: collision_driver.collision(m, d) constraint.make_constraint(m, d) + if not (m.opt.disableflags & DisableBit.ISLAND): + island.island(m, d) smooth.transmission(m, d) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index 6dec37f8a..f6a463fe4 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -738,6 +738,9 @@ def make_data( "eq_active": wp.array(np.tile(mjm.eq_active0.astype(bool), (nworld, 1)), shape=(nworld, mjm.neq), dtype=bool), # flexedge "flexedge_J": None, + # island arrays + "nisland": None, + "tree_island": None, } for f in dataclasses.fields(types.Data): if f.name in d_kwargs: @@ -755,6 +758,10 @@ def make_data( d.flexedge_J = wp.zeros((nworld, 1, mjd.flexedge_J.size), dtype=float) + # island discovery arrays + d.nisland = wp.zeros((nworld,), dtype=int) + d.tree_island = wp.zeros((nworld, mjm.ntree), dtype=int) + return d @@ -889,6 +896,9 @@ def put_data( "actuator_moment": None, "flexedge_J": None, "nacon": None, + # island arrays + "nisland": None, + "tree_island": None, } for f in dataclasses.fields(types.Data): if f.name in d_kwargs: @@ -916,6 +926,10 @@ def put_data( d.flexedge_J = wp.array(np.tile(mjd.flexedge_J.reshape(-1), (nworld, 1)).reshape((nworld, 1, -1)), dtype=float) + # island arrays + d.nisland = wp.zeros((nworld,), dtype=int) + d.tree_island = wp.zeros((nworld, mjm.ntree), dtype=int) + if mujoco.mj_isSparse(mjm): ten_J = np.zeros((mjm.ntendon, mjm.nv)) mujoco.mju_sparse2dense(ten_J, mjd.ten_J.reshape(-1), mjd.ten_J_rownnz, mjd.ten_J_rowadr, mjd.ten_J_colind.reshape(-1)) @@ -1127,6 +1141,11 @@ def get_data_into( # sensors result.sensordata[:] = d.sensordata.numpy()[world_id] + # islands + if not (mjm.opt.disableflags & mujoco.mjtDisableBit.mjDSBL_ISLAND) and mjm.ntree: + result.nisland = d.nisland.numpy()[world_id] + result.tree_island[:] = d.tree_island.numpy()[world_id] + def reset_data(m: types.Model, d: types.Data, reset: Optional[wp.array] = None): """Clear data, set defaults; optionally by world. diff --git a/mujoco_warp/_src/io_test.py b/mujoco_warp/_src/io_test.py index 11753a3f9..35e9b3673 100644 --- a/mujoco_warp/_src/io_test.py +++ b/mujoco_warp/_src/io_test.py @@ -126,6 +126,7 @@ def test_get_data_into(self, nworld, world_id): self.assertEqual(d.ne.numpy()[world_id], mjd.ne) self.assertEqual(d.nf.numpy()[world_id], mjd.nf) self.assertEqual(d.nl.numpy()[world_id], mjd.nl) + self.assertEqual(d.nisland.numpy()[world_id], mjd.nisland) for field in [ "energy", @@ -196,6 +197,7 @@ def test_get_data_into(self, nworld, world_id): "wrap_obj", "wrap_xpos", "sensordata", + "tree_island", ]: _assert_eq( getattr(d, field).numpy()[world_id].reshape(-1), diff --git a/mujoco_warp/_src/island.py b/mujoco_warp/_src/island.py index e51a6c402..3d1d064b4 100644 --- a/mujoco_warp/_src/island.py +++ b/mujoco_warp/_src/island.py @@ -19,6 +19,7 @@ from mujoco_warp._src.types import ConstraintType from mujoco_warp._src.types import EqType from mujoco_warp._src.types import ObjType +from mujoco_warp._src.warp_util import event_scope @wp.kernel @@ -44,7 +45,7 @@ def _tree_edges( # Out: tree_tree: wp.array3d(dtype=int), # kernel_analyzer: off ): - """Find tree edges. Launch: (nworld, njmax).""" + """Find tree edges.""" worldid, efcid = wp.tid() # skip if beyond active constraints @@ -176,3 +177,92 @@ def tree_edges(m: types.Model, d: types.Data, tree_tree: wp.array3d(dtype=int)): ], outputs=[tree_tree], ) + + +@wp.kernel +def _flood_fill( + # Model: + ntree: int, + # In: + tree_tree_in: wp.array3d(dtype=int), + labels_in: wp.array2d(dtype=int), + stack_in: wp.array2d(dtype=int), + # Data out: + nisland_out: wp.array(dtype=int), + # Out: + labels_out: wp.array2d(dtype=int), + stack_out: wp.array2d(dtype=int), +): + """DFS flood fill to discover islands using tree_tree matrix.""" + worldid = wp.tid() + nisland = int(0) + + # iterate over trees + for i in range(ntree): + # already assigned + if labels_in[worldid, i] != -1: + continue + + # check if tree has any edges + has_edge = int(0) + for j in range(ntree): + if tree_tree_in[worldid, i, j] != 0: + has_edge = 1 + break + if has_edge == 0: + continue + + # DFS: push i onto stack + nstack = int(0) + stack_out[worldid, nstack] = i + nstack = nstack + 1 + + while nstack > 0: + # pop v from stack + nstack = nstack - 1 + v = stack_in[worldid, nstack] + + # already assigned + if labels_in[worldid, v] != -1: + continue + + # assign to current island + labels_out[worldid, v] = nisland + + # push neighbors + for neighbor in range(ntree): + if tree_tree_in[worldid, v, neighbor] != 0: + if labels_in[worldid, neighbor] == -1 and nstack < ntree: + stack_out[worldid, nstack] = neighbor + nstack = nstack + 1 + + # island filled + nisland = nisland + 1 + + nisland_out[worldid] = nisland + + +@event_scope +def island( + m: types.Model, + d: types.Data, +) -> None: + """Discover constraint islands.""" + if m.ntree == 0: + d.nisland.zero_() + return + + # Step 1: Find tree edges + tree_tree = wp.zeros((d.nworld, m.ntree, m.ntree), dtype=int) + tree_edges(m, d, tree_tree) + + # Step 2: DFS flood fill + d.tree_island.fill_(-1) + stack_scratch = wp.zeros((d.nworld, m.ntree), dtype=int) + + wp.launch( + _flood_fill, + dim=d.nworld, + inputs=[m.ntree, tree_tree, d.tree_island, stack_scratch], + outputs=[d.nisland, d.tree_island, stack_scratch], + ) diff --git a/mujoco_warp/_src/island_test.py b/mujoco_warp/_src/island_test.py index 00a082eaf..fb95d498f 100644 --- a/mujoco_warp/_src/island_test.py +++ b/mujoco_warp/_src/island_test.py @@ -22,6 +22,7 @@ import mujoco_warp as mjwarp from mujoco_warp import test_data from mujoco_warp._src import island +from mujoco_warp._src.types import DisableBit class IslandEdgeDiscoveryTest(absltest.TestCase): @@ -51,7 +52,7 @@ def test_single_constraint_two_trees(self): """ ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -82,7 +83,7 @@ def test_constraint_within_single_tree_creates_self_edge(self): """ ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -117,7 +118,7 @@ def test_three_bodies_chain(self): """ ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -151,7 +152,7 @@ def test_deduplication(self): """ ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -176,7 +177,7 @@ def test_no_constraints(self): """ ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -207,7 +208,7 @@ def test_multi_world_parallel(self): nworld=2, ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -238,7 +239,7 @@ def test_contact_constraint_edges(self): nworld=2, ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) nefc = d.nefc.numpy() if nefc[0] > 0: @@ -265,7 +266,7 @@ def test_isolated_tree_no_edge(self): nworld=2, ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -301,7 +302,7 @@ def test_mixed_equality_and_contact(self): nworld=2, ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -332,7 +333,7 @@ def test_worldbody_dofs_ignored(self): nworld=2, ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -368,7 +369,7 @@ def test_constraint_touches_three_trees(self): nworld=2, ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -380,6 +381,267 @@ def test_constraint_touches_three_trees(self): self.assertEqual(tt[0, 2, 0], 1) +class IslandDiscoveryTest(absltest.TestCase): + """Tests for full island discovery including label propagation.""" + + def test_two_trees_one_constraint_one_island(self): + """Two trees connected by one constraint form one island. + + topology: + [[0, 1], + [1, 0]] + """ + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + """, + overrides={"opt.disableflags": DisableBit.ISLAND}, + ) + + d.nisland.fill_(-1) + d.tree_island.fill_(-1) + mjwarp.fwd_position(m, d) + island.island(m, d) + + # should have exactly 1 island + self.assertEqual(d.nisland.numpy()[0], 1) + # both trees should be in island 0 + tree_island = d.tree_island.numpy()[0] + self.assertEqual(tree_island[0], tree_island[1]) + self.assertEqual(tree_island[0], 0) + + def test_three_trees_chain_one_island(self): + """Three trees in a chain form one island. + + topology: + [[0, 1, 0], + [1, 0, 1], + [0, 1, 0]] + """ + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + + + + + + """, + overrides={"opt.disableflags": DisableBit.ISLAND}, + ) + + d.nisland.fill_(-1) + d.tree_island.fill_(-1) + mjwarp.fwd_position(m, d) + island.island(m, d) + + # should have exactly 1 island + self.assertEqual(d.nisland.numpy()[0], 1) + # all trees should be in the same island + tree_island = d.tree_island.numpy()[0] + self.assertEqual(tree_island[0], tree_island[1]) + self.assertEqual(tree_island[1], tree_island[2]) + + def test_two_disconnected_pairs_two_islands(self): + """Two pairs of disconnected trees form two islands. + + topology: + [[0, 1, 0, 0], + [1, 0, 0, 0], + [0, 0, 0, 1], + [0, 0, 1, 0]] + """ + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + + + + + + + + + + """, + overrides={"opt.disableflags": DisableBit.ISLAND}, + ) + + d.nisland.fill_(-1) + d.tree_island.fill_(-1) + mjwarp.fwd_position(m, d) + island.island(m, d) + + # should have exactly 2 islands + self.assertEqual(d.nisland.numpy()[0], 2) + # trees 0,1 should be in one island, trees 2,3 in another + tree_island = d.tree_island.numpy()[0] + self.assertEqual(tree_island[0], tree_island[1]) + self.assertEqual(tree_island[2], tree_island[3]) + self.assertNotEqual(tree_island[0], tree_island[2]) + + def test_no_constraints_no_islands(self): + """No constraints means no constrained islands. + + topology: + [[0]] (no edges) + """ + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + """, + overrides={"opt.disableflags": DisableBit.ISLAND}, + ) + + d.nisland.fill_(-1) + d.tree_island.fill_(-1) + mjwarp.fwd_position(m, d) + island.island(m, d) + + # should have 0 islands (unconstrained tree is not an island) + self.assertEqual(d.nisland.numpy()[0], 0) + + def test_multiple_worlds(self): + """Test island discovery with nworld=2. + + topology: + [[0, 1], + [1, 0]] + """ + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + """, + nworld=2, + overrides={"opt.disableflags": DisableBit.ISLAND}, + ) + + d.nisland.fill_(-1) + d.tree_island.fill_(-1) + mjwarp.fwd_position(m, d) + island.island(m, d) + + # both worlds should have exactly 1 island + nisland = d.nisland.numpy() + self.assertEqual(nisland[0], 1) + self.assertEqual(nisland[1], 1) + + # both trees in both worlds should be in island 0 + tree_island = d.tree_island.numpy() + for worldid in range(2): + self.assertEqual(tree_island[worldid, 0], 0) + self.assertEqual(tree_island[worldid, 1], 0) + + def test_three_trees_star_hub_at_end(self): + """Three trees with tree 2 as hub connecting trees 0 and 1. + + topology: + [[0, 0, 1], + [0, 0, 1], + [1, 1, 0]] + """ + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + + + + + + """, + overrides={"opt.disableflags": DisableBit.ISLAND}, + ) + + d.nisland.fill_(-1) + d.tree_island.fill_(-1) + mjwarp.fwd_position(m, d) + island.island(m, d) + + # should have exactly 1 island + self.assertEqual(d.nisland.numpy()[0], 1) + # all trees should be in the same island + tree_island = d.tree_island.numpy()[0] + self.assertEqual(tree_island[0], tree_island[1]) + self.assertEqual(tree_island[1], tree_island[2]) + + if __name__ == "__main__": wp.init() absltest.main() diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 1462b7603..b0498cf2d 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -147,6 +147,7 @@ class DisableBit(enum.IntFlag): SENSOR: sensors EULERDAMP: implicit damping for Euler integration NATIVECCD: native convex collision detection (ignored in MJWarp) + ISLAND: constraint islands """ CONSTRAINT = mujoco.mjtDisableBit.mjDSBL_CONSTRAINT @@ -165,7 +166,8 @@ class DisableBit(enum.IntFlag): SENSOR = mujoco.mjtDisableBit.mjDSBL_SENSOR EULERDAMP = mujoco.mjtDisableBit.mjDSBL_EULERDAMP NATIVECCD = mujoco.mjtDisableBit.mjDSBL_NATIVECCD - # unsupported: MIDPHASE, AUTORESET, ISLAND + ISLAND = mujoco.mjtDisableBit.mjDSBL_ISLAND + # unsupported: MIDPHASE, AUTORESET class EnableBit(enum.IntFlag): @@ -1561,6 +1563,7 @@ class Data: nf: number of friction constraints (nworld,) nl: number of limit constraints (nworld,) nefc: number of constraints (nworld,) + nisland: number of constraint islands (nworld,) time: simulation time (nworld,) energy: potential, kinetic energy (nworld, 2) qpos: position (nworld, nq) @@ -1637,6 +1640,7 @@ class Data: cfrc_ext: com-based external force on body (nworld, nbody, 6) contact: contact data efc: constraint data + tree_island: island ID per tree (-1 if unconstrained) (nworld, ntree) warp only fields: nworld: number of worlds @@ -1654,6 +1658,7 @@ class Data: nf: array("nworld", int) nl: array("nworld", int) nefc: array("nworld", int) + nisland: array("nworld", int) time: array("nworld", float) energy: array("nworld", wp.vec2) qpos: array("nworld", "nq", float) @@ -1726,6 +1731,7 @@ class Data: cfrc_ext: array("nworld", "nbody", wp.spatial_vector) contact: Contact efc: Constraint + tree_island: array("nworld", "ntree", int) # warp only fields: nworld: int From b3246af2d7573d0250ba4dedd8a7ef11e84795c2 Mon Sep 17 00:00:00 2001 From: Taylor Howell Date: Mon, 9 Feb 2026 17:00:28 +0000 Subject: [PATCH 7/7] update --- mujoco_warp/__init__.py | 1 + mujoco_warp/_src/io.py | 9 +++++---- mujoco_warp/_src/io_test.py | 4 +++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/mujoco_warp/__init__.py b/mujoco_warp/__init__.py index 798f97414..679a08861 100644 --- a/mujoco_warp/__init__.py +++ b/mujoco_warp/__init__.py @@ -56,6 +56,7 @@ from mujoco_warp._src.io import set_const as set_const from mujoco_warp._src.io import set_const_0 as set_const_0 from mujoco_warp._src.io import set_const_fixed as set_const_fixed +from mujoco_warp._src.island import island as island from mujoco_warp._src.passive import passive as passive from mujoco_warp._src.ray import ray as ray from mujoco_warp._src.ray import rays as rays diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index c1be8ebaf..a44ccdc14 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -929,8 +929,8 @@ def put_data( d.flexedge_J = wp.array(np.tile(mjd.flexedge_J.reshape(-1), (nworld, 1)).reshape((nworld, 1, -1)), dtype=float) # island arrays - d.nisland = wp.zeros((nworld,), dtype=int) - d.tree_island = wp.zeros((nworld, mjm.ntree), dtype=int) + d.nisland = wp.array(np.full(nworld, mjd.nisland), dtype=int) + d.tree_island = wp.array(np.tile(mjd.tree_island, (nworld, 1)), dtype=int) if mujoco.mj_isSparse(mjm): ten_J = np.zeros((mjm.ntendon, mjm.nv)) @@ -1153,8 +1153,9 @@ def get_data_into( result.sensordata[:] = d.sensordata.numpy()[world_id] # islands - if not (mjm.opt.disableflags & mujoco.mjtDisableBit.mjDSBL_ISLAND) and mjm.ntree: - result.nisland = d.nisland.numpy()[world_id] + nisland = d.nisland.numpy()[world_id] + result.nisland = nisland + if nisland: result.tree_island[:] = d.tree_island.numpy()[world_id] diff --git a/mujoco_warp/_src/io_test.py b/mujoco_warp/_src/io_test.py index a58db8534..5cb55c8ea 100644 --- a/mujoco_warp/_src/io_test.py +++ b/mujoco_warp/_src/io_test.py @@ -212,7 +212,6 @@ def test_get_data_into(self, nworld, world_id): "wrap_obj", "wrap_xpos", "sensordata", - "tree_island", ]: _assert_eq( getattr(d, field).numpy()[world_id].reshape(-1), @@ -298,7 +297,10 @@ def test_get_data_into_io_test_models(self, xml): "xpos", "xquat", "geom_xpos", + "tree_island", ]: + if field == "tree_island" and d.nisland.numpy()[0] == 0: + continue if getattr(mjd, field).size > 0: _assert_eq( getattr(mjd_result, field).reshape(-1),