Julia implementation of the Support and Trivial Prototypical Part Network (ST-ProtoPNet) for interpretable classification using support and trivial prototypes. Built on Flux.jl for model definition, optimization, and training. Features customizable loss functions, data-preparation utilities, and visualization tools for 2D data.
.
├── src/ # Core implementation
├── examples/ # Usage examples
├── tests/ # Basic tests
└── Project.toml # Dependencies
data_prep.jl: Functions to split data into training/test sets and create data loaders
network.jl: The main network structure that performs the classification
prototype_layer.jl: Layer that compares extracted features with learned prototypes
prototype_loss.jl: Loss functions to train the network
training.jl: Functions to train the model in different stages
2d_data_plot.jl: Tools to visualize the data and model results
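For orientation, here is a hedged sketch of what the data-preparation step amounts to for 2D data: split the samples into training/test sets and wrap them in Flux data loaders. The function split_and_load and its keyword arguments are illustrative assumptions, not the package's API (the package's own helper, get_dataloader_flux, appears in the quick-start below).

# Hypothetical sketch, assuming the 2D data is a 2 × n matrix X with a label vector y.
using Flux, Random

function split_and_load(X, y; train_frac=0.8, batchsize=32)
    n = size(X, 2)
    idx = shuffle(1:n)                                  # random permutation of sample indices
    ntrain = round(Int, train_frac * n)
    tr, te = idx[1:ntrain], idx[ntrain+1:end]
    y_onehot = Flux.onehotbatch(y, sort(unique(y)))     # one-hot encode the labels
    train_loader = Flux.DataLoader((X[:, tr], y_onehot[:, tr]); batchsize, shuffle=true)
    test_loader  = Flux.DataLoader((X[:, te], y_onehot[:, te]); batchsize)
    return train_loader, test_loader
end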
The network has three main parts (a rough sketch follows the list):

- Backbone
  - Takes in 2D data
  - Converts the input into features
- Prototype Layer
  - Compares input features with learned prototypes
  - Each prototype represents a pattern the model learns
  - Prototypes are divided equally between the classes
- Classifier
  - Makes the final class prediction
  - Uses a simple Dense layer
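As a rough illustration, and not the package's actual implementation, the three parts could be wired together in Flux as below. Every struct, field, and layer choice here is an assumption made for the sketch.

# Minimal sketch of the three-part structure; names are illustrative only.
using Flux

struct ProtoSketch
    backbone      # maps raw 2D input to a feature vector
    prototypes    # feature_dim × n_prototypes matrix of learned prototypes
    classifier    # Dense layer mapping prototype similarities to class scores
end
Flux.@layer ProtoSketch   # use Flux.@functor on older Flux versions

function (m::ProtoSketch)(x)
    f = m.backbone(x)                                   # features: feature_dim × batch
    # squared distances between every feature column and every prototype
    d2 = sum(abs2,
             reshape(f, size(f, 1), 1, :) .- reshape(m.prototypes, size(m.prototypes)..., 1);
             dims=1)
    sims = dropdims(log.((d2 .+ 1) ./ (d2 .+ 1f-4)); dims=1)   # ProtoPNet-style similarity scores
    return m.classifier(sims)                           # class logits
end

model_sketch = ProtoSketch(
    Chain(Dense(2 => 16, relu), Dense(16 => 10)),       # backbone: 2 inputs → 10 features
    randn(Float32, 10, 4),                              # 4 prototypes, 2 per class for 2 classes
    Dense(4 => 2),                                      # classifier over prototype similarities
)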
Training happens in three stages:

- Train the backbone and prototypes
- Match prototypes to real data points, as sketched below
- Fine-tune the classifier
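The second stage, matching each prototype to its nearest real feature vector, is known as prototype projection in the ProtoPNet literature. Below is a hypothetical sketch of the idea; it assumes access to the backbone and to a prototype matrix with one prototype per column, which may not match the package's internals.

# Hypothetical prototype-projection sketch; names and data layout are assumptions.
using LinearAlgebra

function project_prototypes!(prototypes::AbstractMatrix, backbone, data_loader)
    for j in axes(prototypes, 2)
        best_dist = Inf
        best_feat = prototypes[:, j]
        for (x, _) in data_loader
            feats = backbone(x)                         # feature_dim × batch
            for i in axes(feats, 2)
                d = norm(feats[:, i] .- prototypes[:, j])
                if d < best_dist
                    best_dist = d
                    best_feat = feats[:, i]
                end
            end
        end
        prototypes[:, j] .= best_feat                   # snap the prototype to the closest real feature
    end
    return prototypes
end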
You can find Jupyter notebooks with examples and tutorials in the examples/ folder.
using Flux
using ST_ProtoPNet
# Create model
model = SimpleNet(10, 2, num_classes=2)
# Prepare data (X_train, y_train, X_test, y_test are assumed to be defined already)
train_loader, test_loader = get_dataloader_flux(X_train, y_train, X_test, y_test)
# Train model
opt_state = Flux.setup(Adam(), model)
loss = construct_loss()
standart_train!(model, train_loader, test_loader, loss, opt_state)
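A possible follow-up, not part of the package API: compute test accuracy with plain Flux utilities. This assumes the trained model can be called on a batch and that the loaders yield one-hot labels.

# Accuracy sketch; assumes model(x) returns class scores and y is one-hot encoded.
using Statistics
test_acc = mean(mean(Flux.onecold(model(x)) .== Flux.onecold(y)) for (x, y) in test_loader)
println("test accuracy: ", test_acc)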