Skip to content

Commit defc997

Browse files
authored
Merge pull request #53 from openvpi/stretch-v2
Support time stretching and velocity control
2 parents f63bc47 + a017147 commit defc997

File tree

15 files changed

+322
-73
lines changed

15 files changed

+322
-73
lines changed

augmentation/pitch_shift.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

augmentation/spec_stretch.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from copy import deepcopy
2+
3+
import numpy as np
4+
import torch
5+
6+
from basics.base_augmentation import BaseAugmentation
7+
from data_gen.data_gen_utils import get_pitch_parselmouth
8+
from modules.fastspeech.tts_modules import LengthRegulator
9+
from src.vocoders.base_vocoder import VOCODERS
10+
from utils.hparams import hparams
11+
from utils.pitch_utils import f0_to_coarse
12+
13+
14+
class SpectrogramStretchAugmentation(BaseAugmentation):
    """
    This class contains methods for frequency-domain and time-domain stretching augmentation.
    """
    def __init__(self, data_dirs: list, augmentation_args: dict):
        super().__init__(data_dirs, augmentation_args)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.lr = LengthRegulator().to(self.device)

    def process_item(self, item: dict, key_shift=0., speed=1., replace_spk_id=None) -> dict:
        """
        Produce an augmented copy of a dataset item by re-running the vocoder's
        wav2spec with a key shift (frequency stretch) and/or a speed factor
        (time stretch), then recomputing the dependent fields.

        :param item: source item dict; never modified (a deep copy is taken)
        :param key_shift: pitch shift in semitones (0. = no shift)
        :param speed: time-stretch factor (1. = original speed)
        :param replace_spk_id: if not None, overrides the item's speaker id
        :return: the augmented item dict
        """
        aug_item = deepcopy(item)
        if hparams['vocoder'] in VOCODERS:
            wav, mel = VOCODERS[hparams['vocoder']].wav2spec(
                aug_item['wav_fn'], keyshift=key_shift, speed=speed
            )
        else:
            # Vocoder registered under a dotted path; look up by its last component.
            wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(
                aug_item['wav_fn'], keyshift=key_shift, speed=speed
            )

        aug_item['mel'] = mel

        if speed != 1. or hparams.get('use_speed_embed', False):
            aug_item['len'] = len(mel)
            # The effective hop size must be an integer number of samples, so the
            # real speed can differ slightly from the requested one.
            aug_item['speed'] = int(np.round(hparams['hop_size'] * speed)) / hparams['hop_size']  # real speed
            aug_item['sec'] /= aug_item['speed']
            aug_item['ph_durs'] /= aug_item['speed']
            aug_item['mel2ph'] = self.get_mel2ph(aug_item['ph_durs'], aug_item['len'])
            aug_item['f0'], aug_item['pitch'] = get_pitch_parselmouth(wav, mel, hparams, speed=speed)

        if key_shift != 0. or hparams.get('use_key_shift_embed', False):
            aug_item['key_shift'] = key_shift
            # Shift f0 by key_shift semitones and re-quantize the coarse pitch.
            aug_item['f0'] *= 2 ** (key_shift / 12)
            aug_item['pitch'] = f0_to_coarse(aug_item['f0'])

        if replace_spk_id is not None:
            aug_item['spk_id'] = replace_spk_id

        return aug_item

    @torch.no_grad()
    def get_mel2ph(self, durs, length):
        """
        Convert phoneme durations (in seconds) into a frame-level mel2ph
        alignment of exactly `length` frames.

        :param durs: 1-D array-like of phoneme durations in seconds
        :param length: target number of mel frames
        :return: 1-D numpy array of length `length` mapping frames to phonemes
        """
        # Accumulate durations, convert to frame indices, then diff back to
        # per-phoneme frame counts so rounding errors do not accumulate.
        ph_acc = np.around(
            np.add.accumulate(durs) * hparams['audio_sample_rate'] / hparams['hop_size'] + 0.5
        ).astype('int')
        ph_dur = np.diff(ph_acc, prepend=0)
        ph_dur = torch.LongTensor(ph_dur)[None].to(self.device)
        mel2ph = self.lr(ph_dur).cpu().numpy()[0]
        num_frames = len(mel2ph)
        if num_frames < length:
            # BUG FIX: np.full takes (shape, fill_value) as two arguments; the
            # original passed a single tuple, raising TypeError whenever
            # padding was required. Pad with the last phoneme index instead.
            mel2ph = np.concatenate((mel2ph, np.full(length - num_frames, mel2ph[-1])), axis=0)
        elif num_frames > length:
            mel2ph = mel2ph[:length]
        return mel2ph

