-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
74 lines (52 loc) · 2.19 KB
/
utils.py
File metadata and controls
74 lines (52 loc) · 2.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from types import SimpleNamespace
from IPython import display
import torch
import torchvision.transforms as T
import torchvision as tv
from torch.distributions.bernoulli import Bernoulli
# Namespace bundling the normalization transforms used by the helpers below.
#   normalize        -- per-channel (x - mean) / std with the constants listed here
#                       (presumably the standard ImageNet statistics -- confirm).
#   normalize_invert -- a second Normalize whose mean is -mean/std and whose std is
#                       1/std, i.e. the algebraic inverse of `normalize`.
augs = SimpleNamespace(
    normalize=T.Normalize(
        mean=[.485, .456, .406],
        std=[.229, .224, .225],
    ),
    normalize_invert=T.Normalize(
        mean=[-.485 / .229, -.456 / .224, -.406 / .225],
        std=[1 / .229, 1 / .224, 1 / .225],
    ),
)
# Absolute path of the directory containing this module; the class-index JSON
# file is looked up next to it.
here = os.path.abspath(os.path.dirname(__file__))

# Read the class index with an explicit encoding so that decoding does not
# depend on the platform's locale default.
with open(os.path.join(here, 'imagenet_class_index.json'), 'r', encoding='utf-8') as f:
    imagenet_json = json.load(f)

# The JSON object is keyed by the stringified indices "0", "1", ...; element [1]
# of each entry is taken as the label (presumably the human-readable class
# name -- confirm against the JSON file). The list is ordered by class index.
imagenet_label = [imagenet_json[str(k)][1] for k in range(len(imagenet_json))]
def get_image_from_input_tensor(inp_image, ix=0):
    """Undo the input normalization and return one entry as a NumPy array.

    Args:
        inp_image: Tensor accepted by ``augs.normalize_invert``
            (presumably a normalized batch shaped (N, C, H, W) -- confirm).
        ix: Index along the first axis of the de-normalized tensor to extract.

    Returns:
        The selected entry with its axes permuted by (1, 2, 0), detached from
        the autograd graph, moved to the CPU and converted to ``numpy``.
    """
    restored = augs.normalize_invert(inp_image)
    selected = restored[ix]
    # Move the leading axis of the selected entry to the end.
    channels_last = selected.permute(1, 2, 0)
    return channels_last.detach().cpu().numpy()
def get_input_tensor_from_image(image):
    """Convert a single 3-axis image into a normalized one-element batch tensor.

    Args:
        image: Array-like exposing ``.shape`` with exactly three axes
            (presumably (H, W, C) -- confirm with callers).

    Returns:
        The result of ``augs.normalize`` applied to the image after it has been
        converted with ``torch.Tensor``, permuted by (2, 0, 1) and given a new
        leading axis of length 1.

    Raises:
        AssertionError: If ``image`` does not have exactly three axes.
    """
    n_axes = len(image.shape)
    assert n_axes == 3, 'Image should have 3 axis'
    # Last axis becomes the first one.
    reordered = torch.Tensor(image).permute(2, 0, 1)
    # Indexing with None prepends a singleton axis.
    batched = reordered[None]
    return augs.normalize(batched)
def get_tensor_deciles(x, n_round=0, intervals=torch.linspace(0, 1, 11)):
    """Return quantiles of ``x`` at the given levels as a rounded NumPy array.

    Args:
        x: Floating-point tensor whose quantiles are computed over all elements.
        n_round: Number of decimals passed to ``ndarray.round``.
        intervals: 1-D tensor of quantile levels. The default holds eleven
            evenly spaced levels from 0 to 1 (minimum, the nine deciles,
            maximum). It is created once at definition time and is never
            modified in place here.

    Returns:
        ``numpy.ndarray`` with one rounded quantile per entry of ``intervals``.
    """
    # torch.quantile needs `q` on the same device AND with the same dtype as
    # the input; casting both lets float64 inputs work with the float32 default.
    levels = intervals.to(device=x.device, dtype=x.dtype)
    torch_quants = torch.quantile(x, q=levels)
    return torch_quants.detach().cpu().numpy().round(n_round)
