Skip to content

Commit

Permalink
update setup, add nightly build, fix broken tests (#150)
Browse files Browse the repository at this point in the history
Summary:
Please read through our [contribution guide](https://github.com/pytorch/tnt/blob/main/CONTRIBUTING.md) prior to creating your pull request.

- add necessary packages to requirements. TODO make some of them optional (tensorboard, psutil, etc).
- fix code to make it compatible with python 3.7.
- add nightly build
- update setup.py

Pull Request resolved: #150

Test Plan:
running the action on my fork: https://pypi.org/project/torchtnt-nightly/

Fixes #{issue number}

Reviewed By: daniellepintz

Differential Revision: D38336240

Pulled By: edward-io

fbshipit-source-id: 514675d7a845662c1b696c2d8a8faf9d590d156d
  • Loading branch information
edward-io authored and facebook-github-bot committed Aug 2, 2022
1 parent fe5e0e2 commit b52c028
Show file tree
Hide file tree
Showing 9 changed files with 177 additions and 22 deletions.
77 changes: 77 additions & 0 deletions .github/workflows/nightly_build_cpu.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
name: Push CPU Binary Nightly

on:
# run every day at 11:15am
schedule:
- cron: '15 11 * * *'
# or manually trigger it
workflow_dispatch:
inputs:
append_to_version:
description: "Optional value to append to version string"


jobs:
unit_tests:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
steps:
- name: Check out repo
uses: actions/checkout@v2
- name: Setup conda env
uses: conda-incubator/setup-miniconda@v2
with:
miniconda-version: "latest"
activate-environment: test
python-version: ${{ matrix.python-version }}
- name: Install dependencies
shell: bash -l {0}
run: |
set -eux
conda activate test
conda install pytorch cpuonly -c pytorch-nightly
pip install -r requirements.txt
pip install -r dev-requirements.txt
python setup.py sdist bdist_wheel
pip install dist/*.whl
- name: Run unit tests
shell: bash -l {0}
run: |
set -eux
conda activate test
pytest tests -vv
# TODO figure out how to deduplicate steps
upload_to_pypi:
needs: unit_tests
runs-on: ubuntu-latest
steps:
- name: Check out repo
uses: actions/checkout@v2
- name: Setup conda env
uses: conda-incubator/setup-miniconda@v2
with:
miniconda-version: "latest"
activate-environment: test
python-version: 3.7
- name: Install dependencies
shell: bash -l {0}
run: |
set -eux
conda activate test
conda install pytorch cpuonly -c pytorch-nightly
pip install -r requirements.txt
pip install -r dev-requirements.txt
pip install -e ".[dev]"
- name: Upload to PyPI
shell: bash -l {0}
env:
PYPI_USER: ${{ secrets.PYPI_USER }}
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
run: |
set -eux
conda activate test
pip install twine
python setup.py --nightly --append-to-version=${{ github.event.inputs.append_to_version }} sdist bdist_wheel
twine upload --username "$PYPI_USER" --password "$PYPI_TOKEN" dist/* --verbose
2 changes: 2 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ jobs:
conda activate test
conda install pytorch cpuonly -c pytorch-nightly
pip install -r requirements.txt
pip install -r dev-requirements.txt
pip install -e .
- name: Run unit tests with coverage
shell: bash -l {0}
run: |
Expand Down
2 changes: 2 additions & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pytest
pytest-cov
6 changes: 4 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
torch
numpy
packaging
fsspec
tensorboard
psutil
typing_extensions
setuptools
visdom
102 changes: 85 additions & 17 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,87 @@
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import sys

from datetime import date
from typing import List

from setuptools import find_packages, setup
from torchtnt import __version__


def current_path(file_name: str) -> str:
return os.path.abspath(os.path.join(__file__, os.path.pardir, file_name))


def read_requirements(file_name: str) -> List[str]:
with open(current_path(file_name), encoding="utf8") as f:
return [r for r in f.read().strip().split() if not r.startswith("-")]


def get_nightly_version() -> str:
return date.today().strftime("%Y.%m.%d")


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="torchtnt setup")
parser.add_argument(
"--nightly",
dest="nightly",
action="store_true",
help="enable settings for nightly package build",
default=False,
)
parser.add_argument(
"--append-to-version",
dest="append_version",
help="append string to end of version number (e.g. a1)",
)
return parser.parse_known_args()


if __name__ == "__main__":
with open(current_path("README.md"), encoding="utf8") as f:
readme: str = f.read()

custom_args, setup_args = parse_args()
package_name = "torchtnt" if not custom_args.nightly else "torchtnt-nightly"
version = __version__ if not custom_args.nightly else get_nightly_version()
if custom_args.append_version:
version = f"{version}{custom_args.append_version}"

print(f"using package_name={package_name}, version={version}")

sys.argv = [sys.argv[0]] + setup_args

VERSION = "0.0.5.1"

setup(
# Metadata
name="torchtnt",
version=VERSION,
author="PyTorch",
author_email="[email protected]",
url="https://github.com/pytorch/tnt/",
description="A lightweight library for PyTorch training tools and utilities",
license="BSD",
# Package info
packages=find_packages(exclude=("test", "docs")),
zip_safe=True,
install_requires=["torch", "six", "future", "visdom"],
)
setup(
name=package_name,
version=version,
author="PyTorch",
author_email="[email protected]",
description="A lightweight library for PyTorch training tools and utilities",
long_description=readme,
long_description_content_type="text/markdown",
url="https://github.com/pytorch/tnt",
license="BSD-3",
keywords=["pytorch", "torch", "training", "tools", "utilities"],
python_requires=">=3.7",
install_requires=read_requirements("requirements.txt"),
packages=find_packages(),
zip_safe=True,
classifiers=[
"Development Status :: 2 - Pre-Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: BSD License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
extras_require={"dev": read_requirements("dev-requirements.txt")},
)
1 change: 1 addition & 0 deletions torchtnt/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "0.0.1"
3 changes: 2 additions & 1 deletion torchtnt/loggers/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Protocol, Union
from typing import Dict, Union

from numpy import ndarray
from torch import Tensor
from typing_extensions import Protocol

Scalar = Union[Tensor, ndarray, int, float]

Expand Down
3 changes: 2 additions & 1 deletion torchtnt/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
import subprocess
from collections import defaultdict
from dataclasses import fields, is_dataclass
from typing import Any, Mapping, Protocol, runtime_checkable, TypedDict, TypeVar
from typing import Any, Mapping, TypeVar

import torch
from torchtnt.utils.version import is_torch_version_geq_1_12
from typing_extensions import Protocol, runtime_checkable, TypedDict


def get_device_from_env() -> torch.device:
Expand Down
3 changes: 2 additions & 1 deletion torchtnt/utils/early_stop_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Callable, Dict, final, Literal, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

import torch
from typing_extensions import final, Literal

_log: logging.Logger = logging.getLogger(__name__)

Expand Down

0 comments on commit b52c028

Please sign in to comment.