|
| 1 | +#!/usr/bin/env python |
| 2 | +# -*- coding:utf-8 -*- |
| 3 | +import argparse |
| 4 | +import os |
| 5 | +from datetime import timedelta |
| 6 | + |
| 7 | +import h5py |
| 8 | +from flask import Flask, render_template, request |
| 9 | +from experiments.icdr_trainer import ICLPTrainer |
| 10 | +from model.cdr import CDRModel |
| 11 | +from model.icdr import ICDRModel |
| 12 | +from utils.constant_pool import * |
| 13 | +from utils.common_utils import get_principle_components, get_config |
| 14 | +from utils.link_utils import LinkInfo |
| 15 | +import numpy as np |
| 16 | + |
| 17 | + |
| 18 | +app = Flask(__name__) |
| 19 | +experimenter: ICLPTrainer |
| 20 | +app.config['SEND_FILE_MAX_AGE_DEFAULT'] = timedelta(seconds=1) |
| 21 | + |
| 22 | + |
| 23 | +def wrap_results(embeddings, principle_comps=None, attr_names=None): |
| 24 | + ret_dict = {} |
| 25 | + ret_dict["embeddings"] = embeddings.tolist() |
| 26 | + ret_dict["label"] = experimenter.get_label() |
| 27 | + if principle_comps is not None: |
| 28 | + ret_dict["low_data"] = principle_comps.tolist() |
| 29 | + ret_dict["attrs"] = attr_names |
| 30 | + return ret_dict |
| 31 | + |
| 32 | + |
| 33 | +def build_link_info(embeddings, min_dist): |
| 34 | + links = request.form.get("links") |
| 35 | + link_spreads = request.form.get("link_spreads") |
| 36 | + finetune_epochs = request.form.get("finetune_epochs", type=int) |
| 37 | + |
| 38 | + links = np.array(eval(links)) |
| 39 | + print(links) |
| 40 | + link_spreads = np.array(eval(link_spreads)) |
| 41 | + |
| 42 | + if links.shape[0] == 0: |
| 43 | + experimenter.link_info = None |
| 44 | + return experimenter.link_info |
| 45 | + |
| 46 | + if experimenter.link_info is None: |
| 47 | + experimenter.link_info = LinkInfo(links, link_spreads, finetune_epochs, embeddings, min_dist) |
| 48 | + else: |
| 49 | + experimenter.link_info.process_cur_links(links, link_spreads, embeddings) |
| 50 | + |
| 51 | + return experimenter.link_info |
| 52 | + |
| 53 | + |
| 54 | +def update_config(): |
| 55 | + global configs |
| 56 | + ds_name = request.form.get("dataset", type=str) |
| 57 | + configs.exp_params.dataset = ds_name |
| 58 | + configs.exp_params.n_neighbors = request.form.get("n_neighbors", type=int) |
| 59 | + configs.training_params.epoch_nums = request.form.get("epoch_nums", type=int) |
| 60 | + configs.exp_params.input_dims = request.form.get("input_dims", type=int) |
| 61 | + configs.exp_params.split_upper = request.form.get("split_upper", type=float) |
| 62 | + configs.exp_params.batch_size = int(request.form.get("n_samples", type=int) / 10) |
| 63 | + |
| 64 | + |
| 65 | +def load_experiment(cfg): |
| 66 | + method_name = CDR_METHOD if cfg.exp_params.gradient_redefine else NX_CDR_METHOD |
| 67 | + result_save_dir = ConfigInfo.RESULT_SAVE_DIR.format(method_name, cfg.exp_params.n_neighbors) |
| 68 | + # 创建CLP模型 |
| 69 | + clr_model = ICDRModel(cfg, device=device) |
| 70 | + global experimenter |
| 71 | + experimenter = ICLPTrainer(clr_model, cfg.exp_params.dataset, cfg, result_save_dir, None, device=device) |
| 72 | + |
| 73 | + |
| 74 | +@app.route("/") |
| 75 | +def index(): |
| 76 | + return render_template("index.html") |
| 77 | + |
| 78 | + |
| 79 | +@app.route("/load_dataset_list") |
| 80 | +def load_dataset_list(): |
| 81 | + data = [] |
| 82 | + for item in ConfigInfo.AVAILABLE_DATASETS: |
| 83 | + data_obj = {} |
| 84 | + for i, k in enumerate(ConfigInfo.DATASETS_META): |
| 85 | + data_obj[k] = item[i] |
| 86 | + data.append(data_obj) |
| 87 | + |
| 88 | + return {"data": data} |
| 89 | + |
| 90 | + |
| 91 | +@app.route("/train_for_vis", methods=["POST"]) |
| 92 | +def train_for_vis(): |
| 93 | + update_config() |
| 94 | + load_experiment(configs) |
| 95 | + |
| 96 | + embeddings = experimenter.train_for_visualize() |
| 97 | + principle_comps, attr_names = get_principle_components(experimenter.dataset.data, attr_names=None) |
| 98 | + ret_dict = wrap_results(embeddings, principle_comps, attr_names) |
| 99 | + return ret_dict |
| 100 | + |
| 101 | + |
| 102 | +@app.route("/constraint_resume", methods=["POST"]) |
| 103 | +def constraint_resume(): |
| 104 | + update_config() |
| 105 | + link_info = build_link_info(experimenter.pre_embeddings, experimenter.configs.exp_params.min_dist) |
| 106 | + ft_epoch = request.form.get("finetune_epochs", type=int) |
| 107 | + |
| 108 | + ml_strength = request.form.get("ml_strength", type=float) |
| 109 | + cl_strength = request.form.get("cl_strength", type=float) |
| 110 | + experimenter.update_link_stat(link_info, is_finetune=True, finetune_epoch=ft_epoch) |
| 111 | + |
| 112 | + if link_info is not None: |
| 113 | + experimenter.model.link_stat_update(ft_epoch, experimenter.steady_epoch, ml_strength, cl_strength) |
| 114 | + |
| 115 | + embeddings = experimenter.resume_train(ft_epoch) |
| 116 | + return wrap_results(embeddings) |
| 117 | + |
| 118 | + |
| 119 | +def parse_args(): |
| 120 | + parser = argparse.ArgumentParser() |
| 121 | + parser.add_argument("--configs", type=str, default="configs/ICDR.yaml", help="configuration file path") |
| 122 | + parser.add_argument("--device", type=str, default="cpu") |
| 123 | + return parser.parse_args() |
| 124 | + |
| 125 | + |
| 126 | +def load_available_data(): |
| 127 | + for item in os.listdir(ConfigInfo.DATASET_CACHE_DIR): |
| 128 | + ds = item.split(".")[0] |
| 129 | + n_samples, dims = np.array(h5py.File(os.path.join(ConfigInfo.DATASET_CACHE_DIR, item), "r")['x']).shape |
| 130 | + ds_type = "image" if os.path.exists(os.path.join(ConfigInfo.IMAGE_DIR, ds)) else "tabular" |
| 131 | + ConfigInfo.AVAILABLE_DATASETS.append([ds, n_samples, dims, ds_type]) |
| 132 | + |
| 133 | + |
| 134 | +if __name__ == '__main__': |
| 135 | + app.jinja_env.variable_start_string = '[[' |
| 136 | + app.jinja_env.variable_end_string = ']]' |
| 137 | + |
| 138 | + args = parse_args() |
| 139 | + device = args.device |
| 140 | + config_path = args.configs |
| 141 | + configs = get_config() |
| 142 | + configs.merge_from_file(config_path) |
| 143 | + load_available_data() |
| 144 | + load_experiment(configs) |
| 145 | + app.run(debug=False) |
0 commit comments