Skip to content

Commit

Permalink
add format check
Browse files Browse the repository at this point in the history
  • Loading branch information
phoenixdong committed Jun 11, 2024
1 parent 215bc3a commit 006693f
Show file tree
Hide file tree
Showing 12 changed files with 281 additions and 145 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/format.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: format

on:
pull_request:
branches: [ "main" ]
types: [opened, synchronize, reopened]

jobs:
format:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: Install dependencies
run: |
pip install black
- name: Run Black
run: >-
black --check --diff --include
flagscale/auto_tuner/*.py
flagscale/auto_tuner/prune/*.py
flagscale/auto_tuner/record/*.py
flagscale/auto_tuner/search/*.py
flagscale/launcher/*.py
flagscale/logger.py
flagscale/patches_utils.py
flagscale/datasets/sft_dataset.py
./
19 changes: 19 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
repos:
- repo: local
hooks:
- id: black
name: black
entry: black
language: system
types: [python]
files: |
(?x)^(
flagscale/auto_tuner/.*\.py|
flagscale/auto_tuner/prune/\..*\.py|
flagscale/auto_tuner/record/\..*\.py|
flagscale/auto_tuner/search/\..*\.py|
flagscale/launcher/\..*\.py|
flagscale/logger\.py|
flagscale/patches_utils\.py|
flagscale/datasets/sft_dataset\.py
)$
12 changes: 6 additions & 6 deletions flagscale/auto_tuner/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ def __init__(self, config):
"tensor_model_parallel_size": "tensor_model_parallel_size",
"sequence_parallel": "sequence_parallel",
"pipeline_model_parallel_size": "pipeline_model_parallel_size",
"num_layers_per_virtual_pipeline_stage":
"num_layers_per_virtual_pipeline_stage",
"num_layers_per_virtual_pipeline_stage": "num_layers_per_virtual_pipeline_stage",
"recompute_method": "recompute_method",
"recompute_granularity": "recompute_granularity",
"recompute_num_layers": "recompute_num_layers",
Expand Down Expand Up @@ -81,14 +80,15 @@ def gen(self, strategy):
# Set train_iters of each task
if "control" in config.experiment.auto_tuner:
config.train.model.train_iters = config.experiment.auto_tuner.control.get(
"train_iters", 5)
"train_iters", 5
)
else:
config.train.model.train_iters = 5

# log dir
config.experiment.exp_dir = os.path.join(config.experiment.exp_dir,
"auto_tuner",
f"task_{strategy['idx']}")
config.experiment.exp_dir = os.path.join(
config.experiment.exp_dir, "auto_tuner", f"task_{strategy['idx']}"
)

return config

Expand Down
75 changes: 47 additions & 28 deletions flagscale/auto_tuner/prune/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ def prune_by_micro_batch_size(config, strategy, history=[]):
if retrieval:
for item in retrieval:
# performance prune
if item["micro_batch_size"] > micro_batch_size and item[
"performance"]:
if item["micro_batch_size"] > micro_batch_size and item["performance"]:
logger.info(
f"The strategy {strategy} has been pruned by micro_batch_size performance."
)
Expand All @@ -36,8 +35,7 @@ def prune_by_micro_batch_size(config, strategy, history=[]):
strategy["pruned"] = True
return True
# memory prune
if item["micro_batch_size"] < micro_batch_size and item[
"max_mem"] == "OOM":
if item["micro_batch_size"] < micro_batch_size and item["max_mem"] == "OOM":
logger.info(
f"The strategy {strategy} has been pruned by micro_batch_size memory."
)
Expand Down Expand Up @@ -91,10 +89,13 @@ def prune_by_recompute(config, strategy, history=[]):
strategy["pruned"] = True
return True

if (use_recompute and item["use_recompute"]
and recompute_method == "block"
and recompute_method == item["recompute_method"]
and item["performance"]):
if (
use_recompute
and item["use_recompute"]
and recompute_method == "block"
and recompute_method == item["recompute_method"]
and item["performance"]
):
if recompute_num_layers > item["recompute_num_layers"]:
logger.info(
f"The strategy {strategy} has been pruned by block recompute_num_layers performance."
Expand All @@ -104,10 +105,13 @@ def prune_by_recompute(config, strategy, history=[]):
strategy["pruned"] = True
return True

if (use_recompute and item["use_recompute"]
and recompute_method == "uniform"
and recompute_method == item["recompute_method"]
and item["performance"]):
if (
use_recompute
and item["use_recompute"]
and recompute_method == "uniform"
and recompute_method == item["recompute_method"]
and item["performance"]
):
if recompute_num_layers > item["recompute_num_layers"]:
logger.info(
f"The strategy {strategy} has been pruned by uniform recompute_num_layers performance."
Expand All @@ -117,8 +121,7 @@ def prune_by_recompute(config, strategy, history=[]):
strategy["pruned"] = True
return True
# memory prune
if not use_recompute and item["use_recompute"] and item[
"max_mem"] == "OOM":
if not use_recompute and item["use_recompute"] and item["max_mem"] == "OOM":
logger.info(
f"The strategy {strategy} has been pruned by use_recompute memory."
)
Expand All @@ -127,11 +130,16 @@ def prune_by_recompute(config, strategy, history=[]):
strategy["pruned"] = True
return True

if (use_recompute and item["use_recompute"]
and recompute_method == "uniform"
and recompute_method == item["recompute_method"]):
if (recompute_num_layers > item["recompute_num_layers"]
and item["max_mem"] == "OOM"):
if (
use_recompute
and item["use_recompute"]
and recompute_method == "uniform"
and recompute_method == item["recompute_method"]
):
if (
recompute_num_layers > item["recompute_num_layers"]
and item["max_mem"] == "OOM"
):
logger.info(
f"The strategy {strategy} has been pruned by uniform recompute_num_layers memory."
)
Expand All @@ -140,11 +148,16 @@ def prune_by_recompute(config, strategy, history=[]):
strategy["pruned"] = True
return True

if (use_recompute and item["use_recompute"]
and recompute_method == "block"
and recompute_method == item["recompute_method"]):
if (recompute_num_layers < item["recompute_num_layers"]
and item["max_mem"] == "OOM"):
if (
use_recompute
and item["use_recompute"]
and recompute_method == "block"
and recompute_method == item["recompute_method"]
):
if (
recompute_num_layers < item["recompute_num_layers"]
and item["max_mem"] == "OOM"
):
logger.info(
f"The strategy {strategy} has been pruned by block recompute_num_layers memory."
)
Expand All @@ -163,8 +176,11 @@ def prune_by_sequence_parallel(config, strategy, history=[]):
if retrieval:
for item in retrieval:
# performance prune
if item["sequence_parallel"] and item[
"performance"] and not sequence_parallel:
if (
item["sequence_parallel"]
and item["performance"]
and not sequence_parallel
):
logger.info(
f"The strategy {strategy} has been pruned by sequence_parallel performance."
)
Expand All @@ -173,8 +189,11 @@ def prune_by_sequence_parallel(config, strategy, history=[]):
strategy["pruned"] = True
return True
# memory prune
if item["sequence_parallel"] and item[
"max_mem"] == "OOM" and not sequence_parallel:
if (
item["sequence_parallel"]
and item["max_mem"] == "OOM"
and not sequence_parallel
):
logger.info(
f"The strategy {strategy} has been pruned by sequence_parallel memory."
)
Expand Down
1 change: 1 addition & 0 deletions flagscale/auto_tuner/prune/pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def prune(self, strategy, history=[]):
if func(self.config, strategy, history):
not_run = True
break

history.append(strategy)
if not_run:
self.pruned_count += 1
Expand Down
Loading

0 comments on commit 006693f

Please sign in to comment.