57 changes: 54 additions & 3 deletions benchmark/bert/run_glue.py
@@ -27,6 +27,8 @@
from paddlenlp.data.batchify import Stack, Tuple, Pad
from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer

import paddle.fluid as fluid

FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
@@ -131,6 +133,21 @@ def parse_args():
help="Save checkpoint every X updates steps.")
parser.add_argument(
"--seed", type=int, default=42, help="Random seed for initialization")
parser.add_argument(
"--use_fp16",
type=bool,
default=False,
help="Whether to enable half precision training with fp16.")
parser.add_argument(
"--scale_loss",
type=float,
default=1.0,
help="The value of scale_loss for fp16.")
parser.add_argument(
"--use_dynamic_loss_scaling",
type=bool,
default=True,
help="Whether to use dynamic loss scaling.")
args = parser.parse_args()
return args

@@ -321,6 +338,7 @@ def do_train(args):
num_classes=len(train_dataset.get_labels()))
loss_fct = paddle.nn.loss.CrossEntropyLoss(
) if train_dataset.get_labels() else paddle.nn.loss.MSELoss()
seq_len = paddle.shape(input_ids)[1]
logits = model(input_ids, segment_ids)
loss = loss_fct(logits, labels)
dev_program = main_program.clone(for_test=True)
@@ -347,10 +365,18 @@
p.name for n, p in model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
])
if args.use_fp16:
amp_list = fluid.contrib.mixed_precision.AutoMixedPrecisionLists(custom_white_list=['layer_norm', 'softmax'])
optimizer = paddle.fluid.contrib.mixed_precision.decorate(
optimizer,
amp_lists=amp_list,
init_loss_scaling=args.scale_loss,
use_dynamic_loss_scaling=args.use_dynamic_loss_scaling)
optimizer.minimize(loss)

# Create the metric pass for the validation
with paddle.static.program_guard(dev_program, startup_program):
logits = paddle.fluid.layers.cast(logits, 'float32')
metric = metric_class()
correct = metric.compute(logits, labels)

@@ -364,18 +390,43 @@
pretrained_state_dict)
paddle.static.set_program_state(main_program, reset_state_dict)

exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = 1
exec_strategy.num_iteration_per_drop_scope = 10000

build_strategy = fluid.BuildStrategy()

main_program = fluid.CompiledProgram(main_program).with_data_parallel(
loss_name=loss.name,
exec_strategy=exec_strategy,
build_strategy=build_strategy)

global_step = 0
processed_words_len = 0.0
tic_train = time.time()
for epoch in range(args.num_train_epochs):
for step, batch in enumerate(train_data_loader):
global_step += 1
loss_return = exe.run(main_program, feed=batch, fetch_list=[loss])
"""
if step == 200:
# profiler.start_profiler("All")
fluid.core.nvprof_start()
if step == 210:
fluid.core.nvprof_stop()
# profiler.stop_profiler("total", "./profile")
return
"""
loss_return = exe.run(main_program, feed=batch, fetch_list=[loss, seq_len])
processed_words_len += loss_return[1]
if global_step % args.logging_steps == 0:
log_time = time.time() - tic_train
logger.info(
"global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
"global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s, thoughput: %.2f words/s"
% (global_step, epoch, step, loss_return[0],
args.logging_steps / (time.time() - tic_train)))
args.logging_steps / log_time,
processed_words_len / log_time))
tic_train = time.time()
processed_words_len = 0.0
lr_scheduler.step()
if global_step % args.save_steps == 0:
# Validation pass, record the loss and metric
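For reference, the AMP wiring added to do_train above follows the standard static-graph pattern: build the loss, wrap the base optimizer with fluid.contrib.mixed_precision.decorate, and call minimize on the decorated optimizer, which inserts the cast and loss-scaling ops. Below is a minimal, self-contained sketch of that pattern using the same fluid.contrib.mixed_precision API; the toy fc regression network, the tensor names x and y, and the hyper-parameter values are illustrative only and are not taken from the benchmark script.

import paddle
import paddle.fluid as fluid

# Assumes a Paddle 2.0-era build, matching the paddle.static API used in run_glue.py.
paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[None, 16], dtype='float32')
    y = fluid.data(name='y', shape=[None, 1], dtype='float32')
    pred = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))

    optimizer = fluid.optimizer.Adam(learning_rate=1e-3)
    # Adding layer_norm and softmax to the white list allows those ops to run in fp16.
    amp_list = fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
        custom_white_list=['layer_norm', 'softmax'])
    optimizer = fluid.contrib.mixed_precision.decorate(
        optimizer,
        amp_lists=amp_list,
        init_loss_scaling=128.0,
        use_dynamic_loss_scaling=True)
    optimizer.minimize(loss)

With use_dynamic_loss_scaling enabled, the decorated optimizer raises or lowers the loss scale automatically when overflowing gradients are detected, which is why the script can start from a fixed scale_loss such as 128.0.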
17 changes: 17 additions & 0 deletions benchmark/bert/run_glue_amp.sh
@@ -0,0 +1,17 @@
export CUDA_VISIBLE_DEVICES=0
export TASK_NAME=SST-2

python -u ./run_glue.py \
--model_type bert \
--model_name_or_path bert-base-uncased \
--task_name $TASK_NAME \
--max_seq_length 128 \
--batch_size 64 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--logging_steps 20 \
--save_steps 500 \
--output_dir ./tmp/$TASK_NAME/ \
--use_fp16=true \
--scale_loss=128.0 \
--use_dynamic_loss_scaling=true
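One caveat about the boolean flags used here (general argparse behaviour, not something specific to this script): with type=bool, argparse calls bool() on the raw string, so any non-empty value, including "false", parses as True. The values passed above ("true") therefore do what they look like, but a string-to-bool converter is the usual safer pattern. A small sketch follows; str2bool is a hypothetical helper and is not part of this PR.

import argparse
from distutils.util import strtobool  # understands "true"/"false", "1"/"0", "yes"/"no"

def str2bool(v):
    # Hypothetical helper, not in the PR: rejects unrecognised strings instead of
    # silently treating them as True the way type=bool does.
    return bool(strtobool(v))

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use_fp16",
    type=str2bool,
    default=False,
    help="Whether to enable half precision training with fp16.")

args = parser.parse_args(["--use_fp16=false"])
print(args.use_fp16)  # False; with type=bool this would print True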
11 changes: 9 additions & 2 deletions paddlenlp/data/batchify.py
@@ -44,6 +44,7 @@ class Stack(object):
[8 9 1 2]]
'''
"""

def __init__(self, axis=0, dtype=None):
self._axis = axis
self._dtype = dtype
@@ -56,8 +57,10 @@ def __call__(self, data):
Returns:
numpy.ndarray: Stacked batch data.
"""
data = np.stack(data, axis=self._axis).astype(
self._dtype) if self._dtype else np.stack(data, axis=self._axis)
data = np.stack(
data,
axis=self._axis).astype(self._dtype) if self._dtype else np.stack(
data, axis=self._axis)
return data


@@ -92,6 +95,7 @@ class Pad(object):
[8. 2. 0. 0.]]
'''
"""

def __init__(self, pad_val=0, axis=0, ret_length=None, dtype=None):
self._pad_val = pad_val
self._axis = axis
@@ -116,6 +120,8 @@ def __call__(self, data):
arrs = [np.asarray(ele) for ele in data]
original_length = [ele.shape[self._axis] for ele in arrs]
max_size = max(original_length)
if max_size % 8 != 0:
max_size = (int(max_size / 8) + 1) * 8
ret_shape = list(arrs[0].shape)
ret_shape[self._axis] = max_size
ret_shape = (len(arrs), ) + tuple(ret_shape)
@@ -160,6 +166,7 @@ class Tuple(object):
from paddle.incubate.hapi.text.data_utils import Tuple, Pad, Stack
batchify_fn = Tuple(Pad(axis=0, pad_val=0), Stack())
"""

def __init__(self, fn, *args):
if isinstance(fn, (list, tuple)):
assert len(args) == 0, 'Input pattern not understood. The input of Tuple can be ' \
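The Pad change above rounds the padded batch length up to the next multiple of 8. The diff does not state the motivation, but the usual reason in an fp16 benchmark is that GEMM dimensions that are multiples of 8 map cleanly onto Tensor Cores. Below is a standalone illustration of the rounding and the resulting batch shape; it mirrors the arithmetic added to paddlenlp.data.batchify.Pad but is not that code.

import numpy as np

def round_up_to_multiple_of_8(n):
    # Same arithmetic as the diff: lengths already divisible by 8 are left alone,
    # everything else is bumped to the next multiple of 8.
    return n if n % 8 == 0 else (n // 8 + 1) * 8

seqs = [np.arange(13), np.arange(9), np.arange(5)]
max_size = round_up_to_multiple_of_8(max(len(s) for s in seqs))  # 13 -> 16
batch = np.zeros((len(seqs), max_size), dtype=np.int64)
for i, s in enumerate(seqs):
    batch[i, :len(s)] = s
print(batch.shape)  # (3, 16)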