.gitignore (4 changes: 3 additions & 1 deletion)
@@ -1,4 +1,6 @@
 __pycache__
 resources/*.pt
 resources/*.npy
-.idea
+.idea
+/.vscode/
+*.log
models/dymn/model.py (35 changes: 18 additions & 17 deletions)
@@ -2,7 +2,7 @@
 from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
 from torch import nn, Tensor
 import torch.nn.functional as F
-from torchvision.ops.misc import ConvNormActivation
+from torchvision.ops.misc import Conv2dNormActivation
 from torch.hub import load_state_dict_from_url
 import urllib.parse

@@ -77,13 +77,13 @@ def __init__(

         # building first layer
         firstconv_output_channels = inverted_residual_setting[0].input_channels
-        self.in_c = ConvNormActivation(
-            in_channels,
-            firstconv_output_channels,
-            kernel_size=in_conv_kernel,
-            stride=in_conv_stride,
-            norm_layer=norm_layer,
-            activation_layer=nn.Hardswish,
+        self.in_c = Conv2dNormActivation(
+            in_channels,
+            firstconv_output_channels,
+            kernel_size=in_conv_kernel,
+            stride=in_conv_stride,
+            norm_layer=norm_layer,
+            activation_layer=nn.Hardswish,
         )

         for cnf in inverted_residual_setting:
@@ -107,7 +107,7 @@ def __init__(
         # building last several layers
         lastconv_input_channels = inverted_residual_setting[-1].out_channels
         lastconv_output_channels = 6 * lastconv_input_channels
-        self.out_c = ConvNormActivation(
+        self.out_c = Conv2dNormActivation(
             lastconv_input_channels,
             lastconv_output_channels,
             kernel_size=1,
@@ -174,31 +174,32 @@ def _feature_forward(
             return x, fmaps
         return x

-    def _clf_forward(self, x: Tensor):
-        embed = F.adaptive_avg_pool2d(x, (1, 1)).view(x.size(0), -1)
+    def _clf_forward(self, x: Tensor, frame: bool = False):
+        B, D, _, T = x.shape
+        embed = F.adaptive_avg_pool2d(x, (1, 1)).squeeze(
+        ) if not frame else F.adaptive_avg_pool2d(x, (1, T)).view(B, D, -1).transpose(-2, -1)
         x = self.classifier(x).squeeze()
         if x.dim() == 1:
             # squeezed batch dimension
             x = x.unsqueeze(0)
         return x, embed

     def _forward_impl(
-        self, x: Tensor, return_fmaps: bool = False
+        self, x: Tensor, return_fmaps: bool = False, frame: bool = False
     ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
         if return_fmaps:
             x, fmaps = self._feature_forward(x, return_fmaps=True)
-            x, _ = self._clf_forward(x)
+            x, _ = self._clf_forward(x, frame=frame)
             return x, fmaps
         else:
             x = self._feature_forward(x)
-            x, embed = self._clf_forward(x)
+            x, embed = self._clf_forward(x, frame=frame)
             return x, embed

     def forward(
-        self, x: Tensor, return_fmaps: bool = False
+        self, x: Tensor, return_fmaps: bool = False, frame: bool = False
     ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
-        return self._forward_impl(x, return_fmaps)
-
+        return self._forward_impl(x, return_fmaps, frame=frame)

     def update_params(self, epoch):
         for module in self.modules():
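Note (reviewer sketch, not part of the diff): the new `frame` flag switches `_clf_forward` between one embedding per clip and one embedding per time frame. A standalone illustration of the two pooling paths, with an assumed (batch, channels, freq, time) feature-map shape:

import torch
import torch.nn.functional as F

# assumed feature-map shape, for illustration only
x = torch.randn(4, 960, 4, 31)

# frame=False: pool everything -> one embedding per clip, shape (4, 960)
clip_embed = F.adaptive_avg_pool2d(x, (1, 1)).squeeze()

# frame=True: pool only the frequency axis, keep time -> shape (4, 31, 960)
B, D, _, T = x.shape
frame_embed = F.adaptive_avg_pool2d(x, (1, T)).view(B, D, -1).transpose(-2, -1)

print(clip_embed.shape, frame_embed.shape)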
models/mn/block_types.py (8 changes: 4 additions & 4 deletions)
@@ -2,7 +2,7 @@
 import torch
 import torch.nn as nn
 from torch import Tensor
-from torchvision.ops.misc import ConvNormActivation
+from torchvision.ops.misc import Conv2dNormActivation

 from models.mn.utils import make_divisible, cnn_out_size

@@ -137,7 +137,7 @@ def __init__(
         # expand
         if cnf.expanded_channels != cnf.input_channels:
             layers.append(
-                ConvNormActivation(
+                Conv2dNormActivation(
                     cnf.input_channels,
                     cnf.expanded_channels,
                     kernel_size=1,
@@ -149,7 +149,7 @@ def __init__(
         # depthwise
         stride = 1 if cnf.dilation > 1 else cnf.stride
         layers.append(
-            ConvNormActivation(
+            Conv2dNormActivation(
                 cnf.expanded_channels,
                 cnf.expanded_channels,
                 kernel_size=cnf.kernel,
@@ -165,7 +165,7 @@ def __init__(

         # project
         layers.append(
-            ConvNormActivation(
+            Conv2dNormActivation(
                 cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
             )
         )
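Note: `Conv2dNormActivation` is the current torchvision name for this block; the generic `ConvNormActivation` it replaces is deprecated, and the constructor arguments are unchanged. If older torchvision releases still had to be supported, a fallback import could look like this (a hypothetical shim, not part of this PR):

# hypothetical compatibility shim: alias the old name where the new one is missing
try:
    from torchvision.ops.misc import Conv2dNormActivation
except ImportError:
    from torchvision.ops.misc import ConvNormActivation as Conv2dNormActivation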
models/mn/model.py (24 changes: 13 additions & 11 deletions)
@@ -2,7 +2,7 @@
 from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
 from torch import nn, Tensor
 import torch.nn.functional as F
-from torchvision.ops.misc import ConvNormActivation
+from torchvision.ops.misc import Conv2dNormActivation
 from torch.hub import load_state_dict_from_url
 import urllib.parse

@@ -122,7 +122,7 @@ def __init__(
         # building first layer
         firstconv_output_channels = inverted_residual_setting[0].input_channels
         layers.append(
-            ConvNormActivation(
+            Conv2dNormActivation(
                 in_channels,
                 firstconv_output_channels,
                 kernel_size=in_conv_kernel,
@@ -157,7 +157,7 @@ def __init__(
         lastconv_input_channels = inverted_residual_setting[-1].out_channels
         lastconv_output_channels = 6 * lastconv_input_channels
         layers.append(
-            ConvNormActivation(
+            Conv2dNormActivation(
                 lastconv_input_channels,
                 lastconv_output_channels,
                 kernel_size=1,
@@ -209,29 +209,31 @@ def __init__(
             if m.bias is not None:
                 nn.init.zeros_(m.bias)

-    def _forward_impl(self, x: Tensor, return_fmaps: bool = False) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
+    def _forward_impl(self, x: Tensor, return_fmaps: bool = False, frame: bool = False) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
         fmaps = []

         for i, layer in enumerate(self.features):
             x = layer(x)
             if return_fmaps:
                 fmaps.append(x)

-        features = F.adaptive_avg_pool2d(x, (1, 1)).squeeze()
-
+        B, C, _, T = x.shape
+        features = F.adaptive_avg_pool2d(x, (1, 1)).squeeze(
+        ) if not frame else F.adaptive_avg_pool2d(x, (1, T)).view(B, C, -1).transpose(-2, -1)
         x = self.classifier(x).squeeze()

         if features.dim() == 1 and x.dim() == 1:
             # squeezed batch dimension
             features = features.unsqueeze(0)
             x = x.unsqueeze(0)

         if return_fmaps:
             return x, fmaps
         else:
             return x, features

-    def forward(self, x: Tensor) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
-        return self._forward_impl(x)
+    def forward(self, x: Tensor, return_fmaps: bool = False, frame: bool = False) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
+        return self._forward_impl(x, return_fmaps=return_fmaps, frame=frame)


 def _mobilenet_v3_conf(
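Note (illustrative, not from the diff): the `dim() == 1` guard is needed because `.squeeze()` also drops the batch axis when the batch size is 1. A minimal reproduction, with an assumed feature-map shape:

import torch
import torch.nn.functional as F

x = torch.randn(1, 960, 4, 31)                         # a single clip
features = F.adaptive_avg_pool2d(x, (1, 1)).squeeze()  # -> (960,): batch dim squeezed away
if features.dim() == 1:
    features = features.unsqueeze(0)                   # restore -> (1, 960)
print(features.shape)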
models/preprocess.py (8 changes: 4 additions & 4 deletions)
@@ -16,7 +16,7 @@ def __init__(self, n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024
         self.fmin = fmin
         if fmax is None:
             fmax = sr // 2 - fmax_aug_range // 2
-            print(f"Warning: FMAX is None setting to {fmax} ")
+            print(f"INFO: FMAX is None, setting to {fmax} ")
         self.fmax = fmax
         self.hopsize = hopsize
         self.register_buffer('window',
@@ -40,8 +40,8 @@ def __init__(self, n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024
     def forward(self, x):
         x = nn.functional.conv1d(x.unsqueeze(1), self.preemphasis_coefficient).squeeze(1)
         x = torch.stft(x, self.n_fft, hop_length=self.hopsize, win_length=self.win_length,
-                       center=True, normalized=False, window=self.window, return_complex=False)
-        x = (x ** 2).sum(dim=-1)  # power mag
+                       center=True, normalized=False, window=self.window, return_complex=True)
+        x = torch.square(torch.abs(x))  # power mag
         fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item()
         fmax = self.fmax + self.fmax_aug_range // 2 - torch.randint(self.fmax_aug_range, (1,)).item()
         # don't augment eval data
@@ -53,7 +53,7 @@ def forward(self, x):
                                          fmin, fmax, vtln_low=100.0, vtln_high=-500., vtln_warp_factor=1.0)
         mel_basis = torch.as_tensor(torch.nn.functional.pad(mel_basis, (0, 1), mode='constant', value=0),
                                     device=x.device)
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.amp.autocast(enabled=False, device_type=x.device.type):
             melspec = torch.matmul(mel_basis, x)

         melspec = (melspec + 0.00001).log()
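Note (sanity check, not part of the diff): recent PyTorch deprecates `return_complex=False`, and the replacement is exact: `torch.square(torch.abs(x))` on the complex STFT equals the old `(x ** 2).sum(dim=-1)` on the stacked (real, imag) output, since |z|^2 = re(z)^2 + im(z)^2. A standalone check, using this module's default STFT parameters:

import torch

x = torch.randn(2, 32000)
window = torch.hann_window(800)
X = torch.stft(x, 1024, hop_length=320, win_length=800,
               center=True, normalized=False, window=window, return_complex=True)

power = torch.square(torch.abs(X))                    # new path
power_ref = torch.view_as_real(X).pow(2).sum(dim=-1)  # old formula on (re, im) pairs
print(torch.allclose(power, power_ref))               # expected: True

The autocast change in the same file is analogous modernization: the CUDA-only `torch.cuda.amp.autocast` context manager is replaced by the device-generic `torch.amp.autocast`, so the float32 mel projection also works unchanged on CPU and other backends.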