diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index f483bcc696..afa66935d0 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -80,6 +80,7 @@ class DPTrainer (object): def __init__(self, jdata, run_opt): + paddle.set_device("cpu") self.run_opt = run_opt self._init_param(jdata) @@ -304,7 +305,7 @@ def build (self, def train (self, data, stop_batch) : - paddle.set_device("gpu") + self.stop_batch = stop_batch self.print_head()