-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
86 lines (67 loc) · 2.19 KB
/
train.py
File metadata and controls
86 lines (67 loc) · 2.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch.optim as optim
import model
import data
from torch.utils.data.dataloader import DataLoader
import torch.nn as nn
import time
import math
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import torch
# --- data -------------------------------------------------------------------
dataset = data.NamesDataset("data/names/")
train_loader = DataLoader(dataset, batch_size=1, shuffle=True)

# --- network, loss, optimiser ------------------------------------------------
# NOTE(review): this rebinds `model` from the imported `model` module to the
# network instance; every later use of `model` in this file means the instance.
# NOTE(review): the meaning of each positional GeneratorModel argument is not
# visible here (presumably vocab sizes, hidden size 128, embedding sizes, layer
# count, number of languages) — confirm against model.GeneratorModel.
model = model.GeneratorModel(
    dataset.n_letters,
    128,
    dataset.n_letters,
    int(dataset.n_letters ** 0.25),
    int(len(dataset.languages) ** 0.25),
    1,
    len(dataset.languages),
)
criterion = nn.NLLLoss()
learning_rate = 0.0003
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# --- logging cadence (in training samples) -----------------------------------
print_every = 5000
plot_every = 1000
def add_sos_to_input(input):
    """Return a new 1-D index tensor with the ``<SOS>`` index prepended to ``input``.

    Args:
        input: 1-D tensor of letter indices for one name (``run`` passes the
            single element of a ``batch_size=1`` batch). Not modified.

    Returns:
        A tensor one element longer than ``input`` whose first element is
        ``dataset.letter_to_index('<SOS>')``.
    """
    # Allocate the marker on the same device as `input`: the legacy
    # torch.LongTensor constructor always builds a CPU tensor, which makes
    # torch.cat fail once the data has been moved to a GPU.
    sos = torch.tensor(
        [dataset.letter_to_index('<SOS>')], dtype=torch.long, device=input.device
    )
    return torch.cat((sos, input), 0)
def add_eos_to_target(target):
    """Return a new 1-D index tensor with the ``<EOS>`` index appended to ``target``.

    Args:
        target: 1-D tensor of letter indices for one name (``run`` passes the
            single element of a ``batch_size=1`` batch). Not modified.

    Returns:
        A tensor one element longer than ``target`` whose last element is
        ``dataset.letter_to_index('<EOS>')``.
    """
    # Allocate the marker on the same device as `target`: the legacy
    # torch.LongTensor constructor always builds a CPU tensor, which makes
    # torch.cat fail once the data has been moved to a GPU.
    eos = torch.tensor(
        [dataset.letter_to_index('<EOS>')], dtype=torch.long, device=target.device
    )
    return torch.cat((target, eos), 0)
def train(input, language, gender, target):
    """Run one teacher-forced optimisation step on a single name.

    Feeds ``input`` to the global ``model`` one position at a time, sums the
    global ``criterion`` loss of each step's output against the matching
    ``target`` entry, then backpropagates and steps the global ``optimizer``.

    Args:
        input: 1-D index tensor (``<SOS>`` + name, as built by ``run``).
        language: passed to the model as ``language[0]`` at every step.
        gender: passed to the model as ``gender[0]`` at every step.
        target: 1-D index tensor aligned with ``input`` (name + ``<EOS>``).
            Not modified.

    Returns:
        ``(output, mean_loss)``: the model output from the last step and the
        summed loss divided by the sequence length, as a Python float.

    Raises:
        ValueError: if ``input`` is empty (there would be no loss to
            backpropagate and no output to return).
    """
    seq_len = input.size(0)
    if seq_len == 0:
        raise ValueError("train() needs a non-empty input sequence")

    hidden = model.initHidden()
    optimizer.zero_grad()
    # Out-of-place: the previous `target.unsqueeze_(-1)` silently reshaped the
    # caller's tensor. After this, target[i] is a 1-element tensor
    # (presumably matching a (1, n_classes) model output — confirm).
    target = target.unsqueeze(-1)

    loss = 0
    for i in range(seq_len):
        output, hidden = model(input[i], language[0], gender[0], hidden)
        loss += criterion(output, target[i])

    loss.backward()
    optimizer.step()
    return output, loss.item() / seq_len
def timeSince(since):
    """Format the wall-clock time elapsed since ``since`` as ``'<m>m <s>s'``.

    ``since`` is a ``time.time()`` timestamp. Minutes are floored and the
    remaining seconds are truncated to a whole number, e.g. ``'2m 5s'``.
    """
    elapsed = time.time() - since
    whole_minutes = math.floor(elapsed / 60)
    leftover_seconds = elapsed - whole_minutes * 60
    return f'{whole_minutes}m {int(leftover_seconds)}s'
def run(n_epochs=3):
    """Train the global ``model`` on ``train_loader``, save it, and plot the loss.

    Each sample is one name: the model input is the name with ``<SOS>``
    prepended and the target is the name with ``<EOS>`` appended.

    Args:
        n_epochs: number of passes over ``train_loader`` (default 3, the value
            that used to be hard-coded).

    Side effects:
        Prints progress every ``print_every`` samples of an epoch, saves the
        whole model object to ``generator.pt``, then shows a matplotlib plot of
        the mean training loss over each consecutive ``plot_every`` samples.
    """
    all_losses = []
    start = time.time()
    n_samples = len(dataset)
    current_loss = 0
    step = 0  # samples seen across all epochs; drives the plot buckets

    for _ in range(n_epochs):
        for i, (name_batch, lang, gender) in enumerate(train_loader):
            name = name_batch[0]  # drop the batch dimension (batch_size=1)
            _, loss = train(
                add_sos_to_input(name), lang[0], gender[0], add_eos_to_target(name)
            )
            current_loss += loss
            step += 1

            if i % print_every == 0:
                print('%d %d%% (%s) %.4f' % (i, i / n_samples *
                      100, timeSince(start), loss))

            # Counting from step 1 over all epochs means each point averages
            # exactly plot_every samples. The old `i % plot_every == 0` check
            # fired at i == 0, giving a single loss / plot_every, and its
            # buckets straddled epoch boundaries.
            if step % plot_every == 0:
                all_losses.append(current_loss / plot_every)
                current_loss = 0

    # Save before plt.show(): show() blocks on interactive backends, so the
    # checkpoint is written even if the plot window is never closed.
    torch.save(model, "generator.pt")
    plt.figure()
    plt.plot(all_losses)
    plt.show()
# Start training only when executed as a script (`python train.py`), not as a
# side effect of importing this module.
if __name__ == "__main__":
    run()