diff --git a/models/llama4/vision/embedding.py b/models/llama4/vision/embedding.py
index f6efd50c6..c04cceff0 100644
--- a/models/llama4/vision/embedding.py
+++ b/models/llama4/vision/embedding.py
@@ -18,11 +18,11 @@ class PixelShuffle(nn.Module):
-    def __init__(self, ps_ratio):
+    def __init__(self, ps_ratio: float):
         super().__init__()
         self.ps_ratio = ps_ratio
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         # x: [B, N, C], N = number of patches
         assert self.ps_ratio is not None, "ps_ratio is required for pixel shuffle"
         assert x.dim() == 3, "pixel shuffle requires encoded patches [B, N, C]"
@@ -33,14 +33,14 @@ def forward(self, x):
         return pixel_shuffle_patches
 
 
-def pixel_shuffle_op(input_x, ps_ratio):
-    n, w, h, c = input_x.size()
-    input_x = input_x.view(n, w, int(h * ps_ratio), int(c / ps_ratio))
+def pixel_shuffle_op(input_x: torch.Tensor, ps_ratio: float) -> torch.Tensor:
+    n, h, w, c = input_x.size()
+    input_x = input_x.view(n, h, int(w * ps_ratio), int(c / ps_ratio))
     input_x = input_x.permute(0, 2, 1, 3).contiguous()
     input_x = input_x.view(
         n,
-        int(h * ps_ratio),
         int(w * ps_ratio),
+        int(h * ps_ratio),
         int(c / (ps_ratio * ps_ratio)),
     )
     input_x = input_x.permute(0, 2, 1, 3).contiguous()