Skip to content

Commit

Permalink
docs&fix: Update the docs for vit regarding OOM; Fix CI bug
Browse files Browse the repository at this point in the history
  • Loading branch information
The-truthh committed Jun 18, 2024
1 parent 19c4a5b commit 7db94c6
Show file tree
Hide file tree
Showing 9 changed files with 12 additions and 117 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
pip install "Pillow==9.1.1"
# MindSpore must be installed following the instruction from official web, but not from pypi.
# That's why we exclude mindspore from requirements.txt. Does this work?
pip install "mindspore>=1.8,<=1.10"
pip install "mindspore>=1.8"
- name: Lint with pre-commit
uses: pre-commit/[email protected]
- name: Test with pytest (UT)
Expand Down
2 changes: 1 addition & 1 deletion configs/vit/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ For detailed illustration of all hyper-parameters, please refer to [config.py](h

**Note:**
1) As the global batch size (batch_size x num_devices) is an important hyper-parameter, it is recommended to keep the global batch size unchanged for reproduction or adjust the learning rate linearly to a new global batch size.
2) The current configuration with a batch_size of 512, was initially set for a machine with 64GB of VRAM. To avoid running out of memory (OOM) on machines with smaller VRAM, consider reducing the batch_size to 256 or lower.
2) The current configuration with a batch_size of 512, was initially set for a machine with 64GB of VRAM. To avoid running out of memory (OOM) on machines with smaller VRAM, consider reducing the batch_size to 256 or lower. Simultaneously, to maintain the consistency of training results, please scale the learning rate down proportionally with decreasing batch_size.

* Standalone Training

Expand Down
2 changes: 1 addition & 1 deletion tests/modules/parallel/test_parallel_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_create_dataset_distribute_imagenet(mode, name, split, shuffle, num_para


@pytest.mark.parametrize("mode", [0, 1])
@pytest.mark.parametrize("name", ["MNIST", "CIFAR10"])
@pytest.mark.parametrize("name", ["CIFAR10"])
@pytest.mark.parametrize("split", ["train", "val"])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("num_parallel_workers", [2, 4, 8, 16])
Expand Down
2 changes: 1 addition & 1 deletion tests/modules/parallel/test_parallel_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_transforms_distribute_imagenet(mode, name, image_resize, is_training):


@pytest.mark.parametrize("mode", [0, 1])
@pytest.mark.parametrize("name", ["MNIST", "CIFAR10"])
@pytest.mark.parametrize("name", ["CIFAR10"])
@pytest.mark.parametrize("image_resize", [224, 256, 320])
@pytest.mark.parametrize("is_training", [True, False])
@pytest.mark.parametrize("download", [True, False])
Expand Down
4 changes: 2 additions & 2 deletions tests/modules/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_checker_invalid():


@pytest.mark.parametrize("mode", [0, 1])
@pytest.mark.parametrize("dataset", ["mnist", "imagenet"])
@pytest.mark.parametrize("dataset", ["imagenet"])
def test_parse_args_without_yaml(mode, dataset):
args = parse_args([f"--mode={mode}", f"--dataset={dataset}"])
assert args.mode == mode
Expand All @@ -46,7 +46,7 @@ def test_parse_args_without_yaml(mode, dataset):

@pytest.mark.parametrize("cfg_yaml", ["configs/resnet/resnet_18_ascend.yaml"])
@pytest.mark.parametrize("mode", [1])
@pytest.mark.parametrize("dataset", ["mnist"])
@pytest.mark.parametrize("dataset", ["imagenet"])
def test_parse_args_with_yaml(cfg_yaml, mode, dataset):
args = parse_args([f"--config={cfg_yaml}", f"--mode={mode}", f"--dataset={dataset}"])
assert args.mode == mode
Expand Down
7 changes: 2 additions & 5 deletions tests/modules/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_create_dataset_standalone_imagenet(mode, name, split, shuffle, num_samp
assert dataset is not None


# test MNIST CIFAR10
# test CIFAR10
@pytest.mark.parametrize("mode", [0, 1])
@pytest.mark.parametrize("name", ["CIFAR10"])
@pytest.mark.parametrize("split", ["train", "test"])
Expand Down Expand Up @@ -95,8 +95,5 @@ def test_create_dataset_standalone_mc(mode, name, split, shuffle, num_samples, n
download=download,
)

assert (
type(dataset) == ms.dataset.engine.datasets_vision.MnistDataset
or type(dataset) == ms.dataset.engine.datasets_vision.Cifar10Dataset
)
assert type(dataset) == ms.dataset.engine.datasets_vision.Cifar10Dataset
assert dataset is not None
2 changes: 1 addition & 1 deletion tests/modules/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_transforms_standalone_imagenet(mode, name, image_resize, is_training, a
assert output_shape[0][0] == 3 * batch_size and output_shape[1][0] == 3 * batch_size, "augment splits error!"


# test mnist cifar10
# test cifar10
@pytest.mark.parametrize("mode", [0, 1])
@pytest.mark.parametrize("name", ["CIFAR10"])
@pytest.mark.parametrize("image_resize", [224, 256])
Expand Down
105 changes: 0 additions & 105 deletions tests/tasks/test_train_mnist.py

This file was deleted.

3 changes: 3 additions & 0 deletions tests/tasks/test_train_val_imagenet_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def test_train(mode, val_while_train, model="resnet18"):
res = out.decode()
idx = res.find("Accuracy")
acc = res[idx:].split(",")[0].split(":")[1]
# python 3.9 acc will be np.float64(1.0)
if "(" in acc:
acc = acc.split("(")[-1].rstrip(")")
print("Val acc: ", acc)
assert float(acc) > 0.5, "Acc is too low"

Expand Down

0 comments on commit 7db94c6

Please sign in to comment.