Commit f65c3f2

first commit


62 files changed, +16826 -0 lines

README.md

Lines changed: 90 additions & 0 deletions

# CDR - Interactive Visual Cluster Analysis by Contrastive Dimensionality Reduction

![teaser](teaser.png)

## Environment setup

This project is based on Python 3.6 and PyTorch 1.6.0. See `requirements.txt` for all prerequisites; you can install them with the following command.

```bash
pip install -r requirements.txt
```
## Datasets

|               | Size  | Dimensionality | Clusters |  Type   |                             Link                             |
| :-----------: | :---: | :------------: | :------: | :-----: | :----------------------------------------------------------: |
|    Animals    | 10000 |      512       |    10    |  image  | [Kaggle](https://www.kaggle.com/datasets/alessiocorrado99/animals10) |
| Anuran calls  | 7195  |       22       |    8     | tabular | [UCI](https://archive.ics.uci.edu/ml/datasets/Anuran+Calls+%28MFCCs%29) |
|   Banknote    | 1097  |       4        |    2     |  text   | [UCI](https://archive.ics.uci.edu/ml/datasets/banknote+authentication) |
|    Cifar10    | 10000 |      512       |    10    |  image  | [Alex Krizhevsky](https://www.cs.toronto.edu/~kriz/cifar.html) |
|     Cnae9     |  864  |      856       |    9     |  text   | [UCI](https://archive.ics.uci.edu/ml/datasets/cnae-9) |
| Cats-vs-Dogs  | 10000 |      512       |    2     |  image  | [Kaggle](https://www.kaggle.com/datasets/shaunthesheep/microsoft-catsvsdogs-dataset) |
|     Fish      | 9000  |      512       |    9     |  image  | [Kaggle](https://www.kaggle.com/datasets/crowww/a-large-scale-fish-dataset) |
|     Food      | 3585  |      512       |    11    |  image  | [Kaggle](https://www.kaggle.com/datasets/anshulmehtakaggl/themassiveindianfooddataset) |
|      Har      | 8240  |      561       |    6     | tabular | [UCI](https://archive.ics.uci.edu/ml/datasets/human+activity+recognition+using+smartphones) |
|    Isolet     | 1920  |      617       |    8     |  text   | [UCI](https://archive.ics.uci.edu/ml/datasets/isolet) |
|   ML binary   | 1000  |       10       |    2     | tabular | [Kaggle](https://www.kaggle.com/datasets/rhythmcam/ml-binary-classification-study-data) |
|     MNIST     | 10000 |      784       |    10    |  image  | [Yann LeCun](http://yann.lecun.com/exdb/mnist/) |
|   Pendigits   | 8794  |       16       |    10    | tabular | [UCI](https://archive.ics.uci.edu/ml/datasets/pen-based+recognition+of+handwritten+digits) |
|    Retina     | 10000 |       50       |    12    | tabular | [Paper](https://www.cell.com/fulltext/S0092-8674(15)00549-8) |
|   Satimage    | 5148  |       36       |    6     |  image  | [UCI](https://archive.ics.uci.edu/ml/datasets/Statlog+(Landsat+Satellite)) |
| Stanford Dogs | 1384  |      512       |    7     |  image  | [Stanford University](http://vision.stanford.edu/aditya86/ImageNetDogs/) |
|    Texture    | 4400  |       40       |    11    |  text   | [KEEL](https://sci2s.ugr.es/keel/dataset.php?cod=72) |
|     USPS      | 7440  |      256       |    10    |  image  | [Kaggle](https://www.kaggle.com/bistaumanga/usps-dataset) |
|   Weathers    |  900  |      512       |    4     |  image  | [Kaggle](https://www.kaggle.com/datasets/vijaygiitk/multiclass-weather-dataset) |
|     WiFi      | 1600  |       7        |    4     | tabular | [UCI](https://archive.ics.uci.edu/ml/datasets/Wireless+Indoor+Localization) |

For image datasets such as Animals, Cifar10, Cats-vs-Dogs, Fish, Food, Stanford Dogs, and Weathers, we use [SimCLR](https://github.com/sthalles/SimCLR) to obtain their 512-dimensional representations.

All datasets are stored in **H5 format** (e.g. `usps.h5`) and must be placed in **`data/H5 Data`**. For image datasets, name the images `0.jpg, 1.jpg, ..., n-1.jpg` and put them in the `static/images/(dataset name)` directory (e.g. `static/images/usps`).
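As a concrete illustration of the expected H5 layout, the sketch below writes and reads a toy dataset. The feature matrix lives under the key `"x"` (this is the key `app.py` reads); storing labels under `"y"` is an assumption here, so adjust it to the repo's actual convention.

```python
# A minimal sketch of the H5 layout, assuming features under "x" (as read by
# app.py) and labels under "y" (an assumption for illustration only).
import os
import tempfile

import h5py
import numpy as np

data_dir = tempfile.mkdtemp()                 # stands in for "data/H5 Data"
path = os.path.join(data_dir, "toy.h5")

# Write a toy dataset: 100 instances, 16 features, 4 classes.
rng = np.random.default_rng(0)
with h5py.File(path, "w") as f:
    f.create_dataset("x", data=rng.normal(size=(100, 16)).astype("float32"))
    f.create_dataset("y", data=rng.integers(0, 4, size=100))

# Read it back the way app.py's load_available_data does for the dataset list.
with h5py.File(path, "r") as f:
    n_samples, dims = np.array(f["x"]).shape

print(n_samples, dims)  # 100 16
```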
## Pre-trained model weights

The pre-trained model weights for all of the above datasets can be found on [Google Drive](https://drive.google.com/drive/folders/19WYgUcOI6cOYSUPK_w1eICSr0ceRK9Zb?usp=sharing).
## Training

To train the model on USPS with a single GPU, check the configuration file `configs/CDR.yaml` and run the following command:

```bash
python train.py --configs configs/CDR.yaml
```
55+
## Config File

The configuration files are under the `./configs` folder; we provide two config files in `.yaml` format. Guidance on several key parameters from the paper is given below.

- **n_neighbors (K):** Determines **the granularity of the local structure** preserved in the low-dimensional space. Too small a value can split one high-dimensional cluster into two low-dimensional clusters, while too large a value aggravates cluster overlap. The default is **K = 15**.
- **batch_size (B):** Determines the number of negative samples. Larger is generally better, but the right value also depends on the data size. We recommend **`B = n/10`**, where `n` is the number of instances.
- **temperature (t):** Determines how strongly the model preserves neighborhoods. The smaller the value, the more strictly neighborhoods are preserved, but more erroneous neighbors are also kept. The default is **t = 0.15**.
- **separate_upper (μ):** Determines the intensity of cluster separation. The larger the value, the higher the degree of cluster separation. The default is **μ = 0.11**.
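The `B = n/10` rule above can be sketched as a tiny helper (the repo itself sets `batch_size` by hand in the YAML configs; the function name here is ours):

```python
def recommended_batch_size(n_instances: int) -> int:
    """Suggested negative-sample batch size B = n / 10 (at least 1)."""
    return max(1, n_instances // 10)

# e.g. USPS has 7440 instances, so the rule suggests B = 744.
print(recommended_batch_size(7440))  # 744
```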
## Load pre-trained model for visualization

To use our pre-trained model, try the following command:

```bash
# python vis.py --configs 'configuration file path' --ckpt_path 'model weights path'

# Example on the USPS dataset
python vis.py --configs configs/CDR.yaml --ckpt_path model_weights/usps.pth.tar
```
75+
## Prototype interface

To use our prototype interface for interactive visual cluster analysis, run the following command.

```bash
python app.py --configs configs/ICDR.yaml
```
83+
After that, the prototype interface is available at [http://127.0.0.1:5000](http://127.0.0.1:5000).

![frontend_07](prototype.png)

[comment]: <> "## Cite"

app.py

Lines changed: 145 additions & 0 deletions

```python
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import argparse
import ast
import os
from datetime import timedelta

import h5py
import numpy as np
from flask import Flask, render_template, request

from experiments.icdr_trainer import ICLPTrainer
from model.cdr import CDRModel
from model.icdr import ICDRModel
from utils.constant_pool import *
from utils.common_utils import get_principle_components, get_config
from utils.link_utils import LinkInfo

app = Flask(__name__)
experimenter: ICLPTrainer
app.config['SEND_FILE_MAX_AGE_DEFAULT'] = timedelta(seconds=1)


def wrap_results(embeddings, principle_comps=None, attr_names=None):
    ret_dict = {"embeddings": embeddings.tolist(), "label": experimenter.get_label()}
    if principle_comps is not None:
        ret_dict["low_data"] = principle_comps.tolist()
        ret_dict["attrs"] = attr_names
    return ret_dict


def build_link_info(embeddings, min_dist):
    # Parse the link constraints submitted by the frontend.
    # ast.literal_eval safely parses the list literals (eval would execute arbitrary code).
    links = np.array(ast.literal_eval(request.form.get("links")))
    link_spreads = np.array(ast.literal_eval(request.form.get("link_spreads")))
    finetune_epochs = request.form.get("finetune_epochs", type=int)

    if links.shape[0] == 0:
        experimenter.link_info = None
        return experimenter.link_info

    if experimenter.link_info is None:
        experimenter.link_info = LinkInfo(links, link_spreads, finetune_epochs, embeddings, min_dist)
    else:
        experimenter.link_info.process_cur_links(links, link_spreads, embeddings)

    return experimenter.link_info


def update_config():
    global configs
    configs.exp_params.dataset = request.form.get("dataset", type=str)
    configs.exp_params.n_neighbors = request.form.get("n_neighbors", type=int)
    configs.training_params.epoch_nums = request.form.get("epoch_nums", type=int)
    configs.exp_params.input_dims = request.form.get("input_dims", type=int)
    configs.exp_params.split_upper = request.form.get("split_upper", type=float)
    # Recommended rule B = n / 10 (see README).
    configs.exp_params.batch_size = request.form.get("n_samples", type=int) // 10


def load_experiment(cfg):
    method_name = CDR_METHOD if cfg.exp_params.gradient_redefine else NX_CDR_METHOD
    result_save_dir = ConfigInfo.RESULT_SAVE_DIR.format(method_name, cfg.exp_params.n_neighbors)
    # Create the ICDR model
    clr_model = ICDRModel(cfg, device=device)
    global experimenter
    experimenter = ICLPTrainer(clr_model, cfg.exp_params.dataset, cfg, result_save_dir, None, device=device)


@app.route("/")
def index():
    return render_template("index.html")


@app.route("/load_dataset_list")
def load_dataset_list():
    data = []
    for item in ConfigInfo.AVAILABLE_DATASETS:
        data_obj = {}
        for i, k in enumerate(ConfigInfo.DATASETS_META):
            data_obj[k] = item[i]
        data.append(data_obj)

    return {"data": data}


@app.route("/train_for_vis", methods=["POST"])
def train_for_vis():
    update_config()
    load_experiment(configs)

    embeddings = experimenter.train_for_visualize()
    principle_comps, attr_names = get_principle_components(experimenter.dataset.data, attr_names=None)
    ret_dict = wrap_results(embeddings, principle_comps, attr_names)
    return ret_dict


@app.route("/constraint_resume", methods=["POST"])
def constraint_resume():
    update_config()
    link_info = build_link_info(experimenter.pre_embeddings, experimenter.configs.exp_params.min_dist)
    ft_epoch = request.form.get("finetune_epochs", type=int)

    ml_strength = request.form.get("ml_strength", type=float)
    cl_strength = request.form.get("cl_strength", type=float)
    experimenter.update_link_stat(link_info, is_finetune=True, finetune_epoch=ft_epoch)

    if link_info is not None:
        experimenter.model.link_stat_update(ft_epoch, experimenter.steady_epoch, ml_strength, cl_strength)

    embeddings = experimenter.resume_train(ft_epoch)
    return wrap_results(embeddings)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--configs", type=str, default="configs/ICDR.yaml", help="configuration file path")
    parser.add_argument("--device", type=str, default="cpu")
    return parser.parse_args()


def load_available_data():
    for item in os.listdir(ConfigInfo.DATASET_CACHE_DIR):
        ds = item.split(".")[0]
        n_samples, dims = np.array(h5py.File(os.path.join(ConfigInfo.DATASET_CACHE_DIR, item), "r")['x']).shape
        ds_type = "image" if os.path.exists(os.path.join(ConfigInfo.IMAGE_DIR, ds)) else "tabular"
        ConfigInfo.AVAILABLE_DATASETS.append([ds, n_samples, dims, ds_type])


if __name__ == '__main__':
    # Use [[ ... ]] as Jinja variable delimiters so templates don't clash
    # with the frontend framework's {{ ... }} syntax.
    app.jinja_env.variable_start_string = '[['
    app.jinja_env.variable_end_string = ']]'

    args = parse_args()
    device = args.device
    config_path = args.configs
    configs = get_config()
    configs.merge_from_file(config_path)
    load_available_data()
    load_experiment(configs)
    app.run(debug=False)
```
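The `/load_dataset_list` route above pairs the metadata field names with each row of `AVAILABLE_DATASETS`. A standalone sketch of that pairing (the field names mirror what `load_available_data` appends, `[name, n_samples, dims, type]`; the real `DATASETS_META` lives in `utils.constant_pool`, so the exact keys here are an assumption):

```python
# Hypothetical stand-ins for ConfigInfo.DATASETS_META / AVAILABLE_DATASETS.
DATASETS_META = ["name", "n_samples", "dims", "type"]
AVAILABLE_DATASETS = [
    ["usps", 7440, 256, "image"],
    ["wifi", 1600, 7, "tabular"],
]

def dataset_list(meta, rows):
    # dict(zip(...)) is an idiomatic shortcut for the index loop in app.py.
    return {"data": [dict(zip(meta, row)) for row in rows]}

print(dataset_list(DATASETS_META, AVAILABLE_DATASETS)["data"][0]["name"])  # usps
```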

configs/CDR.yaml

Lines changed: 19 additions & 0 deletions

```yaml
exp_params:
  dataset: "usps"
  input_dims: 256  # 16 x 16 x 1 grayscale images, flattened
  LR: 0.001
  batch_size: 512
  n_neighbors: 15
  optimizer: "adam"  # adam or sgd
  scheduler: "multi_step"  # cosine or multi_step or on_plateau
  temperature: 0.15
  gradient_redefine: True
  separate_upper: 0.1
  separation_begin_ratio: 0.25
  steady_begin_ratio: 0.875

training_params:
  epoch_nums: 1000
  epoch_print_inter_ratio: 0.1
  val_inter_ratio: 1
  ckp_inter_ratio: 1
```

configs/ICDR.yaml

Lines changed: 20 additions & 0 deletions

```yaml
exp_params:
  dataset: "wifi"
  input_dims: 7  # 7 WiFi signal-strength features
  LR: 0.001
  batch_size: 128
  n_neighbors: 15
  optimizer: "adam"  # adam or sgd
  scheduler: "multi_step"  # cosine or multi_step or on_plateau
  temperature: 0.15
  min_dist: 0.1
  separate_upper: 0.11
  gradient_redefine: True
  separation_begin_ratio: 0.25
  steady_begin_ratio: 0.875

training_params:
  epoch_nums: 1000
  epoch_print_inter_ratio: 0.1
  val_inter_ratio: 0.5
  ckp_inter_ratio: 1
```

data/H5 Data/texture.h5

1.72 MB (binary file not shown)

data/H5 Data/usps.h5

14 MB (binary file not shown)

data/H5 Data/wifi.h5

121 KB (binary file not shown)
