Skip to content

summary does not work with the torch.device class #199

@Skaifai

Description

@Skaifai

torchsummary's summary() fails when its device argument is a torch.device object rather than a string such as "cuda" or "cpu".
Code to reproduce the error:

import torch
import torch.nn as nn
from torchvision import models
from torchsummary import summary

# Pick the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using ", device)

class CNN(nn.Module):
    """LeNet-style classifier used to reproduce the torchsummary issue.

    Two conv + max-pool stages are followed by three fully connected layers.
    ``fc1`` expects ``16 * 5 * 5`` flattened features, which is exactly what a
    3x32x32 input produces (32 -> conv5 -> 28 -> pool -> 14 -> conv5 -> 10
    -> pool -> 5). A 3x28x28 input yields ``16 * 4 * 4`` features instead and
    fails at ``fc1`` with a shape mismatch.

    Args:
        train_CNN: Accepted but currently unused.
        num_classes: Accepted but currently unused; ``fc3`` always emits
            10 logits regardless of this value.
    """

    def __init__(self, train_CNN=False, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return logits of shape ``(batch, 10)``.

        ``x`` must be ``(batch, 3, 32, 32)`` for the flattened size to match
        ``fc1`` (see the class docstring).
        """
        # torch.relu rather than F.relu: the snippet never imports
        # torch.nn.functional as F, so F.relu raised NameError here.
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

model = CNN().to(device)
# torchsummary's summary() calls `device.lower()` on this argument (see the
# traceback below), so passing a torch.device object raises AttributeError.
# Passing `device=device.type` ("cuda" or "cpu") works around it until
# summary() accepts torch.device directly.
# NOTE(review): CNN.fc1 expects 16*5*5 features, which a 3x32x32 input
# produces; (3, 28, 28) yields 16*4*4, so this call would likely still fail
# at fc1 once the device issue is resolved. Consider (3, 32, 32).
summary(model, (3, 28, 28), device=device)

Error message:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_17928\2345870344.py in <module>
     27 
     28 model = CNN().to(device)
---> 29 summary(model, (3, 28, 28), device=device)

~\anaconda3\lib\site-packages\torchsummary\torchsummary.py in summary(model, input_size, batch_size, device)
     42             hooks.append(module.register_forward_hook(hook))
     43 
---> 44     device = device.lower()
     45     assert device in [
     46         "cuda",

AttributeError: 'torch.device' object has no attribute 'lower'

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions