1+ import copy
12import shutil
23import os
34os .environ ["OMP_NUM_THREADS" ] = "1"
@@ -218,38 +219,43 @@ def arrange_data_augmentation(self, prefix):
218219 Code for all types of data augmentation should be added here.
219220 """
220221 aug_map = {}
222+ aug_list = []
221223 all_item_names = [item_name for item_name , _ in self .meta_data_iterator (prefix )]
224+ total_scale = 0
222225 if self .augmentation_args .get ('random_pitch_shifting' ) is not None :
223- from augmentation .pitch_shift import PitchShiftAugmentation
226+ from augmentation .spec_stretch import SpectrogramStretchAugmentation
224227 aug_args = self .augmentation_args ['random_pitch_shifting' ]
225228 key_shift_min , key_shift_max = aug_args ['range' ]
226229 assert hparams .get ('use_key_shift_embed' , False ), \
227230 'Random pitch shifting augmentation requires use_key_shift_embed == True.'
228231 assert key_shift_min < 0 < key_shift_max , \
229232 'Random pitch shifting augmentation must have a range where min < 0 < max.'
230233
231- aug_ins = PitchShiftAugmentation (self .raw_data_dirs , aug_args )
234+ aug_ins = SpectrogramStretchAugmentation (self .raw_data_dirs , aug_args )
232235 scale = aug_args ['scale' ]
233- aug_item_names = all_item_names * int (scale ) \
234- + random .sample (all_item_names , int (len (all_item_names ) * (scale - int (scale ))))
236+ aug_item_names = random .choices (all_item_names , k = int (scale * len (all_item_names )))
235237
236238 for aug_item_name in aug_item_names :
237- rand = random .random () * 2 - 1
239+ rand = random .uniform ( - 1 , 1 )
238240 if rand < 0 :
239241 key_shift = key_shift_min * abs (rand )
240242 else :
241243 key_shift = key_shift_max * rand
242244 aug_task = {
245+ 'name' : aug_item_name ,
243246 'func' : aug_ins .process_item ,
244247 'kwargs' : {'key_shift' : key_shift }
245248 }
246249 if aug_item_name in aug_map :
247250 aug_map [aug_item_name ].append (aug_task )
248251 else :
249252 aug_map [aug_item_name ] = [aug_task ]
253+ aug_list .append (aug_task )
254+
255+ total_scale += scale
250256
251257 if self .augmentation_args .get ('fixed_pitch_shifting' ) is not None :
252- from augmentation .pitch_shift import PitchShiftAugmentation
258+ from augmentation .spec_stretch import SpectrogramStretchAugmentation
253259 aug_args = self .augmentation_args ['fixed_pitch_shifting' ]
254260 targets = aug_args ['targets' ]
255261 scale = aug_args ['scale' ]
@@ -262,19 +268,74 @@ def arrange_data_augmentation(self, prefix):
262268 'Fixed pitch shifting augmentation requires num_spk >= (1 + len(targets)) * len(speakers).'
263269 assert scale < 1 , 'Fixed pitch shifting augmentation requires scale < 1.'
264270
265- aug_ins = PitchShiftAugmentation (self .raw_data_dirs , aug_args )
271+ aug_ins = SpectrogramStretchAugmentation (self .raw_data_dirs , aug_args )
266272 for i , target in enumerate (targets ):
267- aug_item_names = random .sample (all_item_names , int (len (all_item_names ) * scale ))
273+ aug_item_names = random .choices (all_item_names , k = int (scale * len (all_item_names )))
268274 for aug_item_name in aug_item_names :
269275 replace_spk_id = int (aug_item_name .split (':' , maxsplit = 1 )[0 ]) + (i + 1 ) * len (self .spk_map )
270276 aug_task = {
277+ 'name' : aug_item_name ,
271278 'func' : aug_ins .process_item ,
272279 'kwargs' : {'key_shift' : target , 'replace_spk_id' : replace_spk_id }
273280 }
274281 if aug_item_name in aug_map :
275282 aug_map [aug_item_name ].append (aug_task )
276283 else :
277284 aug_map [aug_item_name ] = [aug_task ]
285+ aug_list .append (aug_task )
286+
287+ total_scale += scale * len (targets )
288+
289+ if self .augmentation_args .get ('random_time_stretching' ) is not None :
290+ from augmentation .spec_stretch import SpectrogramStretchAugmentation
291+ aug_args = self .augmentation_args ['random_time_stretching' ]
292+ speed_min , speed_max = aug_args ['range' ]
293+ domain = aug_args ['domain' ]
294+ assert hparams .get ('use_speed_embed' , False ), \
295+ 'Random time stretching augmentation requires use_speed_embed == True.'
296+ assert 0 < speed_min < 1 < speed_max , \
297+ 'Random time stretching augmentation must have a range where 0 < min < 1 < max.'
298+ assert domain in ['log' , 'linear' ], 'domain must be \' log\' or \' linear\' .'
299+
300+ aug_ins = SpectrogramStretchAugmentation (self .raw_data_dirs , aug_args )
301+ scale = aug_args ['scale' ]
302+ k_from_raw = int (scale / (1 + total_scale ) * len (all_item_names ))
303+ k_from_aug = int (total_scale * scale / (1 + total_scale ) * len (all_item_names ))
304+ k_mutate = int (total_scale * scale / (1 + scale ) * len (all_item_names ))
305+ aug_types = [0 ] * k_from_raw + [1 ] * k_from_aug + [2 ] * k_mutate
306+ aug_items = random .choices (all_item_names , k = k_from_raw ) + random .choices (aug_list , k = k_from_aug + k_mutate )
307+
308+ for aug_type , aug_item in zip (aug_types , aug_items ):
309+ if domain == 'log' :
310+ # Uniform distribution in log domain
311+ speed = speed_min * (speed_max / speed_min ) ** random .random ()
312+ else :
313+ # Uniform distribution in linear domain
314+ rand = random .uniform (- 1 , 1 )
315+ speed = 1 + (speed_max - 1 ) * rand if rand >= 0 else 1 + (1 - speed_min ) * rand
316+ if aug_type == 0 :
317+ aug_task = {
318+ 'name' : aug_item ,
319+ 'func' : aug_ins .process_item ,
320+ 'kwargs' : {'speed' : speed }
321+ }
322+ if aug_item in aug_map :
323+ aug_map [aug_item ].append (aug_task )
324+ else :
325+ aug_map [aug_item ] = [aug_task ]
326+ aug_list .append (aug_task )
327+ elif aug_type == 1 :
328+ aug_task = copy .deepcopy (aug_item )
329+ aug_item ['kwargs' ]['speed' ] = speed
330+ if aug_item ['name' ] in aug_map :
331+ aug_map [aug_item ['name' ]].append (aug_task )
332+ else :
333+ aug_map [aug_item ['name' ]] = [aug_task ]
334+ aug_list .append (aug_task )
335+ elif aug_type == 2 :
336+ aug_item ['kwargs' ]['speed' ] = speed
337+
338+ total_scale += scale
278339
279340 return aug_map
280341
0 commit comments