Make some modules private (#1093)
This is a PyTorch requirement: a module must either be private or be
documented in a .rst file.

Modules are made private by prefixing them with an `_`, as suggested by
the PyTorch doc checker.

Privatized modules:
`_utils`, `_debug`, `_backward`, `_unflatten`
kwen2501 authored Apr 26, 2024
1 parent ac5c483 commit 73e349b
Showing 11 changed files with 22 additions and 18 deletions.
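For context on why the tests below can switch to `from pippy import ...`: once the implementation modules are private, the public API has to be re-exported from the package root. The diff does not include `pippy/__init__.py`, so the following is only a minimal sketch of that re-export pattern, assuming the names the tests import are surfaced there; the real file's contents are not shown in this commit.

```python
# pippy/__init__.py -- hypothetical sketch, not part of this diff.
# Implementation details live in underscore-prefixed (private) modules such as
# _backward, _debug, _utils, and _unflatten; only public names are re-exported
# here, so the doc checker demands .rst documentation only for public modules.

from .IR import annotate_split_points, pipeline, SplitPoint
from .PipelineSchedule import (
    ManualPipelineStage,
    Schedule1F1B,
    ScheduleGPipe,
    ScheduleInterleaved1F1B,
    ScheduleLoopedBFS,
)
from .PipelineStage import PipelineStage

__all__ = [
    "annotate_split_points",
    "pipeline",
    "SplitPoint",
    "ManualPipelineStage",
    "Schedule1F1B",
    "ScheduleGPipe",
    "ScheduleInterleaved1F1B",
    "ScheduleLoopedBFS",
    "PipelineStage",
]
```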
2 changes: 1 addition & 1 deletion format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
DEFAULT_TARGETS=()
for f in $(git ls-files | grep '\.py$'); do
case "$f" in
'pippy/unflatten.py')
'pippy/_unflatten.py')
# ignore
;;

Expand Down
10 changes: 5 additions & 5 deletions pippy/IR.py

```diff
@@ -14,16 +14,16 @@
 from torch.fx.node import map_aggregate
 from torch.fx.passes.split_module import split_module
 
-from .backward import _null_coalesce_accumulate, stage_backward
-from .debug import PIPPY_VERBOSITY
-from .microbatch import split_args_kwargs_into_chunks, TensorChunkSpec
-from .unflatten import (
+from ._backward import _null_coalesce_accumulate, stage_backward
+from ._debug import PIPPY_VERBOSITY
+from ._unflatten import (
     _assign_attr,
     _AttrKind,
     _outline_submodules,
     _sink_params,
 )
-from .utils import QualnameMapMixin
+from ._utils import QualnameMapMixin
+from .microbatch import split_args_kwargs_into_chunks, TensorChunkSpec
 
 
 logger = logging.getLogger(__name__)
```
6 changes: 3 additions & 3 deletions pippy/PipelineStage.py

```diff
@@ -12,10 +12,10 @@
 from torch.fx.node import map_aggregate
 from torch.nn.parallel import DistributedDataParallel
 
-from .backward import stage_backward
-from .debug import map_debug_info
+from ._backward import stage_backward
+from ._debug import map_debug_info
+from ._utils import flatten_args, modify_graph_op_device
 from .IR import Pipe
-from .utils import flatten_args, modify_graph_op_device
 
 logger = logging.getLogger(__name__)
```
2 changes: 1 addition & 1 deletion pippy/backward.py → pippy/_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

from .debug import map_debug_info
from ._debug import map_debug_info


def stage_backward(
Expand Down
File renamed without changes.
3 changes: 1 addition & 2 deletions pippy/unflatten.py → pippy/_unflatten.py

```diff
@@ -5,14 +5,13 @@
 import copy
 import operator
 from enum import Enum
-from typing import Any, cast, Dict, List, Optional, Union
+from typing import cast, Dict, List, Optional, Union
 
 import torch
 import torch.fx._pytree as fx_pytree
 import torch.utils._pytree as pytree
 from torch.export.exported_program import (
     ConstantArgument,
-    ExportedProgram,
     ModuleCallSignature,
     SymIntArgument,
     TensorArgument,
```
File renamed without changes.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ exclude = [

[tool.ufmt]
excludes = [
"pippy/unflatten.py",
"pippy/_unflatten.py",
]
11 changes: 8 additions & 3 deletions test/test_interleave.py

```diff
@@ -6,9 +6,14 @@
 import torch
 import torch.distributed as dist
 
-from pippy.IR import annotate_split_points, pipeline, SplitPoint
-from pippy.PipelineSchedule import ScheduleInterleaved1F1B, ScheduleLoopedBFS
-from pippy.PipelineStage import PipelineStage
+from pippy import (
+    annotate_split_points,
+    pipeline,
+    PipelineStage,
+    ScheduleInterleaved1F1B,
+    ScheduleLoopedBFS,
+    SplitPoint,
+)
 
 # Using same key words as single-stage tests for convenience in CI.
 schedule_map = {
```
2 changes: 1 addition & 1 deletion test/test_pipeline_schedule_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from pippy.PipelineSchedule import (
from pippy import (
ManualPipelineStage,
Schedule1F1B,
ScheduleGPipe,
Expand Down
2 changes: 1 addition & 1 deletion test/test_stage_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from pippy.backward import stage_backward
from pippy._backward import stage_backward


d_hid = 512
Expand Down
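Taken together, the changes split the import surface in two: public names come from the package root, while only internal code and tests reach into the underscore-prefixed modules. A small illustrative sketch (assuming the re-exports outlined earlier):

```python
# Before this commit, callers imported the now-private modules directly:
#   from pippy.backward import stage_backward
#   from pippy.PipelineStage import PipelineStage

# After this commit, public names come from the package root...
from pippy import pipeline, PipelineStage

# ...and only internal code and tests import the private modules directly,
# as test/test_stage_backward.py does above.
from pippy._backward import stage_backward
```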
