Skip to content

Commit d7eb939

Browse files
authored
Reinsert broadcasted mask (#581)
* 1. Moved the dimension subset check into broadcast_mask 2. Added a brief docstring to broadcast_mask * Add tests for superset dims
1 parent 27a2a96 commit d7eb939

File tree

4 files changed

+20
-6
lines changed

4 files changed

+20
-6
lines changed

linopy/common.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,16 @@ def as_dataarray(
287287

288288

289289
def broadcast_mask(mask: DataArray, labels: DataArray) -> DataArray:
290+
"""
291+
Broadcast a boolean mask to match the shape of labels.
292+
293+
Ensures that mask dimensions are a subset of labels dimensions, broadcasts
294+
the mask accordingly, and fills any NaN values (from missing coordinates)
295+
with False while emitting a FutureWarning.
296+
"""
297+
assert set(mask.dims).issubset(labels.dims), (
298+
"Dimensions of mask not a subset of resulting labels dimensions."
299+
)
290300
mask = mask.broadcast_like(labels)
291301
if mask.isnull().any():
292302
warn(

linopy/model.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -552,9 +552,6 @@ def add_variables(
552552

553553
if mask is not None:
554554
mask = as_dataarray(mask, coords=data.coords, dims=data.dims).astype(bool)
555-
assert set(mask.dims).issubset(data.dims), (
556-
"Dimensions of mask not a subset of resulting labels dimensions."
557-
)
558555
mask = broadcast_mask(mask, data.labels)
559556

560557
# Auto-mask based on NaN in bounds (use numpy for speed)
@@ -750,9 +747,6 @@ def add_constraints(
750747

751748
if mask is not None:
752749
mask = as_dataarray(mask).astype(bool)
753-
assert set(mask.dims).issubset(data.dims), (
754-
"Dimensions of mask not a subset of resulting labels dimensions."
755-
)
756750
mask = broadcast_mask(mask, data.labels)
757751

758752
# Auto-mask based on null expressions or NaN RHS (use numpy for speed)

test/test_constraints.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,11 @@ def test_masked_constraints_broadcast() -> None:
191191
assert (m.constraints.labels.bc3[2:5, :] == -1).all()
192192
assert (m.constraints.labels.bc3[5:10, :] == -1).all()
193193

194+
# Mask with extra dimension not in data should raise
195+
mask4 = xr.DataArray([True, False], dims=["extra_dim"])
196+
with pytest.raises(AssertionError, match="not a subset"):
197+
m.add_constraints(1 * x + 10 * y, EQUAL, 0, name="bc4", mask=mask4)
198+
194199

195200
def test_non_aligned_constraints() -> None:
196201
m: Model = Model()

test/test_variables.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,11 @@ def test_variables_mask_broadcast() -> None:
134134
assert (z.labels[2:5, :] == -1).all()
135135
assert (z.labels[5:10, :] == -1).all()
136136

137+
# Mask with extra dimension not in data should raise
138+
mask4 = xr.DataArray([True, False], dims=["extra_dim"])
139+
with pytest.raises(AssertionError, match="not a subset"):
140+
m.add_variables(lower, upper, name="w", mask=mask4)
141+
137142

138143
def test_variables_get_name_by_label(m: Model) -> None:
139144
assert m.variables.get_name_by_label(4) == "x"

0 commit comments

Comments
 (0)