Skip to content

Commit

Permalink
Add PyTorch 2.5 support (#360)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Nov 12, 2024
1 parent d959724 commit 06a3d5d
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/actions/setup/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ inputs:
default: '3.9'
torch-version:
required: false
default: '2.4.0'
default: '2.5.0'
cuda-version:
required: false
default: cpu
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/aws/upload_nightly_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
wheels_dict[torch_version.replace('2.2.0', '2.2.2')].append(wheel)
if '2.3.0' in torch_version:
wheels_dict[torch_version.replace('2.3.0', '2.3.1')].append(wheel)
if '2.4.0' in torch_version:
wheels_dict[torch_version.replace('2.4.0', '2.4.1')].append(wheel)
if '2.5.0' in torch_version:
wheels_dict[torch_version.replace('2.5.0', '2.5.1')].append(wheel)

index_html = html.format('\n'.join([
href.format(f'{version}.html'.replace('+', '%2B'), version)
Expand Down
10 changes: 8 additions & 2 deletions .github/workflows/building.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ jobs:
matrix:
os: [ubuntu-20.04, macos-14, windows-2019]
python-version: ['3.9', '3.10', '3.11', '3.12']
# torch-version: [1.12.0, 1.13.0, 2.0.0, 2.1.0, 2.2.0, 2.3.0, 2.4.0]
torch-version: [2.4.0]
# torch-version: [1.12.0, 1.13.0, 2.0.0, 2.1.0, 2.2.0, 2.3.0, 2.4.0, 2.5.0]
torch-version: [2.5.0]
cuda-version: ['cpu', 'cu113', 'cu116', 'cu117', 'cu118', 'cu121', 'cu124']
exclude:
- torch-version: 1.12.0
Expand Down Expand Up @@ -80,6 +80,12 @@ jobs:
cuda-version: 'cu116'
- torch-version: 2.4.0
cuda-version: 'cu117'
- torch-version: 2.5.0
cuda-version: 'cu113'
- torch-version: 2.5.0
cuda-version: 'cu116'
- torch-version: 2.5.0
cuda-version: 'cu117'
- os: macos-14
cuda-version: 'cu113'
- os: macos-14
Expand Down
8 changes: 7 additions & 1 deletion .github/workflows/nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
matrix:
os: [ubuntu-20.04, macos-14, windows-2019]
python-version: ['3.9', '3.10', '3.11', '3.12']
torch-version: [1.12.0, 1.13.0, 2.0.0, 2.1.0, 2.2.0, 2.3.0, 2.4.0]
torch-version: [1.12.0, 1.13.0, 2.0.0, 2.1.0, 2.2.0, 2.3.0, 2.4.0, 2.5.0]
cuda-version: ['cpu', 'cu113', 'cu116', 'cu117', 'cu118', 'cu121', 'cu124']
exclude:
- torch-version: 1.12.0
Expand Down Expand Up @@ -83,6 +83,12 @@ jobs:
cuda-version: 'cu116'
- torch-version: 2.4.0
cuda-version: 'cu117'
- torch-version: 2.5.0
cuda-version: 'cu113'
- torch-version: 2.5.0
cuda-version: 'cu116'
- torch-version: 2.5.0
cuda-version: 'cu117'
- os: macos-14
cuda-version: 'cu113'
- os: macos-14
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.5.0] - 2023-MM-DD
### Added
- Added PyTorch 2.5 support ([#360](https://github.com/pyg-team/pyg-lib/pull/338))
- Added PyTorch 2.4 support ([#338](https://github.com/pyg-team/pyg-lib/pull/338))
- Added PyTorch 2.3 support ([#322](https://github.com/pyg-team/pyg-lib/pull/322))
- Added Windows support ([#315](https://github.com/pyg-team/pyg-lib/pull/315))
Expand Down
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,17 @@ pip install pyg-lib -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html

where

* `${TORCH}` should be replaced by either `1.12.0`, `1.13.0`, `2.0.0`, `2.1.0`, `2.2.0`, `2.3.0`, or `2.4.0`
* `${CUDA}` should be replaced by either `cpu`, `cu102`, `cu113`, `cu116`, `cu117`, `cu118`, or `cu121`
* `${TORCH}` should be replaced by either `1.12.0`, `1.13.0`, `2.0.0`, `2.1.0`, `2.2.0`, `2.3.0`, `2.4.0` or `2.5.0`
* `${CUDA}` should be replaced by either `cpu`, `cu102`, `cu113`, `cu116`, `cu117`, `cu118`, `cu121`, or `cu124`

The following combinations are supported:

| PyTorch 2.5 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** || | | ||||
| **Windows** || | | ||||
| **macOS** || | | | | | |

| PyTorch 2.4 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** || | | ||||
Expand Down
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp39-cp39-linux_x86_64.whl
https://download.pytorch.org/whl/cpu/torch-2.5.0%2Bcpu-cp39-cp39-linux_x86_64.whl
git+https://github.com/pyg-team/pyg_sphinx_theme.git
25 changes: 13 additions & 12 deletions pyg_lib/csrc/random/cpu/rand_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,19 +134,20 @@ class RandintEngine {
for (auto& val : dst)
val = static_cast<int>((*this)(beg, end));
};
#if WITH_MKL_BLAS()
const bool use_fallback_func = count > std::numeric_limits<MKL_INT>::max();
if (use_fallback_func) {
fallback_func(beg, end, result);
} else {
const auto b = static_cast<int>(beg);
const auto e = static_cast<int>(end);
const auto c = static_cast<MKL_INT>(count);
viRngUniform(VSL_RNG_METHOD_UNIFORM_STD, stream_, c, result.data(), b, e);
}
#else
// #if WITH_MKL_BLAS()
// const bool use_fallback_func = count >
// std::numeric_limits<MKL_INT>::max(); if (use_fallback_func) {
// fallback_func(beg, end, result);
// } else {
// const auto b = static_cast<int>(beg);
// const auto e = static_cast<int>(end);
// const auto c = static_cast<MKL_INT>(count);
// viRngUniform(VSL_RNG_METHOD_UNIFORM_STD, stream_, c, result.data(),
// b, e);
// }
// #else
fallback_func(beg, end, result);
#endif
// #endif
return result;
}

Expand Down

0 comments on commit 06a3d5d

Please sign in to comment.