diff --git a/mujoco_warp/__init__.py b/mujoco_warp/__init__.py index 798f97414..679a08861 100644 --- a/mujoco_warp/__init__.py +++ b/mujoco_warp/__init__.py @@ -56,6 +56,7 @@ from mujoco_warp._src.io import set_const as set_const from mujoco_warp._src.io import set_const_0 as set_const_0 from mujoco_warp._src.io import set_const_fixed as set_const_fixed +from mujoco_warp._src.island import island as island from mujoco_warp._src.passive import passive as passive from mujoco_warp._src.ray import ray as ray from mujoco_warp._src.ray import rays as rays diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index 032070be8..b6ec5c62c 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -20,6 +20,7 @@ from mujoco_warp._src import collision_driver from mujoco_warp._src import constraint from mujoco_warp._src import derivative +from mujoco_warp._src import island from mujoco_warp._src import math from mujoco_warp._src import passive from mujoco_warp._src import sensor @@ -517,6 +518,8 @@ def fwd_position(m: Model, d: Data, factorize: bool = True): if m.opt.run_collision_detection: collision_driver.collision(m, d) constraint.make_constraint(m, d) + if not (m.opt.disableflags & DisableBit.ISLAND): + island.island(m, d) smooth.transmission(m, d) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index b96ea5406..a44ccdc14 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -740,6 +740,9 @@ def make_data( "eq_active": wp.array(np.tile(mjm.eq_active0.astype(bool), (nworld, 1)), shape=(nworld, mjm.neq), dtype=bool), # flexedge "flexedge_J": None, + # island arrays + "nisland": None, + "tree_island": None, } for f in dataclasses.fields(types.Data): if f.name in d_kwargs: @@ -757,6 +760,10 @@ def make_data( d.flexedge_J = wp.zeros((nworld, 1, mjd.flexedge_J.size), dtype=float) + # island discovery arrays + d.nisland = wp.zeros((nworld,), dtype=int) + d.tree_island = wp.zeros((nworld, mjm.ntree), dtype=int) + return d @@ -891,6 +898,9 @@ def put_data( "actuator_moment": None, "flexedge_J": None, "nacon": None, + # island arrays + "nisland": None, + "tree_island": None, } for f in dataclasses.fields(types.Data): if f.name in d_kwargs: @@ -918,6 +928,10 @@ def put_data( d.flexedge_J = wp.array(np.tile(mjd.flexedge_J.reshape(-1), (nworld, 1)).reshape((nworld, 1, -1)), dtype=float) + # island arrays + d.nisland = wp.array(np.full(nworld, mjd.nisland), dtype=int) + d.tree_island = wp.array(np.tile(mjd.tree_island, (nworld, 1)), dtype=int) + if mujoco.mj_isSparse(mjm): ten_J = np.zeros((mjm.ntendon, mjm.nv)) mujoco.mju_sparse2dense(ten_J, mjd.ten_J.reshape(-1), mjd.ten_J_rownnz, mjd.ten_J_rowadr, mjd.ten_J_colind.reshape(-1)) @@ -1138,6 +1152,12 @@ def get_data_into( # sensors result.sensordata[:] = d.sensordata.numpy()[world_id] + # islands + nisland = d.nisland.numpy()[world_id] + result.nisland = nisland + if nisland: + result.tree_island[:] = d.tree_island.numpy()[world_id] + def reset_data(m: types.Model, d: types.Data, reset: Optional[wp.array] = None): """Clear data, set defaults; optionally by world. diff --git a/mujoco_warp/_src/io_test.py b/mujoco_warp/_src/io_test.py index 4ecf590e5..5cb55c8ea 100644 --- a/mujoco_warp/_src/io_test.py +++ b/mujoco_warp/_src/io_test.py @@ -141,6 +141,7 @@ def test_get_data_into(self, nworld, world_id): self.assertEqual(d.ne.numpy()[world_id], mjd.ne) self.assertEqual(d.nf.numpy()[world_id], mjd.nf) self.assertEqual(d.nl.numpy()[world_id], mjd.nl) + self.assertEqual(d.nisland.numpy()[world_id], mjd.nisland) _assert_eq(d.time.numpy()[world_id], mjd.time, "time") for field in [ @@ -296,7 +297,10 @@ def test_get_data_into_io_test_models(self, xml): "xpos", "xquat", "geom_xpos", + "tree_island", ]: + if field == "tree_island" and d.nisland.numpy()[0] == 0: + continue if getattr(mjd, field).size > 0: _assert_eq( getattr(mjd_result, field).reshape(-1), diff --git a/mujoco_warp/_src/island.py b/mujoco_warp/_src/island.py index e51a6c402..3d1d064b4 100644 --- a/mujoco_warp/_src/island.py +++ b/mujoco_warp/_src/island.py @@ -19,6 +19,7 @@ from mujoco_warp._src.types import ConstraintType from mujoco_warp._src.types import EqType from mujoco_warp._src.types import ObjType +from mujoco_warp._src.warp_util import event_scope @wp.kernel @@ -44,7 +45,7 @@ def _tree_edges( # Out: tree_tree: wp.array3d(dtype=int), # kernel_analyzer: off ): - """Find tree edges. Launch: (nworld, njmax).""" + """Find tree edges.""" worldid, efcid = wp.tid() # skip if beyond active constraints @@ -176,3 +177,92 @@ def tree_edges(m: types.Model, d: types.Data, tree_tree: wp.array3d(dtype=int)): ], outputs=[tree_tree], ) + + +@wp.kernel +def _flood_fill( + # Model: + ntree: int, + # In: + tree_tree_in: wp.array3d(dtype=int), + labels_in: wp.array2d(dtype=int), + stack_in: wp.array2d(dtype=int), + # Data out: + nisland_out: wp.array(dtype=int), + # Out: + labels_out: wp.array2d(dtype=int), + stack_out: wp.array2d(dtype=int), +): + """DFS flood fill to discover islands using tree_tree matrix.""" + worldid = wp.tid() + nisland = int(0) + + # iterate over trees + for i in range(ntree): + # already assigned + if labels_in[worldid, i] != -1: + continue + + # check if tree has any edges + has_edge = int(0) + for j in range(ntree): + if tree_tree_in[worldid, i, j] != 0: + has_edge = 1 + break + if has_edge == 0: + continue + + # DFS: push i onto stack + nstack = int(0) + stack_out[worldid, nstack] = i + nstack = nstack + 1 + + while nstack > 0: + # pop v from stack + nstack = nstack - 1 + v = stack_in[worldid, nstack] + + # already assigned + if labels_in[worldid, v] != -1: + continue + + # assign to current island + labels_out[worldid, v] = nisland + + # push neighbors + for neighbor in range(ntree): + if tree_tree_in[worldid, v, neighbor] != 0: + if labels_in[worldid, neighbor] == -1 and nstack < ntree: + stack_out[worldid, nstack] = neighbor + nstack = nstack + 1 + + # island filled + nisland = nisland + 1 + + nisland_out[worldid] = nisland + + +@event_scope +def island( + m: types.Model, + d: types.Data, +) -> None: + """Discover constraint islands.""" + if m.ntree == 0: + d.nisland.zero_() + return + + # Step 1: Find tree edges + tree_tree = wp.zeros((d.nworld, m.ntree, m.ntree), dtype=int) + tree_edges(m, d, tree_tree) + + # Step 2: DFS flood fill + d.tree_island.fill_(-1) + stack_scratch = wp.zeros((d.nworld, m.ntree), dtype=int) + + wp.launch( + _flood_fill, + dim=d.nworld, + inputs=[m.ntree, tree_tree, d.tree_island, stack_scratch], + outputs=[d.nisland, d.tree_island, stack_scratch], + ) diff --git a/mujoco_warp/_src/island_test.py b/mujoco_warp/_src/island_test.py index 00a082eaf..fb95d498f 100644 --- a/mujoco_warp/_src/island_test.py +++ b/mujoco_warp/_src/island_test.py @@ -22,6 +22,7 @@ import mujoco_warp as mjwarp from mujoco_warp import test_data from mujoco_warp._src import island +from mujoco_warp._src.types import DisableBit class IslandEdgeDiscoveryTest(absltest.TestCase): @@ -51,7 +52,7 @@ def test_single_constraint_two_trees(self): """ ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -82,7 +83,7 @@ def test_constraint_within_single_tree_creates_self_edge(self): """ ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -117,7 +118,7 @@ def test_three_bodies_chain(self): """ ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -151,7 +152,7 @@ def test_deduplication(self): """ ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -176,7 +177,7 @@ def test_no_constraints(self): """ ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -207,7 +208,7 @@ def test_multi_world_parallel(self): nworld=2, ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -238,7 +239,7 @@ def test_contact_constraint_edges(self): nworld=2, ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) nefc = d.nefc.numpy() if nefc[0] > 0: @@ -265,7 +266,7 @@ def test_isolated_tree_no_edge(self): nworld=2, ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -301,7 +302,7 @@ def test_mixed_equality_and_contact(self): nworld=2, ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -332,7 +333,7 @@ def test_worldbody_dofs_ignored(self): nworld=2, ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -368,7 +369,7 @@ def test_constraint_touches_three_trees(self): nworld=2, ) - mjwarp.forward(m, d) + mjwarp.fwd_position(m, d) treetree = wp.empty((d.nworld, m.ntree, m.ntree), dtype=int) island.tree_edges(m, d, treetree) @@ -380,6 +381,267 @@ def test_constraint_touches_three_trees(self): self.assertEqual(tt[0, 2, 0], 1) +class IslandDiscoveryTest(absltest.TestCase): + """Tests for full island discovery including label propagation.""" + + def test_two_trees_one_constraint_one_island(self): + """Two trees connected by one constraint form one island. + + topology: + [[0, 1], + [1, 0]] + """ + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + """, + overrides={"opt.disableflags": DisableBit.ISLAND}, + ) + + d.nisland.fill_(-1) + d.tree_island.fill_(-1) + mjwarp.fwd_position(m, d) + island.island(m, d) + + # should have exactly 1 island + self.assertEqual(d.nisland.numpy()[0], 1) + # both trees should be in island 0 + tree_island = d.tree_island.numpy()[0] + self.assertEqual(tree_island[0], tree_island[1]) + self.assertEqual(tree_island[0], 0) + + def test_three_trees_chain_one_island(self): + """Three trees in a chain form one island. + + topology: + [[0, 1, 0], + [1, 0, 1], + [0, 1, 0]] + """ + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + + + + + + """, + overrides={"opt.disableflags": DisableBit.ISLAND}, + ) + + d.nisland.fill_(-1) + d.tree_island.fill_(-1) + mjwarp.fwd_position(m, d) + island.island(m, d) + + # should have exactly 1 island + self.assertEqual(d.nisland.numpy()[0], 1) + # all trees should be in the same island + tree_island = d.tree_island.numpy()[0] + self.assertEqual(tree_island[0], tree_island[1]) + self.assertEqual(tree_island[1], tree_island[2]) + + def test_two_disconnected_pairs_two_islands(self): + """Two pairs of disconnected trees form two islands. + + topology: + [[0, 1, 0, 0], + [1, 0, 0, 0], + [0, 0, 0, 1], + [0, 0, 1, 0]] + """ + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + + + + + + + + + + """, + overrides={"opt.disableflags": DisableBit.ISLAND}, + ) + + d.nisland.fill_(-1) + d.tree_island.fill_(-1) + mjwarp.fwd_position(m, d) + island.island(m, d) + + # should have exactly 2 islands + self.assertEqual(d.nisland.numpy()[0], 2) + # trees 0,1 should be in one island, trees 2,3 in another + tree_island = d.tree_island.numpy()[0] + self.assertEqual(tree_island[0], tree_island[1]) + self.assertEqual(tree_island[2], tree_island[3]) + self.assertNotEqual(tree_island[0], tree_island[2]) + + def test_no_constraints_no_islands(self): + """No constraints means no constrained islands. + + topology: + [[0]] (no edges) + """ + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + """, + overrides={"opt.disableflags": DisableBit.ISLAND}, + ) + + d.nisland.fill_(-1) + d.tree_island.fill_(-1) + mjwarp.fwd_position(m, d) + island.island(m, d) + + # should have 0 islands (unconstrained tree is not an island) + self.assertEqual(d.nisland.numpy()[0], 0) + + def test_multiple_worlds(self): + """Test island discovery with nworld=2. + + topology: + [[0, 1], + [1, 0]] + """ + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + """, + nworld=2, + overrides={"opt.disableflags": DisableBit.ISLAND}, + ) + + d.nisland.fill_(-1) + d.tree_island.fill_(-1) + mjwarp.fwd_position(m, d) + island.island(m, d) + + # both worlds should have exactly 1 island + nisland = d.nisland.numpy() + self.assertEqual(nisland[0], 1) + self.assertEqual(nisland[1], 1) + + # both trees in both worlds should be in island 0 + tree_island = d.tree_island.numpy() + for worldid in range(2): + self.assertEqual(tree_island[worldid, 0], 0) + self.assertEqual(tree_island[worldid, 1], 0) + + def test_three_trees_star_hub_at_end(self): + """Three trees with tree 2 as hub connecting trees 0 and 1. + + topology: + [[0, 0, 1], + [0, 0, 1], + [1, 1, 0]] + """ + mjm, mjd, m, d = test_data.fixture( + xml=""" + + + + + + + + + + + + + + + + + + + + + """, + overrides={"opt.disableflags": DisableBit.ISLAND}, + ) + + d.nisland.fill_(-1) + d.tree_island.fill_(-1) + mjwarp.fwd_position(m, d) + island.island(m, d) + + # should have exactly 1 island + self.assertEqual(d.nisland.numpy()[0], 1) + # all trees should be in the same island + tree_island = d.tree_island.numpy()[0] + self.assertEqual(tree_island[0], tree_island[1]) + self.assertEqual(tree_island[1], tree_island[2]) + + if __name__ == "__main__": wp.init() absltest.main() diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 41fbd4b3e..8ffdd44db 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -167,6 +167,7 @@ class DisableBit(enum.IntFlag): SENSOR: sensors EULERDAMP: implicit damping for Euler integration NATIVECCD: native convex collision detection (ignored in MJWarp) + ISLAND: constraint islands """ CONSTRAINT = mujoco.mjtDisableBit.mjDSBL_CONSTRAINT @@ -185,7 +186,8 @@ class DisableBit(enum.IntFlag): SENSOR = mujoco.mjtDisableBit.mjDSBL_SENSOR EULERDAMP = mujoco.mjtDisableBit.mjDSBL_EULERDAMP NATIVECCD = mujoco.mjtDisableBit.mjDSBL_NATIVECCD - # unsupported: MIDPHASE, AUTORESET, ISLAND + ISLAND = mujoco.mjtDisableBit.mjDSBL_ISLAND + # unsupported: MIDPHASE, AUTORESET class EnableBit(enum.IntFlag): @@ -1583,6 +1585,7 @@ class Data: nf: number of friction constraints (nworld,) nl: number of limit constraints (nworld,) nefc: number of constraints (nworld,) + nisland: number of constraint islands (nworld,) time: simulation time (nworld,) energy: potential, kinetic energy (nworld, 2) qpos: position (nworld, nq) @@ -1659,6 +1662,7 @@ class Data: cfrc_ext: com-based external force on body (nworld, nbody, 6) contact: contact data efc: constraint data + tree_island: island ID per tree (-1 if unconstrained) (nworld, ntree) warp only fields: nworld: number of worlds @@ -1676,6 +1680,7 @@ class Data: nf: array("nworld", int) nl: array("nworld", int) nefc: array("nworld", int) + nisland: array("nworld", int) time: array("nworld", float) energy: array("nworld", wp.vec2) qpos: array("nworld", "nq", float) @@ -1748,6 +1753,7 @@ class Data: cfrc_ext: array("nworld", "nbody", wp.spatial_vector) contact: Contact efc: Constraint + tree_island: array("nworld", "ntree", int) # warp only fields: nworld: int