Skip to content

Commit

Permalink
feat: add torch-specific test as requested
Browse files Browse the repository at this point in the history
  • Loading branch information
tdejager committed Nov 21, 2024
1 parent 4980718 commit f91bdc8
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions tests/integration_rust/pypi_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ async fn test_index_strategy() {

#[tokio::test]
#[cfg_attr(not(feature = "slow_integration_tests"), ignore)]
/// This test checks if we can pin a package from a PyPI index, by explicitly specifying the index.
async fn test_pinning_index() {
let pypi_indexes = Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/data/pypi-indexes");
let pypi_indexes_url = Url::from_directory_path(pypi_indexes.clone()).unwrap();
Expand Down Expand Up @@ -201,3 +202,43 @@ async fn test_pinning_index() {
.join("foo-1.0.0-py2.py3-none-any.whl")
);
}

#[tokio::test]
#[cfg_attr(not(feature = "slow_integration_tests"), ignore)]
/// This test checks if we can receive torch correctly from the whl/cu124 index.
async fn pin_torch() {
/// Do some platform magic, as the index does not contain wheels for each platform.

Check failure on line 210 in tests/integration_rust/pypi_tests.rs

View workflow job for this annotation

GitHub Actions / Cargo Lint

unused doc comment
let platform = Platform::current();
let platforms = match platform {
Platform::Linux64 => "\"linux-64\"".to_string(),
_ => format!("\"{platform}\", \"linux-64\"", platform = platform),
};

let pixi = PixiControl::from_manifest(&format!(
r#"
[project]
name = "pypi-pinning-index"
platforms = [{platforms}]
channels = ["conda-forge"]
[dependencies]
python = "~=3.12.0"
[target.linux-64.pypi-dependencies]
torch = {{ version = "*", index = "https://download.pytorch.org/whl/cu124" }}
"#,
platforms = platforms,
));

let lock_file = pixi.unwrap().update_lock_file().await.unwrap();
// So the check is as follows:
// 1. The PyPI index is the main index-url, so normally torch would be taken from there.
// 2. We manually check if it is taken from the whl/cu124 index instead.
assert!(lock_file
.get_pypi_package_url("default", Platform::Linux64, "torch")
.unwrap()
.as_url()
.unwrap()
.path()
.contains("/whl/cu124"));
}

0 comments on commit f91bdc8

Please sign in to comment.