Skip to content

Commit 4df2e40

Browse files
committed
heterogeneous meshes
1 parent dc50a56 commit 4df2e40

File tree

10 files changed

+748
-216
lines changed

10 files changed

+748
-216
lines changed

mujoco_warp/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from ._src.inverse import inverse as inverse
4242
from ._src.io import get_data_into as get_data_into
4343
from ._src.io import make_data as make_data
44+
from ._src.io import mesh_batch as mesh_batch
4445
from ._src.io import put_data as put_data
4546
from ._src.io import put_model as put_model
4647
from ._src.io import reset_data as reset_data

mujoco_warp/_src/collision_convex.py

Lines changed: 58 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -223,20 +223,25 @@ def ccd_hfield_kernel(
223223
hfield_ncol: wp.array(dtype=int),
224224
hfield_size: wp.array(dtype=wp.vec4),
225225
hfield_data: wp.array(dtype=float),
226-
mesh_vertadr: wp.array(dtype=int),
227-
mesh_vertnum: wp.array(dtype=int),
228-
mesh_vert: wp.array(dtype=wp.vec3),
229-
mesh_graphadr: wp.array(dtype=int),
230-
mesh_graph: wp.array(dtype=int),
231-
mesh_polynum: wp.array(dtype=int),
232-
mesh_polyadr: wp.array(dtype=int),
233-
mesh_polynormal: wp.array(dtype=wp.vec3),
234-
mesh_polyvertadr: wp.array(dtype=int),
235-
mesh_polyvertnum: wp.array(dtype=int),
236-
mesh_polyvert: wp.array(dtype=int),
237-
mesh_polymapadr: wp.array(dtype=int),
238-
mesh_polymapnum: wp.array(dtype=int),
239-
mesh_polymap: wp.array(dtype=int),
226+
mesh_vertadr_offset: wp.array(dtype=int),
227+
mesh_graphadr_offset: wp.array(dtype=int),
228+
mesh_polyadr_offset: wp.array(dtype=int),
229+
mesh_polyvertadr_offset: wp.array(dtype=int),
230+
mesh_polymapadr_offset: wp.array(dtype=int),
231+
mesh_vertadr: wp.array2d(dtype=int),
232+
mesh_vertnum: wp.array2d(dtype=int),
233+
mesh_vert: wp.array2d(dtype=wp.vec3),
234+
mesh_graphadr: wp.array2d(dtype=int),
235+
mesh_graph: wp.array2d(dtype=int),
236+
mesh_polynum: wp.array2d(dtype=int),
237+
mesh_polyadr: wp.array2d(dtype=int),
238+
mesh_polynormal: wp.array2d(dtype=wp.vec3),
239+
mesh_polyvertadr: wp.array2d(dtype=int),
240+
mesh_polyvertnum: wp.array2d(dtype=int),
241+
mesh_polyvert: wp.array2d(dtype=int),
242+
mesh_polymapadr: wp.array2d(dtype=int),
243+
mesh_polymapnum: wp.array2d(dtype=int),
244+
mesh_polymap: wp.array2d(dtype=int),
240245
pair_dim: wp.array(dtype=int),
241246
pair_solref: wp.array2d(dtype=wp.vec2),
242247
pair_solreffriction: wp.array2d(dtype=wp.vec2),
@@ -340,6 +345,11 @@ def ccd_hfield_kernel(
340345
mesh_polymapadr,
341346
mesh_polymapnum,
342347
mesh_polymap,
348+
mesh_vertadr_offset,
349+
mesh_graphadr_offset,
350+
mesh_polyadr_offset,
351+
mesh_polyvertadr_offset,
352+
mesh_polymapadr_offset,
343353
geom_xpos_in,
344354
geom_xmat_in,
345355
geoms,
@@ -909,20 +919,25 @@ def ccd_kernel(
909919
geom_friction: wp.array2d(dtype=wp.vec3),
910920
geom_margin: wp.array2d(dtype=float),
911921
geom_gap: wp.array2d(dtype=float),
912-
mesh_vertadr: wp.array(dtype=int),
913-
mesh_vertnum: wp.array(dtype=int),
914-
mesh_vert: wp.array(dtype=wp.vec3),
915-
mesh_graphadr: wp.array(dtype=int),
916-
mesh_graph: wp.array(dtype=int),
917-
mesh_polynum: wp.array(dtype=int),
918-
mesh_polyadr: wp.array(dtype=int),
919-
mesh_polynormal: wp.array(dtype=wp.vec3),
920-
mesh_polyvertadr: wp.array(dtype=int),
921-
mesh_polyvertnum: wp.array(dtype=int),
922-
mesh_polyvert: wp.array(dtype=int),
923-
mesh_polymapadr: wp.array(dtype=int),
924-
mesh_polymapnum: wp.array(dtype=int),
925-
mesh_polymap: wp.array(dtype=int),
922+
mesh_vertadr_offset: wp.array(dtype=int),
923+
mesh_graphadr_offset: wp.array(dtype=int),
924+
mesh_polyadr_offset: wp.array(dtype=int),
925+
mesh_polyvertadr_offset: wp.array(dtype=int),
926+
mesh_polymapadr_offset: wp.array(dtype=int),
927+
mesh_vertadr: wp.array2d(dtype=int),
928+
mesh_vertnum: wp.array2d(dtype=int),
929+
mesh_vert: wp.array2d(dtype=wp.vec3),
930+
mesh_graphadr: wp.array2d(dtype=int),
931+
mesh_graph: wp.array2d(dtype=int),
932+
mesh_polynum: wp.array2d(dtype=int),
933+
mesh_polyadr: wp.array2d(dtype=int),
934+
mesh_polynormal: wp.array2d(dtype=wp.vec3),
935+
mesh_polyvertadr: wp.array2d(dtype=int),
936+
mesh_polyvertnum: wp.array2d(dtype=int),
937+
mesh_polyvert: wp.array2d(dtype=int),
938+
mesh_polymapadr: wp.array2d(dtype=int),
939+
mesh_polymapnum: wp.array2d(dtype=int),
940+
mesh_polymap: wp.array2d(dtype=int),
926941
pair_dim: wp.array(dtype=int),
927942
pair_solref: wp.array2d(dtype=wp.vec2),
928943
pair_solreffriction: wp.array2d(dtype=wp.vec2),
@@ -1030,6 +1045,11 @@ def ccd_kernel(
10301045
mesh_polymapadr,
10311046
mesh_polymapnum,
10321047
mesh_polymap,
1048+
mesh_vertadr_offset,
1049+
mesh_graphadr_offset,
1050+
mesh_polyadr_offset,
1051+
mesh_polyvertadr_offset,
1052+
mesh_polymapadr_offset,
10331053
geom_xpos_in,
10341054
geom_xmat_in,
10351055
geoms,
@@ -1226,6 +1246,11 @@ def _pair_count(p1: int, p2: int) -> int:
12261246
m.hfield_ncol,
12271247
m.hfield_size,
12281248
m.hfield_data,
1249+
m.mesh_vertadr_offset,
1250+
m.mesh_graphadr_offset,
1251+
m.mesh_polyadr_offset,
1252+
m.mesh_polyvertadr_offset,
1253+
m.mesh_polymapadr_offset,
12291254
m.mesh_vertadr,
12301255
m.mesh_vertnum,
12311256
m.mesh_vert,
@@ -1314,6 +1339,11 @@ def _pair_count(p1: int, p2: int) -> int:
13141339
m.geom_friction,
13151340
m.geom_margin,
13161341
m.geom_gap,
1342+
m.mesh_vertadr_offset,
1343+
m.mesh_graphadr_offset,
1344+
m.mesh_polyadr_offset,
1345+
m.mesh_polyvertadr_offset,
1346+
m.mesh_polymapadr_offset,
13171347
m.mesh_vertadr,
13181348
m.mesh_vertnum,
13191349
m.mesh_vert,

mujoco_warp/_src/collision_driver_test.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,146 @@ def test_box_box_face_penetration_depth(self):
988988
msg=f"Contact {i}: Expected penetration {expected_penetration:.4f}, got {mjw_dist:.4f}",
989989
)
990990

991+
def test_batch_mesh(self):
992+
"""Test batched meshes."""
993+
_XML = """
994+
<mujoco>
995+
<option gravity="0 0 0"/>
996+
<asset>
997+
<mesh name="mesh1" builtin="sphere" params="{subdiv1}"/>
998+
<mesh name="mesh2" builtin="sphere" params="{subdiv2}"/>
999+
</asset>
1000+
<worldbody>
1001+
<body name="body1">
1002+
<freejoint/>
1003+
<geom type="mesh" mesh="mesh1" size="0.1"/>
1004+
</body>
1005+
<body name="body2" pos="0 0 {z_offset}">
1006+
<freejoint/>
1007+
<geom type="mesh" mesh="mesh2" size="0.1"/>
1008+
</body>
1009+
</worldbody>
1010+
<keyframe>
1011+
<key name="contact" qpos="0 0 0 1 0 0 0
1012+
0 0 {z_offset} 1 0 0 0"/>
1013+
</keyframe>
1014+
</mujoco>
1015+
"""
1016+
1017+
# different subdivisions give different vertex counts
1018+
# use same z_offset so mesh geometry is the only variable
1019+
configs = [
1020+
{"subdiv1": "0", "subdiv2": "0", "z_offset": "0.15"},
1021+
{"subdiv1": "1", "subdiv2": "1", "z_offset": "0.15"},
1022+
{"subdiv1": "2", "subdiv2": "2", "z_offset": "0.15"},
1023+
]
1024+
1025+
mjms = []
1026+
mjds = []
1027+
for config in configs:
1028+
xml = _XML.format(**config)
1029+
mjm = mujoco.MjModel.from_xml_string(xml)
1030+
mjd = mujoco.MjData(mjm)
1031+
mujoco.mj_resetDataKeyframe(mjm, mjd, 0)
1032+
mujoco.mj_forward(mjm, mjd)
1033+
mujoco.mj_collision(mjm, mjd)
1034+
mjms.append(mjm)
1035+
mjds.append(mjd)
1036+
1037+
# verify different vertex counts
1038+
vertnums = [mjm.nmeshvert for mjm in mjms]
1039+
assert vertnums[0] != vertnums[1] != vertnums[2], f"Expected different vertnums: {vertnums}"
1040+
1041+
# test 1: same mesh for all worlds - verify contacts match mujoco
1042+
for i, (config, mjd) in enumerate(zip(configs, mjds)):
1043+
xml = _XML.format(**config)
1044+
1045+
_, _, m, d = test_data.fixture(xml=xml, keyframe=0, nworld=3)
1046+
mjw.collision(m, d)
1047+
1048+
# group contacts by worldid
1049+
nacon = d.nacon.numpy()[0]
1050+
ncon_per_world = [0, 0, 0]
1051+
dist_per_world = [[], [], []]
1052+
for j in range(nacon):
1053+
wid = d.contact.worldid.numpy()[j]
1054+
ncon_per_world[wid] += 1
1055+
dist_per_world[wid].append(d.contact.dist.numpy()[j])
1056+
1057+
# verify each world matches mujoco
1058+
for worldid in range(3):
1059+
_assert_eq(ncon_per_world[worldid], mjd.ncon, f"Model {i} world {worldid}: ncon")
1060+
if mjd.ncon > 0:
1061+
_assert_eq(dist_per_world[worldid][0], mjd.contact.dist[0], f"Model {i} world {worldid}: dist")
1062+
1063+
# test 2: different mesh per world via mesh_batch
1064+
xml0 = _XML.format(**configs[0])
1065+
_, _, m0, d0 = test_data.fixture(xml=xml0, keyframe=0, nworld=3)
1066+
1067+
# create batched mesh arrays
1068+
m0 = mjw.mesh_batch(m0, mjms)
1069+
1070+
# verify mesh data fields are correctly batched
1071+
for worldid, mjm in enumerate(mjms):
1072+
_assert_eq(m0.mesh_vertnum.numpy()[worldid], mjm.mesh_vertnum, f"world {worldid}: mesh_vertnum")
1073+
_assert_eq(m0.mesh_vertadr.numpy()[worldid], mjm.mesh_vertadr, f"world {worldid}: mesh_vertadr")
1074+
_assert_eq(m0.mesh_faceadr.numpy()[worldid], mjm.mesh_faceadr, f"world {worldid}: mesh_faceadr")
1075+
_assert_eq(m0.mesh_quat.numpy()[worldid], mjm.mesh_quat, f"world {worldid}: mesh_quat")
1076+
_assert_eq(m0.mesh_polynum.numpy()[worldid], mjm.mesh_polynum, f"world {worldid}: mesh_polynum")
1077+
1078+
# verify geom bounds are correctly batched
1079+
for worldid, mjm in enumerate(mjms):
1080+
_assert_eq(m0.geom_size.numpy()[worldid], mjm.geom_size, f"world {worldid}: geom_size")
1081+
_assert_eq(m0.geom_aabb.numpy()[worldid].reshape(mjm.geom_aabb.shape), mjm.geom_aabb, f"world {worldid}: geom_aabb")
1082+
_assert_eq(m0.geom_rbound.numpy()[worldid], mjm.geom_rbound, f"world {worldid}: geom_rbound")
1083+
1084+
# verify body mass/inertia are correctly batched
1085+
for worldid, mjm in enumerate(mjms):
1086+
_assert_eq(m0.body_mass.numpy()[worldid], mjm.body_mass, f"world {worldid}: body_mass")
1087+
_assert_eq(m0.body_inertia.numpy()[worldid], mjm.body_inertia, f"world {worldid}: body_inertia")
1088+
_assert_eq(m0.body_ipos.numpy()[worldid], mjm.body_ipos, f"world {worldid}: body_ipos")
1089+
_assert_eq(m0.dof_invweight0.numpy()[worldid], mjm.dof_invweight0, f"world {worldid}: dof_invweight0")
1090+
1091+
mjw.collision(m0, d0)
1092+
1093+
# verify contact distance and normal for each world
1094+
for worldid in range(3):
1095+
mjd = mjds[worldid]
1096+
1097+
# find contacts for this world
1098+
warp_idx = np.where(d0.contact.worldid.numpy() == worldid)[0]
1099+
self.assertGreater(len(warp_idx), 0, f"world {worldid} has no contacts")
1100+
1101+
# compare first contact distance and normal
1102+
mj_dist = mjd.contact.dist[0]
1103+
mj_frame = mjd.contact.frame[0].reshape(3, 3)
1104+
mj_normal = mj_frame[:, 0]
1105+
1106+
wp_dist = d0.contact.dist.numpy()[warp_idx[0]]
1107+
1108+
_assert_eq(wp_dist, mj_dist, f"world {worldid}: contact.dist")
1109+
1110+
# test 3: 4 worlds with 2 unique mesh configs (configs repeat)
1111+
nworld = 4
1112+
_, _, m1, d1 = test_data.fixture(xml=_XML.format(**configs[0]), keyframe=0, nworld=nworld)
1113+
m1 = mjw.mesh_batch(m1, mjms[:2]) # construct with 2 unique mjModel instances
1114+
1115+
# verify fields match corresponding config
1116+
for worldid in range(nworld):
1117+
_assert_eq(m1.mesh_vertnum.numpy()[worldid % 2], mjms[worldid % 2].mesh_vertnum, f"world {worldid}: mesh_vertnum")
1118+
_assert_eq(m1.geom_size.numpy()[worldid % 2], mjms[worldid % 2].geom_size, f"world {worldid}: geom_size")
1119+
_assert_eq(m1.body_mass.numpy()[worldid % 2], mjms[worldid % 2].body_mass, f"world {worldid}: body_mass")
1120+
1121+
mjw.collision(m1, d1)
1122+
1123+
# verify contacts for each world match corresponding config
1124+
for worldid in range(nworld):
1125+
mjd = mjds[worldid % 2]
1126+
warp_idx = np.where(d1.contact.worldid.numpy() == worldid)[0]
1127+
self.assertGreater(len(warp_idx), 0, f"world {worldid} has no contacts")
1128+
wp_dist = d1.contact.dist.numpy()[warp_idx[0]]
1129+
_assert_eq(wp_dist, mjd.contact.dist[0], f"world {worldid}: contact.dist")
1130+
9911131

9921132
if __name__ == "__main__":
9931133
absltest.main()

mujoco_warp/_src/collision_gjk_test.py

Lines changed: 48 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -74,18 +74,22 @@ def _ccd_kernel(
7474
geom_type: wp.array(dtype=int),
7575
geom_dataid: wp.array(dtype=int),
7676
geom_size: wp.array2d(dtype=wp.vec3),
77-
mesh_vertadr: wp.array(dtype=int),
78-
mesh_vertnum: wp.array(dtype=int),
79-
mesh_vert: wp.array(dtype=wp.vec3),
80-
mesh_polynum: wp.array(dtype=int),
81-
mesh_polyadr: wp.array(dtype=int),
82-
mesh_polynormal: wp.array(dtype=wp.vec3),
83-
mesh_polyvertadr: wp.array(dtype=int),
84-
mesh_polyvertnum: wp.array(dtype=int),
85-
mesh_polyvert: wp.array(dtype=int),
86-
mesh_polymapadr: wp.array(dtype=int),
87-
mesh_polymapnum: wp.array(dtype=int),
88-
mesh_polymap: wp.array(dtype=int),
77+
mesh_vertadr_offset: wp.array(dtype=int),
78+
mesh_polyadr_offset: wp.array(dtype=int),
79+
mesh_polyvertadr_offset: wp.array(dtype=int),
80+
mesh_polymapadr_offset: wp.array(dtype=int),
81+
mesh_vertadr: wp.array2d(dtype=int),
82+
mesh_vertnum: wp.array2d(dtype=int),
83+
mesh_vert: wp.array2d(dtype=wp.vec3),
84+
mesh_polynum: wp.array2d(dtype=int),
85+
mesh_polyadr: wp.array2d(dtype=int),
86+
mesh_polynormal: wp.array2d(dtype=wp.vec3),
87+
mesh_polyvertadr: wp.array2d(dtype=int),
88+
mesh_polyvertnum: wp.array2d(dtype=int),
89+
mesh_polyvert: wp.array2d(dtype=int),
90+
mesh_polymapadr: wp.array2d(dtype=int),
91+
mesh_polymapnum: wp.array2d(dtype=int),
92+
mesh_polymap: wp.array2d(dtype=int),
8993
# Data in:
9094
geom_xpos_in: wp.array2d(dtype=wp.vec3),
9195
geom_xmat_in: wp.array2d(dtype=wp.mat33),
@@ -139,18 +143,20 @@ def _ccd_kernel(
139143

140144
if geom_dataid[gid1] >= 0 and geom_type[gid1] == GeomType.MESH:
141145
dataid = geom_dataid[gid1]
142-
geom1.vertadr = mesh_vertadr[dataid]
143-
geom1.vertnum = mesh_vertnum[dataid]
144-
geom1.mesh_polynum = mesh_polynum[dataid]
145-
geom1.mesh_polyadr = mesh_polyadr[dataid]
146-
geom1.vert = mesh_vert
147-
geom1.mesh_polynormal = mesh_polynormal
148-
geom1.mesh_polyvertadr = mesh_polyvertadr
149-
geom1.mesh_polyvertnum = mesh_polyvertnum
150-
geom1.mesh_polyvert = mesh_polyvert
151-
geom1.mesh_polymapadr = mesh_polymapadr
152-
geom1.mesh_polymapnum = mesh_polymapnum
153-
geom1.mesh_polymap = mesh_polymap
146+
world_vert_offset = mesh_vertadr_offset[0]
147+
world_poly_offset = mesh_polyadr_offset[0]
148+
geom1.vertadr = mesh_vertadr[0, dataid] + world_vert_offset
149+
geom1.vertnum = mesh_vertnum[0, dataid]
150+
geom1.mesh_polynum = mesh_polynum[0, dataid]
151+
geom1.mesh_polyadr = mesh_polyadr[0, dataid] + world_poly_offset
152+
geom1.vert = mesh_vert[0]
153+
geom1.mesh_polynormal = mesh_polynormal[0]
154+
geom1.mesh_polyvertadr = mesh_polyvertadr[0]
155+
geom1.mesh_polyvertnum = mesh_polyvertnum[0]
156+
geom1.mesh_polyvert = mesh_polyvert[0]
157+
geom1.mesh_polymapadr = mesh_polymapadr[0]
158+
geom1.mesh_polymapnum = mesh_polymapnum[0]
159+
geom1.mesh_polymap = mesh_polymap[0]
154160

155161
geom2 = Geom()
156162
geom2.index = -1
@@ -170,18 +176,20 @@ def _ccd_kernel(
170176

171177
if geom_dataid[gid2] >= 0 and geom_type[gid2] == GeomType.MESH:
172178
dataid = geom_dataid[gid2]
173-
geom2.vertadr = mesh_vertadr[dataid]
174-
geom2.vertnum = mesh_vertnum[dataid]
175-
geom2.mesh_polynum = mesh_polynum[dataid]
176-
geom2.mesh_polyadr = mesh_polyadr[dataid]
177-
geom2.vert = mesh_vert
178-
geom2.mesh_polynormal = mesh_polynormal
179-
geom2.mesh_polyvertadr = mesh_polyvertadr
180-
geom2.mesh_polyvertnum = mesh_polyvertnum
181-
geom2.mesh_polyvert = mesh_polyvert
182-
geom2.mesh_polymapadr = mesh_polymapadr
183-
geom2.mesh_polymapnum = mesh_polymapnum
184-
geom2.mesh_polymap = mesh_polymap
179+
world_vert_offset = mesh_vertadr_offset[0]
180+
world_poly_offset = mesh_polyadr_offset[0]
181+
geom2.vertadr = mesh_vertadr[0, dataid] + world_vert_offset
182+
geom2.vertnum = mesh_vertnum[0, dataid]
183+
geom2.mesh_polynum = mesh_polynum[0, dataid]
184+
geom2.mesh_polyadr = mesh_polyadr[0, dataid] + world_poly_offset
185+
geom2.vert = mesh_vert[0]
186+
geom2.mesh_polynormal = mesh_polynormal[0]
187+
geom2.mesh_polyvertadr = mesh_polyvertadr[0]
188+
geom2.mesh_polyvertnum = mesh_polyvertnum[0]
189+
geom2.mesh_polyvert = mesh_polyvert[0]
190+
geom2.mesh_polymapadr = mesh_polymapadr[0]
191+
geom2.mesh_polymapnum = mesh_polymapnum[0]
192+
geom2.mesh_polymap = mesh_polymap[0]
185193

186194
(
187195
dist,
@@ -254,6 +262,10 @@ def _ccd_kernel(
254262
m.geom_type,
255263
m.geom_dataid,
256264
m.geom_size,
265+
m.mesh_vertadr_offset,
266+
m.mesh_polyadr_offset,
267+
m.mesh_polyvertadr_offset,
268+
m.mesh_polymapadr_offset,
257269
m.mesh_vertadr,
258270
m.mesh_vertnum,
259271
m.mesh_vert,

0 commit comments

Comments
 (0)