-
Notifications
You must be signed in to change notification settings - Fork 414
Open
Description
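PyTorch summary (`torchsummary.summary`) does not work with the `torch.device` class: passing a `torch.device` object as the `device` argument raises an `AttributeError`, because `summary` calls `device.lower()`, a method that only exists on strings.
Code to reproduce the error: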
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torchvision import models
# Pick the GPU when PyTorch reports one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using ", device)
class CNN(nn.Module):
    """Small two-stage conv net used to reproduce the torchsummary device bug.

    Two conv + max-pool stages feed three fully connected layers. ``fc1``
    expects ``16 * 5 * 5`` flattened features, which is what a 3x32x32 input
    produces (32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5).

    Args:
        train_CNN: Accepted but currently unused.
        num_classes: Accepted but currently unused; ``fc3`` always has 10
            outputs.
    """

    def __init__(self, train_CNN=False, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 16 channels * 5 * 5 spatial positions, which assumes 32x32 input.
        # A 3x28x28 input flattens to 16 * 4 * 4 = 256 features instead, and
        # fc1 would reject it with a shape mismatch in forward().
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        # NOTE(review): hard-coded 10 outputs; `num_classes` is ignored.
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the network and return the raw ``fc3`` outputs, shape (batch, 10).

        ``x`` is assumed to be (batch, 3, 32, 32). Three channels are fixed by
        ``conv1``, and 32x32 is needed so the flattened size matches ``fc1``.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
model = CNN().to(device)
# NOTE: `device` here is a torch.device object, not a str. torchsummary's
# summary() calls device.lower() on it (see the traceback below), which raises
# AttributeError. Passing device.type (the string "cuda" or "cpu") instead
# works around it.
# NOTE(review): fc1 expects 16*5*5 features, i.e. a 32x32 input. The size
# (3, 28, 28) flattens to 16*4*4, so the forward pass would still fail after
# the device workaround. Confirm the intended input size.
summary(model, (3, 28, 28), device=device)Error message:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_17928\2345870344.py in <module>
27
28 model = CNN().to(device)
---> 29 summary(model, (3, 28, 28), device=device)
~\anaconda3\lib\site-packages\torchsummary\torchsummary.py in summary(model, input_size, batch_size, device)
42 hooks.append(module.register_forward_hook(hook))
43
---> 44 device = device.lower()
45 assert device in [
46 "cuda",
AttributeError: 'torch.device' object has no attribute 'lower'
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels