290 changes: 253 additions & 37 deletions paddlenlp/trainer/trainer.py
@@ -121,9 +121,13 @@
from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler
from ..utils.env import (
DISLORA_WEIGHTS_NAME,
EMA_STATE_DIC,
LOKR_WEIGHTS_NAME,
LORA_WEIGHTS_NAME,
MASTER_WEIGHT_DIC,
MODEL_META_NAME,
MODEL_STATE_DIC,
OPTIMIZER_STATE_DIC,
PADDLE_MASTER_WEIGHTS_INDEX_NAME,
PADDLE_OPTIMIZER_NAME,
PADDLE_PEFT_WEIGHTS_INDEX_NAME,
@@ -185,6 +189,8 @@
from .unified_checkpoint import UnifiedCheckpointHandler
from .utils import reshard as reshard_util
from .utils.async_save import AsyncSaver
from .utils.reshard import SHARDING_STRATEGY_V1, split_opt_state
from .utils.sharding_io import GroupGetter, to_device

try:
from .utils.zero_cost_checkpoint import (
@@ -673,6 +679,248 @@ def _load_from_peft_checkpoint(self, resume_from_checkpoint=None):
elif resume_from_checkpoint is not None:
logger.info(f"not loading ckpt :{self.args.dataset_rank}")

def _load_flex_checkpoint(self, resume_from_checkpoint):
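"""Load model, optimizer, and master-weight state from a flex checkpoint.

Supports two layouts: HuggingFace-format safetensors (model weights only) and the
distributed flex-checkpoint layout with separate model-state, optimizer-state, and
master-weight sub-directories (plus optional EMA and bf16 master-weight recovery).
"""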
def get_metadata_file_name(path):
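"""Return the single .metadata file expected in a flex-checkpoint sub-directory."""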
files = os.listdir(path)
metadata_files = [f for f in files if f.endswith(".metadata")]
assert len(metadata_files) > 0, f"Found no metadata files in {path}"
assert len(metadata_files) == 1, f"Found multiple metadata files in {path}"
return metadata_files[0]

model_sharded_state_dict = self.model.sharded_state_dict()
hf_aoa_config = self.model._gen_aoa_config(self.model.config)
master_weights_path = os.path.join(resume_from_checkpoint, MASTER_WEIGHT_DIC)
opt_states_path = os.path.join(resume_from_checkpoint, OPTIMIZER_STATE_DIC)
model_states_path = os.path.join(resume_from_checkpoint, MODEL_STATE_DIC)

if self.args.load_from_hf:
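# HuggingFace-format checkpoints contain only model weights, so optimizer and LR state must be ignored (asserted below).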
hcg = dist.fleet.get_hybrid_communicate_group()
assert (
self.args.ignore_load_lr_and_optim
), "Loading from HuggingFace format is only allowed when learning rate and optimizer state are ignored."
try:
moe_sharding_group = hcg.get_moe_sharding_parallel_group()
except Exception:
moe_sharding_group = None

if moe_sharding_group is None or moe_sharding_group.nranks <= 1:
# when moe_sharding_group is None, we use the default process_group
logger.info(f"Loading model weights from '{resume_from_checkpoint}' in safetensors format.")
dist.load_state_dict(
model_sharded_state_dict,
resume_from_checkpoint,
aoa_config=hf_aoa_config,
offload=self.args.load_via_cpu,
safetensors=True,
process_group=None,
comm_method=self.args.comm_method,
)
else:
try:
pp_group = hcg.get_pipe_parallel_group()
if pp_group is None or pp_group.nranks < 1:
raise NotImplementedError("Only support when pp_group is not None.")
except Exception:
raise RuntimeError("Only support when pp_group is not None.")

try:
moe_group = hcg.get_expert_parallel_group()
if moe_group is None or moe_group.nranks < 1:
raise NotImplementedError("Only support when moe_group is not None.")
except Exception:
raise RuntimeError("Only support when moe_group is not None.")
moe_sharding_rank = moe_sharding_group.rank
cur_rank = dist.get_rank()
if moe_sharding_rank == 0:
moe_group_ranks = []
dist.all_gather_object(moe_group_ranks, cur_rank, group=moe_group)
pp_group_ranks = []
dist.all_gather_object(pp_group_ranks, moe_group_ranks, group=pp_group)
process_group_ranks = [rank for ranks in pp_group_ranks for rank in ranks]
else:
process_group_ranks = [0] * (pp_group.nranks * moe_group.nranks)
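# All ranks must pass the same rank list to dist.new_group, so broadcast it from the first moe_sharding shard.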
src_rank = hcg.get_moe_sharding_parallel_group_src_rank()
dist.broadcast_object_list(process_group_ranks, src=src_rank, group=moe_sharding_group)
assert any(process_group_ranks), "process_group_ranks should not be all 0"
logger.info(f"Creating a temporary process group with ranks: {process_group_ranks}")
process_group = dist.new_group(process_group_ranks)

if moe_sharding_rank == 0:
logger.info(f"Loading model weights from '{resume_from_checkpoint}' in safetensors format.")
# Only the first moe_sharding process is allowed to load the model weights.
dist.load_state_dict(
model_sharded_state_dict,
resume_from_checkpoint,
aoa_config=hf_aoa_config,
offload=self.args.load_via_cpu,
safetensors=True,
process_group=process_group,
comm_method=self.args.comm_method,
)

dist.barrier()
logger.info("Destroying the temporary process group.")
dist.destroy_process_group(process_group)
# The first moe_sharding group loads the model weights and then broadcasts them to all other moe_sharding groups.
logger.info(
"First shard (moe_sharding_group) has loaded safetensors weights, starting broadcast on moe_sharding_groups."
)
for param_name, param in self.model.state_dict().items():
dist.broadcast(param, src=src_rank, group=moe_sharding_group)
logger.info("Safetensors format weights have been loaded successfully.")
return

if not self.args.ignore_load_lr_and_optim:
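# Collect the single .metadata file from each of the three state directories and merge them
# so init_optimizer sees the complete state_dict metadata.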
state_dict_metadata = {}
metadata_paths = [
os.path.join(model_states_path, get_metadata_file_name(model_states_path)),
os.path.join(opt_states_path, get_metadata_file_name(opt_states_path)),
os.path.join(master_weights_path, get_metadata_file_name(master_weights_path)),
]

for metadata_file in metadata_paths:
if not os.path.exists(metadata_file):
raise FileNotFoundError(f"Metadata file not found: {metadata_file}")
metadata = paddle.load(metadata_file)
state_dict_metadata.update(metadata.state_dict_metadata)

init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)

optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)

opt_states = {}
master_weights = {}
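# Keys ending in ".w_0" are master weights; everything else is optimizer state.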
for k, v in optimizer_sharded_state_dict.items():
if k.endswith(".w_0"):
master_weights[k] = v
else:
opt_states[k] = v

dist.load_state_dict(
opt_states,
opt_states_path,
aoa_config=self.args.aoa_config,
offload=self.args.load_via_cpu,
comm_method=self.args.comm_method,
)

if not self.args.sharded_model_from_ema:
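# Not resuming from EMA: load the master weights directly from the checkpoint.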
dist.load_state_dict(
master_weights,
master_weights_path,
aoa_config=self.args.aoa_config,
offload=self.args.load_via_cpu,
)

self._load_scheduler(resume_from_checkpoint)

should_load_stage1 = self.args.should_load_sharding_stage1_model
if should_load_stage1 and self.args.sharded_model_from_ema:
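# EMA path: each rank loads its own EMA shard, copies the EMA master weights into the optimizer,
# then all-gathers the remaining EMA state across the sharding group to rebuild the model weights.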
ema_states_path = os.path.join(resume_from_checkpoint, EMA_STATE_DIC, f"{dist.get_rank()}_0.distcp")
ema_state_dict = paddle.load(ema_states_path)
ema_master_weights = ema_state_dict.pop("master_weights", None)
opt_master_weights = self.optimizer.state_dict()["master_weights"]
for k, v in opt_master_weights.items():
assert (
k in ema_master_weights
), f"{k} not in ema_master_weights, emas_master_weight keys {ema_master_weights.keys()}"
paddle.assign(ema_master_weights[k], opt_master_weights[k])

ema_state_dict = reshard_util.all_gather_state_dict(ema_state_dict, lambda x: True, self.sharding_group)
self.model.set_state_dict(ema_state_dict)
else:
dist.load_state_dict(
model_sharded_state_dict,
model_states_path,
aoa_config=self.args.aoa_config,
offload=self.args.load_via_cpu,
)

if self.args.bf16 and (not self.args.ignore_load_lr_and_optim) and should_load_stage1:
opt_state_dict = self.optimizer.state_dict()

def recover_params_from_master_weight(opt_state_dict, group):
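"""Rebuild bf16 model parameters from sharded master weights: cast to bf16, reshard them back
to their owning ranks (sharding v1 or v2), all-gather, then copy into the matching parameters."""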
master_weights = opt_state_dict["master_weights"]
tmp = OrderedDict()
(master_weights, tmp) = (tmp, master_weights)
# cast the fp32 master weights to bfloat16 on CPU before resharding them back to their owning ranks
for (k, v) in tmp.items():
name = v.name
master_weights[k] = paddle.cast(to_device(v), paddle.bfloat16).cpu()
master_weights[k].name = name

structure_name_map = {k: v.name for (k, v) in self.model.state_dict().items()}
node_model_state = reshard_util.NodeModelState(group=group)
node_model_state_tmp = reshard_util.NodeModelState(group=group)
node_model_state_tmp.add_master_weights(master_weights)
node_model_state_tmp.pack_keys(structure_name_map)
node_model_state.merge_from(node_model_state_tmp, max(group.rank, 0))
del node_model_state_tmp
sharding_strategy = reshard_util.get_sharding_strategy(self.optimizer)
logger.debug(f"sharding_strategy: {sharding_strategy}")
restore_func = (
reshard_util.sharding_v1.restore
if sharding_strategy == SHARDING_STRATEGY_V1
else reshard_util.sharding_v2.restore
)
node_model_state = restore_func(node_model_state, self.model, self.optimizer)
node_model_state.unpack_keys()
master_weights = node_model_state.master_weights

master_weights = reshard_util.all_gather_state_dict(master_weights, lambda x: True, group)

model_state_dict = self.model.state_dict()
for key, param in model_state_dict.items():
if param.name in master_weights:
logger.debug(
f"key {key}, convert master weights {param.name} shape {master_weights[param.name].shape} to param {param.name} shape{param.shape}"
)
assert (
param.shape == master_weights[param.name].shape
), f"got {param.shape} vs {master_weights[param.name].shape}"
master_weight = paddle.reshape(master_weights[param.name], param.shape)
paddle.assign(paddle.cast(to_device(master_weight), paddle.bfloat16), model_state_dict[key])

group_getter = GroupGetter(self.model)
opt_state_dict = split_opt_state(opt_state_dict, group_getter)
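# Recover the bf16 parameters group by group, each within its own communication group.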
for gid in group_getter.get_group_ids():
sub_opt_state_dict = opt_state_dict[gid]
group = group_getter.get_group_by_id(gid)
if self.args.bf16:
recover_params_from_master_weight(sub_opt_state_dict, group)

def _save_flex_model_state(self, output_dir):
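"""Save the model's sharded state dict into the flex-checkpoint model-state directory."""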
model_sharded_state_dict = self.model.sharded_state_dict()
model_state_dict_path = os.path.join(output_dir, MODEL_STATE_DIC)
os.makedirs(model_state_dict_path, exist_ok=True)
dist.save_state_dict(
model_sharded_state_dict,
model_state_dict_path,
)

def _save_flex_optimizer_state(self, output_dir):
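"""Save optimizer states and master weights into separate flex-checkpoint directories,
splitting master weights out by the ".w_0" key suffix."""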
optimizer_state_dict_path = os.path.join(output_dir, OPTIMIZER_STATE_DIC)
optimizer_states = {}
master_weights = {}
model_sharded_state_dict = self.model.sharded_state_dict()
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
for k, v in optimizer_sharded_state_dict.items():
if k.endswith(".w_0"):
master_weights[k] = v
else:
optimizer_states[k] = v

dist.save_state_dict(
optimizer_states,
optimizer_state_dict_path,
)

master_weights_path = os.path.join(output_dir, MASTER_WEIGHT_DIC)
dist.save_state_dict(
master_weights,
master_weights_path,
)

def _load_from_checkpoint(self, resume_from_checkpoint=None):
"""load state_dict from_checkpoint, Only load model state dict.

@@ -1048,27 +1296,7 @@ def train(
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

if resume_from_checkpoint is not None:
if not self.args.ignore_load_lr_and_optim:
model_sharded_state_dict = self.model.sharded_state_dict()
accessible_files = os.listdir(resume_from_checkpoint)
metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
assert len(metadata_files) == 1, "Only support one metadata file now."
metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
state_dict_metadata = metadata.state_dict_metadata
init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
dist.load_state_dict(
sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
)
self._load_scheduler(resume_from_checkpoint)
else:
model_sharded_state_dict = self.model.sharded_state_dict()
sharded_state_dict = model_sharded_state_dict
dist.load_state_dict(
sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
)
self._load_flex_checkpoint(resume_from_checkpoint)
else:
model = self._wrap_model(self.model_wrapped)
# for the rest of this function `model` is the outside model, whether it was wrapped or not
@@ -2867,8 +3095,7 @@ def _save_checkpoint(self, model, metrics=None):
self.save_model(output_dir)

if self.args.save_checkpoint_format == "flex_checkpoint":
model_sharded_state_dict = self.model.sharded_state_dict()
os.makedirs(output_dir, exist_ok=True)
self._save_flex_model_state(output_dir)

# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
Expand Down Expand Up @@ -2932,11 +3159,7 @@ def _save_checkpoint(self, model, metrics=None):
)
else:
if self.args.save_checkpoint_format == "flex_checkpoint":
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
dist.save_state_dict(
{**model_sharded_state_dict, **optimizer_sharded_state_dict},
output_dir,
)
self._save_flex_optimizer_state(output_dir)
if self.args.should_save:
if self.tokenizer is not None and self.args.save_tokenizer:
self.tokenizer.save_pretrained(output_dir)
Expand Down Expand Up @@ -2992,11 +3215,7 @@ def _save_checkpoint(self, model, metrics=None):
signal_dir,
)
elif self.args.save_checkpoint_format == "flex_checkpoint":
optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
dist.save_state_dict(
{**model_sharded_state_dict, **optimizer_sharded_state_dict},
output_dir,
)
self._save_flex_optimizer_state(output_dir)
if self.args.should_save:
if self.tokenizer is not None and self.args.save_tokenizer:
self.tokenizer.save_pretrained(output_dir)
Expand Down Expand Up @@ -3039,10 +3258,7 @@ def _save_checkpoint(self, model, metrics=None):
self._offload_optimizer()
else:
if self.args.save_checkpoint_format == "flex_checkpoint":
dist.save_state_dict(
model_sharded_state_dict,
output_dir,
)
self._save_flex_model_state(output_dir)
if self.args.should_save:
if self.tokenizer is not None and self.args.save_tokenizer:
self.tokenizer.save_pretrained(output_dir)