basics/base_binarizer.py

Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import shutil
23
import os
34
os.environ["OMP_NUM_THREADS"] = "1"
@@ -218,38 +219,43 @@ def arrange_data_augmentation(self, prefix):
218219
Code for all types of data augmentation should be added here.
219220
"""
220221
aug_map = {}
222+
aug_list = []
221223
all_item_names = [item_name for item_name, _ in self.meta_data_iterator(prefix)]
224+
total_scale = 0
222225
if self.augmentation_args.get('random_pitch_shifting') is not None:
223-
from augmentation.pitch_shift import PitchShiftAugmentation
226+
from augmentation.spec_stretch import SpectrogramStretchAugmentation
224227
aug_args = self.augmentation_args['random_pitch_shifting']
225228
key_shift_min, key_shift_max = aug_args['range']
226229
assert hparams.get('use_key_shift_embed', False), \
227230
'Random pitch shifting augmentation requires use_key_shift_embed == True.'
228231
assert key_shift_min < 0 < key_shift_max, \
229232
'Random pitch shifting augmentation must have a range where min < 0 < max.'
230233

231-
aug_ins = PitchShiftAugmentation(self.raw_data_dirs, aug_args)
234+
aug_ins = SpectrogramStretchAugmentation(self.raw_data_dirs, aug_args)
232235
scale = aug_args['scale']
233-
aug_item_names = all_item_names * int(scale) \
234-
+ random.sample(all_item_names, int(len(all_item_names) * (scale - int(scale))))
236+
aug_item_names = random.choices(all_item_names, k=int(scale * len(all_item_names)))
235237

236238
for aug_item_name in aug_item_names:
237-
rand = random.random() * 2 - 1
239+
rand = random.uniform(-1, 1)
238240
if rand < 0:
239241
key_shift = key_shift_min * abs(rand)
240242
else:
241243
key_shift = key_shift_max * rand
242244
aug_task = {
245+
'name': aug_item_name,
243246
'func': aug_ins.process_item,
244247
'kwargs': {'key_shift': key_shift}
245248
}
246249
if aug_item_name in aug_map:
247250
aug_map[aug_item_name].append(aug_task)
248251
else:
249252
aug_map[aug_item_name] = [aug_task]
253+
aug_list.append(aug_task)
254+
255+
total_scale += scale
250256

251257
if self.augmentation_args.get('fixed_pitch_shifting') is not None:
252-
from augmentation.pitch_shift import PitchShiftAugmentation
258+
from augmentation.spec_stretch import SpectrogramStretchAugmentation
253259
aug_args = self.augmentation_args['fixed_pitch_shifting']
254260
targets = aug_args['targets']
255261
scale = aug_args['scale']
@@ -262,19 +268,74 @@ def arrange_data_augmentation(self, prefix):
262268
'Fixed pitch shifting augmentation requires num_spk >= (1 + len(targets)) * len(speakers).'
263269
assert scale < 1, 'Fixed pitch shifting augmentation requires scale < 1.'
264270

265-
aug_ins = PitchShiftAugmentation(self.raw_data_dirs, aug_args)
271+
aug_ins = SpectrogramStretchAugmentation(self.raw_data_dirs, aug_args)
266272
for i, target in enumerate(targets):
267-
aug_item_names = random.sample(all_item_names, int(len(all_item_names) * scale))
273+
aug_item_names = random.choices(all_item_names, k=int(scale * len(all_item_names)))
268274
for aug_item_name in aug_item_names:
269275
replace_spk_id = int(aug_item_name.split(':', maxsplit=1)[0]) + (i + 1) * len(self.spk_map)
270276
aug_task = {
277+
'name': aug_item_name,
271278
'func': aug_ins.process_item,
272279
'kwargs': {'key_shift': target, 'replace_spk_id': replace_spk_id}
273280
}
274281
if aug_item_name in aug_map:
275282
aug_map[aug_item_name].append(aug_task)
276283
else:
277284
aug_map[aug_item_name] = [aug_task]
285+
aug_list.append(aug_task)
286+
287+
total_scale += scale * len(targets)
288+
289+
if self.augmentation_args.get('random_time_stretching') is not None:
290+
from augmentation.spec_stretch import SpectrogramStretchAugmentation
291+
aug_args = self.augmentation_args['random_time_stretching']
292+
speed_min, speed_max = aug_args['range']
293+
domain = aug_args['domain']
294+
assert hparams.get('use_speed_embed', False), \
295+
'Random time stretching augmentation requires use_speed_embed == True.'
296+
assert 0 < speed_min < 1 < speed_max, \
297+
'Random time stretching augmentation must have a range where 0 < min < 1 < max.'
298+
assert domain in ['log', 'linear'], 'domain must be \'log\' or \'linear\'.'
299+
300+
aug_ins = SpectrogramStretchAugmentation(self.raw_data_dirs, aug_args)
301+
scale = aug_args['scale']
302+
k_from_raw = int(scale / (1 + total_scale) * len(all_item_names))
303+
k_from_aug = int(total_scale * scale / (1 + total_scale) * len(all_item_names))
304+
k_mutate = int(total_scale * scale / (1 + scale) * len(all_item_names))
305+
aug_types = [0] * k_from_raw + [1] * k_from_aug + [2] * k_mutate
306+
aug_items = random.choices(all_item_names, k=k_from_raw) + random.choices(aug_list, k=k_from_aug + k_mutate)
307+
308+
for aug_type, aug_item in zip(aug_types, aug_items):
309+
if domain == 'log':
310+
# Uniform distribution in log domain
311+
speed = speed_min * (speed_max / speed_min) ** random.random()
312+
else:
313+
# Uniform distribution in linear domain
314+
rand = random.uniform(-1, 1)
315+
speed = 1 + (speed_max - 1) * rand if rand >= 0 else 1 + (1 - speed_min) * rand
316+
if aug_type == 0:
317+
aug_task = {
318+
'name': aug_item,
319+
'func': aug_ins.process_item,
320+
'kwargs': {'speed': speed}
321+
}
322+
if aug_item in aug_map:
323+
aug_map[aug_item].append(aug_task)
324+
else:
325+
aug_map[aug_item] = [aug_task]
326+
aug_list.append(aug_task)
327+
elif aug_type == 1:
328+
aug_task = copy.deepcopy(aug_item)
329+
aug_item['kwargs']['speed'] = speed
330+
if aug_item['name'] in aug_map:
331+
aug_map[aug_item['name']].append(aug_task)
332+
else:
333+
aug_map[aug_item['name']] = [aug_task]
334+
aug_list.append(aug_task)
335+
elif aug_type == 2:
336+
aug_item['kwargs']['speed'] = speed
337+
338+
total_scale += scale
278339

279340
return aug_map
280341

configs/acoustic/nomidi.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ binarization_args:
4040
# fixed_pitch_shifting:
4141
# targets: [-5., 5.]
4242
# scale: 0.75
43+
# random_time_stretching:
44+
# range: [0.5, 2.]
45+
# domain: log # or linear
46+
# scale: 2.0
4347

4448
raw_data_dir: 'data/opencpop/raw'
4549
processed_data_dir: ''
@@ -66,7 +70,8 @@ use_uv: false
6670
use_midi: false
6771
use_spk_embed: false
6872
use_spk_id: false
69-
#use_key_shift_embed: true
73+
use_key_shift_embed: false
74+
use_speed_embed: false
7075
use_gt_f0: false # for midi exp
7176
use_gt_dur: false # for further midi exp
7277
f0_embed_type: continuous

data_gen/data_gen_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,22 +147,24 @@ def process_utterance(wav_path,
147147
return wav, mel, spc
148148

149149

150-
def get_pitch_parselmouth(wav_data, mel, hparams):
150+
def get_pitch_parselmouth(wav_data, mel, hparams, speed=1):
    """
    Extract an f0 curve aligned to the mel spectrogram using Parselmouth.

    :param wav_data: [T]
    :param mel: [T, mel_bins]
    :param hparams: hyper-parameter dict (uses hop_size, audio_sample_rate)
    :param speed: time-stretch factor; scales the effective hop size
    :return: (f0, coarse pitch), both padded to len(mel)
    """
    sample_rate = hparams['audio_sample_rate']
    # Effective hop size in samples after time stretching (rounded to int).
    hop_size = int(np.round(hparams['hop_size'] * speed))

    time_step = hop_size / sample_rate * 1000
    f0_min, f0_max = 65, 800

    sound = parselmouth.Sound(wav_data, sample_rate)
    f0 = sound.to_pitch_ac(
        time_step=time_step / 1000,
        voicing_threshold=0.6,
        pitch_floor=f0_min,
        pitch_ceiling=f0_max,
    ).selected_array['frequency']

    # Center the extracted curve, then pad both sides to match the mel length.
    pad_size = (int(len(wav_data) // hop_size) - len(f0) + 1) // 2
    f0 = np.pad(f0, [[pad_size, len(mel) - len(f0) - pad_size]], mode='constant')
    pitch_coarse = f0_to_coarse(f0)
    return f0, pitch_coarse

inference/ds_cascade.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ def preprocess_phoneme_level_input(self, inp):
4444
gender = np.array(inp['gender'].split(), 'float')
4545
else:
4646
gender = float(inp['gender'])
47+
velocity_timestep = None
48+
velocity = None
49+
if inp.get('velocity') is not None:
50+
velocity_timestep = float(inp['velocity_timestep'])
51+
velocity = np.array(inp['velocity'].split(), 'float')
4752
ph_seq_lst = ph_seq.split()
4853
if inp['ph_dur'] is not None:
4954
ph_dur = np.array(inp['ph_dur'].split(), 'float')
@@ -58,7 +63,8 @@ def preprocess_phoneme_level_input(self, inp):
5863
f'{len(note_lst)} {len(ph_seq.split())} {len(midi_dur_lst)}')
5964
print(f'Processed {len(ph_seq_lst)} tokens: {" ".join(ph_seq_lst)}')
6065

61-
return ph_seq, note_lst, midi_dur_lst, is_slur, ph_dur, f0_timestep, f0_seq, gender_timestep, gender
66+
return ph_seq, note_lst, midi_dur_lst, is_slur, ph_dur, \
67+
f0_timestep, f0_seq, gender_timestep, gender, velocity_timestep, velocity
6268

6369
def preprocess_input(self, inp, input_type='word'):
6470
"""
@@ -90,9 +96,10 @@ def preprocess_input(self, inp, input_type='word'):
9096
# get ph seq, note lst, midi dur lst, is slur lst.
9197
if input_type == 'word':
9298
ph_seq, note_lst, midi_dur_lst, is_slur = self.preprocess_word_level_input(inp)
93-
ph_dur = f0_timestep = f0_seq = gender_timestep = gender = None
99+
ph_dur = f0_timestep = f0_seq = gender_timestep = gender = velocity_timestep = velocity = None
94100
elif input_type == 'phoneme': # like transcriptions.txt in Opencpop dataset.
95-
ph_seq, note_lst, midi_dur_lst, is_slur, ph_dur, f0_timestep, f0_seq, gender_timestep, gender = \
101+
ph_seq, note_lst, midi_dur_lst, is_slur, ph_dur, \
102+
f0_timestep, f0_seq, gender_timestep, gender, velocity_timestep, velocity = \
96103
self.preprocess_phoneme_level_input(inp)
97104
else:
98105
raise ValueError('Invalid input type. Must be \'word\' or \'phoneme\'.')
@@ -118,6 +125,8 @@ def preprocess_input(self, inp, input_type='word'):
118125
item['f0_seq'] = f0_seq
119126
item['gender_timestep'] = gender_timestep
120127
item['gender'] = gender
128+
item['velocity_timestep'] = velocity_timestep
129+
item['velocity'] = velocity
121130
item['spk_mix_timestep'] = inp.get('spk_mix_timestep')
122131
return item
123132

