Skip to content

Commit

Permalink
init pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
rlsu9 committed Jan 3, 2025
1 parent dd75ee8 commit af3c5c3
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 1 deletion.
28 changes: 28 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: Run Tests

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Check out repository
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12' # or any version you need

- name: Install dependencies
run: |
pip install --upgrade pip
pip install -e .
pip install pytest
- name: Run Pytest
run: |
pytest
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ wandb/
*.pt
cache_dir/
wandb/
test*
sample_video*
sample_image*
512*
Expand Down
35 changes: 35 additions & 0 deletions tests/test_save_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os
import shutil
import pytest

import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from fastvideo.utils.checkpoint import save_checkpoint
from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel
from fastvideo.utils.fsdp_util import get_dit_fsdp_kwargs

@pytest.fixture(scope="module", autouse=True)
def setup_distributed():
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_RANK"] = "0"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12345"

dist.init_process_group("nccl")
yield
dist.destroy_process_group()

def test_save_and_remove_checkpoint():
transformer = MochiTransformer3DModel(num_layers=0)
fsdp_kwargs, _ = get_dit_fsdp_kwargs(transformer, "none")
transformer = FSDP(transformer, **fsdp_kwargs)

test_folder = "./test_checkpoint"
save_checkpoint(transformer, 0, test_folder, 0)

assert os.path.exists(test_folder), "Checkpoint folder was not created."

shutil.rmtree(test_folder)
assert not os.path.exists(test_folder), "Checkpoint folder still exists."

0 comments on commit af3c5c3

Please sign in to comment.