Commit cc7b56d

Support custom lr gamma
1 parent defc997 commit cc7b56d
File tree

3 files changed: +8 −3 lines changed

configs/acoustic/nomidi.yaml

Lines changed: 1 addition & 0 deletions
@@ -103,6 +103,7 @@ gen_tgt_spk_id: -1
 num_sanity_val_steps: 1
 lr: 0.0004
 decay_steps: 50000
+gamma: 0.5
 max_tokens: 80000
 max_sentences: 48
 val_check_interval: 2000

pipelines/no_midi_preparation.ipynb

Lines changed: 6 additions & 2 deletions
@@ -1076,9 +1076,9 @@
 "\n",
 "These two parameters jointly determine the batch size at training time, the former representing maximum number of frames in one batch and the latter limiting the maximum batch size. Larger batches consume more GPU memory at training time. This value can be adjusted according to your GPU memory. Remember not to set this value too low because the model may not converge with small batches.\n",
 "\n",
-"##### `lr` and `decay_steps`\n",
+"##### `lr`, `decay_steps`, `gamma`\n",
 "\n",
-"These two values refer to the learning rate and number of steps everytime the learning rate decays. If you decreased your batch size, you may consider using a smaller learning rate and more decay steps.\n",
+"The learning rate starts at `lr` and decays by a factor of `gamma` every `decay_steps` steps during training. If you decreased your batch size, you may consider using a smaller learning rate and more decay steps, or a larger `gamma`.\n",
 "\n",
 "##### `val_check_interval`, `num_ckpt_keep` and `max_updates`\n",
 "\n",
@@ -1137,6 +1137,7 @@
 "\n",
 "lr = 0.0004\n",
 "decay_steps = 50000\n",
+"gamma = 0.5\n",
 "\n",
 "val_check_interval = 2000\n",
 "num_ckpt_keep = 5\n",
@@ -1185,6 +1186,7 @@
 " 'max_sentences': max_sentences,\n",
 " 'lr': lr,\n",
 " 'decay_steps': decay_steps,\n",
+" 'gamma': gamma,\n",
 " 'val_check_interval': val_check_interval,\n",
 " 'num_valid_plots': min(10, len(test_prefixes)),\n",
 " 'num_ckpt_keep': num_ckpt_keep,\n",
@@ -1411,6 +1413,7 @@
 "\n",
 "lr = 0.0004\n",
 "decay_steps = 50000\n",
+"gamma = 0.5\n",
 "\n",
 "val_check_interval = 2000\n",
 "num_ckpt_keep = 5\n",
@@ -1485,6 +1488,7 @@
 " 'max_sentences': max_sentences,\n",
 " 'lr': lr,\n",
 " 'decay_steps': decay_steps,\n",
+" 'gamma': gamma,\n",
 " 'val_check_interval': val_check_interval,\n",
 " 'num_valid_plots': min(20, len(test_prefixes)),\n",
 " 'num_ckpt_keep': num_ckpt_keep,\n",

src/task.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def validation_step(self, sample, batch_idx):
         return outputs

     def build_scheduler(self, optimizer):
-        return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5)
+        return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=hparams.get('gamma', 0.5))

     def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx):
         if optimizer is None:
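Using `hparams.get('gamma', 0.5)` rather than a direct key lookup keeps older configs, which have no `gamma` entry, on the previously hard-coded decay rate. A minimal sketch of that fallback (a plain dict stands in for the project's `hparams`):

```python
def scheduler_gamma(hparams):
    # Old configs without a gamma key fall back to the former hard-coded 0.5.
    return hparams.get('gamma', 0.5)

print(scheduler_gamma({'decay_steps': 50000}))                # → 0.5
print(scheduler_gamma({'decay_steps': 50000, 'gamma': 0.9}))  # → 0.9
```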
