use torchx for manual many replica (20+) tests (#75)
d4l3k authored Jan 18, 2025
1 parent 3ee2360 commit 39a40b2
Showing 10 changed files with 129 additions and 19 deletions.
4 changes: 4 additions & 0 deletions .pyre_configuration
@@ -6,5 +6,9 @@
       "import_root": ".",
       "source": "torchft"
     }
   ],
+  "search_path": [
+    {"site-package": "torchx"},
+    {"site-package": "parameterized"}
+  ]
 }
3 changes: 3 additions & 0 deletions .torchxconfig
@@ -0,0 +1,3 @@
[cli:run]
component=torchft/torchx.py:hsdp
scheduler=local_cwd
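
With this config in place, a bare `torchx run` resolves to the `hsdp` component in torchft/torchx.py and the `local_cwd` scheduler, so only component flags need to be passed after `--`. A minimal sketch, assuming the lighthouse described in CONTRIBUTING.md is already running:

```sh
# Component and scheduler come from .torchxconfig;
# flags after `--` go to the hsdp component.
torchx run -- --replicas 4 --workers_per_replica 2
```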
15 changes: 15 additions & 0 deletions CONTRIBUTING.md
@@ -96,6 +96,21 @@ make livehtml
The docs will be built in the `docs/build/html` directory and served at http://localhost:8000.
The page will be automatically re-built as long as the process is kept running.

### Running a Multi-Replica Local Job

We use torchx to run local test jobs with multiple replica groups. Start the
lighthouse first, then use torchx to launch as many replica groups as you
want. This runs the [train_ddp.py](./train_ddp.py) script.

```sh
$ torchft_lighthouse --min_replicas 2 --join_timeout_ms 10000 &
$ torchx run -- --replicas 10
```

Once the lighthouse has started, you can view the status of all workers on the lighthouse dashboard.

The default dashboard address is http://localhost:29510.
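
If the lighthouse runs on a different host or port, point the replica groups at it with the `TORCHFT_LIGHTHOUSE` environment variable, which the torchx component forwards to every worker. A minimal sketch, assuming a lighthouse reachable at the hypothetical address `http://remotehost:29510`:

```sh
$ TORCHFT_LIGHTHOUSE=http://remotehost:29510 torchx run -- --replicas 10
```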

## Contributor License Agreement ("CLA")

In order to accept your pull request, we need you to submit a CLA. You only need to do this once to work on any of
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -27,7 +27,8 @@ dev = [
"pyre-check",
"parameterized",
"expecttest",
"numpy"
"numpy",
"torchx"
]

[tool.maturin]
13 changes: 9 additions & 4 deletions src/lib.rs
@@ -302,11 +302,16 @@ impl From<Status> for StatusError {
 #[pymodule]
 fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
     // setup logging on import
-    stderrlog::new()
-        .verbosity(2)
+    let mut log = stderrlog::new();
+    log.verbosity(2)
         .show_module_names(true)
-        .timestamp(stderrlog::Timestamp::Millisecond)
-        .init()
+        .timestamp(stderrlog::Timestamp::Millisecond);
+
+    if env::var("CLICOLOR_FORCE").is_ok() {
+        log.color(stderrlog::ColorChoice::AlwaysAnsi);
+    }
+
+    log.init()
         .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
 
     m.add_class::<Manager>()?;
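
stderrlog's default color choice only emits ANSI colors when stderr is a terminal, which it is not when a launcher captures replica logs; setting `CLICOLOR_FORCE` switches the logger to `AlwaysAnsi` so colors survive piping. A minimal sketch of forcing colored torchft logs from the shell (the `hsdp` component below sets this variable for every replica automatically):

```sh
CLICOLOR_FORCE=1 torchx run -- --replicas 2
```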
28 changes: 17 additions & 11 deletions src/lighthouse.rs
@@ -377,21 +377,26 @@ impl Lighthouse {

         let (_, quorum_status) = quorum_compute(Instant::now(), &state, &self.opt);
 
-        let max_step = {
-            if let Some(quorum) = state.prev_quorum.clone() {
-                quorum
-                    .participants
-                    .iter()
-                    .map(|p| p.step)
-                    .max()
-                    .unwrap_or(-1)
-            } else {
-                -1
-            }
+        let max_step = if let Some(quorum) = &state.prev_quorum {
+            quorum
+                .participants
+                .iter()
+                .map(|p| p.step)
+                .max()
+                .unwrap_or(-1)
+        } else {
+            -1
         };
 
+        let num_participants = if let Some(quorum) = &state.prev_quorum {
+            quorum.participants.len() as i64
+        } else {
+            -1
+        };
+
         StatusTemplate {
             quorum_id: state.quorum_id,
+            num_participants: num_participants,
             prev_quorum: state.prev_quorum.clone(),
             quorum_status: quorum_status,
             max_step: max_step,
@@ -527,6 +532,7 @@ struct StatusTemplate {
     prev_quorum: Option<Quorum>,
     quorum_id: i64,
     quorum_status: String,
+    num_participants: i64,
     max_step: i64,
     heartbeats: HashMap<String, Instant>,

1 change: 1 addition & 0 deletions templates/status.html
@@ -6,6 +6,7 @@ <h3>Previous Quorum</h3>
 {% if let Some(prev_quorum) = prev_quorum %}
 
 Previous quorum id: {{prev_quorum.quorum_id}} <br>
+Num participants: {{num_participants}} <br>
 Quorum age:
 {{SystemTime::try_from(prev_quorum.created.unwrap()).unwrap().elapsed().unwrap().as_secs_f64()}}s

3 changes: 0 additions & 3 deletions torchft/manager_integ_test.py
@@ -10,8 +10,6 @@

 import torch
 import torch.distributed as dist
-
-# pyre-fixme[21]: missing module
 from parameterized import parameterized
 from torch import nn, optim

@@ -292,7 +290,6 @@ def test_ddp_healthy(self) -> None:
         for state_dict in state_dicts:
             torch.testing.assert_close(state_dict, state_dicts[0])
 
-    # pyre-fixme[56]: couldn't infer type of decorator
     @parameterized.expand(
         [
             (
76 changes: 76 additions & 0 deletions torchft/torchx.py
@@ -0,0 +1,76 @@
"""
This is a file for TorchX components used for testing torchft.
"""

import os
from typing import Dict, Optional

import torchx.specs as specs


def hsdp(
*script_args: str,
replicas: int = 2,
workers_per_replica: int = 1,
max_restarts: int = 10,
script: str = "train_ddp.py",
env: Optional[Dict[str, str]] = None,
image: str = "",
h: Optional[str] = None,
cpu: int = 2,
gpu: int = 0,
memMB: int = 1024,
) -> specs.AppDef:
assert replicas > 0, "replicas must be > 0"
assert workers_per_replica > 0, "workers_per_replica must be > 0"

env = env or {}

# Enable logging for PyTorch, torchelastic and Rust.
env.setdefault("TORCH_CPP_LOG_LEVEL", "INFO")
env.setdefault("LOGLEVEL", "INFO")
env.setdefault("RUST_BACKTRACE", "1")

# Enable colored logging for torchft Rust logger.
env.setdefault("CLICOLOR_FORCE", "1")

# Set the lighthouse address for replicas.
# The lighthouse itself must be started externally.
env.setdefault(
"TORCHFT_LIGHTHOUSE",
os.environ.get("TORCHFT_LIGHTHOUSE", f"http://localhost:29510"),
)

# Disable CUDA for CPU-only jobs
env.setdefault("CUDA_VISIBLE_DEVICES", "")

roles = []
for replica_id in range(replicas):
cmd = [
f"--master_port={29600+replica_id}",
"--nnodes=1",
f"--nproc_per_node={workers_per_replica}",
f"--max_restarts={max_restarts}",
]
if script:
cmd += [script]
cmd += list(script_args)

roles.append(
specs.Role(
name=f"replica_{replica_id}",
image=image,
min_replicas=workers_per_replica,
num_replicas=workers_per_replica,
resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
max_retries=0,
env=env,
entrypoint="torchrun",
args=cmd,
)
)

return specs.AppDef(
name="torchft",
roles=roles,
)
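
For reference, with the defaults above each replica group boils down to a single `torchrun` invocation; a sketch for `replica_id = 0` (the master port is offset by the replica id, so group 1 uses 29601, and so on; environment variables omitted):

```sh
torchrun --master_port=29600 --nnodes=1 --nproc_per_node=1 --max_restarts=10 train_ddp.py
```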
2 changes: 2 additions & 0 deletions train_ddp.py
@@ -13,6 +13,7 @@
 import torchvision
 import torchvision.transforms as transforms
 from torch import nn, optim
+from torch.distributed.elastic.multiprocessing.errors import record
 from torchdata.stateful_dataloader import StatefulDataLoader
 
 from torchft import (
@@ -27,6 +28,7 @@
 logging.basicConfig(level=logging.INFO)
 
 
+@record
 def main() -> None:
     REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
     NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
