From 196afb493ac3e378c5f17bb04d14b14e06dddf89 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 3 Feb 2026 18:24:03 -0600 Subject: [PATCH 1/3] Potential fix for lime --- pyhealth/interpret/methods/lime.py | 66 ++++++++++++++++++------------ 1 file changed, 39 insertions(+), 27 deletions(-) diff --git a/pyhealth/interpret/methods/lime.py b/pyhealth/interpret/methods/lime.py index 8bdc1b673..d2c59a3c3 100644 --- a/pyhealth/interpret/methods/lime.py +++ b/pyhealth/interpret/methods/lime.py @@ -389,6 +389,8 @@ def _compute_lime( """ device = input_emb[key].device batch_size = input_emb[key].shape[0] if input_emb[key].dim() >= 2 else 1 + # Keep large intermediate tensors off the GPU to avoid OOM + storage_device = torch.device("cpu") # Storage for samples and predictions interpretable_samples = [] # Binary vectors @@ -438,12 +440,18 @@ def _compute_lime( perturbed_predictions.append(torch.stack(batch_preds, dim=0)) similarity_weights.append(torch.stack(batch_similarities, dim=0)) + # Move small summaries to CPU to free GPU memory + interpretable_samples[-1] = interpretable_samples[-1].float().to(storage_device) + perturbed_predictions[-1] = perturbed_predictions[-1].detach().to(storage_device) + similarity_weights[-1] = similarity_weights[-1].detach().to(storage_device) + # Train weighted linear regression return self._train_interpretable_model( interpretable_samples, perturbed_predictions, similarity_weights, - device, + compute_device=storage_device, + target_device=device, ) def _create_perturbed_sample( @@ -581,23 +589,24 @@ def _compute_similarity( Returns: Similarity weight (scalar tensor). """ - # Flatten embeddings for distance computation - orig_flat = original_emb.reshape(-1).float() - pert_flat = perturbed_emb.reshape(-1).float() - - # Compute distance - if self.distance_mode == "cosine": - cos_sim = CosineSimilarity(dim=0) - distance = 1 - cos_sim(orig_flat, pert_flat) - elif self.distance_mode == "euclidean": - distance = torch.norm(orig_flat - pert_flat) - else: - raise ValueError("Invalid distance_mode") + with torch.no_grad(): + # Flatten embeddings for distance computation + orig_flat = original_emb.reshape(-1).float() + pert_flat = perturbed_emb.reshape(-1).float() + + # Compute distance + if self.distance_mode == "cosine": + cos_sim = CosineSimilarity(dim=0) + distance = 1 - cos_sim(orig_flat, pert_flat) + elif self.distance_mode == "euclidean": + distance = torch.norm(orig_flat - pert_flat) + else: + raise ValueError("Invalid distance_mode") - # Apply exponential kernel - similarity = torch.exp( - -1 * (distance ** 2) / (2 * (self.kernel_width ** 2)) - ) + # Apply exponential kernel + similarity = torch.exp( + -1 * (distance ** 2) / (2 * (self.kernel_width ** 2)) + ) return similarity @@ -606,7 +615,8 @@ def _train_interpretable_model( interpretable_samples: list, predictions: list, weights: list, - device: torch.device, + compute_device: torch.device, + target_device: torch.device, ) -> torch.Tensor: """Train weighted linear regression model. @@ -619,15 +629,16 @@ def _train_interpretable_model( interpretable_samples: List of binary vectors. predictions: List of model predictions. weights: List of similarity weights. - device: Device for computation. + compute_device: Device for regression solve (CPU to save GPU memory). + target_device: Device to place the returned coefficients. Returns: Linear model coefficients (batch_size, n_features). """ # Stack collected data - X = torch.stack(interpretable_samples, dim=0).to(device) # (n_samples, n_features) - Y = torch.stack(predictions, dim=0).to(device) # (n_samples, batch_size) - W = torch.stack(weights, dim=0).to(device) # (n_samples, batch_size) + X = torch.stack(interpretable_samples, dim=0).to(compute_device) # (n_samples, n_features) + Y = torch.stack(predictions, dim=0).to(compute_device) # (n_samples, batch_size) + W = torch.stack(weights, dim=0).to(compute_device) # (n_samples, batch_size) # Solve for each batch item independently batch_size = Y.shape[1] @@ -647,18 +658,18 @@ def _train_interpretable_model( # Solve based on feature selection method if self.feature_selection == "lasso": # L1 regularization (approximated with iterative reweighted least squares) - coef = self._solve_lasso(Xw, yw, device) + coef = self._solve_lasso(Xw, yw, compute_device) elif self.feature_selection == "ridge": # L2 regularization - coef = self._solve_ridge(Xw, yw, device) + coef = self._solve_ridge(Xw, yw, compute_device) else: # "none" # No regularization - coef = self._solve_ols(Xw, yw, device) + coef = self._solve_ols(Xw, yw, compute_device) coefficients.append(coef) # Stack into (batch_size, n_features) - return torch.stack(coefficients, dim=0) + return torch.stack(coefficients, dim=0).to(target_device) def _solve_lasso( self, @@ -874,7 +885,8 @@ def _forward_from_inputs( (perturbed_inputs.shape[0], 1), device=perturbed_inputs.device ) - output = self.model(**model_inputs) + with torch.no_grad(): + output = self.model(**model_inputs) return self._extract_logits(output) def _prepare_time_info( From fd0866633b073dbc65555ecc3ec8e4c46d952963 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 3 Feb 2026 20:24:23 -0600 Subject: [PATCH 2/3] vectorize lime --- pyhealth/interpret/methods/lime.py | 101 ++++++++++++++++++++++------- 1 file changed, 76 insertions(+), 25 deletions(-) diff --git a/pyhealth/interpret/methods/lime.py b/pyhealth/interpret/methods/lime.py index d2c59a3c3..7918a5aae 100644 --- a/pyhealth/interpret/methods/lime.py +++ b/pyhealth/interpret/methods/lime.py @@ -410,35 +410,33 @@ def _compute_lime( torch.ones(n_features, device=device) * 0.5 ) - # Create perturbed sample for each batch item - batch_preds = [] - batch_similarities = [] - - for b_idx in range(batch_size): - # Create perturbed embedding by mixing input and baseline - perturbed_emb = self._create_perturbed_sample( - key, binary_vector, input_emb, baseline_emb, b_idx - ) - - # Get model prediction for perturbed sample - pred = self._evaluate_sample( - key, perturbed_emb, baseline_emb, - target_class_idx, time_info, label_data + # Vectorized perturbed batch + perturbed_batch = self._create_perturbed_sample_batch( + key, binary_vector, input_emb, baseline_emb + ) # (batch, ...) + + # Forward pass for the whole batch + if self.use_embeddings: + logits = self._forward_from_embeddings( + key, perturbed_batch, baseline_emb, time_info, label_data ) - batch_preds.append(pred) - - # Compute similarity weight - similarity = self._compute_similarity( - input_emb[key][b_idx:b_idx+1] if batch_size > 1 else input_emb[key], - perturbed_emb, - binary_vector, + else: + logits = self._forward_from_inputs( + key, perturbed_batch, baseline_emb, time_info, label_data ) - batch_similarities.append(similarity) - + + preds = self._extract_target_prediction(logits, target_class_idx) + if preds.dim() == 0: + preds = preds.unsqueeze(0) + + similarities = self._compute_similarity_batch( + input_emb[key], perturbed_batch, binary_vector + ) + # Store sample information interpretable_samples.append(binary_vector.float()) - perturbed_predictions.append(torch.stack(batch_preds, dim=0)) - similarity_weights.append(torch.stack(batch_similarities, dim=0)) + perturbed_predictions.append(preds.detach()) + similarity_weights.append(similarities.detach()) # Move small summaries to CPU to free GPU memory interpretable_samples[-1] = interpretable_samples[-1].float().to(storage_device) @@ -496,6 +494,32 @@ def _create_perturbed_sample( return perturbed + def _create_perturbed_sample_batch( + self, + key: str, + binary_vector: torch.Tensor, + input_emb: Dict[str, torch.Tensor], + baseline_emb: Dict[str, torch.Tensor], + ) -> torch.Tensor: + """Vectorized mix of input and baseline for a single mask across the batch.""" + dim = input_emb[key].dim() + batch_size = input_emb[key].shape[0] + n_features = binary_vector.shape[0] + + if dim == 4: + mask_view = binary_vector.view(1, n_features, 1, 1) + elif dim == 3: + mask_view = binary_vector.view(1, n_features, 1) + else: + mask_view = binary_vector.view(1, n_features) + + mask_view = mask_view.expand(batch_size, *mask_view.shape[1:]) + base = baseline_emb[key] + if base.shape[0] != batch_size: + base = base.expand(batch_size, *base.shape[1:]) + + return torch.where(mask_view.bool(), input_emb[key], base) + def _evaluate_sample( self, key: str, @@ -610,6 +634,33 @@ def _compute_similarity( return similarity + def _compute_similarity_batch( + self, + original_emb: torch.Tensor, + perturbed_batch: torch.Tensor, + binary_vector: torch.Tensor, + ) -> torch.Tensor: + """Vectorized similarity for a batch with a single mask. + + Returns: (batch,) similarities. + """ + with torch.no_grad(): + orig_flat = original_emb.reshape(original_emb.shape[0], -1).float() + pert_flat = perturbed_batch.reshape(perturbed_batch.shape[0], -1).float() + + if self.distance_mode == "cosine": + distance = 1 - F.cosine_similarity(orig_flat, pert_flat, dim=-1) + elif self.distance_mode == "euclidean": + distance = torch.norm(orig_flat - pert_flat, dim=-1) + else: + raise ValueError("Invalid distance_mode") + + similarity = torch.exp( + -1 * (distance ** 2) / (2 * (self.kernel_width ** 2)) + ) + + return similarity + def _train_interpretable_model( self, interpretable_samples: list, From 6e7106cfac182ea338fc822c40104dbeaaf839f8 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 3 Feb 2026 21:29:47 -0500 Subject: [PATCH 3/3] Revert "vectorize lime" This reverts commit fd0866633b073dbc65555ecc3ec8e4c46d952963. --- pyhealth/interpret/methods/lime.py | 101 +++++++---------------------- 1 file changed, 25 insertions(+), 76 deletions(-) diff --git a/pyhealth/interpret/methods/lime.py b/pyhealth/interpret/methods/lime.py index 7918a5aae..d2c59a3c3 100644 --- a/pyhealth/interpret/methods/lime.py +++ b/pyhealth/interpret/methods/lime.py @@ -410,33 +410,35 @@ def _compute_lime( torch.ones(n_features, device=device) * 0.5 ) - # Vectorized perturbed batch - perturbed_batch = self._create_perturbed_sample_batch( - key, binary_vector, input_emb, baseline_emb - ) # (batch, ...) - - # Forward pass for the whole batch - if self.use_embeddings: - logits = self._forward_from_embeddings( - key, perturbed_batch, baseline_emb, time_info, label_data + # Create perturbed sample for each batch item + batch_preds = [] + batch_similarities = [] + + for b_idx in range(batch_size): + # Create perturbed embedding by mixing input and baseline + perturbed_emb = self._create_perturbed_sample( + key, binary_vector, input_emb, baseline_emb, b_idx ) - else: - logits = self._forward_from_inputs( - key, perturbed_batch, baseline_emb, time_info, label_data + + # Get model prediction for perturbed sample + pred = self._evaluate_sample( + key, perturbed_emb, baseline_emb, + target_class_idx, time_info, label_data ) - - preds = self._extract_target_prediction(logits, target_class_idx) - if preds.dim() == 0: - preds = preds.unsqueeze(0) - - similarities = self._compute_similarity_batch( - input_emb[key], perturbed_batch, binary_vector - ) - + batch_preds.append(pred) + + # Compute similarity weight + similarity = self._compute_similarity( + input_emb[key][b_idx:b_idx+1] if batch_size > 1 else input_emb[key], + perturbed_emb, + binary_vector, + ) + batch_similarities.append(similarity) + # Store sample information interpretable_samples.append(binary_vector.float()) - perturbed_predictions.append(preds.detach()) - similarity_weights.append(similarities.detach()) + perturbed_predictions.append(torch.stack(batch_preds, dim=0)) + similarity_weights.append(torch.stack(batch_similarities, dim=0)) # Move small summaries to CPU to free GPU memory interpretable_samples[-1] = interpretable_samples[-1].float().to(storage_device) @@ -494,32 +496,6 @@ def _create_perturbed_sample( return perturbed - def _create_perturbed_sample_batch( - self, - key: str, - binary_vector: torch.Tensor, - input_emb: Dict[str, torch.Tensor], - baseline_emb: Dict[str, torch.Tensor], - ) -> torch.Tensor: - """Vectorized mix of input and baseline for a single mask across the batch.""" - dim = input_emb[key].dim() - batch_size = input_emb[key].shape[0] - n_features = binary_vector.shape[0] - - if dim == 4: - mask_view = binary_vector.view(1, n_features, 1, 1) - elif dim == 3: - mask_view = binary_vector.view(1, n_features, 1) - else: - mask_view = binary_vector.view(1, n_features) - - mask_view = mask_view.expand(batch_size, *mask_view.shape[1:]) - base = baseline_emb[key] - if base.shape[0] != batch_size: - base = base.expand(batch_size, *base.shape[1:]) - - return torch.where(mask_view.bool(), input_emb[key], base) - def _evaluate_sample( self, key: str, @@ -634,33 +610,6 @@ def _compute_similarity( return similarity - def _compute_similarity_batch( - self, - original_emb: torch.Tensor, - perturbed_batch: torch.Tensor, - binary_vector: torch.Tensor, - ) -> torch.Tensor: - """Vectorized similarity for a batch with a single mask. - - Returns: (batch,) similarities. - """ - with torch.no_grad(): - orig_flat = original_emb.reshape(original_emb.shape[0], -1).float() - pert_flat = perturbed_batch.reshape(perturbed_batch.shape[0], -1).float() - - if self.distance_mode == "cosine": - distance = 1 - F.cosine_similarity(orig_flat, pert_flat, dim=-1) - elif self.distance_mode == "euclidean": - distance = torch.norm(orig_flat - pert_flat, dim=-1) - else: - raise ValueError("Invalid distance_mode") - - similarity = torch.exp( - -1 * (distance ** 2) / (2 * (self.kernel_width ** 2)) - ) - - return similarity - def _train_interpretable_model( self, interpretable_samples: list,