
LearnToPayAttention


PyTorch implementation of the ICLR 2018 paper "Learn To Pay Attention".


Most Recent Updates


My implementation is based on the "(VGG-att3)-concat-pc" configuration from the paper, and the model is trained on the CIFAR-100 dataset.
I implemented two versions of the model; the only difference is whether the attention module is inserted before or after the corresponding max-pooling layer (see the sketch below).
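As an illustration only, the minimal sketch below uses a hypothetical, simplified ToyAttention block (the real (VGG-att3)-concat-pc model computes compatibility scores between local features and the global feature of the final layer); it is meant solely to show attending before versus after max-pooling, not to reproduce this repository's model code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyAttention(nn.Module):
    # Simplified, hypothetical attention estimator: 1x1 conv scores over the
    # local feature map, optionally softmax-normalized over spatial positions.
    def __init__(self, channels, normalize_attn=True):
        super().__init__()
        self.normalize_attn = normalize_attn
        self.score = nn.Conv2d(channels, 1, kernel_size=1, bias=False)

    def forward(self, l):
        n, c, h, w = l.size()
        s = self.score(l)                                    # N x 1 x H x W scores
        if self.normalize_attn:
            a = F.softmax(s.view(n, 1, -1), dim=2).view(n, 1, h, w)
        else:
            a = torch.sigmoid(s)
        g = (a.expand_as(l) * l).view(n, c, -1).sum(dim=2)   # attention-weighted feature
        return a, g

conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
attn = ToyAttention(64)
x = torch.randn(2, 3, 32, 32)      # CIFAR-sized dummy batch
l = F.relu(conv(x))

a1, g1 = attn(l)                   # 'before': attend to the local features, then pool
x_before = F.max_pool2d(l, 2)

x_after = F.max_pool2d(l, 2)       # 'after': pool first, then attend
a2, g2 = attn(x_after)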

(New!) Pre-trained models

Google drive link
Alternative link (Baidu Cloud Disk)

Dependencies

NOTE: If you are using PyTorch < 0.4.1, replace torch.nn.functional.interpolate with torch.nn.Upsample (modify the code in utilities.py).
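As a minimal sketch of the swap described in the note above; the tensor shape and scale factor are illustrative assumptions, not values taken from utilities.py.

import torch
import torch.nn as nn
import torch.nn.functional as F

a = torch.randn(1, 1, 8, 8)        # dummy attention map, N x 1 x H x W

# PyTorch >= 0.4.1: functional API
a_up = F.interpolate(a, scale_factor=4, mode='bilinear', align_corners=False)

# PyTorch < 0.4.1: module-based equivalent
upsample = nn.Upsample(scale_factor=4, mode='bilinear')
a_up_old = upsample(a)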

Training

  1. Pay attention before max-pooling layers
python train.py --attn_mode before --outf logs_before --normalize_attn --log_images
  2. Pay attention after max-pooling layers
python train.py --attn_mode after --outf logs_after --normalize_attn --log_images
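The flags used above come straight from the commands; the argparse sketch below is hypothetical (the defaults and help strings are assumptions, not copied from train.py) and only shows how such flags are typically wired up.

import argparse

parser = argparse.ArgumentParser(description='Learn To Pay Attention on CIFAR-100')
parser.add_argument('--attn_mode', choices=['before', 'after'], default='before',
                    help='insert the attention module before or after max-pooling')
parser.add_argument('--outf', default='logs',
                    help='output folder for checkpoints and logs (assumed default)')
parser.add_argument('--normalize_attn', action='store_true',
                    help='normalize the attention scores across spatial positions')
parser.add_argument('--log_images', action='store_true',
                    help='log attention-map visualizations during training')
args = parser.parse_args()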

Results

Training curve - loss

The x-axis is the number of training iterations.

  1. Pay attention before max-pooling layers

  2. Pay attention after max-pooling layers

  3. Plot in one figure

Training curve - accuracy on test data

The x-axis is the number of training epochs.

  1. Pay attention before max-pooling layers

  2. Pay attention after max-pooling layers

  3. Plot in one figure

Quantitative results (on test data of CIFAR-100)

Method                              Top-1 error (%)
VGG (Simonyan & Zisserman, 2014)    30.62
(VGG-att3)-concat-pc (ICLR 2018)    22.97
attn-before-pooling (my code)       22.62
attn-after-pooling (my code)        22.92

Attention map visualization (on test data of CIFAR-100)

From left to right: L1, L2, L3, original images

  1. Pay attention before max-pooling layers

  2. Pay attention after max-pooling layers
