Skip to content

Commit ec3c7a8

Browse files
committed
ONNX export with NSF speedup
1 parent 31b4899 commit ec3c7a8

File tree

1 file changed

+60
-81
lines changed

1 file changed

+60
-81
lines changed

onnx/export/export_nsf_hifigan.py

Lines changed: 60 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import onnxsim
88
import torch
99
import torch.nn as nn
10-
import torch.nn.functional as functional
11-
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
10+
import torch.nn.functional as F
11+
from torch.nn import Conv1d, ConvTranspose1d
1212
from torch.nn.utils import weight_norm, remove_weight_norm
1313

1414
from modules.nsf_hifigan.env import AttrDict
@@ -45,71 +45,63 @@ def __init__(self, samp_rate, harmonic_num=0,
4545
self.dim = self.harmonic_num + 1
4646
self.sampling_rate = samp_rate
4747
self.voiced_threshold = voiced_threshold
48-
self.diff = Conv2d(
49-
in_channels=1,
50-
out_channels=1,
51-
kernel_size=(2, 1),
52-
stride=(1, 1),
53-
padding=0,
54-
dilation=(1, 1),
55-
bias=False
56-
)
57-
self.diff.weight = nn.Parameter(torch.FloatTensor([[[[-1.], [1.]]]]))
5848

59-
def _f02sine(self, f0_values, upp):
    """ f0_values: (batchsize, length, dim)
        where dim indicates fundamental tone and overtones
        upp: integer upsampling factor from frame rate to sample rate
        returns: per-sample sine waveforms, shape (batchsize, length * upp, dim)
    """
    # Frequency in revolutions per sample; the integer part can be dropped
    # because 2*pi*n does not affect phase.
    # NOTE(translated): taking % 1 here means the n_har multiplication can
    # no longer be folded away by post-processing optimization.
    rad_values = (f0_values / self.sampling_rate).fmod(1.)
    # Random initial phase per overtone (no randomness for the fundamental).
    rand_ini = torch.rand(1, self.dim, device=f0_values.device)
    rand_ini[:, 0] = 0
    rad_values[:, 0, :] += rand_ini
    # Accumulate phase in float64 to limit cumsum rounding error over long
    # utterances; cast back to the working dtype afterwards.
    is_half = rad_values.dtype is not torch.float32
    # NOTE(translated): the % 1 is deliberately deferred — applying it here
    # would prevent the later cumsum from being optimized.
    tmp_over_one = torch.cumsum(rad_values.double(), 1)
    if is_half:
        tmp_over_one = tmp_over_one.half()
    else:
        tmp_over_one = tmp_over_one.float()
    tmp_over_one *= upp
    # Upsample accumulated phase linearly and the per-step increments with
    # nearest neighbour, so both live at the sample rate.
    tmp_over_one = F.interpolate(
        tmp_over_one.transpose(2, 1), scale_factor=upp,
        mode='linear', align_corners=True
    ).transpose(2, 1)
    rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)
    tmp_over_one = tmp_over_one.fmod(1.)
    diff = F.conv2d(
        tmp_over_one.unsqueeze(1), torch.FloatTensor([[[[-1.], [1.]]]]).to(tmp_over_one.device),
        stride=(1, 1), padding=0, dilation=(1, 1)
    ).squeeze(1)  # Equivalent to torch.diff, but able to export ONNX
    # Where the wrapped phase decreased, a full revolution was crossed;
    # shift the running phase by 1 there. sin is 2*pi-periodic, so adding
    # +1 is equivalent to the -1 used by earlier implementations.
    cumsum_shift = (diff < 0).double()
    cumsum_shift = torch.cat((
        torch.zeros((1, 1, self.dim), dtype=torch.double).to(f0_values.device),
        cumsum_shift
    ), dim=1)
    sines = torch.sin(torch.cumsum(rad_values.double() + cumsum_shift, dim=1) * 2 * np.pi)
    if is_half:
        sines = sines.half()
    else:
        sines = sines.float()
    return sines
8585

86-
@torch.no_grad()
def forward(self, f0, upp):
    """ sine_tensor = forward(f0, upp)
    input F0: tensor(batchsize=1, length); f0 for unvoiced steps should be 0
    upp: integer upsampling factor from frame rate to sample rate
    output sine_tensor: tensor(batchsize=1, length * upp, dim)
    """
    f0 = f0.unsqueeze(-1)
    # Harmonic frequencies: f0 * (1, 2, ..., dim) — fundamental plus overtones.
    fn = torch.multiply(f0, torch.arange(1, self.dim + 1, device=f0.device).reshape((1, 1, -1)))
    sine_waves = self._f02sine(fn, upp)
    sine_waves = sine_waves * self.sine_amp
    # Voiced/unvoiced mask at frame rate, upsampled with nearest neighbour
    # to match the sample-rate sine waveforms.
    uv = (f0 > self.voiced_threshold).float()
    uv = F.interpolate(uv.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)
    # Noise level: sine_amp/3 in unvoiced regions (std/3 -> max ~ sine_amp),
    # noise_std on top of the sine in voiced regions.
    noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
    noise = noise_amp * torch.randn_like(sine_waves)
    # Zero the sine where unvoiced, then add the additive noise floor.
    sine_waves = sine_waves * uv + noise
    return sine_waves
