Skip to content

Commit

Permalink
Merge branch 'main' into objectstore-registry
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea authored Oct 23, 2024
2 parents f7a6d56 + ef42f54 commit 10db562
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pr-gpu.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: PR GPU tests
on:
- pull_request:
+ pull_request_target:
workflow_dispatch:
# Cancel old runs when a new commit is pushed to the same branch if not on main
# or dev
Expand Down
5 changes: 5 additions & 0 deletions composer/callbacks/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,15 @@ def __init__(
num_concurrent_uploads: int = 1,
upload_timeout_in_seconds: int = 3600,
):

backend, _, local_folder = parse_uri(str(folder))
if local_folder == '':
local_folder = '.'

is_remote_folder = backend != ''
if is_remote_folder: # If uploading to a remote path, use a temporary directory to save local checkpoints.
local_folder = os.path.join(tempfile.mkdtemp(), local_folder)

filename = str(filename)
remote_file_name = str(remote_file_name) if remote_file_name is not None else None
latest_filename = str(latest_filename) if latest_filename is not None else None
Expand Down
12 changes: 9 additions & 3 deletions composer/loggers/mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,10 @@ def _start_mlflow_run(self, state):
)
self.monitor_process.start()

- def _global_exception_handler(self, exc_type, exc_value, exc_traceback):
+ def _global_exception_handler(self, original_excepthook, exc_type, exc_value, exc_traceback):
"""Catch global exception."""
self._global_exception_occurred += 1
- sys.__excepthook__(exc_type, exc_value, exc_traceback)
+ original_excepthook(exc_type, exc_value, exc_traceback)

def init(self, state: State, logger: Logger) -> None:
del logger # unused
Expand All @@ -322,7 +322,13 @@ def init(self, state: State, logger: Logger) -> None:
self.run_name += f'-rank{dist.get_global_rank()}'

# Register the global exception handler so that uncaught exception is tracked.
- sys.excepthook = self._global_exception_handler
+ original_excepthook = sys.excepthook
+ sys.excepthook = lambda exc_type, exc_value, exc_traceback: self._global_exception_handler(
+ original_excepthook,
+ exc_type,
+ exc_value,
+ exc_traceback,
+ )
# Start run
if self._enabled:
self._start_mlflow_run(state)
Expand Down
3 changes: 2 additions & 1 deletion tests/trainer/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,8 @@ def _get_tmp_dir(self):

if delete_local:
# delete files locally, forcing trainer to look in object store
- shutil.rmtree('first')
+ assert trainer_1._checkpoint_saver is not None
+ shutil.rmtree(trainer_1._checkpoint_saver.folder)

trainer_2 = self.get_trainer(
latest_filename=latest_filename,
Expand Down

0 comments on commit 10db562

Please sign in to comment.