5 changes: 5 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -1180,6 +1180,11 @@ class TrainingArguments:
},
)

save_replicas: Optional[bool] = field(
default=False,
        metadata={"help": "Whether to save replicas across checkpoint files in the distributed save/load system."},
)

def __post_init__(self):
world_size = paddle.distributed.get_world_size()
if in_auto_parallel_align_mode():
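The new flag defaults to False, so existing configurations are unaffected. A minimal sketch of enabling it (the output_dir value is a placeholder; only save_replicas comes from this diff):

    from paddlenlp.trainer import TrainingArguments

    # Hypothetical usage sketch: opt in to keeping redundant shard copies.
    args = TrainingArguments(
        output_dir="./checkpoints",  # placeholder path
        save_replicas=True,          # new flag added by this change
    )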
59 changes: 43 additions & 16 deletions paddlenlp/trainer/utils/zero_cost_checkpoint.py
@@ -21,7 +21,8 @@
import os
import random
import time
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from dataclasses import replace
from enum import Enum

import numpy as np
@@ -37,10 +38,7 @@
LocalTensorMetadata,
Metadata,
)
from paddle.distributed.flex_checkpoint.dcp.save_state_dict import (
balanced_dedup_key_in_dict,
dedup_key_in_dict,
)
from paddle.distributed.flex_checkpoint.dcp.save_state_dict import dedup_key_in_dict
from paddle.distributed.flex_checkpoint.dcp.sharded_weight import ShardedWeight
from paddle.distributed.flex_checkpoint.dcp.utils import (
flatten_state_dict,
@@ -1380,7 +1378,7 @@ def check_same_strategy(self, resume_from_checkpoint=None):
return True, None


def saved_ckptmeta(state_dict, ckpt_file_name, process_group=None):
def saved_ckptmeta(state_dict, ckpt_file_name, process_group=None, save_replicas=False):
with paddle.base.dygraph.guard():
assert isinstance(state_dict, dict), "The state_dict should be a dictionary."
flat_state_dict, mapping = flatten_state_dict(state_dict)
@@ -1416,19 +1414,16 @@ def saved_ckptmeta(state_dict, ckpt_file_name, process_group=None):
else:
flattened_range = None
local_state_dict_metadata[key] = LocalTensorMetadata(
global_offset,
local_shape,
tuple(global_offset),
tuple(local_shape),
local_tensor_dtype,
global_shape,
tuple(global_shape),
is_flattened,
flattened_range,
)
local_storage_metadata[
LocalTensorIndex(
key,
tuple(global_offset),
is_flattened,
flattened_range,
key, tuple(global_offset), is_flattened, flattened_range, local_shape=tuple(local_shape)
)
] = ckpt_file_name

@@ -1450,6 +1445,34 @@
global_storage_metadata.append(local_storage_metadata)
global_flatten_mapping.append(mapping)

        def balanced_dedup_key_in_dict(global_storage_metadata):
            # Collect every checkpoint file that holds a copy of each tensor index.
            lti_to_files = defaultdict(set)
            for storage_metadata in global_storage_metadata:
                for lti, fname in storage_metadata.items():
                    lti_to_files[lti].add(fname)

            # Greedily assign each tensor index to the least-loaded file so that
            # reads are spread evenly across checkpoint files.
            file_load = defaultdict(int)
            out = {}
            for lti, file_candidates in lti_to_files.items():
                candidates = sorted(file_candidates)
                selected_main_file = min(candidates, key=lambda f: file_load[f])
                file_load[selected_main_file] += 1

                if save_replicas:
                    # Keep all copies: the chosen file becomes replica_id 0 and the
                    # remaining duplicates are numbered in sorted-file order.
                    lti_main = replace(lti, replica_id=0)
                    out[lti_main] = selected_main_file
                    replica_id = 1
                    for fname in candidates:
                        if fname == selected_main_file:
                            continue
                        lti_replica = replace(lti, replica_id=replica_id)
                        out[lti_replica] = fname
                        replica_id += 1
                else:
                    # Deduplicate: record only the chosen primary copy.
                    out[lti] = selected_main_file

            return out

metadata.state_dict_metadata = merge_state_dict_metadata(global_state_dict_metadata)
metadata.storage_metadata = balanced_dedup_key_in_dict(global_storage_metadata)
metadata.flat_mapping = dedup_key_in_dict(global_flatten_mapping)
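To illustrate what balanced_dedup_key_in_dict produces, here is a self-contained sketch of the same selection logic; FakeIndex is a simplified stand-in for LocalTensorIndex, and the file names are made up:

    from collections import defaultdict
    from dataclasses import dataclass, replace

    @dataclass(frozen=True)
    class FakeIndex:
        # Simplified stand-in for LocalTensorIndex: just a key and a replica_id.
        key: str
        replica_id: int = 0

    # Two ranks report the same shard stored in two different files.
    global_storage_metadata = [
        {FakeIndex("linear.weight"): "rank0.distcp"},
        {FakeIndex("linear.weight"): "rank1.distcp"},
    ]

    lti_to_files = defaultdict(set)
    for md in global_storage_metadata:
        for lti, fname in md.items():
            lti_to_files[lti].add(fname)

    file_load = defaultdict(int)
    out = {}
    for lti, files in lti_to_files.items():
        candidates = sorted(files)
        main = min(candidates, key=lambda f: file_load[f])  # least-loaded file wins
        file_load[main] += 1
        out[replace(lti, replica_id=0)] = main  # primary copy
        for rid, fname in enumerate((f for f in candidates if f != main), start=1):
            out[replace(lti, replica_id=rid)] = fname  # duplicates kept as replicas

    print(out)
    # {FakeIndex(key='linear.weight', replica_id=0): 'rank0.distcp',
    #  FakeIndex(key='linear.weight', replica_id=1): 'rank1.distcp'}

With save_replicas=False, the second entry is dropped and only the primary copy is recorded.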
@@ -1570,7 +1593,7 @@ def create_ckpt_file_name():
self.ckpt_data_name, self.ckpt_meta_name = create_ckpt_file_name()
# self.model_ckpt_meta, self.model_state_filter = saved_ckptmeta(model.sharded_state_dict(), self.ckpt_data_name)
self.model_ckpt_meta, self.model_state_filter = saved_ckptmeta(
self.manipulated_state_dict, self.ckpt_data_name
self.manipulated_state_dict, self.ckpt_data_name, save_replicas=self.args.save_replicas
)

# opt state dict ckpt meta and filter
@@ -1584,8 +1607,12 @@
else:
opt_state_dict[k] = v

self.opt_ckpt_meta, self.opt_state_filter = saved_ckptmeta(opt_state_dict, self.ckpt_data_name)
self.master_weight_ckpt_meta, self.master_weights_filter = saved_ckptmeta(master_weights, self.ckpt_data_name)
self.opt_ckpt_meta, self.opt_state_filter = saved_ckptmeta(
opt_state_dict, self.ckpt_data_name, save_replicas=self.args.save_replicas
)
self.master_weight_ckpt_meta, self.master_weights_filter = saved_ckptmeta(
master_weights, self.ckpt_data_name, save_replicas=self.args.save_replicas
)

        # generate unified name mapping for optimizer
self.unified_name_mapping, self.param_slice_info = self._gen_unified_name(
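One likely motivation for keeping replica entries in the metadata is read-side failover: if the primary checkpoint file is lost, a loader can fall back to the next replica_id. A hypothetical reader-side sketch (not part of this PR; pick_file and its inputs are illustrative):

    def pick_file(storage_entries, key, available_files):
        # storage_entries maps (key, replica_id) -> file name, mirroring the
        # metadata.storage_metadata built above in simplified form.
        replica_id = 0
        while (key, replica_id) in storage_entries:
            fname = storage_entries[(key, replica_id)]
            if fname in available_files:
                return fname
            replica_id += 1
        raise FileNotFoundError(f"no surviving copy of {key}")

    entries = {("linear.weight", 0): "rank0.distcp", ("linear.weight", 1): "rank1.distcp"}
    print(pick_file(entries, "linear.weight", {"rank1.distcp"}))  # -> rank1.distcp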