diff --git a/source/tests/pt/test_multitask.py b/source/tests/pt/test_multitask.py index 984c9c6079..3c78484e1f 100644 --- a/source/tests/pt/test_multitask.py +++ b/source/tests/pt/test_multitask.py @@ -138,7 +138,11 @@ def test_multitask_train(self): multi_state_dict[state_key.replace("model_3", "model_2")], multi_state_dict_finetuned[state_key], ) - elif "model_4" in state_key and "fitting_net" not in state_key and "out_bias" not in state_key: + elif ( + "model_4" in state_key + and "fitting_net" not in state_key + and "out_bias" not in state_key + ): torch.testing.assert_close( multi_state_dict[state_key.replace("model_4", "model_2")], multi_state_dict_finetuned[state_key],