-
Notifications
You must be signed in to change notification settings - Fork 427
Expand file tree
/
Copy pathdata_pipe.py
More file actions
123 lines (113 loc) · 4.58 KB
/
data_pipe.py
File metadata and controls
123 lines (113 loc) · 4.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from pathlib import Path
from torch.utils.data import Dataset, ConcatDataset, DataLoader
from torchvision import transforms as trans
from torchvision.datasets import ImageFolder
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import numpy as np
import cv2
import bcolz
import pickle
import torch
import mxnet as mx
from tqdm import tqdm
def de_preprocess(tensor):
    """Undo the Normalize([0.5]*3, [0.5]*3) step: map values from [-1, 1] back to [0, 1]."""
    rescaled = tensor * 0.5
    return rescaled + 0.5
def get_train_dataset(imgs_folder):
    """Build an ImageFolder training dataset with flip + [-1, 1] normalization.

    Args:
        imgs_folder: directory laid out for torchvision ImageFolder
            (one subdirectory per class).

    Returns:
        (ds, class_num): the ImageFolder dataset and the number of classes.
    """
    train_transform = trans.Compose([
        trans.RandomHorizontalFlip(),
        trans.ToTensor(),
        trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    ds = ImageFolder(imgs_folder, train_transform)
    # ImageFolder assigns contiguous labels 0..C-1, so len(ds.classes) equals
    # the old ds[-1][1] + 1 — without decoding and transforming the last
    # image from disk just to read its label.
    class_num = len(ds.classes)
    return ds, class_num
def get_train_loader(conf):
    """Create the training DataLoader for the dataset picked by conf.data_mode.

    conf must provide: data_mode ('ms1m', 'vgg', 'concat' or 'emore'), the
    matching *_folder Path(s), batch_size, pin_memory and num_workers.

    Returns:
        (loader, class_num): a shuffling DataLoader and the total class count.

    Raises:
        ValueError: if conf.data_mode is not one of the supported modes
            (previously this fell through and crashed later with a NameError).
    """
    if conf.data_mode in ['ms1m', 'concat']:
        ms1m_ds, ms1m_class_num = get_train_dataset(conf.ms1m_folder/'imgs')
        print('ms1m loader generated')
    if conf.data_mode in ['vgg', 'concat']:
        vgg_ds, vgg_class_num = get_train_dataset(conf.vgg_folder/'imgs')
        print('vgg loader generated')
    if conf.data_mode == 'vgg':
        ds = vgg_ds
        class_num = vgg_class_num
    elif conf.data_mode == 'ms1m':
        ds = ms1m_ds
        class_num = ms1m_class_num
    elif conf.data_mode == 'concat':
        # Shift vgg labels past the ms1m label range so the concatenated
        # dataset has a single contiguous label space.
        for i, (url, label) in enumerate(vgg_ds.imgs):
            vgg_ds.imgs[i] = (url, label + ms1m_class_num)
        ds = ConcatDataset([ms1m_ds, vgg_ds])
        class_num = vgg_class_num + ms1m_class_num
    elif conf.data_mode == 'emore':
        ds, class_num = get_train_dataset(conf.emore_folder/'imgs')
    else:
        raise ValueError('unknown conf.data_mode: {!r}'.format(conf.data_mode))
    loader = DataLoader(ds, batch_size=conf.batch_size, shuffle=True,
                        pin_memory=conf.pin_memory, num_workers=conf.num_workers)
    return loader, class_num
def load_bin(path, rootdir, transform, image_size=(112, 112)):
    """Unpack a pickled verification .bin into an on-disk bcolz carray.

    The pickle at `path` holds (bins, issame_list): a list of encoded images
    and the boolean same/different label for each pair. Each image is decoded,
    converted to BGR channel order, transformed, and written into the carray;
    the labels are saved next to it as '<rootdir>_list.npy'.

    Args:
        path: pickle file containing (bins, issame_list).
        rootdir: Path where the bcolz carray is created.
        transform: torchvision transform applied to each PIL image.
        image_size: (height, width) of the decoded images.

    Returns:
        (data, issame_list): the filled carray and the pair labels.
    """
    if not rootdir.exists():
        rootdir.mkdir()
    # Use a context manager so the file handle is not leaked.
    with open(path, 'rb') as f:
        bins, issame_list = pickle.load(f, encoding='bytes')
    data = bcolz.fill([len(bins), 3, image_size[0], image_size[1]],
                      dtype=np.float32, rootdir=rootdir, mode='w')
    for i, _bin in enumerate(bins):
        img = mx.image.imdecode(_bin).asnumpy()
        # mxnet decodes to RGB; convert to BGR.
        # NOTE(review): presumably the downstream pipeline expects BGR — confirm.
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        img = Image.fromarray(img.astype(np.uint8))
        data[i, ...] = transform(img)
        # The original bumped i inside the loop (a no-op for the range
        # iterator) purely to log a 1-based count; keep that count.
        if (i + 1) % 1000 == 0:
            print('loading bin', i + 1)
    print(data.shape)
    np.save(str(rootdir) + '_list', np.array(issame_list))
    return data, issame_list
def get_val_pair(path, name):
    """Load one validation benchmark: its bcolz image carray plus the
    matching same/different label array saved as '<name>_list.npy'."""
    images = bcolz.carray(rootdir=path / name, mode='r')
    labels = np.load(path / '{}_list.npy'.format(name))
    return images, labels
def get_val_data(data_path):
    """Load the agedb_30, cfp_fp and lfw verification sets from data_path.

    Returns the three image carrays followed by their three issame arrays,
    in that order.
    """
    pairs = [get_val_pair(data_path, name) for name in ('agedb_30', 'cfp_fp', 'lfw')]
    (agedb_30, agedb_30_issame), (cfp_fp, cfp_fp_issame), (lfw, lfw_issame) = pairs
    return agedb_30, cfp_fp, lfw, agedb_30_issame, cfp_fp_issame, lfw_issame
def load_mx_rec(rec_path):
    """Export an mxnet RecordIO set (train.idx / train.rec) to a
    class-per-folder JPEG tree under rec_path/'imgs'.

    Each record's image is saved as '<idx>.jpg' (quality 95) inside a
    directory named after its integer label.
    """
    out_dir = rec_path / 'imgs'
    if not out_dir.exists():
        out_dir.mkdir()
    imgrec = mx.recordio.MXIndexedRecordIO(
        str(rec_path / 'train.idx'), str(rec_path / 'train.rec'), 'r')
    # Record 0's first label field is read as the number of records to export.
    header, _ = mx.recordio.unpack(imgrec.read_idx(0))
    max_idx = int(header.label[0])
    for idx in tqdm(range(1, max_idx)):
        rec_header, img_arr = mx.recordio.unpack_img(imgrec.read_idx(idx))
        label = int(rec_header.label[0])
        image = Image.fromarray(img_arr)
        class_dir = out_dir / str(label)
        if not class_dir.exists():
            class_dir.mkdir()
        image.save(class_dir / '{}.jpg'.format(idx), quality=95)
# class train_dataset(Dataset):
# def __init__(self, imgs_bcolz, label_bcolz, h_flip=True):
# self.imgs = bcolz.carray(rootdir = imgs_bcolz)
# self.labels = bcolz.carray(rootdir = label_bcolz)
# self.h_flip = h_flip
# self.length = len(self.imgs) - 1
# if h_flip:
# self.transform = trans.Compose([
# trans.ToPILImage(),
# trans.RandomHorizontalFlip(),
# trans.ToTensor(),
# trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
# ])
# self.class_num = self.labels[-1] + 1
# def __len__(self):
# return self.length
# def __getitem__(self, index):
# img = torch.tensor(self.imgs[index+1], dtype=torch.float)
# label = torch.tensor(self.labels[index+1], dtype=torch.long)
# if self.h_flip:
# img = de_preprocess(img)
# img = self.transform(img)
# return img, label