Skip to content

Commit 3f414a8

Browse files
authored
Fix lime CUDA OOM (#811)
* Potential fix for lime
* Vectorize lime
* Revert "vectorize lime" — this reverts commit fd08666.
1 parent d027604 commit 3f414a8

File tree

1 file changed

+39
-27
lines changed

1 file changed

+39
-27
lines changed

pyhealth/interpret/methods/lime.py

Lines changed: 39 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -389,6 +389,8 @@ def _compute_lime(
389389
"""
390390
device = input_emb[key].device
391391
batch_size = input_emb[key].shape[0] if input_emb[key].dim() >= 2 else 1
392+
# Keep large intermediate tensors off the GPU to avoid OOM
393+
storage_device = torch.device("cpu")
392394

393395
# Storage for samples and predictions
394396
interpretable_samples = [] # Binary vectors
@@ -438,12 +440,18 @@ def _compute_lime(
438440
perturbed_predictions.append(torch.stack(batch_preds, dim=0))
439441
similarity_weights.append(torch.stack(batch_similarities, dim=0))
440442

443+
# Move small summaries to CPU to free GPU memory
444+
interpretable_samples[-1] = interpretable_samples[-1].float().to(storage_device)
445+
perturbed_predictions[-1] = perturbed_predictions[-1].detach().to(storage_device)
446+
similarity_weights[-1] = similarity_weights[-1].detach().to(storage_device)
447+
441448
# Train weighted linear regression
442449
return self._train_interpretable_model(
443450
interpretable_samples,
444451
perturbed_predictions,
445452
similarity_weights,
446-
device,
453+
compute_device=storage_device,
454+
target_device=device,
447455
)
448456

449457
def _create_perturbed_sample(
@@ -581,23 +589,24 @@ def _compute_similarity(
581589
Returns:
582590
Similarity weight (scalar tensor).
583591
"""
584-
# Flatten embeddings for distance computation
585-
orig_flat = original_emb.reshape(-1).float()
586-
pert_flat = perturbed_emb.reshape(-1).float()
587-
588-
# Compute distance
589-
if self.distance_mode == "cosine":
590-
cos_sim = CosineSimilarity(dim=0)
591-
distance = 1 - cos_sim(orig_flat, pert_flat)
592-
elif self.distance_mode == "euclidean":
593-
distance = torch.norm(orig_flat - pert_flat)
594-
else:
595-
raise ValueError("Invalid distance_mode")
592+
with torch.no_grad():
593+
# Flatten embeddings for distance computation
594+
orig_flat = original_emb.reshape(-1).float()
595+
pert_flat = perturbed_emb.reshape(-1).float()
596+
597+
# Compute distance
598+
if self.distance_mode == "cosine":
599+
cos_sim = CosineSimilarity(dim=0)
600+
distance = 1 - cos_sim(orig_flat, pert_flat)
601+
elif self.distance_mode == "euclidean":
602+
distance = torch.norm(orig_flat - pert_flat)
603+
else:
604+
raise ValueError("Invalid distance_mode")
596605

597-
# Apply exponential kernel
598-
similarity = torch.exp(
599-
-1 * (distance ** 2) / (2 * (self.kernel_width ** 2))
600-
)
606+
# Apply exponential kernel
607+
similarity = torch.exp(
608+
-1 * (distance ** 2) / (2 * (self.kernel_width ** 2))
609+
)
601610

602611
return similarity
603612

@@ -606,7 +615,8 @@ def _train_interpretable_model(
606615
interpretable_samples: list,
607616
predictions: list,
608617
weights: list,
609-
device: torch.device,
618+
compute_device: torch.device,
619+
target_device: torch.device,
610620
) -> torch.Tensor:
611621
"""Train weighted linear regression model.
612622
@@ -619,15 +629,16 @@ def _train_interpretable_model(
619629
interpretable_samples: List of binary vectors.
620630
predictions: List of model predictions.
621631
weights: List of similarity weights.
622-
device: Device for computation.
632+
compute_device: Device for regression solve (CPU to save GPU memory).
633+
target_device: Device to place the returned coefficients.
623634
624635
Returns:
625636
Linear model coefficients (batch_size, n_features).
626637
"""
627638
# Stack collected data
628-
X = torch.stack(interpretable_samples, dim=0).to(device) # (n_samples, n_features)
629-
Y = torch.stack(predictions, dim=0).to(device) # (n_samples, batch_size)
630-
W = torch.stack(weights, dim=0).to(device) # (n_samples, batch_size)
639+
X = torch.stack(interpretable_samples, dim=0).to(compute_device) # (n_samples, n_features)
640+
Y = torch.stack(predictions, dim=0).to(compute_device) # (n_samples, batch_size)
641+
W = torch.stack(weights, dim=0).to(compute_device) # (n_samples, batch_size)
631642

632643
# Solve for each batch item independently
633644
batch_size = Y.shape[1]
@@ -647,18 +658,18 @@ def _train_interpretable_model(
647658
# Solve based on feature selection method
648659
if self.feature_selection == "lasso":
649660
# L1 regularization (approximated with iterative reweighted least squares)
650-
coef = self._solve_lasso(Xw, yw, device)
661+
coef = self._solve_lasso(Xw, yw, compute_device)
651662
elif self.feature_selection == "ridge":
652663
# L2 regularization
653-
coef = self._solve_ridge(Xw, yw, device)
664+
coef = self._solve_ridge(Xw, yw, compute_device)
654665
else: # "none"
655666
# No regularization
656-
coef = self._solve_ols(Xw, yw, device)
667+
coef = self._solve_ols(Xw, yw, compute_device)
657668

658669
coefficients.append(coef)
659670

660671
# Stack into (batch_size, n_features)
661-
return torch.stack(coefficients, dim=0)
672+
return torch.stack(coefficients, dim=0).to(target_device)
662673

663674
def _solve_lasso(
664675
self,
@@ -874,7 +885,8 @@ def _forward_from_inputs(
874885
(perturbed_inputs.shape[0], 1), device=perturbed_inputs.device
875886
)
876887

877-
output = self.model(**model_inputs)
888+
with torch.no_grad():
889+
output = self.model(**model_inputs)
878890
return self._extract_logits(output)
879891

880892
def _prepare_time_info(

0 commit comments

Comments (0)