Commit 31b4899

Speedup neural source filter
1 parent a59a056 commit 31b4899

2 files changed: +54 −168 lines changed

modules/nsf_hifigan/env.py

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,9 @@ def __init__(self, *args, **kwargs):
         super(AttrDict, self).__init__(*args, **kwargs)
         self.__dict__ = self

+    def __getattr__(self, item):
+        return self[item]
+

 def build_env(config, config_name, path):
     t_path = os.path.join(path, config_name)
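
The added `__getattr__` makes failed attribute lookups fall back to dictionary access, so config entries can be read as attributes. A minimal sketch of the resulting behaviour, assuming the module is importable from the repository root; the key and value here are illustrative, not taken from the commit:

    from modules.nsf_hifigan.env import AttrDict

    h = AttrDict({'sampling_rate': 44100})
    print(h['sampling_rate'])  # 44100, plain dict lookup
    print(h.sampling_rate)     # 44100, attribute-style lookup used throughout models.py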

modules/nsf_hifigan/models.py

Lines changed: 51 additions & 168 deletions
@@ -17,16 +17,12 @@ def load_model(model_path, device='cuda'):
     with open(config_file) as f:
         data = f.read()

-    global h
     json_config = json.loads(data)
     h = AttrDict(json_config)

     generator = Generator(h).to(device)

-    if torch.cuda.is_available():
-        cp_dict = torch.load(model_path)
-    else:
-        cp_dict = torch.load(model_path, map_location="cpu")
+    cp_dict = torch.load(model_path, map_location=device)
     generator.load_state_dict(cp_dict['generator'])
     generator.eval()
     generator.remove_weight_norm()
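
The two-branch loading above collapses into a single `torch.load` call because `map_location` already handles device placement. A minimal usage sketch; the checkpoint path is a placeholder:

    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # map_location moves every tensor in the checkpoint onto `device` while loading
    cp_dict = torch.load('model.ckpt', map_location=device)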
@@ -98,59 +94,6 @@ def remove_weight_norm(self):
             remove_weight_norm(l)


-class Generator(torch.nn.Module):
-    def __init__(self, h):
-        super(Generator, self).__init__()
-        self.h = h
-        self.num_kernels = len(h.resblock_kernel_sizes)
-        self.num_upsamples = len(h.upsample_rates)
-        self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
-        resblock = ResBlock1 if h.resblock == '1' else ResBlock2
-
-        self.ups = nn.ModuleList()
-        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
-            self.ups.append(weight_norm(
-                ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
-                                k, u, padding=(k - u) // 2)))
-
-        self.resblocks = nn.ModuleList()
-        for i in range(len(self.ups)):
-            ch = h.upsample_initial_channel // (2 ** (i + 1))
-            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
-                self.resblocks.append(resblock(h, ch, k, d))
-
-        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
-        self.ups.apply(init_weights)
-        self.conv_post.apply(init_weights)
-
-    def forward(self, x):
-        x = self.conv_pre(x)
-        for i in range(self.num_upsamples):
-            x = F.leaky_relu(x, LRELU_SLOPE)
-            x = self.ups[i](x)
-            xs = None
-            for j in range(self.num_kernels):
-                if xs is None:
-                    xs = self.resblocks[i * self.num_kernels + j](x)
-                else:
-                    xs += self.resblocks[i * self.num_kernels + j](x)
-            x = xs / self.num_kernels
-        x = F.leaky_relu(x)
-        x = self.conv_post(x)
-        x = torch.tanh(x)
-
-        return x
-
-    def remove_weight_norm(self):
-        print('Removing weight norm...')
-        for l in self.ups:
-            remove_weight_norm(l)
-        for l in self.resblocks:
-            l.remove_weight_norm()
-        remove_weight_norm(self.conv_pre)
-        remove_weight_norm(self.conv_post)
-
-
 class SineGen(torch.nn.Module):
     """ Definition of sine generator
     SineGen(samp_rate, harmonic_num = 0,
@@ -169,109 +112,64 @@ class SineGen(torch.nn.Module):

     def __init__(self, samp_rate, harmonic_num=0,
                  sine_amp=0.1, noise_std=0.003,
-                 voiced_threshold=0,
-                 flag_for_pulse=False):
+                 voiced_threshold=0):
         super(SineGen, self).__init__()
         self.sine_amp = sine_amp
         self.noise_std = noise_std
         self.harmonic_num = harmonic_num
         self.dim = self.harmonic_num + 1
         self.sampling_rate = samp_rate
         self.voiced_threshold = voiced_threshold
-        self.flag_for_pulse = flag_for_pulse

     def _f02uv(self, f0):
         # generate uv signal
-        uv = (f0 > self.voiced_threshold).type(torch.float32)
+        uv = torch.ones_like(f0)
+        uv = uv * (f0 > self.voiced_threshold)
         return uv

-    def _f02sine(self, f0_values):
-        """ f0_values: (batchsize, length, dim)
-            where dim indicates fundamental tone and overtones
-        """
-        # convert to F0 in rad. The integer part n can be ignored
-        # because 2 * np.pi * n doesn't affect phase
-        rad_values = (f0_values / self.sampling_rate) % 1
-
-        # initial phase noise (no noise for fundamental component)
-        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
-                              device=f0_values.device)
-        rand_ini[:, 0] = 0
-        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
-
-        # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
-        if not self.flag_for_pulse:
-            # for normal case
-
-            # To prevent torch.cumsum numerical overflow,
-            # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
-            # Buffer tmp_over_one_idx indicates the time step to add -1.
-            # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
-            tmp_over_one = torch.cumsum(rad_values, 1) % 1
-            tmp_over_one_idx = (torch.diff(tmp_over_one, dim=1)) < 0
-            cumsum_shift = torch.zeros_like(rad_values)
-            cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
-
-            sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1)
-                              * 2 * np.pi)
-        else:
-            # If necessary, make sure that the first time step of every
-            # voiced segments is sin(pi) or cos(0)
-            # This is used for pulse-train generation
-
-            # identify the last time step in unvoiced segments
-            uv = self._f02uv(f0_values)
-            uv_1 = torch.roll(uv, shifts=-1, dims=1)
-            uv_1[:, -1, :] = 1
-            u_loc = (uv < 1) * (uv_1 > 0)
-
-            # get the instantaneous phase
-            tmp_cumsum = torch.cumsum(rad_values, dim=1)
-            # different batch needs to be processed differently
-            for idx in range(f0_values.shape[0]):
-                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
-                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
-                # stores the accumulation of i.phase within
-                # each voiced segments
-                tmp_cumsum[idx, :, :] = 0
-                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
-
-            # rad_values - tmp_cumsum: remove the accumulation of i.phase
-            # within the previous voiced segment.
-            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
-
-            # get the sines
-            sines = torch.cos(i_phase * 2 * np.pi)
-        return sines
-
-    def forward(self, f0):
+    @torch.no_grad()
+    def forward(self, f0, upp):
         """ sine_tensor, uv = forward(f0)
         input F0: tensor(batchsize=1, length, dim=1)
                   f0 for unvoiced steps should be 0
         output sine_tensor: tensor(batchsize=1, length, dim)
         output uv: tensor(batchsize=1, length, 1)
         """
-        with torch.no_grad():
-            # fundamental component
-            fn = torch.multiply(f0, torch.arange(1, self.harmonic_num + 2, device=f0.device))
-
-            # generate sine waveforms
-            sine_waves = self._f02sine(fn) * self.sine_amp
-
-            # generate uv signal
-            # uv = torch.ones(f0.shape)
-            # uv = uv * (f0 > self.voiced_threshold)
-            uv = self._f02uv(f0)
-
-            # noise: for unvoiced should be similar to sine_amp
-            #        std = self.sine_amp/3 -> max value ~ self.sine_amp
-            #        for voiced regions is self.noise_std
-            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
-            noise = noise_amp * torch.randn_like(sine_waves)
-
-            # first: set the unvoiced part to 0 by uv
-            # then: additive noise
-            sine_waves = sine_waves * uv + noise
+        f0 = f0.unsqueeze(-1)
+        fn = torch.multiply(f0, torch.arange(1, self.dim + 1, device=f0.device).reshape((1, 1, -1)))
+        rad_values = (fn / self.sampling_rate) % 1  # the % 1 here keeps the harmonic products from being folded into later post-processing
+        rand_ini = torch.rand(fn.shape[0], fn.shape[2], device=fn.device)
+        rand_ini[:, 0] = 0
+        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+        is_half = rad_values.dtype is not torch.float32
+        tmp_over_one = torch.cumsum(rad_values.double(), 1)  # % 1  # a % 1 here would keep the following cumsum from being optimized
+        if is_half:
+            tmp_over_one = tmp_over_one.half()
+        else:
+            tmp_over_one = tmp_over_one.float()
+        tmp_over_one *= upp
+        tmp_over_one = F.interpolate(
+            tmp_over_one.transpose(2, 1), scale_factor=upp,
+            mode='linear', align_corners=True
+        ).transpose(2, 1)
+        rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)
+        tmp_over_one %= 1
+        tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
+        cumsum_shift = torch.zeros_like(rad_values)
+        cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+        rad_values = rad_values.double()
+        cumsum_shift = cumsum_shift.double()
+        sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)
+        if is_half:
+            sine_waves = sine_waves.half()
+        else:
+            sine_waves = sine_waves.float()
+        sine_waves = sine_waves * self.sine_amp
+        uv = self._f02uv(f0)
+        uv = F.interpolate(uv.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)
+        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+        noise = noise_amp * torch.randn_like(sine_waves)
+        sine_waves = sine_waves * uv + noise
         return sine_waves, uv, noise

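The rewritten `forward` now takes frame-rate F0 together with `upp`, the product of the generator's upsample rates, and expands the source to sample rate internally with `F.interpolate` instead of receiving pre-upsampled F0. A minimal sketch of that relationship, using illustrative upsample rates; the real values come from the checkpoint's config.json:

    import numpy as np

    # Illustrative rates only, not taken from this commit.
    upsample_rates = [8, 8, 2, 2]
    upp = int(np.prod(upsample_rates))  # 256 waveform samples per mel/F0 frame

    # A frame-rate F0 track with T frames therefore yields a sine source of T * upp samples.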

@@ -308,20 +206,10 @@ def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
         self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
         self.l_tanh = torch.nn.Tanh()

-    def forward(self, x):
-        """
-        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
-        F0_sampled (batchsize, length, 1)
-        Sine_source (batchsize, length, 1)
-        noise_source (batchsize, length 1)
-        """
-        # source for harmonic branch
-        sine_wavs, uv, _ = self.l_sin_gen(x)
+    def forward(self, x, upp):
+        sine_wavs, uv, _ = self.l_sin_gen(x, upp)
         sine_merge = self.l_tanh(self.l_linear(sine_wavs))
-
-        # source for noise branch, in the same shape as uv
-        noise = torch.randn_like(uv) * self.sine_amp / 3
-        return sine_merge, noise, uv
+        return sine_merge


 class Generator(torch.nn.Module):
@@ -330,10 +218,10 @@ def __init__(self, h):
         self.h = h
         self.num_kernels = len(h.resblock_kernel_sizes)
         self.num_upsamples = len(h.upsample_rates)
-        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h.upsample_rates))
         self.m_source = SourceModuleHnNSF(
             sampling_rate=h.sampling_rate,
-            harmonic_num=8)
+            harmonic_num=8
+        )
         self.noise_convs = nn.ModuleList()
         self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
         resblock = ResBlock1 if h.resblock == '1' else ResBlock2
@@ -345,35 +233,30 @@ def __init__(self, h):
                 ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
                                 k, u, padding=(k - u) // 2)))
             if i + 1 < len(h.upsample_rates):  #
-                stride_f0 = np.prod(h.upsample_rates[i + 1:])
+                stride_f0 = int(np.prod(h.upsample_rates[i + 1:]))
                 self.noise_convs.append(Conv1d(
                     1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
             else:
                 self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
         self.resblocks = nn.ModuleList()
+        ch = h.upsample_initial_channel
         for i in range(len(self.ups)):
-            ch = h.upsample_initial_channel // (2 ** (i + 1))
+            ch //= 2
             for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                 self.resblocks.append(resblock(h, ch, k, d))

         self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
         self.ups.apply(init_weights)
         self.conv_post.apply(init_weights)
+        self.upp = int(np.prod(h.upsample_rates))

     def forward(self, x, f0):
-        # print(1,x.shape,f0.shape,f0[:, None].shape)
-        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
-        # print(2,f0.shape)
-        har_source, noi_source, uv = self.m_source(f0)
-        har_source = har_source.transpose(1, 2)
+        har_source = self.m_source(f0, self.upp).transpose(1, 2)
         x = self.conv_pre(x)
-        # print(124,x.shape,har_source.shape)
         for i in range(self.num_upsamples):
             x = F.leaky_relu(x, LRELU_SLOPE)
-            # print(3,x.shape)
             x = self.ups[i](x)
             x_source = self.noise_convs[i](har_source)
-            # print(4,x_source.shape,har_source.shape,x.shape)
             x = x + x_source
             xs = None
             for j in range(self.num_kernels):
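
With these changes the generator consumes mel features plus a frame-rate F0 track and builds the harmonic source internally via `self.upp`. A rough usage sketch, assuming `generator` and its config `h` were obtained from `load_model` above; shapes and values are illustrative:

    import torch

    mel = torch.randn(1, h.num_mels, 100)   # (batch, num_mels, frames), dummy mel input
    f0 = torch.full((1, 100), 220.0)        # (batch, frames) F0 in Hz; 0 marks unvoiced frames
    with torch.no_grad():
        wav = generator(mel, f0)            # waveform of roughly frames * generator.upp samples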
