-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsampler.py
More file actions
87 lines (68 loc) · 2.35 KB
/
sampler.py
File metadata and controls
87 lines (68 loc) · 2.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
from data import NamesDataset
import sys
import random
from pprint import pprint
import pycountry
# Trained generator, loaded once at import time and shared by sample().
# The checkpoint is a whole pickled model object (it is used below via
# .eval(), .initHidden() and direct calls), not a bare state_dict.
# torch >= 2.6 defaults torch.load to weights_only=True, which refuses to
# unpickle custom module classes, so the full-object load must be explicit.
# SECURITY NOTE: weights_only=False runs pickle; only load trusted files.
model = torch.load("generator.pt", weights_only=False)
model.eval()  # switch to inference mode (affects layers such as dropout, if any)
# Vocabulary and label encoders used to build model inputs and decode outputs.
dataset = NamesDataset("data/names/")
# Maximum number of characters sample() emits for one name.
max_length = 20
def add_sos_to_input(input):
    """Return *input* with the '<SOS>' token index prepended along dim 0."""
    start_token = dataset.letter_to_index('<SOS>')
    return torch.cat([torch.LongTensor([start_token]), input], dim=0)
def add_eos_to_target(target):
    """Return *target* with the '<EOS>' token index appended along dim 0."""
    end_token = dataset.letter_to_index('<EOS>')
    return torch.cat([target, torch.LongTensor([end_token])], dim=0)
def sample(lang, gender, var=1):
    """Generate one name by autoregressive sampling from the global ``model``.

    Generation starts from the '<SOS>' token and feeds each chosen character
    back into the model. It stops when the end index is produced or after
    ``max_length`` characters.

    Args:
        lang: Language label, encoded with ``dataset.lang_to_tensor``.
        gender: Gender label (``main`` passes "M" or "F"), encoded with
            ``dataset.gender_to_tensor``.
        var: Randomness for every character after the first. The next
            character is drawn uniformly from the top ``var + 1`` scores.
            The value is clamped to the available candidates (at most the
            top 10), so ``var=0`` is greedy and out-of-range values no
            longer raise.

    Returns:
        The generated string without start/end markers (may be empty).
    """
    with torch.no_grad():
        step_input = add_sos_to_input(dataset.name_to_tensor(""))
        hidden = model.initHidden()
        lang_tensor = dataset.lang_to_tensor(lang)
        gender_tensor = dataset.gender_to_tensor(gender)
        output_name = ""
        for i in range(max_length):
            output, hidden = model(
                step_input[0], lang_tensor[0], gender_tensor[0], hidden
            )
            if i == 0:
                # First character: draw from the top 21 of the top 30 scores.
                top_k, max_rank = 30, 20
            else:
                top_k, max_rank = 10, var
            # topk raises when k exceeds the vocabulary size, and randint is
            # inclusive at both ends, so keep the drawn rank in [0, top_k - 1].
            top_k = min(top_k, output.size(-1))
            _, candidates = output.topk(top_k)
            max_rank = min(max(max_rank, 0), top_k - 1)
            topi = candidates[0][random.randint(0, max_rank)]
            # NOTE(review): treats the last vocabulary index as the end marker;
            # confirm it equals dataset.letter_to_index('<EOS>').
            if topi == dataset.n_letters - 1:
                break
            letter = dataset.all_letters[topi]
            output_name += letter
            step_input = dataset.name_to_tensor(letter)
        return output_name
def _generate_names(lang, gender, count, var=1, max_attempts=None):
    """Sample up to *count* distinct "First Last" names for one language/gender.

    A second ``sample`` call supplies a surname only when the first result is
    a single word. Whitespace is normalised, so a missing surname leaves no
    trailing space, and empty results are skipped. Sampling stops after
    *max_attempts* draws (default ``100 * count``), so a model that cannot
    produce enough distinct names cannot hang the script.
    """
    if max_attempts is None:
        max_attempts = 100 * count
    names = set()
    attempts = 0
    while len(names) < count and attempts < max_attempts:
        attempts += 1
        first = sample(lang, gender, var)
        last = ""
        if len(first.split(" ")) < 2:
            last = sample(lang, gender, var)
        # str.title treats the space as a word boundary, so titling the joined
        # string equals titling each part; split() drops stray spaces.
        full_name = " ".join(f"{first} {last}".title().split())
        if full_name:
            names.add(full_name)
    if len(names) < count:
        print(
            f"warning: only {len(names)} distinct names for {lang}-{gender} "
            f"after {attempts} attempts",
            file=sys.stderr,
        )
    return names


def main():
    """Write generated full names for every language and gender.

    For each entry of ``dataset.languages`` this writes
    ``generated/new/<lang>-M.txt`` and ``generated/new/<lang>-F.txt``. Each
    file holds up to 3000 distinct names, one per line, sorted and
    UTF-8 encoded.
    """
    from pathlib import Path

    out_dir = Path("generated/new")
    out_dir.mkdir(parents=True, exist_ok=True)  # open() fails on a missing dir
    for lang in dataset.languages:
        for gender in ("M", "F"):
            names = _generate_names(lang, gender, 3000)
            path = out_dir / f"{lang}-{gender}.txt"
            # Explicit UTF-8: generated names may contain non-ASCII letters.
            with open(path, "w", encoding="utf-8") as out:
                print(*sorted(names), sep="\n", file=out)


if __name__ == "__main__":
    main()