- Support multi-head attention
| Model | Heads | Params (M) | Acc. (%) |
|---|---|---|---|
| ResNet50 baseline (ref) | - | 23.5 | 93.62 |
| BoTNet-50 | 1 | 18.8 | 95.11 |
| BoTNet-50 | 4 | 18.8 | 95.78 |
| BoTNet-S1-50 | 1 | 18.8 | 95.67 |
| BoTNet-S1-59 | 1 | 27.5 | 95.98 |
| BoTNet-S1-77 | 1 | 44.9 | WIP |
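To make the trade-offs in the results table easier to read, the short script below expresses each finished BoTNet variant relative to the ResNet50 baseline: parameter change in percent and accuracy change in percentage points. The numbers are copied from the table; the helper name `relative_to_baseline` is hypothetical and not part of this repository.

```python
# (params in millions, accuracy in %), copied from the results table.
BASELINE = (23.5, 93.62)
RESULTS = {
    "BoTNet-50 (1 head)": (18.8, 95.11),
    "BoTNet-50 (4 heads)": (18.8, 95.78),
    "BoTNet-S1-50 (1 head)": (18.8, 95.67),
    "BoTNet-S1-59 (1 head)": (27.5, 95.98),
}


def relative_to_baseline(params, acc, baseline=BASELINE):
    """Return (parameter change in %, accuracy change in percentage points)."""
    base_params, base_acc = baseline
    param_change = 100.0 * (params - base_params) / base_params
    acc_change = acc - base_acc
    return round(param_change, 1), round(acc_change, 2)


for name, (params, acc) in RESULTS.items():
    delta_p, delta_a = relative_to_baseline(params, acc)
    print(f"{name}: params {delta_p:+.1f}%, accuracy {delta_a:+.2f} pp")
```

For example, BoTNet-50 with 4 heads uses about 20% fewer parameters than the baseline while scoring 2.16 points higher.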
- Model
```python
import torch
from model import ResNet50

model = ResNet50(num_classes=1000, resolution=(224, 224))
x = torch.randn([2, 3, 224, 224])  # a batch of two 224x224 RGB images
print(model(x).size())  # shape of the model output
```
- Module
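```python
from model import MHSA

planes = 512  # example channel count of the input feature map (assumed value)
resolution = 14  # spatial size (height = width) of the input feature map
mhsa = MHSA(planes, width=resolution, height=resolution)
```
- Paper link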
- Authors: Aravind Srinivas, Tsung-Yi Lin, Niki Parmar, Jonathon Shlens, Pieter Abbeel, Ashish Vaswani
- Organization: UC Berkeley, Google Research
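For intuition about what the `MHSA` module computes, here is a minimal NumPy sketch of BoTNet-style multi-head self-attention over a (C, H, W) feature map. Each head's attention logits are the sum of a content-content term (q·k) and a content-position term (q·r). As a simplifying assumption, the position embedding r is formed by adding a per-row embedding and a per-column embedding, and every weight is a random stand-in for a learned parameter. The function `mhsa_sketch` and its arguments are hypothetical and do not reflect this repository's `MHSA` implementation.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def mhsa_sketch(feat, heads, rng):
    """Multi-head self-attention over a (C, H, W) feature map (illustrative only).

    Weights are random stand-ins for learned parameters (assumption).
    Returns an array with the same (C, H, W) shape as the input.
    """
    c, h, w = feat.shape
    assert c % heads == 0, "channels must divide evenly across heads"
    d = c // heads
    n = h * w

    # 1x1 projections for queries, keys and values (a 1x1 conv is a matmul over C).
    w_q, w_k, w_v = (rng.standard_normal((c, c)) / np.sqrt(c) for _ in range(3))
    x = feat.reshape(c, n)                    # (C, N) with N = H * W positions
    q = (w_q @ x).reshape(heads, d, n)        # (heads, d, N)
    k = (w_k @ x).reshape(heads, d, n)
    v = (w_v @ x).reshape(heads, d, n)

    # Position embeddings: per-row (R_h) plus per-column (R_w) vectors, summed.
    r_h = rng.standard_normal((heads, d, h, 1))
    r_w = rng.standard_normal((heads, d, 1, w))
    r = (r_h + r_w).reshape(heads, d, n)      # (heads, d, N)

    # Logits: content-content (q.k) plus content-position (q.r), shape (heads, N, N).
    logits = np.einsum("hdi,hdj->hij", q, k) + np.einsum("hdi,hdj->hij", q, r)
    attn = softmax(logits, axis=-1)           # each query's weights sum to 1

    out = np.einsum("hij,hdj->hdi", attn, v)  # weighted sum of values: (heads, d, N)
    return out.reshape(c, h, w)


rng = np.random.default_rng(0)
feat = rng.standard_normal((64, 14, 14))      # C=64 channels on a 14x14 grid
y = mhsa_sketch(feat, heads=4, rng=rng)
print(y.shape)  # (64, 14, 14)
```

Because the logits have shape (heads, N, N) with N = H × W, memory and compute grow quadratically with spatial size. This is why BoTNet applies self-attention only in the final, lowest-resolution stage of the ResNet.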
