
The number of channels is weird #296

@optstats


I ran unet.py on the CIFAR-10 dataset and printed the output shape of each stage of the UNet, as follows. The channel counts look odd to me. Is this right?

Input shape: torch.Size([128, 3, 32, 32])
Time embedding shape: torch.Size([128, 256])
After image_proj shape: torch.Size([128, 64, 32, 32])

Downsampling process:
Down block 1 output shape: torch.Size([128, 64, 32, 32])
Down block 2 output shape: torch.Size([128, 64, 32, 32])
Down block 3 output shape: torch.Size([128, 64, 16, 16])
Down block 4 output shape: torch.Size([128, 128, 16, 16])
Down block 5 output shape: torch.Size([128, 128, 16, 16])
Down block 6 output shape: torch.Size([128, 128, 8, 8])
Down block 7 output shape: torch.Size([128, 256, 8, 8])
Down block 8 output shape: torch.Size([128, 256, 8, 8])
Down block 9 output shape: torch.Size([128, 256, 4, 4])
Down block 10 output shape: torch.Size([128, 1024, 4, 4])
Down block 11 output shape: torch.Size([128, 1024, 4, 4])

Middle block:
Middle block output shape: torch.Size([128, 1024, 4, 4])

Upsampling process:
Concatenated input shape before up block 1: torch.Size([128, 2048, 4, 4])
Up block 1 output shape: torch.Size([128, 1024, 4, 4])
Concatenated input shape before up block 2: torch.Size([128, 2048, 4, 4])
Up block 2 output shape: torch.Size([128, 1024, 4, 4])
Concatenated input shape before up block 3: torch.Size([128, 1280, 4, 4])
Up block 3 output shape: torch.Size([128, 256, 4, 4])
Upsample 4 output shape: torch.Size([128, 256, 8, 8])
Concatenated input shape before up block 5: torch.Size([128, 512, 8, 8])
Up block 5 output shape: torch.Size([128, 256, 8, 8])
Concatenated input shape before up block 6: torch.Size([128, 512, 8, 8])
Up block 6 output shape: torch.Size([128, 256, 8, 8])
Concatenated input shape before up block 7: torch.Size([128, 384, 8, 8])
Up block 7 output shape: torch.Size([128, 128, 8, 8])
Upsample 8 output shape: torch.Size([128, 128, 16, 16])
Concatenated input shape before up block 9: torch.Size([128, 256, 16, 16])
Up block 9 output shape: torch.Size([128, 128, 16, 16])
Concatenated input shape before up block 10: torch.Size([128, 256, 16, 16])
Up block 10 output shape: torch.Size([128, 128, 16, 16])
Concatenated input shape before up block 11: torch.Size([128, 192, 16, 16])
Up block 11 output shape: torch.Size([128, 64, 16, 16])
Upsample 12 output shape: torch.Size([128, 64, 32, 32])
Concatenated input shape before up block 13: torch.Size([128, 128, 32, 32])
Up block 13 output shape: torch.Size([128, 64, 32, 32])
Concatenated input shape before up block 14: torch.Size([128, 128, 32, 32])
Up block 14 output shape: torch.Size([128, 64, 32, 32])

Final output shape: torch.Size([128, 3, 32, 32])
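Here is one possible explanation for the channel counts. It is a sketch that has not been checked against unet.py. Suppose each resolution multiplies the previous level's channel count, i.e. `out = in * ch_mults[i]`. Then the hypothetical settings `n_channels=64`, `ch_mults=(1, 2, 2, 4)` and `n_blocks=2` give 64, 128, 256, 1024. That would account for the jump from 256 straight to 1024 at down block 10. A non-cumulative rule (`out = n_channels * ch_mults[i]`) with the same multipliers would give 64, 128, 128, 256 instead. The pure-Python code below (no PyTorch) only tracks channels and spatial size. The parameter names, the cumulative rule, and pushing the `image_proj` output as the first skip connection are all assumptions.

```python
# Hypothetical reconstruction of the UNet channel/resolution schedule.
# Assumptions (NOT read from unet.py): n_channels=64, ch_mults=(1, 2, 2, 4),
# n_blocks=2, channels multiplied cumulatively (out = in * ch_mults[i]),
# and the image_proj output pushed as the first skip connection.


def unet_schedule(n_channels=64, ch_mults=(1, 2, 2, 4), n_blocks=2, size=32):
    """Return (down, up) step lists.

    down: (kind, channels, size) per down-path module
    up:   (kind, in_channels, out_channels, size) per up-path module,
          where in_channels for a block is the concatenated (x + skip) count
    """
    down, up = [], []
    skips = [(n_channels, size)]  # assumed: image_proj output is the first skip
    ch, res = n_channels, size
    n_res = len(ch_mults)

    # Down path: n_blocks blocks per resolution, then a 2x downsample
    # (which keeps the channel count) on every level except the last.
    for i, mult in enumerate(ch_mults):
        ch = ch * mult  # cumulative: 64, 128, 256, 1024 with the defaults
        for _ in range(n_blocks):
            down.append(("block", ch, res))
            skips.append((ch, res))
        if i < n_res - 1:
            res //= 2
            down.append(("downsample", ch, res))
            skips.append((ch, res))

    # The middle block is assumed to keep (ch, res) unchanged.

    # Up path: each block concatenates one popped skip; the last block of
    # each level divides the channels by that level's multiplier.
    for i in reversed(range(n_res)):
        for j in range(n_blocks + 1):
            skip_ch, skip_res = skips.pop()
            assert skip_res == res, "skip and current feature map must match spatially"
            out = ch if j < n_blocks else ch // ch_mults[i]
            up.append(("block", ch + skip_ch, out, res))
            ch = out
        if i > 0:
            res *= 2
            up.append(("upsample", ch, ch, res))
    return down, up


if __name__ == "__main__":
    down, up = unet_schedule()
    for k, (kind, ch, res) in enumerate(down, start=1):
        print(f"Down {kind} {k}: channels={ch}, size={res}x{res}")
    for k, (kind, c_in, c_out, res) in enumerate(up, start=1):
        print(f"Up {kind} {k}: in={c_in}, out={c_out}, size={res}x{res}")
```

Under these assumptions the sketch matches every down-path and up-path shape in the log above, including the 2048- and 1280-channel concatenations. It also produces a fifteenth up step that the posted log does not show. That step concatenates the 64-channel `image_proj` output and maps 128 channels to 64 at 32x32.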
