Julia implementation of ST-ProtoPNet for interpretable classification with support and trivial prototypes. Built on Flux.jl with custom losses, data prep tools, and 2D visualization.

anokhver/ProtoPNet_julia

 
 


ST-ProtoPNet.jl

Overview

Julia implementation of ST-ProtoPNet, a prototypical part network for interpretable classification that uses support and trivial prototypes. It is built on Flux.jl for optimization and training, and provides customizable loss functions, data preparation utilities, and visualization tools for 2D data.

Wang et al. (2023)

Project Structure

.
├── src/              # Core implementation
├── examples/         # Usage examples
├── tests/            # Basic tests
└── Project.toml      # Dependencies

data_prep.jl: Functions to split data into training/test sets and create data loaders
network.jl: The main network structure that does the classification
prototype_layer.jl: Layer that compares input data with learned prototypes
prototype_loss.jl: Loss functions to train the network
training.jl: Functions to train the model in different stages
2d_data_plot.jl: Tools to visualize the data and model results
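
To make the role of data_prep.jl more concrete, here is a minimal sketch of a train/test split followed by mini-batching, written in plain Julia. The names `split_train_test` and `minibatches` are hypothetical and are not taken from the package; the samples-as-columns layout is also an assumption (it is the usual Flux convention).

```julia
using Random

# Hypothetical helper, not the package API.
# X holds one sample per column; y holds one label per sample.
function split_train_test(rng, X, y; test_fraction=0.2)
    n = size(X, 2)
    perm = randperm(rng, n)                    # shuffle sample indices
    n_test = round(Int, test_fraction * n)
    test_idx = perm[1:n_test]
    train_idx = perm[n_test+1:end]
    return X[:, train_idx], y[train_idx], X[:, test_idx], y[test_idx]
end

# Hypothetical helper: cut the data into (features, labels) mini-batches,
# similar in spirit to what a data loader yields during training.
function minibatches(X, y; batchsize=4)
    n = size(X, 2)
    return [(X[:, i:min(i + batchsize - 1, n)], y[i:min(i + batchsize - 1, n)])
            for i in 1:batchsize:n]
end

rng = MersenneTwister(1)
X = reshape(collect(1.0:20.0), 2, 10)          # 10 samples with 2 features each
y = collect(1:10)
X_train, y_train, X_test, y_test = split_train_test(rng, X, y; test_fraction=0.2)
batches = minibatches(X_train, y_train; batchsize=3)
```

With 10 samples and a test fraction of 0.2, the sketch keeps 8 samples for training and cuts them into batches of at most 3.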

Model Structure

The network has 3 main parts:

  1. Backbone

    • Takes in 2D data
    • Converts input into features
  2. Prototype Layer

    • Compares input features with learned prototypes
    • Each prototype represents a pattern the model learns
    • Prototypes are divided equally between all classes
  3. Classifier

    • Makes final class prediction
    • Uses a simple Dense layer
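
The three parts can be sketched as a forward pass in plain Julia. This is an illustration under stated assumptions, not the package's code: `ToyProtoNet` and its fields are hypothetical names, the backbone is reduced to a single linear map with ReLU, and the log-ratio similarity is borrowed from the original ProtoPNet formulation (the package may use a different similarity).

```julia
using LinearAlgebra, Random

# Hypothetical toy model; none of these names come from the package.
struct ToyProtoNet
    W_backbone::Matrix{Float64}   # (feature_dim x 2) backbone weights
    b_backbone::Vector{Float64}   # backbone bias
    prototypes::Matrix{Float64}   # (feature_dim x num_prototypes), one column each
    proto_class::Vector{Int}      # class that owns each prototype
    W_out::Matrix{Float64}        # (num_classes x num_prototypes) classifier weights
end

function ToyProtoNet(rng, feature_dim, protos_per_class, num_classes)
    num_protos = protos_per_class * num_classes
    # Prototypes are divided equally between classes: 1, 1, ..., 2, 2, ...
    proto_class = repeat(1:num_classes, inner=protos_per_class)
    return ToyProtoNet(randn(rng, feature_dim, 2), zeros(feature_dim),
                       randn(rng, feature_dim, num_protos), proto_class,
                       randn(rng, num_classes, num_protos))
end

# 1. Backbone: 2D input -> features.
backbone(m::ToyProtoNet, X) = max.(m.W_backbone * X .+ m.b_backbone, 0.0)

# 2. Prototype layer: squared Euclidean distance from every feature column to
#    every prototype, turned into a similarity that grows as distance shrinks.
function prototype_similarity(m::ToyProtoNet, F)
    d2 = [sum(abs2, F[:, j] .- m.prototypes[:, k])
          for k in axes(m.prototypes, 2), j in axes(F, 2)]
    return log.((d2 .+ 1.0) ./ (d2 .+ 1e-4))
end

# 3. Classifier: a dense layer over the similarity scores gives class logits.
forward(m::ToyProtoNet, X) = m.W_out * prototype_similarity(m, backbone(m, X))

rng = MersenneTwister(0)
model = ToyProtoNet(rng, 10, 2, 2)
X = randn(rng, 2, 5)                          # 5 samples stored as columns
logits = forward(model, X)                    # one column of logits per sample
predicted = [argmax(c) for c in eachcol(logits)]
```

Because each prototype belongs to one class, the similarity scores can be read as "this input looks like prototype k of class c", which is what makes the prediction inspectable.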

Training Steps

  1. Train backbone and prototypes
  2. Match prototypes to real data points
  3. Fine-tune classifier
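
Step 2 can be sketched as a projection that replaces each prototype with the nearest training feature of its own class. The function `project_prototypes` below is a hypothetical helper written for illustration; restricting candidates to the prototype's own class and using squared Euclidean distance are assumptions, and the package's implementation may differ.

```julia
# Hypothetical helper, not part of the package API.
# prototypes:  (feature_dim x p) learned prototypes
# proto_class: length-p class assigned to each prototype
# features:    (feature_dim x n) backbone outputs for the training set
# labels:      length-n class labels
function project_prototypes(prototypes, proto_class, features, labels)
    projected = copy(prototypes)
    for k in axes(prototypes, 2)
        # Only training points of the prototype's own class are candidates.
        candidates = findall(==(proto_class[k]), labels)
        isempty(candidates) && continue
        dists = [sum(abs2, features[:, j] .- prototypes[:, k]) for j in candidates]
        nearest = candidates[argmin(dists)]
        projected[:, k] = features[:, nearest]
    end
    return projected
end

features = [0.0 1.0 5.0 6.0;
            0.0 1.0 5.0 6.0]
labels = [1, 1, 2, 2]
prototypes = [0.8 5.2;
              0.9 4.7]
proto_class = [1, 2]
projected = project_prototypes(prototypes, proto_class, features, labels)
```

In this toy case the class 1 prototype snaps to the point (1, 1) and the class 2 prototype snaps to (5, 5), so every prototype afterwards corresponds to an actual training example.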

Basic Usage

You can find Jupyter notebooks with examples and tutorials in the examples/ folder.

using Flux          # provides Flux.setup and Adam
using ST_ProtoPNet

# Create model
model = SimpleNet(10, 2, num_classes=2)

# Prepare data (X_train, y_train, X_test, y_test are assumed to be loaded already)
train_loader, test_loader = get_dataloader_flux(X_train, y_train, X_test, y_test)

# Train model
opt_state = Flux.setup(Adam(), model)
loss = construct_loss()
standart_train!(model, train_loader, test_loader, loss, opt_state)
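
For readers new to Flux, the following sketch shows how an optimiser state from `Flux.setup` is typically used in an explicit training loop. It is a generic Flux example under stated assumptions, not the body of `standart_train!`: the small `Chain` stands in for the package's model, the data is random, and plain cross-entropy replaces the package's prototype loss.

```julia
using Flux

# Stand-in model and random toy data (assumptions for illustration only).
toy_model = Chain(Dense(2 => 8, relu), Dense(8 => 2))
X = randn(Float32, 2, 64)                         # 64 samples stored as columns
y = Flux.onehotbatch(rand(1:2, 64), 1:2)          # one-hot labels for 2 classes
loader = Flux.DataLoader((X, y), batchsize=16, shuffle=true)

opt_state = Flux.setup(Adam(), toy_model)
losses = Float32[]

for epoch in 1:5
    for (xb, yb) in loader
        # Compute the batch loss and its gradient with respect to the model.
        batch_loss, grads = Flux.withgradient(toy_model) do m
            Flux.logitcrossentropy(m(xb), yb)
        end
        # Apply one optimiser step using the stored optimiser state.
        Flux.update!(opt_state, toy_model, grads[1])
        push!(losses, batch_loss)
    end
end
```

The optimiser state is created once and updated in place on every step, which is why it is passed alongside the model rather than rebuilt inside the loop.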
