Skip to content

Commit d2a3b2e

Browse files
authored
Merge pull request #1188 from IAHispano/formatter/main
chore(format): run black on main
2 parents d08c5b4 + b5166e2 commit d2a3b2e

File tree

3 files changed

+67
-14
lines changed

3 files changed

+67
-14
lines changed

rvc/lib/algorithm/discriminators.py

Lines changed: 63 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,14 @@ class MultiPeriodDiscriminator(torch.nn.Module):
2121
Defaults to False.
2222
"""
2323

24-
def __init__(self, use_spectral_norm: bool = False, checkpointing: bool = False, version: str = "v2"):
24+
def __init__(
25+
self,
26+
use_spectral_norm: bool = False,
27+
checkpointing: bool = False,
28+
version: str = "v2",
29+
):
2530
super().__init__()
26-
31+
2732
if version == "v1":
2833
periods = [2, 3, 5, 7, 11, 17]
2934
resolutions = []
@@ -33,12 +38,15 @@ def __init__(self, use_spectral_norm: bool = False, checkpointing: bool = False,
3338
elif version == "v3":
3439
periods = [2, 3, 5, 7, 11]
3540
resolutions = [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]]
36-
41+
3742
self.checkpointing = checkpointing
3843
self.discriminators = torch.nn.ModuleList(
3944
[DiscriminatorS(use_spectral_norm=use_spectral_norm)]
4045
+ [DiscriminatorP(p, use_spectral_norm=use_spectral_norm) for p in periods]
41-
+ [DiscriminatorR(r, use_spectral_norm=use_spectral_norm) for r in resolutions]
46+
+ [
47+
DiscriminatorR(r, use_spectral_norm=use_spectral_norm)
48+
for r in resolutions
49+
]
4250
)
4351

4452
def forward(self, y, y_hat):
@@ -160,6 +168,7 @@ def forward(self, x):
160168
x = torch.flatten(x, 1, -1)
161169
return x, fmap
162170

171+
163172
class DiscriminatorR(torch.nn.Module):
164173
def __init__(self, resolution, use_spectral_norm=False):
165174
super().__init__()
@@ -170,11 +179,49 @@ def __init__(self, resolution, use_spectral_norm=False):
170179

171180
self.convs = torch.nn.ModuleList(
172181
[
173-
norm_f(torch.nn.Conv2d( 1, 32, (3, 9), padding=(1, 4),)),
174-
norm_f(torch.nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4),)),
175-
norm_f(torch.nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4),)),
176-
norm_f(torch.nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4),)),
177-
norm_f(torch.nn.Conv2d(32, 32, (3, 3), padding=(1, 1),)),
182+
norm_f(
183+
torch.nn.Conv2d(
184+
1,
185+
32,
186+
(3, 9),
187+
padding=(1, 4),
188+
)
189+
),
190+
norm_f(
191+
torch.nn.Conv2d(
192+
32,
193+
32,
194+
(3, 9),
195+
stride=(1, 2),
196+
padding=(1, 4),
197+
)
198+
),
199+
norm_f(
200+
torch.nn.Conv2d(
201+
32,
202+
32,
203+
(3, 9),
204+
stride=(1, 2),
205+
padding=(1, 4),
206+
)
207+
),
208+
norm_f(
209+
torch.nn.Conv2d(
210+
32,
211+
32,
212+
(3, 9),
213+
stride=(1, 2),
214+
padding=(1, 4),
215+
)
216+
),
217+
norm_f(
218+
torch.nn.Conv2d(
219+
32,
220+
32,
221+
(3, 3),
222+
padding=(1, 1),
223+
)
224+
),
178225
]
179226
)
180227
self.conv_post = norm_f(torch.nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
@@ -183,7 +230,7 @@ def forward(self, x):
183230
fmap = []
184231

185232
x = self.spectrogram(x).unsqueeze(1)
186-
233+
187234
for layer in self.convs:
188235
x = F.leaky_relu(layer(x), self.lrelu_slope)
189236
fmap.append(x)
@@ -195,13 +242,17 @@ def forward(self, x):
195242
def spectrogram(self, x):
196243
n_fft, hop_length, win_length = self.resolution
197244
pad = int((n_fft - hop_length) / 2)
198-
x = F.pad(x, (pad, pad), mode="reflect",).squeeze(1)
245+
x = F.pad(
246+
x,
247+
(pad, pad),
248+
mode="reflect",
249+
).squeeze(1)
199250
x = torch.stft(
200251
x,
201252
n_fft=n_fft,
202253
hop_length=hop_length,
203254
win_length=win_length,
204-
window=torch.ones(win_length, device=x.device),
255+
window=torch.ones(win_length, device=x.device),
205256
center=False,
206257
return_complex=True,
207258
)

rvc/train/train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,9 @@ def run(
429429
)
430430

431431
net_d = MultiPeriodDiscriminator(
432-
config.model.use_spectral_norm, checkpointing=checkpointing, version=disc_version,
432+
config.model.use_spectral_norm,
433+
checkpointing=checkpointing,
434+
version=disc_version,
433435
)
434436

435437
if torch.cuda.is_available():

tabs/train/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def train_tab():
343343
info=i18n(
344344
"Choose the vocoder for audio synthesis:\n- **HiFi-GAN**: Default option, compatible with all clients.\n- **MRF HiFi-GAN**: Higher fidelity, Applio-only.\n- **RefineGAN**: Superior audio quality, Applio-only."
345345
),
346-
choices=["HiFi-GAN", "RefineGAN"], #"MRF HiFi-GAN", ],
346+
choices=["HiFi-GAN", "RefineGAN"], # "MRF HiFi-GAN", ],
347347
value="HiFi-GAN",
348348
interactive=True,
349349
visible=True,

0 commit comments

Comments
 (0)