Skip to content

Commit f5436af

Browse files
mgarrardfacebook-github-bot
authored andcommitted
Fix read-only numpy arrays in cross-validation (facebook#4794)
Summary: Prompted after some failures in exports not related to my changes: https://github.com/facebook/Ax/actions/runs/21225324302/job/61070638075?fbclid=IwY2xjawPeXMFleHRuA2FlbQIxMQBicmlkETFRTkR6WlE4NHVrd3IyQXNlc3J0YwZhcHBfaWQBMAABHjTAiZi71n24w95hvzEewrKNPKOGzJisgR7t4qJ3APRMYlusgFC-gu7RLiSb_aem_Zk3pmTDonCFsJvZCTkpeMA Add `.copy()` after `.numpy()` calls to ensure arrays are writeable. PyTorch tensors converted via `.detach().cpu().numpy()` return read-only arrays in some cases. The subsequent squeeze operations create read-only views, and the in-place assignment `loo_covs[:, diag_idx, diag_idx] = loo_vars` fails with "assignment destination is read-only" error. Differential Revision: D91185467
1 parent 6045c31 commit f5436af

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

ax/adapter/cross_validation.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -240,12 +240,13 @@ def _efficient_loo_cross_validate(
240240
# This is a known limitation for fully Bayesian models and models with
241241
# non-Gaussian posteriors (such as PFNs).
242242
# Shape: n x 1 x m
243-
loo_means = posterior.mixture_mean.detach().cpu().numpy()
244-
loo_vars = posterior.mixture_variance.detach().cpu().numpy()
243+
# Use .copy() to ensure arrays are writeable (numpy returns read-only views)
244+
loo_means = posterior.mixture_mean.detach().cpu().numpy().copy()
245+
loo_vars = posterior.mixture_variance.detach().cpu().numpy().copy()
245246
else:
246247
# Shape: n x 1 x m
247-
loo_means = posterior.mean.detach().cpu().numpy()
248-
loo_vars = posterior.variance.detach().cpu().numpy()
248+
loo_means = posterior.mean.detach().cpu().numpy().copy()
249+
loo_vars = posterior.variance.detach().cpu().numpy().copy()
249250

250251
# Squeeze out the q dimension: n x 1 x m -> n x m
251252
loo_means = loo_means.squeeze(1)

0 commit comments

Comments
 (0)