def normalize_minmax(x):
    """Linearly rescale ``x`` so its minimum maps to 0 and its maximum to 1.

    Args:
        x: Object exposing ``.min()`` and ``.max()`` and supporting elementwise
            arithmetic (e.g. a tensor or NumPy array).

    Returns:
        ``(x - min) / (max - min)``. No guard is applied when all values are
        equal, so the division is by zero in that case.
    """
    lowest = x.min()
    highest = x.max()
    span = highest - lowest
    return (x - lowest) / span
def normalize_minmax_gentle(x, upper=0.98, lower=0.02):
    """Rescale ``x`` to [0, 1] using quantiles instead of the raw min/max.

    The ``lower`` quantile of ``x`` is mapped to 0 and the ``upper`` quantile
    to 1; values outside that range are clipped.

    Args:
        x: Object exposing ``.quantile(q)`` and ``.clip(min=, max=)``
            (e.g. a floating-point ``torch.Tensor``).
        upper: Quantile level mapped to 1.
        lower: Quantile level mapped to 0.

    Returns:
        The rescaled values, clipped to the closed interval [0, 1]. No guard is
        applied when both quantiles coincide (division by zero).
    """
    q_low = x.quantile(lower)
    q_high = x.quantile(upper)
    # Shift by the LOWER quantile: shifting by the upper one would push almost
    # every value below zero, and the clip would then flatten them all to 0.
    x = (x - q_low) / (q_high - q_low)
    return x.clip(min=0, max=1)
def plot_grid(grid, use_display=False, is_grid=False, nrows=1, ncols=None, figsize=None, subtitles=None):
    """Draw a batch of images (or a pre-built image grid) with matplotlib.

    Args:
        grid: List, ``numpy.ndarray`` or tensor of images; or, when ``is_grid``
            is true, a single already-tiled image tensor.
        use_display: When true, pass the current figure to IPython's
            ``display.display`` after plotting.
        is_grid: When false, ``grid`` is tiled with
            ``torchvision.utils.make_grid(grid, nrow=ncols)`` first.
        nrows: Only used to derive the default figure height (``5 * nrows``).
        ncols: Value forwarded as ``nrow`` to ``make_grid`` and used for the
            default figure width (``4 * ncols``). Defaults to ``len(grid)`` for
            a list and ``grid.shape[0]`` otherwise.
        figsize: Explicit figure size; overrides the default derived size.
        subtitles: If not ``None``, passed to ``plt.title``.

    Returns:
        None. A new matplotlib figure is created as a side effect.
    """
    if ncols is None:
        ncols = len(grid) if isinstance(grid, list) else grid.shape[0]

    # NumPy input is converted so the tensor-only calls below work.
    if isinstance(grid, np.ndarray):
        grid = torch.Tensor(grid)

    if not is_grid:
        grid = tv.utils.make_grid(grid, nrow=ncols)

    canvas_size = (4 * ncols, 5 * nrows) if figsize is None else figsize
    plt.figure(figsize=canvas_size)

    # imshow receives the tensor with its leading axis moved to the end
    # (presumably (C, H, W) -> (H, W, C) -- confirm).
    plt.imshow(grid.cpu().permute(1, 2, 0))
    plt.axis('off')

    if subtitles is not None:
        plt.title(subtitles)

    if use_display:
        display.display(plt.gcf())
def decide_randomly(p, thrs=.5):
    """Draw one Bernoulli(p) sample and report whether it exceeds ``thrs``.

    Args:
        p: Success probability handed to ``Bernoulli``.
        thrs: Threshold the drawn value (0.0 or 1.0) is compared against.

    Returns:
        ``True`` if the sampled value is strictly greater than ``thrs``,
        otherwise ``False``.
    """
    coin = Bernoulli(p)
    # A single draw, returned as a one-element tensor.
    draw = coin.sample((1,))
    outcome = draw.item()
    return outcome > thrs