Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 39 additions & 27 deletions pyhealth/interpret/methods/lime.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,8 @@ def _compute_lime(
"""
device = input_emb[key].device
batch_size = input_emb[key].shape[0] if input_emb[key].dim() >= 2 else 1
# Keep large intermediate tensors off the GPU to avoid OOM
storage_device = torch.device("cpu")

# Storage for samples and predictions
interpretable_samples = [] # Binary vectors
Expand Down Expand Up @@ -438,12 +440,18 @@ def _compute_lime(
perturbed_predictions.append(torch.stack(batch_preds, dim=0))
similarity_weights.append(torch.stack(batch_similarities, dim=0))

# Move small summaries to CPU to free GPU memory
interpretable_samples[-1] = interpretable_samples[-1].float().to(storage_device)
perturbed_predictions[-1] = perturbed_predictions[-1].detach().to(storage_device)
similarity_weights[-1] = similarity_weights[-1].detach().to(storage_device)

# Train weighted linear regression
return self._train_interpretable_model(
interpretable_samples,
perturbed_predictions,
similarity_weights,
device,
compute_device=storage_device,
target_device=device,
)

def _create_perturbed_sample(
Expand Down Expand Up @@ -581,23 +589,24 @@ def _compute_similarity(
Returns:
Similarity weight (scalar tensor).
"""
# Flatten embeddings for distance computation
orig_flat = original_emb.reshape(-1).float()
pert_flat = perturbed_emb.reshape(-1).float()

# Compute distance
if self.distance_mode == "cosine":
cos_sim = CosineSimilarity(dim=0)
distance = 1 - cos_sim(orig_flat, pert_flat)
elif self.distance_mode == "euclidean":
distance = torch.norm(orig_flat - pert_flat)
else:
raise ValueError("Invalid distance_mode")
with torch.no_grad():
# Flatten embeddings for distance computation
orig_flat = original_emb.reshape(-1).float()
pert_flat = perturbed_emb.reshape(-1).float()

# Compute distance
if self.distance_mode == "cosine":
cos_sim = CosineSimilarity(dim=0)
distance = 1 - cos_sim(orig_flat, pert_flat)
elif self.distance_mode == "euclidean":
distance = torch.norm(orig_flat - pert_flat)
else:
raise ValueError("Invalid distance_mode")

# Apply exponential kernel
similarity = torch.exp(
-1 * (distance ** 2) / (2 * (self.kernel_width ** 2))
)
# Apply exponential kernel
similarity = torch.exp(
-1 * (distance ** 2) / (2 * (self.kernel_width ** 2))
)

return similarity

Expand All @@ -606,7 +615,8 @@ def _train_interpretable_model(
interpretable_samples: list,
predictions: list,
weights: list,
device: torch.device,
compute_device: torch.device,
target_device: torch.device,
) -> torch.Tensor:
"""Train weighted linear regression model.

Expand All @@ -619,15 +629,16 @@ def _train_interpretable_model(
interpretable_samples: List of binary vectors.
predictions: List of model predictions.
weights: List of similarity weights.
device: Device for computation.
compute_device: Device for regression solve (CPU to save GPU memory).
target_device: Device to place the returned coefficients.

Returns:
Linear model coefficients (batch_size, n_features).
"""
# Stack collected data
X = torch.stack(interpretable_samples, dim=0).to(device) # (n_samples, n_features)
Y = torch.stack(predictions, dim=0).to(device) # (n_samples, batch_size)
W = torch.stack(weights, dim=0).to(device) # (n_samples, batch_size)
X = torch.stack(interpretable_samples, dim=0).to(compute_device) # (n_samples, n_features)
Y = torch.stack(predictions, dim=0).to(compute_device) # (n_samples, batch_size)
W = torch.stack(weights, dim=0).to(compute_device) # (n_samples, batch_size)

# Solve for each batch item independently
batch_size = Y.shape[1]
Expand All @@ -647,18 +658,18 @@ def _train_interpretable_model(
# Solve based on feature selection method
if self.feature_selection == "lasso":
# L1 regularization (approximated with iterative reweighted least squares)
coef = self._solve_lasso(Xw, yw, device)
coef = self._solve_lasso(Xw, yw, compute_device)
elif self.feature_selection == "ridge":
# L2 regularization
coef = self._solve_ridge(Xw, yw, device)
coef = self._solve_ridge(Xw, yw, compute_device)
else: # "none"
# No regularization
coef = self._solve_ols(Xw, yw, device)
coef = self._solve_ols(Xw, yw, compute_device)

coefficients.append(coef)

# Stack into (batch_size, n_features)
return torch.stack(coefficients, dim=0)
return torch.stack(coefficients, dim=0).to(target_device)

def _solve_lasso(
self,
Expand Down Expand Up @@ -874,7 +885,8 @@ def _forward_from_inputs(
(perturbed_inputs.shape[0], 1), device=perturbed_inputs.device
)

output = self.model(**model_inputs)
with torch.no_grad():
output = self.model(**model_inputs)
return self._extract_logits(output)

def _prepare_time_info(
Expand Down