-
Notifications
You must be signed in to change notification settings - Fork 552
TextPredictor #486
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
TextPredictor #486
Changes from 19 commits
0a0a4c6
002683f
60a847c
60a9e27
bc7f38d
f9ca56b
fe0ecbb
4a52ac7
30cc834
14e6720
6b75a73
d10945e
06f64b2
c9ff3d4
e7b6f6d
2307b37
bf3203b
10c93b2
53b5f09
ee3cacb
8096a89
fed989b
c40af7d
d0b3b11
30e9f60
d15dd60
6c42839
301eb16
c59a3b2
f04b69e
2f07223
ea515d2
6cc2f9e
4cc2b4e
4fa136d
c5d9914
25c1baf
f9d3b22
c1568b4
1e4201d
9692d4e
1b2cb28
05941bc
984d000
74f27b5
c8848c7
7be2c5c
1c7f7ad
be60fa6
3a29c5b
543b660
4296129
ca30eab
5bd061f
2b150e7
505c894
cd98daf
98ee138
ff8c078
24a5333
2aeb563
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1965,6 +1965,165 @@ class XGBoostLimitDepth_TS(TS_SKLearn): | |
| base_class = XGBoostLimitDepthEstimator | ||
|
|
||
|
|
||
| class AGTextPredictorEstimator(BaseEstimator): | ||
| """ | ||
| The class for tuning AutoGluon TextPredictor | ||
| """ | ||
| def __init__(self, task="binary", **params,): | ||
| from autogluon.text.text_prediction.mx_predictor import MXTextPredictor | ||
|
|
||
| super().__init__(task, **params) | ||
| self.estimator_class = MXTextPredictor | ||
|
|
||
| @classmethod | ||
| def search_space(cls, **params): | ||
| """ | ||
| Add the possible search space configs here, e.g. 'optimization.lr' | ||
| reference: | ||
| https://auto.gluon.ai/stable/tutorials/text_prediction/customization.html#custom-hyperparameter-values | ||
| """ | ||
| search_space_dict = { | ||
| "model.network.agg_net.mid_units": { | ||
| "domain": tune.choice(list(range(32, 129))), | ||
| "init_value": 128, | ||
| }, | ||
| "optimization.lr": { | ||
| "domain": tune.loguniform(lower=1E-5, upper=1E-4), | ||
| "init_value": 1E-4, | ||
| }, | ||
| "optimization.wd": { | ||
| "domain": tune.choice([1E-4, 1E-3, 1E-2]), | ||
| "init_value":1E-4, | ||
| }, | ||
| "optimization.warmup_portion": { | ||
| "domain": tune.choice([0.1, 0.2]), | ||
| "init_value":0.1, | ||
| }, | ||
| } | ||
| return search_space_dict | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There were only 4 hyperparameters and now there are 9. Which search space was used in your original experiment for autogluon?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The original four are "model.network.agg_net.mid_units", "optimization.warmup_portion", "optimization.lr", "optimization.wd". |
||
|
|
||
| def _init_fix_args(self, automl_fit_kwargs: dict=None): | ||
|
||
| """ | ||
| Save the customed fix args here | ||
| this includes: | ||
| "output_dir", | ||
| "text_backbone": "electra_base" | ||
| "multimodal_fusion_strategy":"fuse_late", | ||
| """ | ||
| fix_args = {} | ||
| FIX_ARGS_LIST = ["output_dir", "dataset_name", "label_column", "per_device_batch_size", | ||
| "text_backbone", "multimodal_fusion_strategy", "num_train_epochs", "batch_size"] | ||
| for key, value in automl_fit_kwargs["custom_fix_args"].items(): | ||
| assert ( | ||
| key in FIX_ARGS_LIST | ||
| ), "The specified key {} is not in the argument list: output_dir, label_column, dataset_name, text_backbone,\ | ||
| multimodal_fusion_strategy".format(key) | ||
|
|
||
| fix_args[key] = value | ||
|
|
||
| self.fix_args = fix_args | ||
|
|
||
| def _init_hp_config(self, text_backbone: str, multimodal_fusion_strategy: str): | ||
|
||
|
|
||
| """" | ||
| Ref: | ||
| https://auto.gluon.ai/stable/tutorials/text_prediction/customization.html#custom-hyperparameter-values | ||
| """ | ||
| from autogluon.text.text_prediction.legacy_presets import ag_text_presets | ||
|
|
||
| base_key = f'{text_backbone}_{multimodal_fusion_strategy}' | ||
| cfg = ag_text_presets.create(base_key) | ||
| # NOTE: if the search_space() is modified, add new items or delete here too. | ||
| TUNABLE_HP = set(["model.network.agg_net.mid_units", | ||
| "optimization.batch_size", | ||
| "optimization.layerwise_lr_decay", | ||
| "optimization.lr", | ||
| "optimization.nbest", | ||
| "optimization.num_train_epochs", | ||
| "optimization.per_device_batch_size", | ||
| "optimization.wd", | ||
| "optimization.warmup_portion", | ||
| ]) | ||
| search_space = cfg["models"]["MultimodalTextModel"]["search_space"] | ||
| search_space["optimization.per_device_batch_size"] = self.fix_args.get("per_device_batch_size", 4) | ||
| search_space["optimization.num_train_epochs"] = self.fix_args.get("num_train_epochs", 10) | ||
| search_space["optimization.batch_size"] = self.fix_args.get("batch_size", 128) | ||
| for key, value in self.params.items(): | ||
| if key in TUNABLE_HP: | ||
| # NOTE: FLAML uses np.float64 but AG uses float, need to transform | ||
| if isinstance(value, np.float64): | ||
| search_space[key] = value.item() | ||
| else: | ||
| search_space[key] = value | ||
| return cfg | ||
|
|
||
| def _set_seed(self, seed): | ||
| import random | ||
| import mxnet as mx | ||
| import torch as th | ||
| th.manual_seed(seed) | ||
| mx.random.seed(seed) | ||
| np.random.seed(seed) | ||
| random.seed(seed) | ||
|
||
|
|
||
| def fit(self, X_train=None, y_train=None, budget=None, **kwargs): | ||
| self._kwargs = kwargs | ||
| self._init_fix_args(kwargs) | ||
| # the seed set in the bash script for ag experiment is 123 | ||
| seed = self.params.get("seed", 123) | ||
| self._set_seed(seed) | ||
|
|
||
| # get backbone and fusion strategy | ||
| text_backbone = self.fix_args["text_backbone"] | ||
|
||
| multimodal_fusion_strategy = self.fix_args["multimodal_fusion_strategy"] | ||
|
|
||
| # get & set the save dir, get the dataset info | ||
| save_dir = self.fix_args["output_dir"] | ||
| label_column = self.fix_args["label_column"] | ||
| dataset_name = self.fix_args["dataset_name"] | ||
| ag_model_save_dir = os.path.join(save_dir, f"{dataset_name}_ag_text_multimodal_{text_backbone}\ | ||
|
||
| _{multimodal_fusion_strategy}_no_ensemble") | ||
|
|
||
| # set the hyperparameters | ||
| self.hyperparameters = self._init_hp_config(text_backbone, multimodal_fusion_strategy) | ||
| PROBLEM_TYPE_MAPPING = {"binary": "binary", "multi": "multiclass", "regression": "regression"} | ||
| TASK_METRIC_MAPPING = {"multi": "acc", "binary": "roc_auc", "regression": "r2"} | ||
|
|
||
| # train the model | ||
| start_time = time.time() | ||
|
|
||
| self._model = self.estimator_class(path=ag_model_save_dir, | ||
Qiaochu-Song marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| label=label_column, | ||
| problem_type=PROBLEM_TYPE_MAPPING[self._task], | ||
| eval_metric=TASK_METRIC_MAPPING[self._task]) | ||
|
|
||
| train_data = self._kwargs["train_data"] | ||
|
|
||
| self._model.fit(train_data=train_data, | ||
| hyperparameters=self.hyperparameters, | ||
| time_limit=budget, | ||
| seed=seed) | ||
|
|
||
| training_time = time.time() - start_time | ||
| return training_time | ||
|
|
||
| def predict(self, X): | ||
| output = self._model.predict(self._kwargs["valid_data"], as_pandas=False) | ||
| return output | ||
|
|
||
| def predict_proba(self, X, as_multiclass=True): | ||
| # only works for classification tasks | ||
| assert ( | ||
| self._task in CLASSIFICATION | ||
| ), "predict_proba() only for classification tasks." | ||
|
|
||
| output = self._model.predict_proba(self._kwargs["valid_data"], as_pandas=False) | ||
| if not as_multiclass: | ||
| if self._task == "binary": | ||
| output = output[:, 1] | ||
| return output | ||
|
|
||
|
|
||
| class suppress_stdout_stderr(object): | ||
| def __init__(self): | ||
| # Open a pair of null files | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,135 @@ | ||
| from flaml import AutoML | ||
| import pandas as pd | ||
| import requests | ||
| import gc | ||
| import numpy as np | ||
| import os | ||
| import sys | ||
| import platform | ||
| from sklearn.model_selection import train_test_split | ||
| os.environ["AUTOGLUON_TEXT_TRAIN_WITHOUT_GPU"] = "1" | ||
|
|
||
|
|
||
| def default_holdout_frac(num_train_rows, hyperparameter_tune=False): | ||
| """ | ||
| Returns default holdout_frac used in fit(). | ||
| Between row count 5,000 and 25,000 keep 0.1 holdout_frac, as we want to grow validation set to a stable 2500 examples. | ||
| Ref: https://github.com/awslabs/autogluon/blob/master/core/src/autogluon/core/utils/utils.py#L243 | ||
| """ | ||
| if num_train_rows < 5000: | ||
| holdout_frac = max(0.1, min(0.2, 500.0 / num_train_rows)) | ||
| else: | ||
| holdout_frac = max(0.01, min(0.1, 2500.0 / num_train_rows)) | ||
|
|
||
| if hyperparameter_tune: | ||
| holdout_frac = min(0.2, holdout_frac * 2) # to allocate more validation data for HPO to avoid overfitting | ||
|
|
||
| return holdout_frac | ||
|
|
||
|
|
||
| def test_ag_text_predictor(): | ||
| # # DEBUG | ||
| # return | ||
| # # DEBUG | ||
| if sys.version < "3.7": | ||
| # do not test on python3.6 | ||
| return | ||
| elif platform.system() == "Windows": | ||
| # do not test on windows with py3.8 | ||
| return | ||
|
|
||
| seed = 123 | ||
| metric = "roc_auc" | ||
| train_data = { | ||
| "sentence1": [ | ||
| 'Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .', | ||
| "Yucaipa owned Dominick 's before selling the chain to Safeway in 1998 for $ 2.5 billion .", | ||
| "They had published an advertisement on the Internet on June 10 , offering the cargo for sale , he added .", | ||
| "Around 0335 GMT , Tab shares were up 19 cents , or 4.4 % , at A $ 4.56 , having earlier set a record high of A $ 4.57 .", | ||
| "The stock rose $ 2.11 , or about 11 percent , to close Friday at $ 21.51 on the New York Stock Exchange .", | ||
| "Revenue in the first quarter of the year dropped 15 percent from the same period a year earlier .", | ||
| "The Nasdaq had a weekly gain of 17.27 , or 1.2 percent , closing at 1,520.15 on Friday .", | ||
| "The DVD-CCA then appealed to the state Supreme Court .", | ||
| ], | ||
| "sentence2": [ | ||
| 'Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .', | ||
| "Yucaipa bought Dominick 's in 1995 for $ 693 million and sold it to Safeway for $ 1.8 billion in 1998 .", | ||
| "On June 10 , the ship 's owners had published an advertisement on the Internet , offering the explosives for sale .", | ||
| "Tab shares jumped 20 cents , or 4.6 % , to set a record closing high at A $ 4.57 .", | ||
| "PG & E Corp. shares jumped $ 1.63 or 8 percent to $ 21.03 on the New York Stock Exchange on Friday .", | ||
| "With the scandal hanging over Stewart 's company , revenue the first quarter of the year dropped 15 percent from the same period a year earlier .", | ||
| "The tech-laced Nasdaq Composite .IXIC rallied 30.46 points , or 2.04 percent , to 1,520.15 .", | ||
| "The DVD CCA appealed that decision to the U.S. Supreme Court .", | ||
| ], | ||
| "numerical1": [1, 2, 3, 4, 5, 6, 7, 8], | ||
| "categorical1": ["a", "b", "a", "a", "a", "b", "a", "a"], | ||
| "label": [1, 0, 1, 0, 1, 1, 0, 1], | ||
| "idx": [0, 1, 2, 3, 4, 5, 6, 7], | ||
| } | ||
| train_dataset = pd.DataFrame(train_data) | ||
|
|
||
| test_data = { | ||
| "sentence1": [ | ||
| "That compared with $ 35.18 million , or 24 cents per share , in the year-ago period .", | ||
| "Shares of Genentech , a much larger company with several products on the market , rose more than 2 percent .", | ||
| "Legislation making it harder for consumers to erase their debts in bankruptcy court won overwhelming House approval in March .", | ||
| "The Nasdaq composite index increased 10.73 , or 0.7 percent , to 1,514.77 .", | ||
| ], | ||
| "sentence2": [ | ||
| "Earnings were affected by a non-recurring $ 8 million tax benefit in the year-ago period .", | ||
| "Shares of Xoma fell 16 percent in early trade , while shares of Genentech , a much larger company with several products on the market , were up 2 percent .", | ||
| "Legislation making it harder for consumers to erase their debts in bankruptcy court won speedy , House approval in March and was endorsed by the White House .", | ||
| "The Nasdaq Composite index , full of technology stocks , was lately up around 18 points .", | ||
| ], | ||
| "numerical1": [3, 4, 5, 6], | ||
| "categorical1": ["b", "a", "a", "b"], | ||
| "label": [0, 1, 1, 0], | ||
| "idx": [8, 10, 11, 12], | ||
| } | ||
| test_dataset = pd.DataFrame(test_data) | ||
|
|
||
| # FORCE THE SAME TRAIN-VALID SPLIT IN & OUT THE PREDICTOR | ||
| holdout_frac = default_holdout_frac(len(train_dataset), False) | ||
|
|
||
| _, valid_dataset = train_test_split(train_dataset, | ||
| test_size=holdout_frac, | ||
| random_state=np.random.RandomState(seed)) | ||
|
|
||
| feature_columns = ["sentence1", "sentence2", "numerical1", "categorical1"] | ||
|
|
||
| automl = AutoML() | ||
| automl_settings = { | ||
| "gpu_per_trial": 0, | ||
| "max_iter": 2, | ||
| "time_budget": 50, | ||
| "task": "binary", | ||
| "metric": "roc_auc", | ||
| } | ||
|
|
||
| automl_settings["custom_fix_args"] = { | ||
|
||
| "output_dir": "test/ag/output/", | ||
| "text_backbone": "electra_base", | ||
| "multimodal_fusion_strategy": "fuse_late", | ||
| "dataset_name": "test_ag", | ||
|
||
| "label_column": "label", | ||
| "per_device_batch_size": 4, | ||
| "num_train_epochs": 2, | ||
| "batch_size": 4, | ||
| } | ||
|
|
||
| automl.fit( | ||
| dataframe=train_dataset[feature_columns+["label"]], | ||
| label="label", | ||
| train_data=train_dataset[feature_columns+["label"]], | ||
| valid_data=valid_dataset[feature_columns+["label"]], | ||
| X_val=valid_dataset[feature_columns], | ||
| y_val=valid_dataset["label"], | ||
| estimator_list=["agtextpredictor"], | ||
| **automl_settings | ||
| ) | ||
|
|
||
| print("Begin to run inference on test set") | ||
| score = automl.model.estimator.evaluate(test_dataset) | ||
| print(f"Inference on test set complete, {metric}: {score}") | ||
| del automl | ||
| gc.collect() | ||
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please update all occurrences
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please update the commit.