diff --git a/benchmark/bert/run_glue.py b/benchmark/bert/run_glue.py
index 241d36b..9d938ab 100644
--- a/benchmark/bert/run_glue.py
+++ b/benchmark/bert/run_glue.py
@@ -27,6 +27,8 @@
 from paddlenlp.data.batchify import Stack, Tuple, Pad
 from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
 
+import paddle.fluid as fluid
+
 FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
 logging.basicConfig(level=logging.INFO, format=FORMAT)
 logger = logging.getLogger(__name__)
@@ -131,6 +133,21 @@ def parse_args():
         help="Save checkpoint every X updates steps.")
     parser.add_argument(
         "--seed", type=int, default=42, help="Random seed for initialization")
+    parser.add_argument(
+        "--use_fp16",
+        type=bool,
+        default=False,
+        help="Whether to enable half precision training with fp16.")
+    parser.add_argument(
+        "--scale_loss",
+        type=float,
+        default=1.0,
+        help="The value of scale_loss for fp16.")
+    parser.add_argument(
+        "--use_dynamic_loss_scaling",
+        type=bool,
+        default=True,
+        help="Whether to use dynamic loss scaling.")
     args = parser.parse_args()
     return args
 
@@ -347,10 +364,16 @@ def do_train(args):
                 p.name for n, p in model.named_parameters()
                 if not any(nd in n for nd in ["bias", "norm"])
             ])
+        if args.use_fp16:
+            optimizer = paddle.fluid.contrib.mixed_precision.decorate(
+                optimizer,
+                init_loss_scaling=args.scale_loss,
+                use_dynamic_loss_scaling=args.use_dynamic_loss_scaling)
         optimizer.minimize(loss)
 
     # Create the metric pass for the validation
     with paddle.static.program_guard(dev_program, startup_program):
+        logits = paddle.fluid.layers.cast(logits, 'float32')
         metric = metric_class()
         correct = metric.compute(logits, labels)
 
@@ -364,6 +387,17 @@ def do_train(args):
                                                 pretrained_state_dict)
     paddle.static.set_program_state(main_program, reset_state_dict)
 
+    exec_strategy = fluid.ExecutionStrategy()
+    exec_strategy.num_threads = 1
+    exec_strategy.num_iteration_per_drop_scope = 10000
+
+    build_strategy = fluid.BuildStrategy()
+
+    main_program = fluid.CompiledProgram(main_program).with_data_parallel(
+        loss_name=loss.name,
+        exec_strategy=exec_strategy,
+        build_strategy=build_strategy)
+
     global_step = 0
     tic_train = time.time()
     for epoch in range(args.num_train_epochs):
diff --git a/benchmark/bert/run_glue_amp.sh b/benchmark/bert/run_glue_amp.sh
new file mode 100755
index 0000000..1c70334
--- /dev/null
+++ b/benchmark/bert/run_glue_amp.sh
@@ -0,0 +1,17 @@
+export CUDA_VISIBLE_DEVICES=0
+export TASK_NAME=SST-2
+
+python -u ./run_glue.py \
+    --model_type bert \
+    --model_name_or_path bert-base-uncased \
+    --task_name $TASK_NAME \
+    --max_seq_length 128 \
+    --batch_size 64 \
+    --learning_rate 2e-5 \
+    --num_train_epochs 3 \
+    --logging_steps 20 \
+    --save_steps 500 \
+    --output_dir ./tmp/$TASK_NAME/ \
+    --use_fp16=true \
+    --scale_loss=128.0 \
+    --use_dynamic_loss_scaling=true \
diff --git a/paddlenlp/data/batchify.py b/paddlenlp/data/batchify.py
index 32dd3dd..b51ef48 100644
--- a/paddlenlp/data/batchify.py
+++ b/paddlenlp/data/batchify.py
@@ -44,6 +44,7 @@ class Stack(object):
              [8 9 1 2]]
             '''
     """
+
     def __init__(self, axis=0, dtype=None):
         self._axis = axis
         self._dtype = dtype
@@ -56,8 +57,10 @@ def __call__(self, data):
         Returns:
             numpy.ndarray: Stacked batch data.
         """
-        data = np.stack(data, axis=self._axis).astype(
-            self._dtype) if self._dtype else np.stack(data, axis=self._axis)
+        data = np.stack(
+            data,
+            axis=self._axis).astype(self._dtype) if self._dtype else np.stack(
+                data, axis=self._axis)
         return data
 
 
@@ -92,6 +95,7 @@ class Pad(object):
              [8. 2. 0. 0.]]
             '''
     """
+
     def __init__(self, pad_val=0, axis=0, ret_length=None, dtype=None):
         self._pad_val = pad_val
         self._axis = axis
@@ -116,6 +120,8 @@ def __call__(self, data):
         arrs = [np.asarray(ele) for ele in data]
         original_length = [ele.shape[self._axis] for ele in arrs]
         max_size = max(original_length)
+        if max_size % 8 != 0:
+            max_size = (int(max_size / 8) + 1) * 8
         ret_shape = list(arrs[0].shape)
         ret_shape[self._axis] = max_size
         ret_shape = (len(arrs), ) + tuple(ret_shape)
@@ -160,6 +166,7 @@ class Tuple(object):
             from paddle.incubate.hapi.text.data_utils import Tuple, Pad, Stack
             batchify_fn = Tuple(Pad(axis=0, pad_val=0), Stack())
     """
+
     def __init__(self, fn, *args):
         if isinstance(fn, (list, tuple)):
             assert len(args) == 0, 'Input pattern not understood. The input of Tuple can be ' \