@@ -21,9 +21,14 @@ class MultiPeriodDiscriminator(torch.nn.Module):
             Defaults to False.
     """
 
-    def __init__(self, use_spectral_norm: bool = False, checkpointing: bool = False, version: str = "v2"):
+    def __init__(
+        self,
+        use_spectral_norm: bool = False,
+        checkpointing: bool = False,
+        version: str = "v2",
+    ):
         super().__init__()
-
+
         if version == "v1":
             periods = [2, 3, 5, 7, 11, 17]
             resolutions = []
@@ -33,12 +38,15 @@ def __init__(self, use_spectral_norm: bool = False, checkpointing: bool = False,
3338 elif version == "v3" :
3439 periods = [2 , 3 , 5 , 7 , 11 ]
3540 resolutions = [[1024 , 120 , 600 ], [2048 , 240 , 1200 ], [512 , 50 , 240 ]]
36-
41+
3742 self .checkpointing = checkpointing
3843 self .discriminators = torch .nn .ModuleList (
3944 [DiscriminatorS (use_spectral_norm = use_spectral_norm )]
4045 + [DiscriminatorP (p , use_spectral_norm = use_spectral_norm ) for p in periods ]
41- + [DiscriminatorR (r , use_spectral_norm = use_spectral_norm ) for r in resolutions ]
46+ + [
47+ DiscriminatorR (r , use_spectral_norm = use_spectral_norm )
48+ for r in resolutions
49+ ]
4250 )
4351
4452 def forward (self , y , y_hat ):
@@ -160,6 +168,7 @@ def forward(self, x):
         x = torch.flatten(x, 1, -1)
         return x, fmap
 
+
 class DiscriminatorR(torch.nn.Module):
     def __init__(self, resolution, use_spectral_norm=False):
         super().__init__()
@@ -170,11 +179,49 @@ def __init__(self, resolution, use_spectral_norm=False):
 
         self.convs = torch.nn.ModuleList(
             [
-                norm_f(torch.nn.Conv2d(1, 32, (3, 9), padding=(1, 4),)),
-                norm_f(torch.nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4),)),
-                norm_f(torch.nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4),)),
-                norm_f(torch.nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4),)),
-                norm_f(torch.nn.Conv2d(32, 32, (3, 3), padding=(1, 1),)),
+                norm_f(
+                    torch.nn.Conv2d(
+                        1,
+                        32,
+                        (3, 9),
+                        padding=(1, 4),
+                    )
+                ),
+                norm_f(
+                    torch.nn.Conv2d(
+                        32,
+                        32,
+                        (3, 9),
+                        stride=(1, 2),
+                        padding=(1, 4),
+                    )
+                ),
+                norm_f(
+                    torch.nn.Conv2d(
+                        32,
+                        32,
+                        (3, 9),
+                        stride=(1, 2),
+                        padding=(1, 4),
+                    )
+                ),
+                norm_f(
+                    torch.nn.Conv2d(
+                        32,
+                        32,
+                        (3, 9),
+                        stride=(1, 2),
+                        padding=(1, 4),
+                    )
+                ),
+                norm_f(
+                    torch.nn.Conv2d(
+                        32,
+                        32,
+                        (3, 3),
+                        padding=(1, 1),
+                    )
+                ),
             ]
         )
         self.conv_post = norm_f(torch.nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
@@ -183,7 +230,7 @@ def forward(self, x):
         fmap = []
 
         x = self.spectrogram(x).unsqueeze(1)
-
+
         for layer in self.convs:
             x = F.leaky_relu(layer(x), self.lrelu_slope)
             fmap.append(x)
@@ -195,13 +242,17 @@ def forward(self, x):
     def spectrogram(self, x):
         n_fft, hop_length, win_length = self.resolution
         pad = int((n_fft - hop_length) / 2)
-        x = F.pad(x, (pad, pad), mode="reflect",).squeeze(1)
+        x = F.pad(
+            x,
+            (pad, pad),
+            mode="reflect",
+        ).squeeze(1)
         x = torch.stft(
             x,
             n_fft=n_fft,
             hop_length=hop_length,
             win_length=win_length,
-            window=torch.ones(win_length, device=x.device),
+            window=torch.ones(win_length, device=x.device),
             center=False,
             return_complex=True,
         )
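
Aside from the diff itself: the reformatted DiscriminatorR.spectrogram above reflect-pads the signal by (n_fft - hop_length) / 2 on each side and then takes a non-centered STFT with a rectangular window. Below is a minimal standalone sketch of that computation, assuming an input shaped [batch, 1, samples] and one of the v3 resolutions listed in the first hunk; the helper name and the sample tensor are illustrative only, and whatever the method does with the complex STFT afterwards lies outside this hunk.

import torch
import torch.nn.functional as F

def rectangular_spectrogram(x, resolution):  # illustrative helper, not part of the PR
    n_fft, hop_length, win_length = resolution
    # Reflect-pad both ends so the non-centered STFT covers the whole signal.
    pad = int((n_fft - hop_length) / 2)
    x = F.pad(x, (pad, pad), mode="reflect").squeeze(1)  # [batch, samples + 2 * pad]
    return torch.stft(
        x,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=torch.ones(win_length, device=x.device),  # rectangular window
        center=False,
        return_complex=True,
    )

spec = rectangular_spectrogram(torch.randn(1, 1, 16000), [1024, 120, 600])
print(spec.shape)  # torch.Size([1, 513, 133]): (n_fft // 2 + 1) frequency bins by frames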