This repository contains a PyTorch implementation of the paper "Selective Amnesia: A Continual Learning Approach to Forgetting in Deep Generative Models" presented at NeurIPS 2023. The code is structured as a step-by-step Google Colab notebook that demonstrates how to make a Variational Autoencoder (VAE) forget a specific concept (in this case, an MNIST digit).
Original Paper: [arXiv:2305.10120](https://arxiv.org/abs/2305.10120)
The project is broken down into a series of steps that first replicate the core findings of the paper and then explore the robustness and limitations of the unlearning technique.
First, a standard Conditional Variational Autoencoder (CVAE) is trained on the complete MNIST dataset. This model learns to generate high-quality images of all 10 digits.
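For reference, here is a minimal sketch of a conditional VAE of the kind trained in this step. The layer sizes and the 8-dimensional latent are illustrative; the notebook's exact architecture may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CVAE(nn.Module):
    """Minimal conditional VAE sketch (sizes are illustrative)."""
    def __init__(self, x_dim=784, y_dim=10, z_dim=8, h_dim=400):
        super().__init__()
        self.enc = nn.Linear(x_dim + y_dim, h_dim)
        self.mu = nn.Linear(h_dim, z_dim)
        self.logvar = nn.Linear(h_dim, z_dim)
        self.dec1 = nn.Linear(z_dim + y_dim, h_dim)
        self.dec2 = nn.Linear(h_dim, x_dim)

    def encode(self, x, y):
        h = F.relu(self.enc(torch.cat([x, y], dim=1)))
        return self.mu(h), self.logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def decode(self, z, y):
        h = F.relu(self.dec1(torch.cat([z, y], dim=1)))
        return torch.sigmoid(self.dec2(h))

    def forward(self, x, y):
        mu, logvar = self.encode(x, y)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, y), mu, logvar

def vae_loss(x_hat, x, mu, logvar):
    # Negative ELBO: reconstruction term + KL divergence to the standard normal prior
    recon = F.binary_cross_entropy(x_hat, x, reduction='sum')
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kl
```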
Next, we analyze the trained VAE to determine the importance of each weight in its network by calculating the Fisher Information Matrix (FIM). The FIM tells us which connections are crucial for remembering the learned digits. This information is key to protecting the model's memory during the forgetting process.
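A common way to estimate a diagonal FIM is to accumulate the squared gradients of the loss over many samples. The sketch below reuses the `CVAE` and `vae_loss` sketches above; note that the paper estimates the FIM from samples generated by the model itself, so treating a data loader as the source here is an assumption.

```python
def compute_fisher_diagonal(model, data_loader, device, n_batches=100):
    # Diagonal Fisher approximation: accumulate squared gradients of the loss.
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    model.eval()
    for i, (x, y) in enumerate(data_loader):
        if i >= n_batches:
            break
        x = x.flatten(1).to(device)
        y = F.one_hot(y, num_classes=10).float().to(device)
        model.zero_grad()
        x_hat, mu, logvar = model(x, y)
        vae_loss(x_hat, x, mu, logvar).backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
    return {n: f / n_batches for n, f in fisher.items()}
```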
This is the core of the paper's technique. We take the fully trained VAE and the FIM and begin a second phase of training to make the model forget a single digit (e.g., '1'). The model is taught to do three things at once (a sketch of the combined objective follows the list):
- Forget: Associate the label of the forgotten class with random noise.
- Remember: Continue to generate clear images for all other classes.
- Protect: Use the FIM as an EWC penalty that discourages large changes to important weights, thus preserving the knowledge of the other digits.
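Here is a hedged sketch of how these three terms might combine into a single training objective, reusing the `CVAE` and `vae_loss` sketches above. The weighting `lam` and the exact forget target are assumptions; see the paper for the precise objective.

```python
def selective_amnesia_loss(model, fisher, old_params,
                           x_forget, y_forget, x_keep, y_keep, lam=100.0):
    # Forget: push reconstructions for the forgotten label toward random noise
    noise = torch.rand_like(x_forget)
    x_hat_f, _, _ = model(x_forget, y_forget)
    loss_forget = F.binary_cross_entropy(x_hat_f, noise, reduction='sum')

    # Remember: keep optimizing the standard ELBO on the retained classes
    x_hat_k, mu_k, logvar_k = model(x_keep, y_keep)
    loss_remember = vae_loss(x_hat_k, x_keep, mu_k, logvar_k)

    # Protect: EWC penalty anchors Fisher-important weights to their old values
    loss_ewc = sum((fisher[n] * (p - old_params[n]) ** 2).sum()
                   for n, p in model.named_parameters())

    return loss_forget + loss_remember + lam * loss_ewc  # lam is an assumed weight
```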
To prove that the forgetting was successful, we train an independent classifier and use it to evaluate samples generated by the original and amnesiac VAEs.
| Model | Avg. Classifier Probability of '1' | Classifier Entropy |
|---|---|---|
| Original Model (Before) | 0.9704 | 0.0999 |
| Amnesiac Model (After) | 0.2045 | 2.2175 |
As expected, the probability of the target class drops significantly, while the classifier's "confusion" (entropy) increases to near its maximum possible value. The visual results below clearly show the effect of the Selective Amnesia algorithm. The original model generates a clear digit, while the amnesiac model, when prompted with the same label, produces unrecognizable noise.
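Both metrics in the table can be computed directly from the classifier's outputs. A minimal sketch follows; for 10 classes the maximum possible entropy is ln(10) ≈ 2.303, which the amnesiac model approaches.

```python
def evaluate_forgetting(classifier, samples, target_class=1):
    # Average probability assigned to the target class, plus mean prediction entropy.
    with torch.no_grad():
        probs = torch.softmax(classifier(samples), dim=1)
    avg_target_prob = probs[:, target_class].mean().item()
    entropy = -(probs * probs.clamp_min(1e-12).log()).sum(dim=1).mean().item()
    return avg_target_prob, entropy
```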
The following steps go beyond the basic implementation to test the limits of the Selective Amnesia technique.
This experiment investigates whether repeatedly applying Selective Amnesia degrades the model's overall performance. We sequentially force the model to forget five different digits (0, 1, 2, 3, and 4) and then evaluate its ability to generate the remaining digits.
The results are quite impressive. Even after five rounds of forgetting, the model continues to generate high-quality samples for the retained digits. This indicates that the EWC penalty is highly effective at protecting the knowledge of the concepts we want to keep, and the unlearning process is well-targeted.
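In outline, the sequential experiment is a loop over the digits to forget. The sketch below uses hypothetical helper names (`run_selective_amnesia`, `evaluate_retained_digits`, `trained_cvae`) standing in for the notebook's cells, and re-estimating the FIM after each round is an assumption.

```python
# Hypothetical outline of the sequential-forgetting experiment.
model = trained_cvae  # the fully trained CVAE from the first step
for digit in [0, 1, 2, 3, 4]:
    # Re-estimate the FIM so the EWC penalty protects the *current* model
    fisher = compute_fisher_diagonal(model, train_loader, device)
    old_params = {n: p.detach().clone() for n, p in model.named_parameters()}
    model = run_selective_amnesia(model, fisher, old_params, forget_class=digit)
    evaluate_retained_digits(model, forgotten=list(range(digit + 1)))
```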
This section documents a series of increasingly sophisticated attempts to recover the "ghost" of the forgotten digit from the amnesiac model's latent space.
- Random Latent Sampling: Probing the model by providing a random latent vector while using the condition for the forgotten digit.
- Reconstruction with Noisy Conditions: Finding the "true" latent coordinates of a forgotten digit (using the original VAE) and feeding them to the amnesiac decoder along with a "fuzzy" or noisy condition vector.
- Latent Walk with a Deceptive Condition: Walking the latent path from a known digit to a forgotten one, but keeping the condition vector fixed to the known digit to try and "trick" the decoder.
- Double Interpolation: Simultaneously walking the path in both the latent space and the condition space from a known digit to a forgotten one.
- Triangular Latent Walk: Walking from a known digit, through the latent location of the forgotten digit, and on to a second known digit, to see if the model's need for a smooth transition reveals the forgotten concept.
- Latent & Condition Optimization (Adversarial Attack): The final, successful approach. An optimizer simultaneously modifies both a latent vector `z` and a condition vector `c`, using an expert classifier as a guide to find a "secret key" that forces the amnesiac decoder to regenerate the forgotten digit (see the sketch after this list).
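Below is a sketch of this final attack, assuming a decoder with the signature `decoder(z, c)` (as in the `CVAE` sketch above) and a pretrained classifier on flattened images. Names and hyperparameters are illustrative.

```python
def adversarial_recovery(decoder, classifier, target_class=1, z_dim=8,
                         steps=2000, lr=0.05, device='cpu'):
    # Jointly optimize a latent z and a condition c so the amnesiac decoder's
    # output is classified as the forgotten digit.
    z = torch.randn(1, z_dim, device=device, requires_grad=True)
    c = torch.randn(1, 10, device=device, requires_grad=True)
    opt = torch.optim.Adam([z, c], lr=lr)
    target = torch.tensor([target_class], device=device)
    for _ in range(steps):
        opt.zero_grad()
        x_hat = decoder(z, torch.softmax(c, dim=1))  # softmax keeps c a valid condition
        loss = F.cross_entropy(classifier(x_hat), target)
        loss.backward()
        opt.step()
    return z.detach(), torch.softmax(c, dim=1).detach()
```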
To understand how the model forgets, we can visualize its latent space using several techniques.
Original VAE 8D Latent Space: We use t-SNE to create a 2D "map" of the model's internal representation of digit styles, and a manifold plot shows the actual images generated from points across this map. The resulting plot shows a well-structured latent space where every region decodes into a clear digit.
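A sketch of how such a map can be built with scikit-learn's t-SNE on the encoder means, assuming the `CVAE.encode` interface from the sketch above:

```python
from sklearn.manifold import TSNE
import numpy as np

def latent_tsne_map(model, data_loader, device, n_points=2000):
    # Collect encoder means for a subset of the data, then project 8D -> 2D.
    zs, labels = [], []
    model.eval()
    with torch.no_grad():
        for x, y in data_loader:
            x = x.flatten(1).to(device)
            y_onehot = F.one_hot(y, num_classes=10).float().to(device)
            mu, _ = model.encode(x, y_onehot)
            zs.append(mu.cpu().numpy())
            labels.append(y.numpy())
            if sum(map(len, labels)) >= n_points:
                break
    zs = np.concatenate(zs)[:n_points]
    labels = np.concatenate(labels)[:n_points]
    z2d = TSNE(n_components=2).fit_transform(zs)
    return z2d, labels  # scatter-plot z2d colored by labels to get the "map"
```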
Amnesiac VAE 8D Latent Space: Comparing this to the original reveals the core pattern: the forgetting process doesn't erase a region of the t-SNE map. Instead, it teaches the decoder to output garbage when it is asked to decode a point from the "style" region of a forgotten digit while also being given the forgotten digit's label. This creates "dead zones" on the manifold, which are the footprint of the forgotten concept.
2D Latent Space: To get a clearer, more direct view, we can train a VAE with a `z_dim` of 2. This allows us to plot the encoder's output directly without dimensionality reduction. This experiment reveals a fascinating "quarantine" strategy: in a constrained 2D space, the easiest way for the model to forget a concept is to first isolate it into its own cluster and then apply the forgetting rule only to that area.
Latent Space Centroid Analysis: This is the most powerful visualization for understanding the structural changes in the latent space. We calculate the "center of gravity" (centroid) for each digit's cluster and analyze how these centers move. The results show that the forgetting process doesn't just erase a concept; it actively pushes its representation into a remote, isolated corner of the latent space, making it an "outcast" from the other digits.
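A sketch of the centroid computation, again assuming the `CVAE` interface from above. Comparing the pairwise distance matrices before and after forgetting exposes the "outcast" effect.

```python
def latent_centroids(model, data_loader, device, n_classes=10, z_dim=8):
    # Accumulate per-class sums of encoder means, then divide by counts.
    sums = torch.zeros(n_classes, z_dim)
    counts = torch.zeros(n_classes)
    model.eval()
    with torch.no_grad():
        for x, y in data_loader:
            x = x.flatten(1).to(device)
            y_onehot = F.one_hot(y, num_classes=n_classes).float().to(device)
            mu, _ = model.encode(x, y_onehot)
            mu = mu.cpu()
            for c in range(n_classes):
                mask = (y == c)
                sums[c] += mu[mask].sum(dim=0)
                counts[c] += mask.sum()
    centroids = sums / counts.unsqueeze(1)
    # Pairwise centroid distances show whether a class drifts away from the rest
    return centroids, torch.cdist(centroids, centroids)
```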
This experiment tests a key hypothesis: does the unlearning process accidentally damage the model's ability to generate rare, "outlier" styles of the remembered digits? We identify the latent vectors for the most unusual examples of the forgotten digit from the entire MNIST dataset and compare their reconstructions before and after the forgetting process. The results show that the unlearning is remarkably robust, effectively suppressing even the rarest styles of the forgotten concept.
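One way to find such outliers is to rank examples of the forgotten class by their latent distance from the class centroid. A sketch under the same interface assumptions as above:

```python
def find_outlier_latents(model, data_loader, device, target_class=1, k=16):
    # Rank examples of the target class by latent distance from the class
    # centroid; the farthest ones are the "rare styles" probed here.
    zs, xs = [], []
    model.eval()
    with torch.no_grad():
        for x, y in data_loader:
            mask = (y == target_class)
            if not mask.any():
                continue
            x_c = x[mask].flatten(1).to(device)
            y_onehot = F.one_hot(y[mask], num_classes=10).float().to(device)
            mu, _ = model.encode(x_c, y_onehot)
            zs.append(mu.cpu())
            xs.append(x_c.cpu())
    zs, xs = torch.cat(zs), torch.cat(xs)
    dists = (zs - zs.mean(dim=0)).norm(dim=1)
    idx = dists.argsort(descending=True)[:k]
    return xs[idx], zs[idx]  # decode zs[idx] with both models to compare styles
```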
This final experiment explores alternative unlearning strategies inspired by methods in continual learning. We compare the standard Selective Amnesia method against two custom techniques:
- Symmetric Gradient Projection: A method that purifies both the "forget" and "remember" gradients before combining them (sketched below).
- Alternating Gradient Projection: A method that alternates between applying the purified "forget" and "remember" gradients.
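One plausible reading of "purifying" a gradient is to project out its component along the competing gradient. A minimal sketch, operating on flattened 1-D parameter gradients (e.g., built via `torch.cat([p.grad.flatten() for p in model.parameters()])`); the notebook's exact projection may differ.

```python
def project_out(g, h):
    # Remove from g its component along h (Gram-Schmidt-style "purification").
    h_unit = h / (h.norm() + 1e-12)
    return g - (g @ h_unit) * h_unit

def symmetric_projection(g_forget, g_remember):
    # Purify each gradient against the other, then combine (illustrative reading).
    g_f = project_out(g_forget, g_remember)  # forget step that doesn't harm retention
    g_r = project_out(g_remember, g_forget)  # remember step that doesn't re-learn
    return g_f + g_r
```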
The visual results show that while the custom methods are effective at forgetting, they cause more significant degradation to the quality of the remembered digits compared to the standard SA method, which is stabilized by the EWC penalty.
A numerical comparison of the gradient norms during training confirms that the different methods have distinct update behaviors.
The quantitative and visual results together confirm the effectiveness of the Selective Amnesia algorithm. The model successfully learns to forget a specific concept on command, and this forgetting is robust enough to withstand simple probes.
However, the advanced exploration shows that while the unlearning is strong, a determined adversarial attack using optimization can still recover a "ghost" of the unlearned concept, indicating that the information is suppressed rather than completely erased.
The code for this project is contained within a single Google Colab notebook.
1. Open the notebook in Google Colab.
2. Ensure the runtime is set to use a GPU for faster training.
3. Run each cell in order, from top to bottom. The notebook is designed to save all model checkpoints and results to your Google Drive to prevent data loss from disconnections.