Description
I ran unet.py on the CIFAR-10 dataset and printed the tensor shape after each stage of the UNet. The output is below, and it does not look right to me. Is this expected?
Input shape: torch.Size([128, 3, 32, 32])
Time embedding shape: torch.Size([128, 256])
After image_proj shape: torch.Size([128, 64, 32, 32])
Downsampling process:
Down block 1 output shape: torch.Size([128, 64, 32, 32])
Down block 2 output shape: torch.Size([128, 64, 32, 32])
Down block 3 output shape: torch.Size([128, 64, 16, 16])
Down block 4 output shape: torch.Size([128, 128, 16, 16])
Down block 5 output shape: torch.Size([128, 128, 16, 16])
Down block 6 output shape: torch.Size([128, 128, 8, 8])
Down block 7 output shape: torch.Size([128, 256, 8, 8])
Down block 8 output shape: torch.Size([128, 256, 8, 8])
Down block 9 output shape: torch.Size([128, 256, 4, 4])
Down block 10 output shape: torch.Size([128, 1024, 4, 4])
Down block 11 output shape: torch.Size([128, 1024, 4, 4])
Middle block:
Middle block output shape: torch.Size([128, 1024, 4, 4])
Upsampling process:
Concatenated input shape before up block 1: torch.Size([128, 2048, 4, 4])
Up block 1 output shape: torch.Size([128, 1024, 4, 4])
Concatenated input shape before up block 2: torch.Size([128, 2048, 4, 4])
Up block 2 output shape: torch.Size([128, 1024, 4, 4])
Concatenated input shape before up block 3: torch.Size([128, 1280, 4, 4])
Up block 3 output shape: torch.Size([128, 256, 4, 4])
Upsample 4 output shape: torch.Size([128, 256, 8, 8])
Concatenated input shape before up block 5: torch.Size([128, 512, 8, 8])
Up block 5 output shape: torch.Size([128, 256, 8, 8])
Concatenated input shape before up block 6: torch.Size([128, 512, 8, 8])
Up block 6 output shape: torch.Size([128, 256, 8, 8])
Concatenated input shape before up block 7: torch.Size([128, 384, 8, 8])
Up block 7 output shape: torch.Size([128, 128, 8, 8])
Upsample 8 output shape: torch.Size([128, 128, 16, 16])
Concatenated input shape before up block 9: torch.Size([128, 256, 16, 16])
Up block 9 output shape: torch.Size([128, 128, 16, 16])
Concatenated input shape before up block 10: torch.Size([128, 256, 16, 16])
Up block 10 output shape: torch.Size([128, 128, 16, 16])
Concatenated input shape before up block 11: torch.Size([128, 192, 16, 16])
Up block 11 output shape: torch.Size([128, 64, 16, 16])
Upsample 12 output shape: torch.Size([128, 64, 32, 32])
Concatenated input shape before up block 13: torch.Size([128, 128, 32, 32])
Up block 13 output shape: torch.Size([128, 64, 32, 32])
Concatenated input shape before up block 14: torch.Size([128, 128, 32, 32])
Up block 14 output shape: torch.Size([128, 64, 32, 32])
Final output shape: torch.Size([128, 3, 32, 32])
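
One way to sanity-check the shapes above is to redo the channel and resolution bookkeeping without building any tensors. The following is a minimal sketch, not the repository's code: `trace_unet_shapes` is a hypothetical helper, and its configuration is an assumption inferred from the printed shapes (base width 64, multipliers (1, 2, 2, 4) applied to the previous width rather than to the base width, two blocks per resolution, and a skip connection saved from the input projection and from every down block).

```python
# Shape bookkeeping only; no tensors are created.
# ASSUMED configuration (inferred from the log, not stated in the issue):
# base width 64, cumulative multipliers (1, 2, 2, 4), 2 blocks per resolution.

def trace_unet_shapes(n_channels=64, ch_mults=(1, 2, 2, 4), n_blocks=2, size=32):
    """Return (down, concat, up) as lists of (channels, spatial_size)."""
    ch = n_channels
    skips = [(ch, size)]  # the input projection output is the first skip
    down = []
    last = len(ch_mults) - 1
    for i, mult in enumerate(ch_mults):
        out_ch = ch * mult  # multiplier scales the previous width
        for _ in range(n_blocks):
            ch = out_ch
            down.append((ch, size))
        if i < last:
            size //= 2  # downsample halves the resolution, keeps channels
            down.append((ch, size))
    skips += down

    # The middle block is assumed to keep the shape unchanged.
    concat, up = [], []
    for i in reversed(range(len(ch_mults))):
        for _ in range(n_blocks):
            skip_ch, _ = skips.pop()
            concat.append((ch + skip_ch, size))  # channel-wise concatenation
            up.append((ch, size))
        out_ch = ch // ch_mults[i]  # undo this level's multiplier
        skip_ch, _ = skips.pop()
        concat.append((ch + skip_ch, size))
        ch = out_ch
        up.append((ch, size))
        if i > 0:
            size *= 2  # upsample doubles the resolution, keeps channels
            up.append((ch, size))
    return down, concat, up


down, concat, up = trace_unet_shapes()
for n, (c, s) in enumerate(down, 1):
    print(f"Down block {n}: channels={c}, size={s}x{s}")
for c, s in concat:
    print(f"Concatenated input: channels={c}, size={s}x{s}")
```

Under these assumptions the widths come out as 64, 128, 256, 1024 (that is, 64·1, 64·2, 128·2, 256·4), and each concatenated width is the current width plus the width of the matching skip, for example 1024 + 256 = 1280 and 128 + 64 = 192. The sketch produces twelve concatenations, three of them at 32×32, whereas the log above lists two at that resolution.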