diff --git a/mujoco_warp/__init__.py b/mujoco_warp/__init__.py
index 798f97414..679a08861 100644
--- a/mujoco_warp/__init__.py
+++ b/mujoco_warp/__init__.py
@@ -56,6 +56,7 @@
from mujoco_warp._src.io import set_const as set_const
from mujoco_warp._src.io import set_const_0 as set_const_0
from mujoco_warp._src.io import set_const_fixed as set_const_fixed
+from mujoco_warp._src.island import island as island
from mujoco_warp._src.passive import passive as passive
from mujoco_warp._src.ray import ray as ray
from mujoco_warp._src.ray import rays as rays
diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py
index 032070be8..b6ec5c62c 100644
--- a/mujoco_warp/_src/forward.py
+++ b/mujoco_warp/_src/forward.py
@@ -20,6 +20,7 @@
from mujoco_warp._src import collision_driver
from mujoco_warp._src import constraint
from mujoco_warp._src import derivative
+from mujoco_warp._src import island
from mujoco_warp._src import math
from mujoco_warp._src import passive
from mujoco_warp._src import sensor
@@ -517,6 +518,8 @@ def fwd_position(m: Model, d: Data, factorize: bool = True):
if m.opt.run_collision_detection:
collision_driver.collision(m, d)
constraint.make_constraint(m, d)
+ if not (m.opt.disableflags & DisableBit.ISLAND):
+ island.island(m, d)
smooth.transmission(m, d)
diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py
index b96ea5406..a44ccdc14 100644
--- a/mujoco_warp/_src/io.py
+++ b/mujoco_warp/_src/io.py
@@ -740,6 +740,9 @@ def make_data(
"eq_active": wp.array(np.tile(mjm.eq_active0.astype(bool), (nworld, 1)), shape=(nworld, mjm.neq), dtype=bool),
# flexedge
"flexedge_J": None,
+ # island arrays
+ "nisland": None,
+ "tree_island": None,
}
for f in dataclasses.fields(types.Data):
if f.name in d_kwargs:
@@ -757,6 +760,10 @@ def make_data(
d.flexedge_J = wp.zeros((nworld, 1, mjd.flexedge_J.size), dtype=float)
+ # island discovery arrays
+ d.nisland = wp.zeros((nworld,), dtype=int)
+ d.tree_island = wp.zeros((nworld, mjm.ntree), dtype=int)
+
return d
@@ -891,6 +898,9 @@ def put_data(
"actuator_moment": None,
"flexedge_J": None,
"nacon": None,
+ # island arrays
+ "nisland": None,
+ "tree_island": None,
}
for f in dataclasses.fields(types.Data):
if f.name in d_kwargs:
@@ -918,6 +928,10 @@ def put_data(
d.flexedge_J = wp.array(np.tile(mjd.flexedge_J.reshape(-1), (nworld, 1)).reshape((nworld, 1, -1)), dtype=float)
+ # island arrays
+ d.nisland = wp.array(np.full(nworld, mjd.nisland), dtype=int)
+ d.tree_island = wp.array(np.tile(mjd.tree_island, (nworld, 1)), dtype=int)
+
if mujoco.mj_isSparse(mjm):
ten_J = np.zeros((mjm.ntendon, mjm.nv))
mujoco.mju_sparse2dense(ten_J, mjd.ten_J.reshape(-1), mjd.ten_J_rownnz, mjd.ten_J_rowadr, mjd.ten_J_colind.reshape(-1))
@@ -1138,6 +1152,12 @@ def get_data_into(
# sensors
result.sensordata[:] = d.sensordata.numpy()[world_id]
+ # islands
+ nisland = d.nisland.numpy()[world_id]
+ result.nisland = nisland
+ if nisland:
+ result.tree_island[:] = d.tree_island.numpy()[world_id]
+
def reset_data(m: types.Model, d: types.Data, reset: Optional[wp.array] = None):
"""Clear data, set defaults; optionally by world.
diff --git a/mujoco_warp/_src/io_test.py b/mujoco_warp/_src/io_test.py
index 4ecf590e5..5cb55c8ea 100644
--- a/mujoco_warp/_src/io_test.py
+++ b/mujoco_warp/_src/io_test.py
@@ -141,6 +141,7 @@ def test_get_data_into(self, nworld, world_id):
self.assertEqual(d.ne.numpy()[world_id], mjd.ne)
self.assertEqual(d.nf.numpy()[world_id], mjd.nf)
self.assertEqual(d.nl.numpy()[world_id], mjd.nl)
+ self.assertEqual(d.nisland.numpy()[world_id], mjd.nisland)
_assert_eq(d.time.numpy()[world_id], mjd.time, "time")
for field in [
@@ -296,7 +297,10 @@ def test_get_data_into_io_test_models(self, xml):
"xpos",
"xquat",
"geom_xpos",
+ "tree_island",
]:
+ if field == "tree_island" and d.nisland.numpy()[0] == 0:
+ continue
if getattr(mjd, field).size > 0:
_assert_eq(
getattr(mjd_result, field).reshape(-1),
diff --git a/mujoco_warp/_src/island.py b/mujoco_warp/_src/island.py
index e51a6c402..3d1d064b4 100644
--- a/mujoco_warp/_src/island.py
+++ b/mujoco_warp/_src/island.py
@@ -19,6 +19,7 @@
from mujoco_warp._src.types import ConstraintType
from mujoco_warp._src.types import EqType
from mujoco_warp._src.types import ObjType
+from mujoco_warp._src.warp_util import event_scope
@wp.kernel
@@ -44,7 +45,7 @@ def _tree_edges(
# Out:
tree_tree: wp.array3d(dtype=int), # kernel_analyzer: off
):
- """Find tree edges. Launch: (nworld, njmax)."""
+ """Find tree edges."""
worldid, efcid = wp.tid()
# skip if beyond active constraints
@@ -176,3 +177,92 @@ def tree_edges(m: types.Model, d: types.Data, tree_tree: wp.array3d(dtype=int)):
],
outputs=[tree_tree],
)
+
+
+@wp.kernel
+def _flood_fill(
+ # Model:
+ ntree: int,
+ # In:
+ tree_tree_in: wp.array3d(dtype=int),
+ labels_in: wp.array2d(dtype=int),
+ stack_in: wp.array2d(dtype=int),
+ # Data out:
+ nisland_out: wp.array(dtype=int),
+ # Out:
+ labels_out: wp.array2d(dtype=int),
+ stack_out: wp.array2d(dtype=int),
+):
+ """DFS flood fill to discover islands using tree_tree matrix."""
+ worldid = wp.tid()
+ nisland = int(0)
+
+ # iterate over trees
+ for i in range(ntree):
+ # already assigned
+ if labels_in[worldid, i] != -1:
+ continue
+
+ # check if tree has any edges
+ has_edge = int(0)
+ for j in range(ntree):
+ if tree_tree_in[worldid, i, j] != 0:
+ has_edge = 1
+ break
+ if has_edge == 0:
+ continue
+
+ # DFS: push i onto stack
+ nstack = int(0)
+ stack_out[worldid, nstack] = i
+ nstack = nstack + 1
+
+ while nstack > 0:
+ # pop v from stack
+ nstack = nstack - 1
+ v = stack_in[worldid, nstack]
+
+ # already assigned
+ if labels_in[worldid, v] != -1:
+ continue
+
+ # assign to current island
+ labels_out[worldid, v] = nisland
+
+ # push neighbors
+ for neighbor in range(ntree):
+ if tree_tree_in[worldid, v, neighbor] != 0:
+ if labels_in[worldid, neighbor] == -1 and nstack < ntree:
+ stack_out[worldid, nstack] = neighbor
+ nstack = nstack + 1
+
+ # island filled
+ nisland = nisland + 1
+
+ nisland_out[worldid] = nisland
+
+
+@event_scope
+def island(
+ m: types.Model,
+ d: types.Data,
+) -> None:
+ """Discover constraint islands."""
+ if m.ntree == 0:
+ d.nisland.zero_()
+ return
+
+ # Step 1: Find tree edges
+ tree_tree = wp.zeros((d.nworld, m.ntree, m.ntree), dtype=int)
+ tree_edges(m, d, tree_tree)
+
+ # Step 2: DFS flood fill
+ d.tree_island.fill_(-1)
+ stack_scratch = wp.zeros((d.nworld, m.ntree), dtype=int)
+
+ wp.launch(
+ _flood_fill,
+ dim=d.nworld,
+ inputs=[m.ntree, tree_tree, d.tree_island, stack_scratch],
+ outputs=[d.nisland, d.tree_island, stack_scratch],
+ )
diff --git a/mujoco_warp/_src/island_test.py b/mujoco_warp/_src/island_test.py
index 00a082eaf..fb95d498f 100644
--- a/mujoco_warp/_src/island_test.py
+++ b/mujoco_warp/_src/island_test.py
@@ -22,6 +22,7 @@
import mujoco_warp as mjwarp
from mujoco_warp import test_data
from mujoco_warp._src import island
+from mujoco_warp._src.types import DisableBit
class IslandEdgeDiscoveryTest(absltest.TestCase):
@@ -51,7 +52,7 @@ def test_single_constraint_two_trees(self):
"""
)
- mjwarp.forward(m, d)
+ mjwarp.fwd_position(m, d)
treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int)
island.tree_edges(m, d, treetree)
@@ -82,7 +83,7 @@ def test_constraint_within_single_tree_creates_self_edge(self):
"""
)
- mjwarp.forward(m, d)
+ mjwarp.fwd_position(m, d)
treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int)
island.tree_edges(m, d, treetree)
@@ -117,7 +118,7 @@ def test_three_bodies_chain(self):
"""
)
- mjwarp.forward(m, d)
+ mjwarp.fwd_position(m, d)
treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int)
island.tree_edges(m, d, treetree)
@@ -151,7 +152,7 @@ def test_deduplication(self):
"""
)
- mjwarp.forward(m, d)
+ mjwarp.fwd_position(m, d)
treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int)
island.tree_edges(m, d, treetree)
@@ -176,7 +177,7 @@ def test_no_constraints(self):
"""
)
- mjwarp.forward(m, d)
+ mjwarp.fwd_position(m, d)
treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int)
island.tree_edges(m, d, treetree)
@@ -207,7 +208,7 @@ def test_multi_world_parallel(self):
nworld=2,
)
- mjwarp.forward(m, d)
+ mjwarp.fwd_position(m, d)
treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int)
island.tree_edges(m, d, treetree)
@@ -238,7 +239,7 @@ def test_contact_constraint_edges(self):
nworld=2,
)
- mjwarp.forward(m, d)
+ mjwarp.fwd_position(m, d)
nefc = d.nefc.numpy()
if nefc[0] > 0:
@@ -265,7 +266,7 @@ def test_isolated_tree_no_edge(self):
nworld=2,
)
- mjwarp.forward(m, d)
+ mjwarp.fwd_position(m, d)
treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int)
island.tree_edges(m, d, treetree)
@@ -301,7 +302,7 @@ def test_mixed_equality_and_contact(self):
nworld=2,
)
- mjwarp.forward(m, d)
+ mjwarp.fwd_position(m, d)
treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int)
island.tree_edges(m, d, treetree)
@@ -332,7 +333,7 @@ def test_worldbody_dofs_ignored(self):
nworld=2,
)
- mjwarp.forward(m, d)
+ mjwarp.fwd_position(m, d)
treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int)
island.tree_edges(m, d, treetree)
@@ -368,7 +369,7 @@ def test_constraint_touches_three_trees(self):
nworld=2,
)
- mjwarp.forward(m, d)
+ mjwarp.fwd_position(m, d)
treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int)
island.tree_edges(m, d, treetree)
@@ -380,6 +381,267 @@ def test_constraint_touches_three_trees(self):
self.assertEqual(tt[0, 2, 0], 1)
+class IslandDiscoveryTest(absltest.TestCase):
+ """Tests for full island discovery including label propagation."""
+
+ def test_two_trees_one_constraint_one_island(self):
+ """Two trees connected by one constraint form one island.
+
+ topology:
+ [[0, 1],
+ [1, 0]]
+ """
+ mjm, mjd, m, d = test_data.fixture(
+ xml="""
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """,
+ overrides={"opt.disableflags": DisableBit.ISLAND},
+ )
+
+ d.nisland.fill_(-1)
+ d.tree_island.fill_(-1)
+ mjwarp.fwd_position(m, d)
+ island.island(m, d)
+
+ # should have exactly 1 island
+ self.assertEqual(d.nisland.numpy()[0], 1)
+ # both trees should be in island 0
+ tree_island = d.tree_island.numpy()[0]
+ self.assertEqual(tree_island[0], tree_island[1])
+ self.assertEqual(tree_island[0], 0)
+
+ def test_three_trees_chain_one_island(self):
+ """Three trees in a chain form one island.
+
+ topology:
+ [[0, 1, 0],
+ [1, 0, 1],
+ [0, 1, 0]]
+ """
+ mjm, mjd, m, d = test_data.fixture(
+ xml="""
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """,
+ overrides={"opt.disableflags": DisableBit.ISLAND},
+ )
+
+ d.nisland.fill_(-1)
+ d.tree_island.fill_(-1)
+ mjwarp.fwd_position(m, d)
+ island.island(m, d)
+
+ # should have exactly 1 island
+ self.assertEqual(d.nisland.numpy()[0], 1)
+ # all trees should be in the same island
+ tree_island = d.tree_island.numpy()[0]
+ self.assertEqual(tree_island[0], tree_island[1])
+ self.assertEqual(tree_island[1], tree_island[2])
+
+ def test_two_disconnected_pairs_two_islands(self):
+ """Two pairs of disconnected trees form two islands.
+
+ topology:
+ [[0, 1, 0, 0],
+ [1, 0, 0, 0],
+ [0, 0, 0, 1],
+ [0, 0, 1, 0]]
+ """
+ mjm, mjd, m, d = test_data.fixture(
+ xml="""
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """,
+ overrides={"opt.disableflags": DisableBit.ISLAND},
+ )
+
+ d.nisland.fill_(-1)
+ d.tree_island.fill_(-1)
+ mjwarp.fwd_position(m, d)
+ island.island(m, d)
+
+ # should have exactly 2 islands
+ self.assertEqual(d.nisland.numpy()[0], 2)
+ # trees 0,1 should be in one island, trees 2,3 in another
+ tree_island = d.tree_island.numpy()[0]
+ self.assertEqual(tree_island[0], tree_island[1])
+ self.assertEqual(tree_island[2], tree_island[3])
+ self.assertNotEqual(tree_island[0], tree_island[2])
+
+ def test_no_constraints_no_islands(self):
+ """No constraints means no constrained islands.
+
+ topology:
+ [[0]] (no edges)
+ """
+ mjm, mjd, m, d = test_data.fixture(
+ xml="""
+
+
+
+
+
+
+
+
+ """,
+ overrides={"opt.disableflags": DisableBit.ISLAND},
+ )
+
+ d.nisland.fill_(-1)
+ d.tree_island.fill_(-1)
+ mjwarp.fwd_position(m, d)
+ island.island(m, d)
+
+ # should have 0 islands (unconstrained tree is not an island)
+ self.assertEqual(d.nisland.numpy()[0], 0)
+
+ def test_multiple_worlds(self):
+ """Test island discovery with nworld=2.
+
+ topology:
+ [[0, 1],
+ [1, 0]]
+ """
+ mjm, mjd, m, d = test_data.fixture(
+ xml="""
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """,
+ nworld=2,
+ overrides={"opt.disableflags": DisableBit.ISLAND},
+ )
+
+ d.nisland.fill_(-1)
+ d.tree_island.fill_(-1)
+ mjwarp.fwd_position(m, d)
+ island.island(m, d)
+
+ # both worlds should have exactly 1 island
+ nisland = d.nisland.numpy()
+ self.assertEqual(nisland[0], 1)
+ self.assertEqual(nisland[1], 1)
+
+ # both trees in both worlds should be in island 0
+ tree_island = d.tree_island.numpy()
+ for worldid in range(2):
+ self.assertEqual(tree_island[worldid, 0], 0)
+ self.assertEqual(tree_island[worldid, 1], 0)
+
+ def test_three_trees_star_hub_at_end(self):
+ """Three trees with tree 2 as hub connecting trees 0 and 1.
+
+ topology:
+ [[0, 0, 1],
+ [0, 0, 1],
+ [1, 1, 0]]
+ """
+ mjm, mjd, m, d = test_data.fixture(
+ xml="""
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """,
+ overrides={"opt.disableflags": DisableBit.ISLAND},
+ )
+
+ d.nisland.fill_(-1)
+ d.tree_island.fill_(-1)
+ mjwarp.fwd_position(m, d)
+ island.island(m, d)
+
+ # should have exactly 1 island
+ self.assertEqual(d.nisland.numpy()[0], 1)
+ # all trees should be in the same island
+ tree_island = d.tree_island.numpy()[0]
+ self.assertEqual(tree_island[0], tree_island[1])
+ self.assertEqual(tree_island[1], tree_island[2])
+
+
if __name__ == "__main__":
wp.init()
absltest.main()
diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py
index 41fbd4b3e..8ffdd44db 100644
--- a/mujoco_warp/_src/types.py
+++ b/mujoco_warp/_src/types.py
@@ -167,6 +167,7 @@ class DisableBit(enum.IntFlag):
SENSOR: sensors
EULERDAMP: implicit damping for Euler integration
NATIVECCD: native convex collision detection (ignored in MJWarp)
+ ISLAND: constraint islands
"""
CONSTRAINT = mujoco.mjtDisableBit.mjDSBL_CONSTRAINT
@@ -185,7 +186,8 @@ class DisableBit(enum.IntFlag):
SENSOR = mujoco.mjtDisableBit.mjDSBL_SENSOR
EULERDAMP = mujoco.mjtDisableBit.mjDSBL_EULERDAMP
NATIVECCD = mujoco.mjtDisableBit.mjDSBL_NATIVECCD
- # unsupported: MIDPHASE, AUTORESET, ISLAND
+ ISLAND = mujoco.mjtDisableBit.mjDSBL_ISLAND
+ # unsupported: MIDPHASE, AUTORESET
class EnableBit(enum.IntFlag):
@@ -1583,6 +1585,7 @@ class Data:
nf: number of friction constraints (nworld,)
nl: number of limit constraints (nworld,)
nefc: number of constraints (nworld,)
+ nisland: number of constraint islands (nworld,)
time: simulation time (nworld,)
energy: potential, kinetic energy (nworld, 2)
qpos: position (nworld, nq)
@@ -1659,6 +1662,7 @@ class Data:
cfrc_ext: com-based external force on body (nworld, nbody, 6)
contact: contact data
efc: constraint data
+ tree_island: island ID per tree (-1 if unconstrained) (nworld, ntree)
warp only fields:
nworld: number of worlds
@@ -1676,6 +1680,7 @@ class Data:
nf: array("nworld", int)
nl: array("nworld", int)
nefc: array("nworld", int)
+ nisland: array("nworld", int)
time: array("nworld", float)
energy: array("nworld", wp.vec2)
qpos: array("nworld", "nq", float)
@@ -1748,6 +1753,7 @@ class Data:
cfrc_ext: array("nworld", "nbody", wp.spatial_vector)
contact: Contact
efc: Constraint
+ tree_island: array("nworld", "ntree", int)
# warp only fields:
nworld: int