VAE Concept Forgetting: An Implementation of Selective Amnesia

This repository contains a PyTorch implementation of the paper "Selective Amnesia: A Continual Learning Approach to Forgetting in Deep Generative Models" presented at NeurIPS 2023. The code is structured as a step-by-step Google Colab notebook that demonstrates how to make a Variational Autoencoder (VAE) forget a specific concept (in this case, an MNIST digit).

Original Paper: arXiv:2305.10120


How It Works

The project is broken down into a series of steps that first replicate the core findings of the paper and then explore the robustness and limitations of the unlearning technique.

Initial CVAE Training

First, a standard Conditional Variational Autoencoder (CVAE) is trained on the complete MNIST dataset. This model learns to generate high-quality images of all 10 digits.
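A minimal CVAE sketch in PyTorch conveys the idea; the architecture and dimensions here are illustrative, not the exact ones used in the notebook:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CVAE(nn.Module):
    """Minimal conditional VAE sketch: the one-hot label is concatenated to
    both the encoder input and the latent code (sizes are illustrative)."""
    def __init__(self, x_dim=784, n_classes=10, z_dim=8, h_dim=256):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(x_dim + n_classes, h_dim), nn.ReLU())
        self.mu = nn.Linear(h_dim, z_dim)
        self.logvar = nn.Linear(h_dim, z_dim)
        self.dec = nn.Sequential(
            nn.Linear(z_dim + n_classes, h_dim), nn.ReLU(),
            nn.Linear(h_dim, x_dim), nn.Sigmoid())

    def forward(self, x, y):
        h = self.enc(torch.cat([x, y], dim=1))
        mu, logvar = self.mu(h), self.logvar(h)
        # Reparameterisation trick: sample z while keeping gradients.
        z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()
        return self.dec(torch.cat([z, y], dim=1)), mu, logvar

def vae_loss(x_hat, x, mu, logvar):
    """Negative ELBO: reconstruction term plus KL to the N(0, I) prior."""
    recon = F.binary_cross_entropy(x_hat, x, reduction="sum")
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kl
```

Training then follows the usual loop: encode a batch with its labels, decode, and minimise `vae_loss` with any standard optimizer.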

VAE Learning Process

Calculate Fisher Information Matrix (FIM)

Next, we analyze the trained VAE to determine the importance of each weight in its network by calculating the Fisher Information Matrix (FIM). The FIM tells us which connections are crucial for remembering the learned digits. This information is key to protecting the model's memory during the forgetting process.
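A common way to estimate this, borrowed from EWC, is the diagonal FIM: average the squared gradients of the loss over data batches. A sketch, with the helper name and batching scheme as assumptions rather than the notebook's exact code:

```python
import torch

def diagonal_fim(model, loss_fn, data_loader, n_batches=10):
    """Estimate the diagonal Fisher information as the running mean of
    squared per-batch gradients (the usual EWC-style approximation)."""
    fim = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    count = 0
    for i, (x, y) in enumerate(data_loader):
        if i >= n_batches:
            break
        model.zero_grad()
        loss_fn(model, x, y).backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fim[n] += p.grad.detach() ** 2  # squared gradient per weight
        count += 1
    return {n: f / max(count, 1) for n, f in fim.items()}
```

Large entries mark weights whose perturbation would most change the model's outputs, i.e. the connections worth protecting.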

Forgetting Training with Selective Amnesia

This is the core of the paper's technique. We take the fully trained VAE and the FIM and begin a second phase of training to make the model forget a single digit (e.g., '1'). The model is taught to:

  1. Forget: Associate the label of the forgotten class with random noise.

  2. Remember: Continue to generate clear images for all other classes.

  3. Protect: Use the FIM as a penalty to prevent large changes to important weights, thus preserving the knowledge of the other digits.
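The three objectives above can be sketched as a single combined loss. The weights, helper names, and the uniform-noise surrogate target are illustrative assumptions; the paper's exact formulation differs in detail:

```python
import torch

def selective_amnesia_loss(model, fim, old_params,
                           x_forget, y_forget, x_remember, y_remember,
                           recon_loss, lam=100.0):
    """Sketch of the three-part objective: (1) map the forgotten label to
    noise, (2) replay the retained classes, (3) EWC penalty weighted by the
    Fisher information anchored at the pre-forgetting parameters."""
    # 1. Forget: train the forgotten class toward a random-noise target.
    noise = torch.rand_like(x_forget)
    x_hat_f, _, _ = model(x_forget, y_forget)
    forget_term = recon_loss(x_hat_f, noise)
    # 2. Remember: keep reconstructing the other classes faithfully.
    x_hat_r, _, _ = model(x_remember, y_remember)
    remember_term = recon_loss(x_hat_r, x_remember)
    # 3. Protect: quadratic EWC penalty on important weights.
    ewc = sum((fim[n] * (p - old_params[n]) ** 2).sum()
              for n, p in model.named_parameters())
    return forget_term + remember_term + lam * ewc
```

The `lam` coefficient trades off plasticity (forgetting speed) against stability (retention quality).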

Forgetting Process

Quantitative Evaluation

To prove that the forgetting was successful, we train an independent classifier and use it to evaluate samples generated by the original and amnesiac VAEs.
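The two reported metrics can be computed roughly as follows; the function name and interface are hypothetical:

```python
import torch
import torch.nn.functional as F

def evaluate_forgetting(classifier, samples, target_class):
    """Score generated samples with an independent classifier: the mean
    probability assigned to the forgotten class, and the mean entropy of
    the predictive distribution (ln 10 ~= 2.30 is the max for 10 classes)."""
    with torch.no_grad():
        probs = F.softmax(classifier(samples), dim=1)
    avg_target_prob = probs[:, target_class].mean().item()
    entropy = -(probs * probs.clamp_min(1e-12).log()).sum(dim=1).mean().item()
    return avg_target_prob, entropy
```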

| Model | Avg. probability of a '1' | Classifier entropy |
| --- | --- | --- |
| Original model (before) | 0.9704 | 0.0999 |
| Amnesiac model (after) | 0.2045 | 2.2175 |

As expected, the probability of the target class drops significantly, while the classifier's "confusion" (entropy) rises to near its maximum possible value of ln 10 ≈ 2.30 for ten classes. The visual results below clearly show the effect of the Selective Amnesia algorithm: the original model generates a clear digit, while the amnesiac model, when prompted with the same label, produces unrecognizable noise.

final comparison

Advanced Experiments

The following steps go beyond the basic implementation to test the limits of the Selective Amnesia technique.

Sequential Forgetting and Generalization Check

This experiment investigates whether repeatedly applying Selective Amnesia degrades the model's overall performance. We sequentially force the model to forget five different digits (0, 1, 2, 3, and 4) and then evaluate its ability to generate the remaining digits.
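The experimental loop is simple to sketch, with `compute_fim` and `forget_one` standing in for the Fisher-information and forgetting routines described above (both names are placeholders, not the notebook's actual functions):

```python
def sequential_forgetting(model, compute_fim, forget_one, classes=(0, 1, 2, 3, 4)):
    """Apply the forgetting procedure once per class, recomputing the Fisher
    information before each round so the EWC penalty always reflects the
    model's current knowledge."""
    for c in classes:
        fim = compute_fim(model)   # re-anchor the EWC penalty each round
        forget_one(model, c, fim)
    return model
```

Recomputing the FIM each round is the design choice that keeps already-forgotten classes from being protected.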

Sequential Forgetting Comparison

The results are quite impressive. Even after five rounds of forgetting, the model continues to generate high-quality samples for the retained digits. This indicates that the EWC penalty is highly effective at protecting the knowledge of the concepts we want to keep, and the unlearning process is well-targeted.

Latent Space Exploration of the Forgotten Class

This section documents a series of increasingly sophisticated attempts to recover the "ghost" of the forgotten digit from the amnesiac model's latent space.

  1. Random Latent Sampling: Probing the model by providing a random latent vector while using the condition for the forgotten digit.

  2. Reconstruction with Noisy Conditions: Finding the "true" latent coordinates of a forgotten digit (using the original VAE) and feeding them to the amnesiac decoder along with a "fuzzy" or noisy condition vector.

  3. Latent Walk with a Deceptive Condition: Walking the latent path from a known digit to a forgotten one, but keeping the condition vector fixed to the known digit to try to "trick" the decoder.

  4. Double Interpolation: Simultaneously walking the path in both the latent space and the condition space from a known digit to a forgotten one.

  5. Triangular Latent Walk: Walking from a known digit, through the latent location of the forgotten digit, and on to a second known digit, to see whether the model's need for a smooth transition reveals the forgotten concept.

  6. Latent & Condition Optimization (Adversarial Attack): The final, successful approach. An optimizer simultaneously modifies both a latent vector z and a condition vector c, using an expert classifier as a guide to find a "secret key" that forces the amnesiac decoder to regenerate the forgotten digit.
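The final, successful optimization probe can be sketched as follows; the decoder/classifier interfaces, step count, and learning rate are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

def recover_forgotten(decoder, classifier, target_class,
                      z_dim=8, n_classes=10, steps=200, lr=0.1):
    """Sketch of the adversarial probe: gradient-optimize a free latent
    vector z and an unconstrained condition-logit vector so the decoder's
    output maximises the expert classifier's probability of the forgotten
    class."""
    z = torch.randn(1, z_dim, requires_grad=True)
    c_logits = torch.zeros(1, n_classes, requires_grad=True)
    opt = torch.optim.Adam([z, c_logits], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        c = F.softmax(c_logits, dim=1)            # soft condition vector
        x = decoder(torch.cat([z, c], dim=1))
        # Cross-entropy against the forgotten label drives the search.
        loss = F.cross_entropy(classifier(x), torch.tensor([target_class]))
        loss.backward()
        opt.step()
    return z.detach(), F.softmax(c_logits, dim=1).detach(), x.detach()
```

Because the condition vector is optimized as continuous logits rather than a fixed one-hot label, the probe can reach "off-manifold" conditions the forgetting training never saw.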

Visualizing the Latent Space

To understand how the model forgets, we can visualize its latent space using several techniques.

Original VAE 8D Latent Space: We use t-SNE to create a 2D "map" of the model's internal representation of digit styles; the manifold plot shows the actual images generated from points across this map. The plot reveals a well-structured latent space where every region decodes into a clear digit.

t-SNE visualization latent space manifold

Amnesiac VAE 8D Latent Space: Comparing this to the original reveals the core pattern: the forgetting process doesn't erase a region of the t-SNE map. Instead, it teaches the decoder to output garbage when it is asked to decode a point from the "style" region of a forgotten digit while also being given the forgotten digit's label. This creates "dead zones" on the manifold, which are the footprint of the forgotten concept.

t-SNE visualization latent space manifold

2D Latent Space: To get a clearer, more direct view, we can train a VAE with a z_dim of 2. This allows us to plot the encoder's output directly without dimensionality reduction. This experiment reveals a fascinating "quarantine" strategy: in a constrained 2D space, the easiest way for the model to forget a concept is to first isolate it into its own cluster and then apply the forgetting rule only to that area.

2D Latent Space

Latent Space Centroid Analysis: This is the most powerful visualization for understanding the structural changes in the latent space. We calculate the "center of gravity" (centroid) for each digit's cluster and analyze how these centers move. The results show that the forgetting process doesn't just erase a concept; it actively pushes its representation into a remote, isolated corner of the latent space, making it an "outcast" from the other digits.
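A rough sketch of the centroid computation (the interfaces are assumptions; `z` is the batch of encoder means and `labels` the integer class labels):

```python
import torch

def class_centroids(z, labels, n_classes=10):
    """Mean latent code ('centre of gravity') for each class cluster."""
    return torch.stack([z[labels == k].mean(dim=0) for k in range(n_classes)])

def centroid_shift(before, after):
    """Per-class Euclidean distance each centroid moved during forgetting."""
    return (after - before).norm(dim=1)
```

Comparing `class_centroids` of the original and amnesiac encoders, class by class, is what reveals the forgotten digit being pushed away from the others.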


Testing Robustness in Low-Probability Latent Regions

This experiment tests a key hypothesis: does the unlearning process accidentally damage the model's ability to generate rare, "outlier" styles of the remembered digits? We identify the latent vectors for the most unusual examples of the forgotten digit from the entire MNIST dataset and compare their reconstructions before and after the forgetting process. The results show that the unlearning is remarkably robust, effectively suppressing even the rarest styles of the forgotten concept.
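Under the standard normal prior, the rarest styles correspond to encodings farthest from the origin, so outliers can be ranked by latent norm. A sketch with a hypothetical helper name:

```python
import torch

def most_unusual(mu, k=16):
    """Rank encoded examples by improbability under the N(0, I) prior:
    the larger ||mu||, the lower the prior density, so the top-k norms
    pick out the rarest 'outlier' styles."""
    return mu.norm(dim=1).topk(k).indices
```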

Outlier Robustness Test

Selective Amnesia with Gradient Projection

This final experiment explores alternative unlearning strategies inspired by methods in continual learning. We compare the standard Selective Amnesia method against two custom techniques:

  1. Symmetric Gradient Projection: A method that purifies both the "forget" and "remember" gradients before combining them.

  2. Alternating Gradient Projection: A method that alternates between applying the purified "forget" and "remember" gradients.
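The "purification" in both variants can be sketched as standard vector rejection: each gradient has its component along the other removed before being applied (the exact variant used in the notebook may differ in detail):

```python
import torch

def project_out(g, ref):
    """Remove from g its component along ref (vector rejection)."""
    ref_flat = ref.flatten()
    denom = ref_flat.dot(ref_flat).clamp_min(1e-12)
    coef = g.flatten().dot(ref_flat) / denom
    return g - coef * ref

def symmetric_projection(g_forget, g_remember):
    """Sketch of the symmetric variant: purify each gradient of its
    component along the other, then combine. The alternating variant
    would instead apply the two purified gradients on alternating steps."""
    return project_out(g_forget, g_remember) + project_out(g_remember, g_forget)
```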

The visual results show that while the custom methods are effective at forgetting, they cause more significant degradation to the quality of the remembered digits compared to the standard SA method, which is stabilized by the EWC penalty.

Gradient Projection Visual Comparison

A numerical comparison of the gradient norms during training confirms that the different methods have distinct update behaviors.

Gradient Norm Comparison

Final Results

The quantitative and visual results together confirm the effectiveness of the Selective Amnesia algorithm. The model successfully learns to forget a specific concept on command, and this forgetting is robust enough to withstand simple probes.

However, the advanced exploration shows that while the unlearning is strong, a determined adversarial attack using optimization can still recover a "ghost" of the unlearned concept, indicating that the information is suppressed rather than completely erased.

How to Run

The code for this project is contained within a single Google Colab notebook.

  1. Open the notebook in Google Colab.

  2. Ensure the runtime is set to use a GPU for faster training.

  3. Run each cell in order, from top to bottom. The notebook is designed to save all model checkpoints and results to your Google Drive to prevent data loss from disconnections.
