
Why is model, rather than delta_model, what gets passed to the Trainer for training? #12

@zxcvbnmkj

lora_config = json.load(open("config/lora_config.json"))
lora_config["lora_r"] = sparse_args.lora_r
lora_config = LoraConfig.from_dict(lora_config)
delta_model = LoraModel.from_config(lora_config, backbone_model=model)
delta_model.freeze_module(set_state_dict=True)
delta_model.log(delta_ratio=True, trainable_ratio=True, visualization=False)

if training_args.train_sparse:
    print("building sparse optimizer and scheduler")
    from src.trainer import GATE_PARAM_NAME
    valid_param_name = []
    for n, p in delta_model.named_parameters():
        print(f"Parameter name: {n}, requires_grad: {p.requires_grad}")
        if GATE_PARAM_NAME in n:
            valid_param_name.append(n)
    print("valid param name:", valid_param_name)
    # Note: the gate parameters handed to the sparse optimizer are collected from model, not delta_model
    sparse_optimizer = SparseAdamW(
        sparse_lambda=sparse_args.sparse_lambda_2,
        lambda_schedule=sparse_args.lambda_schedule,
        max_lambda=sparse_args.max_lambda,
        lambda_num=sparse_args.lambda_num,
        params=[p for n, p in model.named_parameters() if GATE_PARAM_NAME in n and p.requires_grad],
        lr=sparse_args.sparse_lr,
    )
    sparse_scheduler = get_linear_schedule_with_warmup(
        sparse_optimizer,
        num_warmup_steps=int(training_args.num_train_epochs * (len(train_dataset) / training_args.train_batch_size) * training_args.warmup_ratio),
        num_training_steps=int(training_args.num_train_epochs * (len(train_dataset) / training_args.train_batch_size)),
    )

trainer = SparseTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset if training_args.do_train else None,
    eval_dataset=eval_dataset if training_args.do_eval else None,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
    data_collator=data_collator,
    optimizers=(optimizer, lr_scheduler),
    sparse_lambda=sparse_args.sparse_lambda,
    sparse_optimizer=(sparse_optimizer, sparse_scheduler),
)
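The scheduler arguments in the train_sparse branch derive both step counts from the number of epochs, the dataset size and the batch size. A minimal sketch of that arithmetic, using a hypothetical helper `schedule_steps` and made-up numbers (3 epochs, 10,000 examples, batch size 32, 6% warmup):

```python
def schedule_steps(num_train_epochs, num_examples, train_batch_size, warmup_ratio):
    """Hypothetical helper mirroring the step arithmetic passed to
    get_linear_schedule_with_warmup in the snippet above."""
    # Float division, exactly as in the original expression
    steps_per_epoch = num_examples / train_batch_size
    num_training_steps = int(num_train_epochs * steps_per_epoch)
    num_warmup_steps = int(num_train_epochs * steps_per_epoch * warmup_ratio)
    return num_warmup_steps, num_training_steps


# Made-up example values, not taken from the issue
print(schedule_steps(3, 10_000, 32, 0.06))  # (56, 937)
```

Because `len(train_dataset) / train_batch_size` is a float, this count drops the final partial batch and ignores gradient accumulation, so it may differ slightly from the number of optimizer steps the Trainer actually takes.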

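I'm confused about why delta_model never appears again in the rest of the code. From my limited understanding, once the model has been wrapped into delta_model, it should be delta_model that is fine-tuned in the Trainer. As it stands, delta_model doesn't seem to serve any real purpose after it is defined.
I don't know OpenDelta well enough yet, so I'd really appreciate any guidance.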
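One possible explanation, assuming OpenDelta behaves as its documentation describes: `LoraModel.from_config(..., backbone_model=model)` inserts the LoRA modules into `model` itself rather than returning a separate network, and `freeze_module` switches `requires_grad` on the backbone's own parameters. If so, passing `model` to the Trainer already trains the delta parameters. The toy sketch below illustrates only that reference semantics in plain Python; `Backbone`, `Linear`, `ToyDeltaModel` and the `lora_A`/`lora_B` names are hypothetical stand-ins, not OpenDelta's API.

```python
class Linear:
    """Toy layer; each parameter is just a dict holding a requires_grad flag."""
    def __init__(self):
        self.params = {"weight": {"requires_grad": True}}


class Backbone:
    """Toy backbone with two layers, standing in for `model`."""
    def __init__(self):
        self.layers = {"query": Linear(), "value": Linear()}

    def named_parameters(self):
        for lname, layer in self.layers.items():
            for pname, p in layer.params.items():
                yield f"{lname}.{pname}", p


class ToyDeltaModel:
    """Hypothetical wrapper that attaches low-rank params to the backbone IN PLACE."""
    def __init__(self, backbone, modified_modules=("query", "value")):
        self.backbone = backbone  # a reference to the same object, not a copy
        for name in modified_modules:
            layer = backbone.layers[name]
            layer.params["lora_A"] = {"requires_grad": True}
            layer.params["lora_B"] = {"requires_grad": True}

    def freeze_module(self):
        # Freeze every backbone parameter except the injected delta params
        for name, p in self.backbone.named_parameters():
            p["requires_grad"] = "lora_" in name


model = Backbone()
delta_model = ToyDeltaModel(model)
delta_model.freeze_module()

# Iterating over `model` (what the Trainer receives) already sees the delta params
trainable = [n for n, p in model.named_parameters() if p["requires_grad"]]
print(trainable)  # ['query.lora_A', 'query.lora_B', 'value.lora_A', 'value.lora_B']
```

Under this assumption, delta_model acts mainly as a handle for setup (injection, freezing, logging), while `model` remains the object that is trained and evaluated.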