113105

114106

115107
class SourceModuleHnNSF(torch.nn.Module):
@@ -144,20 +136,10 @@ def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
144136
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
145137
self.l_tanh = torch.nn.Tanh()
146138

147-
def forward(self, x):
148-
"""
149-
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
150-
F0_sampled (batchsize, length, 1)
151-
Sine_source (batchsize, length, 1)
152-
noise_source (batchsize, length 1)
153-
"""
154-
# source for harmonic branch
155-
sine_wavs, uv, _ = self.l_sin_gen(x)
139+
def forward(self, x, upp):
140+
sine_wavs = self.l_sin_gen(x, upp)
156141
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
157-
158-
# source for noise branch, in the same shape as uv
159-
noise = torch.randn_like(uv) * (self.sine_amp / 3)
160-
return sine_merge, noise, uv
142+
return sine_merge
161143

162144

163145
class Generator(torch.nn.Module):
@@ -166,10 +148,10 @@ def __init__(self, h):
166148
self.h = h
167149
self.num_kernels = len(h.resblock_kernel_sizes)
168150
self.num_upsamples = len(h.upsample_rates)
169-
self.f0_upsamp = torch.nn.Upsample(scale_factor=float(np.prod(h.upsample_rates)))
170151
self.m_source = SourceModuleHnNSF(
171152
sampling_rate=h.sampling_rate,
172-
harmonic_num=8)
153+
harmonic_num=8
154+
)
173155
self.noise_convs = nn.ModuleList()
174156
self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
175157
resblock = ResBlock1 if h.resblock == '1' else ResBlock2
@@ -187,24 +169,22 @@ def __init__(self, h):
187169
else:
188170
self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
189171
self.resblocks = nn.ModuleList()
190-
ch = None
172+
ch = h.upsample_initial_channel
191173
for i in range(len(self.ups)):
192-
ch = h.upsample_initial_channel // (2 ** (i + 1))
174+
ch //= 2
193175
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
194176
self.resblocks.append(resblock(h, ch, k, d))
195177

196178
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
197179
self.ups.apply(init_weights)
198180
self.conv_post.apply(init_weights)
181+
self.upp = int(np.prod(h.upsample_rates))
199182

200183
def forward(self, x, f0):
201-
f0 = self.f0_upsamp(f0.unsqueeze(1)).transpose(1, 2) # bs,n,t
202-
har_source, noi_source, uv = self.m_source(f0)
203-
har_source = har_source.transpose(1, 2)
184+
har_source = self.m_source(f0, self.upp).transpose(1, 2)
204185
x = self.conv_pre(x)
205-
206186
for i in range(self.num_upsamples):
207-
x = functional.leaky_relu(x, LRELU_SLOPE)
187+
x = F.leaky_relu(x, LRELU_SLOPE)
208188

209189
x = self.ups[i](x)
210190
x_source = self.noise_convs[i](har_source)
@@ -217,10 +197,9 @@ def forward(self, x, f0):
217197
else:
218198
xs += self.resblocks[i * self.num_kernels + j](x)
219199
x = xs / self.num_kernels
220-
x = functional.leaky_relu(x)
200+
x = F.leaky_relu(x)
221201
x = self.conv_post(x)
222202
x = torch.tanh(x)
223-
224203
x = x.squeeze(1)
225204
return x
226205

@@ -261,7 +240,7 @@ def load_model(model_path, device):
261240
generator = Generator(h).to(device)
262241

263242
cp_dict = torch.load(model_path)
264-
generator.load_state_dict(cp_dict['generator'], strict=False)
243+
generator.load_state_dict(cp_dict['generator'])
265244
generator.eval()
266245
generator.remove_weight_norm()
267246
del cp_dict
@@ -323,18 +302,18 @@ def export(model_path):
323302
1: 'n_frames'
324303
}
325304
},
326-
opset_version=11
305+
opset_version=13
327306
)
328307
print('PyTorch ONNX export finished.')
329308

330309

331310
if __name__ == '__main__':
    # Fake command-line arguments; presumably consumed by the config/hparams
    # loader invoked inside export() — confirm against its implementation.
    sys.argv = [
        'inference/ds_cascade.py',
        '--config',
        'configs/acoustic/nomidi.yaml',
    ]
    path = 'onnx/assets/nsf_hifigan2.onnx'
    export(path)  # export the PyTorch NSF-HiFiGAN generator to ONNX
    # In-place graph simplification; presumably wraps onnxsim (imported at
    # the top of the file) — confirm against the simplify() definition.
    simplify(path, path)
    print(f'| export \'NSF-HiFiGAN\' to \'{path}\'.')

0 commit comments

Comments
 (0)