77import onnxsim
88import torch
99import torch .nn as nn
10- import torch .nn .functional as functional
11- from torch .nn import Conv1d , ConvTranspose1d , Conv2d
10+ import torch .nn .functional as F
11+ from torch .nn import Conv1d , ConvTranspose1d
1212from torch .nn .utils import weight_norm , remove_weight_norm
1313
1414from modules .nsf_hifigan .env import AttrDict
@@ -45,71 +45,63 @@ def __init__(self, samp_rate, harmonic_num=0,
4545 self .dim = self .harmonic_num + 1
4646 self .sampling_rate = samp_rate
4747 self .voiced_threshold = voiced_threshold
48- self .diff = Conv2d (
49- in_channels = 1 ,
50- out_channels = 1 ,
51- kernel_size = (2 , 1 ),
52- stride = (1 , 1 ),
53- padding = 0 ,
54- dilation = (1 , 1 ),
55- bias = False
56- )
57- self .diff .weight = nn .Parameter (torch .FloatTensor ([[[[- 1. ], [1. ]]]]))
5848
59- def _f02sine (self , f0_values ):
49+ def _f02sine (self , f0_values , upp ):
6050 """ f0_values: (batchsize, length, dim)
6151 where dim indicates fundamental tone and overtones
6252 """
63- # convert to F0 in rad. The integer part n can be ignored
64- # because 2 * np.pi * n doesn't affect phase
65- rad_values = (f0_values / self .sampling_rate ).fmod (1. )
66-
67- # initial phase noise (no noise for fundamental component)
53+ rad_values = (f0_values / self .sampling_rate ).fmod (1. ) ###%1意味着n_har的乘积无法后处理优化
6854 rand_ini = torch .rand (1 , self .dim , device = f0_values .device )
6955 rand_ini [:, 0 ] = 0
7056 rad_values [:, 0 , :] += rand_ini
71-
72- # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
73-
74- # To prevent torch.cumsum numerical overflow,
75- # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
76- # Buffer tmp_over_one_idx indicates the time step to add -1.
77- # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
78- tmp_over_one = torch .cumsum (rad_values , dim = 1 ).fmod (1. )
79-
80- diff = self .diff (tmp_over_one .unsqueeze (1 )).squeeze (1 ) # Equivalent to torch.diff, but able to export ONNX
81- cumsum_shift = (diff < 0 ).float ()
82- cumsum_shift = torch .cat ((torch .zeros ((1 , 1 , self .dim )).to (f0_values .device ), cumsum_shift ), dim = 1 )
83- sines = torch .sin (torch .cumsum (rad_values - cumsum_shift , dim = 1 ) * (2 * np .pi ))
57+ is_half = rad_values .dtype is not torch .float32
58+ tmp_over_one = torch .cumsum (rad_values .double (), 1 ) # % 1 #####%1意味着后面的cumsum无法再优化
59+ if is_half :
60+ tmp_over_one = tmp_over_one .half ()
61+ else :
62+ tmp_over_one = tmp_over_one .float ()
63+ tmp_over_one *= upp
64+ tmp_over_one = F .interpolate (
65+ tmp_over_one .transpose (2 , 1 ), scale_factor = upp ,
66+ mode = 'linear' , align_corners = True
67+ ).transpose (2 , 1 )
68+ rad_values = F .interpolate (rad_values .transpose (2 , 1 ), scale_factor = upp , mode = 'nearest' ).transpose (2 , 1 )
69+ tmp_over_one = tmp_over_one .fmod (1. )
70+ diff = F .conv2d (
71+ tmp_over_one .unsqueeze (1 ), torch .FloatTensor ([[[[- 1. ], [1. ]]]]).to (tmp_over_one .device ),
72+ stride = (1 , 1 ), padding = 0 , dilation = (1 , 1 )
73+ ).squeeze (1 ) # Equivalent to torch.diff, but able to export ONNX
74+ cumsum_shift = (diff < 0 ).double ()
75+ cumsum_shift = torch .cat ((
76+ torch .zeros ((1 , 1 , self .dim ), dtype = torch .double ).to (f0_values .device ),
77+ cumsum_shift
78+ ), dim = 1 )
79+ sines = torch .sin (torch .cumsum (rad_values .double () + cumsum_shift , dim = 1 ) * 2 * np .pi )
80+ if is_half :
81+ sines = sines .half ()
82+ else :
83+ sines = sines .float ()
8484 return sines
8585
@torch.no_grad()
def forward(self, f0, upp):
    """sine_tensor = forward(f0, upp)

    input F0: tensor(batchsize=1, length, dim=1); zero for unvoiced steps.
    upp: integer upsampling factor from frame rate to sample rate.
    output sine_tensor: tensor(batchsize=1, length * upp, dim).
    """
    f0 = f0.unsqueeze(-1)
    # Harmonic multipliers 1..dim broadcast over the F0 track.
    harmonic_idx = torch.arange(1, self.dim + 1, device=f0.device).reshape((1, 1, -1))
    fn = torch.multiply(f0, harmonic_idx)
    # Upsampled sine components, scaled to the configured amplitude.
    sine_waves = self._f02sine(fn, upp) * self.sine_amp
    # Voiced/unvoiced mask, upsampled to sample rate with nearest-neighbor.
    uv = (f0 > self.voiced_threshold).float()
    uv = F.interpolate(uv.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)
    # Noise std: noise_std where voiced, sine_amp/3 where unvoiced so the
    # unvoiced noise peaks near the sine amplitude.
    noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
    noise = noise_amp * torch.randn_like(sine_waves)
    # Zero the unvoiced sines, then add the noise floor.
    return sine_waves * uv + noise
113105
114106
115107class SourceModuleHnNSF (torch .nn .Module ):
@@ -144,20 +136,10 @@ def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
144136 self .l_linear = torch .nn .Linear (harmonic_num + 1 , 1 )
145137 self .l_tanh = torch .nn .Tanh ()
146138
def forward(self, x, upp):
    """Merge the harmonic sine sources into a single excitation signal.

    x: F0 track (batchsize=1, length); upp: integer upsampling factor.
    Returns tensor (batchsize=1, length * upp, 1).
    """
    harmonic_source = self.l_sin_gen(x, upp)
    # Linear combination of fundamental + overtones, squashed by tanh.
    return self.l_tanh(self.l_linear(harmonic_source))
161143
162144
163145class Generator (torch .nn .Module ):
@@ -166,10 +148,10 @@ def __init__(self, h):
166148 self .h = h
167149 self .num_kernels = len (h .resblock_kernel_sizes )
168150 self .num_upsamples = len (h .upsample_rates )
169- self .f0_upsamp = torch .nn .Upsample (scale_factor = float (np .prod (h .upsample_rates )))
170151 self .m_source = SourceModuleHnNSF (
171152 sampling_rate = h .sampling_rate ,
172- harmonic_num = 8 )
153+ harmonic_num = 8
154+ )
173155 self .noise_convs = nn .ModuleList ()
174156 self .conv_pre = weight_norm (Conv1d (h .num_mels , h .upsample_initial_channel , 7 , 1 , padding = 3 ))
175157 resblock = ResBlock1 if h .resblock == '1' else ResBlock2
@@ -187,24 +169,22 @@ def __init__(self, h):
187169 else :
188170 self .noise_convs .append (Conv1d (1 , c_cur , kernel_size = 1 ))
189171 self .resblocks = nn .ModuleList ()
190- ch = None
172+ ch = h . upsample_initial_channel
191173 for i in range (len (self .ups )):
192- ch = h . upsample_initial_channel // ( 2 ** ( i + 1 ))
174+ ch //= 2
193175 for j , (k , d ) in enumerate (zip (h .resblock_kernel_sizes , h .resblock_dilation_sizes )):
194176 self .resblocks .append (resblock (h , ch , k , d ))
195177
196178 self .conv_post = weight_norm (Conv1d (ch , 1 , 7 , 1 , padding = 3 ))
197179 self .ups .apply (init_weights )
198180 self .conv_post .apply (init_weights )
181+ self .upp = int (np .prod (h .upsample_rates ))
199182
200183 def forward (self , x , f0 ):
201- f0 = self .f0_upsamp (f0 .unsqueeze (1 )).transpose (1 , 2 ) # bs,n,t
202- har_source , noi_source , uv = self .m_source (f0 )
203- har_source = har_source .transpose (1 , 2 )
184+ har_source = self .m_source (f0 , self .upp ).transpose (1 , 2 )
204185 x = self .conv_pre (x )
205-
206186 for i in range (self .num_upsamples ):
207- x = functional .leaky_relu (x , LRELU_SLOPE )
187+ x = F .leaky_relu (x , LRELU_SLOPE )
208188
209189 x = self .ups [i ](x )
210190 x_source = self .noise_convs [i ](har_source )
@@ -217,10 +197,9 @@ def forward(self, x, f0):
217197 else :
218198 xs += self .resblocks [i * self .num_kernels + j ](x )
219199 x = xs / self .num_kernels
220- x = functional .leaky_relu (x )
200+ x = F .leaky_relu (x )
221201 x = self .conv_post (x )
222202 x = torch .tanh (x )
223-
224203 x = x .squeeze (1 )
225204 return x
226205
@@ -261,7 +240,7 @@ def load_model(model_path, device):
261240 generator = Generator (h ).to (device )
262241
263242 cp_dict = torch .load (model_path )
264- generator .load_state_dict (cp_dict ['generator' ], strict = False )
243+ generator .load_state_dict (cp_dict ['generator' ])
265244 generator .eval ()
266245 generator .remove_weight_norm ()
267246 del cp_dict
@@ -323,18 +302,18 @@ def export(model_path):
323302 1 : 'n_frames'
324303 }
325304 },
326- opset_version = 11
305+ opset_version = 13
327306 )
328307 print ('PyTorch ONNX export finished.' )
329308
330309
if __name__ == '__main__':
    # Fake argv so the downstream config loader picks up the acoustic
    # (no-MIDI) config — presumably read via hparams; verify against the
    # config system.
    sys.argv = [
        'inference/ds_cascade.py',
        '--config',
        'configs/acoustic/nomidi.yaml',
    ]
    onnx_path = 'onnx/assets/nsf_hifigan2.onnx'
    # Export the PyTorch model to ONNX, then simplify the graph in place.
    export(onnx_path)
    simplify(onnx_path, onnx_path)
    print(f"| export 'NSF-HiFiGAN' to '{onnx_path}'.")
0 commit comments