diff --git a/pyhealth/interpret/methods/lime.py b/pyhealth/interpret/methods/lime.py index 8bdc1b67..d2c59a3c 100644 --- a/pyhealth/interpret/methods/lime.py +++ b/pyhealth/interpret/methods/lime.py @@ -389,6 +389,8 @@ def _compute_lime( """ device = input_emb[key].device batch_size = input_emb[key].shape[0] if input_emb[key].dim() >= 2 else 1 + # Keep large intermediate tensors off the GPU to avoid OOM + storage_device = torch.device("cpu") # Storage for samples and predictions interpretable_samples = [] # Binary vectors @@ -438,12 +440,18 @@ def _compute_lime( perturbed_predictions.append(torch.stack(batch_preds, dim=0)) similarity_weights.append(torch.stack(batch_similarities, dim=0)) + # Move small summaries to CPU to free GPU memory + interpretable_samples[-1] = interpretable_samples[-1].float().to(storage_device) + perturbed_predictions[-1] = perturbed_predictions[-1].detach().to(storage_device) + similarity_weights[-1] = similarity_weights[-1].detach().to(storage_device) + # Train weighted linear regression return self._train_interpretable_model( interpretable_samples, perturbed_predictions, similarity_weights, - device, + compute_device=storage_device, + target_device=device, ) def _create_perturbed_sample( @@ -581,23 +589,24 @@ def _compute_similarity( Returns: Similarity weight (scalar tensor). """ - # Flatten embeddings for distance computation - orig_flat = original_emb.reshape(-1).float() - pert_flat = perturbed_emb.reshape(-1).float() - - # Compute distance - if self.distance_mode == "cosine": - cos_sim = CosineSimilarity(dim=0) - distance = 1 - cos_sim(orig_flat, pert_flat) - elif self.distance_mode == "euclidean": - distance = torch.norm(orig_flat - pert_flat) - else: - raise ValueError("Invalid distance_mode") + with torch.no_grad(): + # Flatten embeddings for distance computation + orig_flat = original_emb.reshape(-1).float() + pert_flat = perturbed_emb.reshape(-1).float() + + # Compute distance + if self.distance_mode == "cosine": + cos_sim = CosineSimilarity(dim=0) + distance = 1 - cos_sim(orig_flat, pert_flat) + elif self.distance_mode == "euclidean": + distance = torch.norm(orig_flat - pert_flat) + else: + raise ValueError("Invalid distance_mode") - # Apply exponential kernel - similarity = torch.exp( - -1 * (distance ** 2) / (2 * (self.kernel_width ** 2)) - ) + # Apply exponential kernel + similarity = torch.exp( + -1 * (distance ** 2) / (2 * (self.kernel_width ** 2)) + ) return similarity @@ -606,7 +615,8 @@ def _train_interpretable_model( interpretable_samples: list, predictions: list, weights: list, - device: torch.device, + compute_device: torch.device, + target_device: torch.device, ) -> torch.Tensor: """Train weighted linear regression model. @@ -619,15 +629,16 @@ def _train_interpretable_model( interpretable_samples: List of binary vectors. predictions: List of model predictions. weights: List of similarity weights. - device: Device for computation. + compute_device: Device for regression solve (CPU to save GPU memory). + target_device: Device to place the returned coefficients. Returns: Linear model coefficients (batch_size, n_features). """ # Stack collected data - X = torch.stack(interpretable_samples, dim=0).to(device) # (n_samples, n_features) - Y = torch.stack(predictions, dim=0).to(device) # (n_samples, batch_size) - W = torch.stack(weights, dim=0).to(device) # (n_samples, batch_size) + X = torch.stack(interpretable_samples, dim=0).to(compute_device) # (n_samples, n_features) + Y = torch.stack(predictions, dim=0).to(compute_device) # (n_samples, batch_size) + W = torch.stack(weights, dim=0).to(compute_device) # (n_samples, batch_size) # Solve for each batch item independently batch_size = Y.shape[1] @@ -647,18 +658,18 @@ def _train_interpretable_model( # Solve based on feature selection method if self.feature_selection == "lasso": # L1 regularization (approximated with iterative reweighted least squares) - coef = self._solve_lasso(Xw, yw, device) + coef = self._solve_lasso(Xw, yw, compute_device) elif self.feature_selection == "ridge": # L2 regularization - coef = self._solve_ridge(Xw, yw, device) + coef = self._solve_ridge(Xw, yw, compute_device) else: # "none" # No regularization - coef = self._solve_ols(Xw, yw, device) + coef = self._solve_ols(Xw, yw, compute_device) coefficients.append(coef) # Stack into (batch_size, n_features) - return torch.stack(coefficients, dim=0) + return torch.stack(coefficients, dim=0).to(target_device) def _solve_lasso( self, @@ -874,7 +885,8 @@ def _forward_from_inputs( (perturbed_inputs.shape[0], 1), device=perturbed_inputs.device ) - output = self.model(**model_inputs) + with torch.no_grad(): + output = self.model(**model_inputs) return self._extract_logits(output) def _prepare_time_info(