@@ -210,6 +219,23 @@ def input_to_batch(self, item):
210219
else:
211220
key_shift = None
212221

222+
if hparams.get('use_speed_embed', False):
223+
if item['velocity'] is None:
224+
print('Using default velocity curve')
225+
speed = torch.FloatTensor([1.]).to(self.device)
226+
else:
227+
print('Using manual velocity curve')
228+
velocity_timestep = item['velocity_timestep']
229+
velocity_seq = item['velocity']
230+
speed_min, speed_max = hparams['augmentation_args']['random_time_stretching']['range']
231+
speed_seq = np.clip(velocity_seq, a_min=speed_min, a_max=speed_max)
232+
t_max = (len(speed_seq) - 1) * velocity_timestep
233+
dt = hparams['hop_size'] / hparams['audio_sample_rate']
234+
speed_interp = np.interp(np.arange(0, t_max, dt), velocity_timestep * np.arange(len(speed_seq)), speed_seq)
235+
speed = torch.FloatTensor(speed_interp)[None, :].to(self.device)
236+
else:
237+
speed = None
238+
213239
batch = {
214240
'item_name': item_names,
215241
'text': text,
@@ -222,7 +248,8 @@ def input_to_batch(self, item):
222248
'is_slur': is_slur,
223249
'mel2ph': mel2ph,
224250
'log2f0': log2f0,
225-
'key_shift': key_shift
251+
'key_shift': key_shift,
252+
'speed': speed
226253
}
227254
return batch
228255

