Skip to content

Commit

Permalink
Adding script to run a training job and validate the training_loss (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
parambole authored Sep 17, 2024
1 parent e68bbc0 commit edbf599
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 1 deletion.
59 changes: 59 additions & 0 deletions end_to_end/tpu/eval_assert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
Copyright 2024 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# pylint: skip-file
"""Reads and asserts over target values"""
from absl import app
from typing import Sequence
import json

def get_last_n_data(metrics_file, target, n=10):
    """Collect the most recent values recorded for ``target``.

    The metrics file is read as one JSON document per line. Lines are
    scanned from the end of the file backwards, so the returned list is
    ordered newest first.

    Args:
        metrics_file: Path to the metrics file to read.
        target: Key to look up in each decoded line, e.g. 'learning/loss'.
        n: Number of collected values at which the scan stops
            (expected to be >= 1).

    Returns:
        A list with the collected values, newest first. It is shorter than
        ``n`` when fewer lines contain ``target`` and empty when none do.

    Raises:
        OSError: If the metrics file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    last_n_data = []
    with open(metrics_file, 'r', encoding='utf8') as file:
        lines = file.readlines()
    for line in reversed(lines):
        # Blank lines carry no record; json.loads would raise on them and
        # abort the whole check, so skip them.
        if not line.strip():
            continue
        metrics = json.loads(line)
        if target in metrics:
            last_n_data.append(metrics[target])
            if len(last_n_data) >= n:
                break
    return last_n_data


def test_final_loss(metrics_file, target_loss):
    """Check that the mean of the most recent losses is below a target.

    Averages up to the last 10 'learning/loss' values found in
    ``metrics_file`` and compares the mean against ``target_loss``.

    Args:
        metrics_file: Path to the metrics file passed to get_last_n_data.
        target_loss: Upper bound (exclusive) for the mean loss; converted
            with float(), so a numeric string is accepted.

    Raises:
        ValueError: If ``target_loss`` is not convertible to float, or the
            metrics file contains no 'learning/loss' values.
        AssertionError: If the mean loss is not below ``target_loss``.
    """
    target_loss = float(target_loss)
    use_last_n_data = 10
    last_n_data = get_last_n_data(metrics_file, 'learning/loss', use_last_n_data)
    if not last_n_data:
        # Without this guard the mean below fails with a bare
        # ZeroDivisionError that does not say what went wrong.
        raise ValueError(f"No 'learning/loss' values found in {metrics_file}")
    avg_last_n_data = sum(last_n_data) / len(last_n_data)
    print(f"Mean of last {len(last_n_data)} losses is {avg_last_n_data}")
    print(f"Target loss is {target_loss}")
    # Raise explicitly instead of using `assert`, which is stripped under
    # `python -O` and would let a failing run pass. The negated `<` keeps a
    # NaN mean failing, exactly like the original assertion.
    if not avg_last_n_data < target_loss:
        raise AssertionError(
            f"Mean loss {avg_last_n_data} is not below target loss {target_loss}"
        )
    print('Final loss test passed.')


def main(argv: Sequence[str]) -> None:
    """Dispatch to the test scenario named on the command line.

    Args:
        argv: Argument list. The first element is ignored, the second is
            the scenario name, and the rest are forwarded positionally to
            the scenario function.

    Raises:
        ValueError: If the scenario name is not 'final_loss'.
    """
    _, test_scenario, *test_vars = argv

    # Reject unknown scenarios up front; 'final_loss' is the only one defined.
    if test_scenario != 'final_loss':
        raise ValueError(f"Unrecognized test_scenario {test_scenario}")

    test_final_loss(*test_vars)


# Script entry point: hand main() to absl's app.run when executed directly.
if __name__ == "__main__":
    app.run(main)
25 changes: 25 additions & 0 deletions end_to_end/tpu/test_sdxl_training_loss.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/bin/bash
# Runs an SDXL training job, then asserts that the final training loss
# written to metrics.txt is below LOSS_THRESHOLD.
#
# Every argument of the form KEY=VALUE is exported as an environment variable.
# The script reads STEPS, RUN_NAME, OUTPUT_DIR and LOSS_THRESHOLD.
set -ex

echo "Running test_sdxl_training_loss.sh"

# Set environment variables
for ARGUMENT in "$@"; do
    IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
    export "$KEY"="$VALUE"
done

# Fail before launching the training job rather than after it: an empty STEPS
# yields a bare "max_train_steps=" argument, and an empty LOSS_THRESHOLD leaves
# eval_assert.py without its threshold argument.
: "${STEPS:?STEPS must be set (pass STEPS=<number of training steps>)}"
: "${LOSS_THRESHOLD:?LOSS_THRESHOLD must be set (pass LOSS_THRESHOLD=<max loss>)}"

# Build the command as an array so every argument reaches python as exactly one
# word, even when a value contains spaces or glob characters.
# NOTE(review): run_name is passed twice; presumably the later value
# (RUN_NAME) wins - confirm against the config parser, then drop the
# hard-coded one.
TRAIN_CMD=(
    python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml
    pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0
    revision=refs/pr/95 activations_dtype=bfloat16 weights_dtype=bfloat16 metrics_file=metrics.txt write_metrics=True
    dataset_name=gs://jfacevedo-maxdiffusion-v5p/pokemon-datasets/pokemon-gpt4-captions_xl resolution=1024 per_device_batch_size=1
    jax_cache_dir=gs://jfacevedo-maxdiffusion/cache_dir/ "max_train_steps=${STEPS}" attention=flash run_name=sdxl-fsdp-v5p-64-ddp enable_profiler=True
    "run_name=${RUN_NAME}"
    "output_dir=${OUTPUT_DIR}"
)

# Train
export LIBTPU_INIT_ARGS=""
"${TRAIN_CMD[@]}"

# Assert training loss is smaller than input LOSS_THRESHOLD
python3 end_to_end/tpu/eval_assert.py final_loss metrics.txt "$LOSS_THRESHOLD"
2 changes: 1 addition & 1 deletion src/maxdiffusion/trainers/sdxl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
num_model_parameters = max_utils.calculate_num_params_from_pytree(unet_state.params)

max_utils.add_text_to_summary_writer("number_model_parameters", str(num_model_parameters), writer)
max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ["LIBTPU_INIT_ARGS"], writer)
max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ.get("LIBTPU_INIT_ARGS",""), writer)
max_utils.add_config_to_summary_writer(self.config, writer)

if jax.process_index() == 0:
Expand Down

0 comments on commit edbf599

Please sign in to comment.