@@ -389,6 +389,8 @@ def _compute_lime(
389389 """
390390 device = input_emb [key ].device
391391 batch_size = input_emb [key ].shape [0 ] if input_emb [key ].dim () >= 2 else 1
392+ # Keep large intermediate tensors off the GPU to avoid OOM
393+ storage_device = torch .device ("cpu" )
392394
393395 # Storage for samples and predictions
394396 interpretable_samples = [] # Binary vectors
@@ -438,12 +440,18 @@ def _compute_lime(
438440 perturbed_predictions .append (torch .stack (batch_preds , dim = 0 ))
439441 similarity_weights .append (torch .stack (batch_similarities , dim = 0 ))
440442
443+ # Move small summaries to CPU to free GPU memory
444+ interpretable_samples [- 1 ] = interpretable_samples [- 1 ].float ().to (storage_device )
445+ perturbed_predictions [- 1 ] = perturbed_predictions [- 1 ].detach ().to (storage_device )
446+ similarity_weights [- 1 ] = similarity_weights [- 1 ].detach ().to (storage_device )
447+
441448 # Train weighted linear regression
442449 return self ._train_interpretable_model (
443450 interpretable_samples ,
444451 perturbed_predictions ,
445452 similarity_weights ,
446- device ,
453+ compute_device = storage_device ,
454+ target_device = device ,
447455 )
448456
449457 def _create_perturbed_sample (
@@ -581,23 +589,24 @@ def _compute_similarity(
581589 Returns:
582590 Similarity weight (scalar tensor).
583591 """
584- # Flatten embeddings for distance computation
585- orig_flat = original_emb .reshape (- 1 ).float ()
586- pert_flat = perturbed_emb .reshape (- 1 ).float ()
587-
588- # Compute distance
589- if self .distance_mode == "cosine" :
590- cos_sim = CosineSimilarity (dim = 0 )
591- distance = 1 - cos_sim (orig_flat , pert_flat )
592- elif self .distance_mode == "euclidean" :
593- distance = torch .norm (orig_flat - pert_flat )
594- else :
595- raise ValueError ("Invalid distance_mode" )
592+ with torch .no_grad ():
593+ # Flatten embeddings for distance computation
594+ orig_flat = original_emb .reshape (- 1 ).float ()
595+ pert_flat = perturbed_emb .reshape (- 1 ).float ()
596+
597+ # Compute distance
598+ if self .distance_mode == "cosine" :
599+ cos_sim = CosineSimilarity (dim = 0 )
600+ distance = 1 - cos_sim (orig_flat , pert_flat )
601+ elif self .distance_mode == "euclidean" :
602+ distance = torch .norm (orig_flat - pert_flat )
603+ else :
604+ raise ValueError ("Invalid distance_mode" )
596605
597- # Apply exponential kernel
598- similarity = torch .exp (
599- - 1 * (distance ** 2 ) / (2 * (self .kernel_width ** 2 ))
600- )
606+ # Apply exponential kernel
607+ similarity = torch .exp (
608+ - 1 * (distance ** 2 ) / (2 * (self .kernel_width ** 2 ))
609+ )
601610
602611 return similarity
603612
@@ -606,7 +615,8 @@ def _train_interpretable_model(
606615 interpretable_samples : list ,
607616 predictions : list ,
608617 weights : list ,
609- device : torch .device ,
618+ compute_device : torch .device ,
619+ target_device : torch .device ,
610620 ) -> torch .Tensor :
611621 """Train weighted linear regression model.
612622
@@ -619,15 +629,16 @@ def _train_interpretable_model(
619629 interpretable_samples: List of binary vectors.
620630 predictions: List of model predictions.
621631 weights: List of similarity weights.
622- device: Device for computation.
632+ compute_device: Device for regression solve (CPU to save GPU memory).
633+ target_device: Device to place the returned coefficients.
623634
624635 Returns:
625636 Linear model coefficients (batch_size, n_features).
626637 """
627638 # Stack collected data
628- X = torch .stack (interpretable_samples , dim = 0 ).to (device ) # (n_samples, n_features)
629- Y = torch .stack (predictions , dim = 0 ).to (device ) # (n_samples, batch_size)
630- W = torch .stack (weights , dim = 0 ).to (device ) # (n_samples, batch_size)
639+ X = torch .stack (interpretable_samples , dim = 0 ).to (compute_device ) # (n_samples, n_features)
640+ Y = torch .stack (predictions , dim = 0 ).to (compute_device ) # (n_samples, batch_size)
641+ W = torch .stack (weights , dim = 0 ).to (compute_device ) # (n_samples, batch_size)
631642
632643 # Solve for each batch item independently
633644 batch_size = Y .shape [1 ]
@@ -647,18 +658,18 @@ def _train_interpretable_model(
647658 # Solve based on feature selection method
648659 if self .feature_selection == "lasso" :
649660 # L1 regularization (approximated with iterative reweighted least squares)
650- coef = self ._solve_lasso (Xw , yw , device )
661+ coef = self ._solve_lasso (Xw , yw , compute_device )
651662 elif self .feature_selection == "ridge" :
652663 # L2 regularization
653- coef = self ._solve_ridge (Xw , yw , device )
664+ coef = self ._solve_ridge (Xw , yw , compute_device )
654665 else : # "none"
655666 # No regularization
656- coef = self ._solve_ols (Xw , yw , device )
667+ coef = self ._solve_ols (Xw , yw , compute_device )
657668
658669 coefficients .append (coef )
659670
660671 # Stack into (batch_size, n_features)
661- return torch .stack (coefficients , dim = 0 )
672+ return torch .stack (coefficients , dim = 0 ). to ( target_device )
662673
663674 def _solve_lasso (
664675 self ,
@@ -874,7 +885,8 @@ def _forward_from_inputs(
874885 (perturbed_inputs .shape [0 ], 1 ), device = perturbed_inputs .device
875886 )
876887
877- output = self .model (** model_inputs )
888+ with torch .no_grad ():
889+ output = self .model (** model_inputs )
878890 return self ._extract_logits (output )
879891
880892 def _prepare_time_info (
0 commit comments