@@ -240,8 +267,8 @@ def forward_model(self, inp, return_mel=False):
240267
output = self.model(txt_tokens, spk_mix_embed=spk_mix_embed, ref_mels=None, infer=True,
241268
pitch_midi=sample['pitch_midi'], midi_dur=sample['midi_dur'],
242269
is_slur=sample['is_slur'], mel2ph=sample['mel2ph'], f0=sample['log2f0'],
243-
key_shift=sample['key_shift'])
244-
mel_out = output['mel_out'] # [B, T,80]
270+
key_shift=sample['key_shift'], speed=sample['speed'])
271+
mel_out = output['mel_out'] # [B, T, M]
245272
f0_pred = output['f0_denorm']
246273
if return_mel:
247274
return mel_out.cpu(), f0_pred.cpu()

modules/fastspeech/tts_modules.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,33 @@ def forward(self, dur, dur_padding=None, alpha=1.0):
189189
return mel2ph
190190

191191

192+
class StretchRegulator(torch.nn.Module):
    def forward(self, dur, mel2ph):
        """
        Compute each frame's fractional position within its phoneme
        (0 at the phoneme onset, rising toward 1 at its end).

        Worked example (batch dimension omitted):
            1. dur            = [2, 4, 3]
            2. mel2ph         = [1, 1, 2, 2, 2, 2, 3, 3, 3]
            3. mel2dur        = [2, 2, 4, 4, 4, 4, 3, 3, 3]
            4. bound_mask     = [0, 1, 0, 0, 0, 1, 0, 0, 1]
            5. 1 - bound_mask * mel2dur = [1,-1,1,1,1,-3,1,1,-2]
               => shift right => [0, 1,-1, 1, 1, 1,-3, 1, 1]
            6. stretch_denorm = [0, 1, 0, 1, 2, 3, 0, 1, 2]

        :param dur: Batch of durations of each frame (B, T_txt)
        :param mel2ph: Batch of mel2ph (B, T_speech)
        :return:
            stretch (B, T_speech)
        """
        # Prepend a dummy duration of 1 so phoneme index 0 (padding) is safe
        # to gather and never divides by zero.
        padded_dur = F.pad(dur, [1, 0], value=1)
        frame_dur = torch.gather(padded_dur, 1, mel2ph)
        # True on the last frame of each phoneme segment.
        is_boundary = torch.gt(mel2ph[:, 1:], mel2ph[:, :-1])
        is_boundary = F.pad(is_boundary, [0, 1], mode='constant', value=True)
        # Each frame adds 1; at a segment boundary subtract the segment length,
        # so after shifting right the running sum restarts at 0 per phoneme.
        deltas = 1 - is_boundary * frame_dur
        deltas = F.pad(deltas, [1, -1], mode='constant', value=0)
        position_in_ph = torch.cumsum(deltas, dim=1)
        stretch = position_in_ph / frame_dur
        # Zero out padding frames (mel2ph == 0).
        return stretch * (mel2ph > 0)
217+
218+
192219
class PitchPredictor(torch.nn.Module):
193220
def __init__(self, idim, n_layers=5, n_chans=384, odim=2, kernel_size=5,
194221
dropout_rate=0.1, padding='SAME'):

0 commit comments

Comments
 (0)