From 1650d6a9ae4478ac7c312923579ecdba7de1824a Mon Sep 17 00:00:00 2001 From: nichmor Date: Fri, 22 Nov 2024 16:45:53 +0200 Subject: [PATCH] feat: add pypi index (#2416) Co-authored-by: Tim de Jager Co-authored-by: Tim de Jager Co-authored-by: Ruben Arts --- .../src/pypi/pypi_requirement.rs | 53 +++++++++++-- ...quirement__tests__deserialize_failing.snap | 2 +- ...rement__tests__deserialize_succeeding.snap | 1 + .../pixi_uv_conversions/src/requirements.rs | 4 +- docs/reference/project_configuration.md | 13 +++ schema/examples/valid/full.toml | 1 + schema/model.py | 4 + schema/schema.json | 6 ++ src/cli/project/export/conda_environment.rs | 4 +- tests/integration_rust/add_tests.rs | 1 + tests/integration_rust/pypi_tests.rs | 79 +++++++++++++++++++ 11 files changed, 156 insertions(+), 12 deletions(-) diff --git a/crates/pixi_manifest/src/pypi/pypi_requirement.rs b/crates/pixi_manifest/src/pypi/pypi_requirement.rs index abf069e2d..3e621450a 100644 --- a/crates/pixi_manifest/src/pypi/pypi_requirement.rs +++ b/crates/pixi_manifest/src/pypi/pypi_requirement.rs @@ -85,6 +85,8 @@ pub enum PyPiRequirement { version: VersionOrStar, #[serde(default)] extras: Vec, + #[serde(default)] + index: Option, }, RawVersion(VersionOrStar), } @@ -138,6 +140,10 @@ struct RawPyPiRequirement { // Git and Url only pub subdirectory: Option, + + // Pinned index + #[serde(default)] + pub index: Option, } impl<'de> Deserialize<'de> for PyPiRequirement { @@ -186,18 +192,24 @@ impl<'de> Deserialize<'de> for PyPiRequirement { ))); } - let req = match (raw_req.url, raw_req.path, raw_req.git, raw_req.extras) { - (Some(url), None, None, extras) => PyPiRequirement::Url { + let req = match ( + raw_req.url, + raw_req.path, + raw_req.git, + raw_req.extras, + raw_req.index, + ) { + (Some(url), None, None, extras, None) => PyPiRequirement::Url { url, extras, subdirectory: raw_req.subdirectory, }, - (None, Some(path), None, extras) => PyPiRequirement::Path { + (None, Some(path), None, extras, None) => PyPiRequirement::Path { path, editable: raw_req.editable, extras, }, - (None, None, Some(git), extras) => PyPiRequirement::Git { + (None, None, Some(git), extras, None) => PyPiRequirement::Git { url: ParsedGitUrl { git, branch: raw_req.branch, @@ -207,13 +219,15 @@ impl<'de> Deserialize<'de> for PyPiRequirement { }, extras, }, - (None, None, None, extras) => PyPiRequirement::Version { + (None, None, None, extras, index) => PyPiRequirement::Version { version: raw_req.version.unwrap_or(VersionOrStar::Star), extras, + index, }, - (_, _, _, extras) if !extras.is_empty() => PyPiRequirement::Version { + (_, _, _, extras, index) if !extras.is_empty() => PyPiRequirement::Version { version: raw_req.version.unwrap_or(VersionOrStar::Star), extras, + index, }, _ => { return Err(serde_untagged::de::Error::custom( @@ -278,17 +292,35 @@ impl From for toml_edit::Value { } } + fn insert_index(table: &mut toml_edit::InlineTable, index: &Option) { + if let Some(index) = index { + table.insert( + "index", + toml_edit::Value::String(toml_edit::Formatted::new(index.to_string())), + ); + } + } + match &val { - PyPiRequirement::Version { version, extras } if extras.is_empty() => { + PyPiRequirement::Version { + version, + extras, + index, + } if extras.is_empty() && index.is_none() => { toml_edit::Value::from(version.to_string()) } - PyPiRequirement::Version { version, extras } => { + PyPiRequirement::Version { + version, + extras, + index, + } => { let mut table = toml_edit::Table::new().into_inline_table(); table.insert( "version", toml_edit::Value::String(toml_edit::Formatted::new(version.to_string())), ); insert_extras(&mut table, extras); + insert_index(&mut table, index); toml_edit::Value::InlineTable(table.to_owned()) } PyPiRequirement::Git { @@ -423,6 +455,7 @@ impl TryFrom for PyPiRequirement { pep508_rs::VersionOrUrl::VersionSpecifier(v) => PyPiRequirement::Version { version: v.into(), extras: req.extras, + index: None, }, pep508_rs::VersionOrUrl::Url(u) => { let url = u.to_url(); @@ -494,6 +527,7 @@ impl TryFrom for PyPiRequirement { PyPiRequirement::Version { version: VersionOrStar::Star, extras: req.extras, + index: None, } } else { PyPiRequirement::RawVersion(VersionOrStar::Star) @@ -616,6 +650,7 @@ mod tests { &PyPiRequirement::Version { version: ">=3.12".parse().unwrap(), extras: vec![ExtraName::from_str("bar").unwrap()], + index: None, } ); @@ -636,6 +671,7 @@ mod tests { ExtraName::from_str("bar").unwrap(), ExtraName::from_str("foo").unwrap(), ], + index: None, } ); } @@ -659,6 +695,7 @@ mod tests { ExtraName::from_str("feature1").unwrap(), ExtraName::from_str("feature2").unwrap() ], + index: None, } ); } diff --git a/crates/pixi_manifest/src/pypi/snapshots/pixi_manifest__pypi__pypi_requirement__tests__deserialize_failing.snap b/crates/pixi_manifest/src/pypi/snapshots/pixi_manifest__pypi__pypi_requirement__tests__deserialize_failing.snap index 139b6c4d9..fff924ca7 100644 --- a/crates/pixi_manifest/src/pypi/snapshots/pixi_manifest__pypi__pypi_requirement__tests__deserialize_failing.snap +++ b/crates/pixi_manifest/src/pypi/snapshots/pixi_manifest__pypi__pypi_requirement__tests__deserialize_failing.snap @@ -5,7 +5,7 @@ expression: snapshot - input: ver: 1.2.3 result: - error: "ERROR: unknown field `ver`, expected one of `version`, `extras`, `path`, `editable`, `git`, `branch`, `tag`, `rev`, `url`, `subdirectory`" + error: "ERROR: unknown field `ver`, expected one of `version`, `extras`, `path`, `editable`, `git`, `branch`, `tag`, `rev`, `url`, `subdirectory`, `index`" - input: path: foobar version: "==1.2.3" diff --git a/crates/pixi_manifest/src/pypi/snapshots/pixi_manifest__pypi__pypi_requirement__tests__deserialize_succeeding.snap b/crates/pixi_manifest/src/pypi/snapshots/pixi_manifest__pypi__pypi_requirement__tests__deserialize_succeeding.snap index 431b7e49d..00e440f7b 100644 --- a/crates/pixi_manifest/src/pypi/snapshots/pixi_manifest__pypi__pypi_requirement__tests__deserialize_succeeding.snap +++ b/crates/pixi_manifest/src/pypi/snapshots/pixi_manifest__pypi__pypi_requirement__tests__deserialize_succeeding.snap @@ -9,6 +9,7 @@ expression: snapshot result: version: "==1.2.3" extras: [] + index: ~ - input: "*" result: "*" - input: diff --git a/crates/pixi_uv_conversions/src/requirements.rs b/crates/pixi_uv_conversions/src/requirements.rs index 1e9977868..f2ba973a6 100644 --- a/crates/pixi_uv_conversions/src/requirements.rs +++ b/crates/pixi_uv_conversions/src/requirements.rs @@ -82,11 +82,11 @@ pub fn as_uv_req( ) -> Result { let name = PackageName::new(name.to_owned())?; let source = match req { - PyPiRequirement::Version { version, .. } => { + PyPiRequirement::Version { version, index, .. } => { // TODO: implement index later RequirementSource::Registry { specifier: to_version_specificers(version)?, - index: None, + index: index.clone(), } } PyPiRequirement::Git { diff --git a/docs/reference/project_configuration.md b/docs/reference/project_configuration.md index 9897e4bec..4c6aa17a3 100644 --- a/docs/reference/project_configuration.md +++ b/docs/reference/project_configuration.md @@ -469,6 +469,19 @@ ruff = "~=1.0.0" pytest = {version = "*", extras = ["dev"]} ``` +##### `index` + +The index parameter allows you to specify the URL of a custom package index for the installation of a specific package. +This feature is useful when you want to ensure that a package is retrieved from a particular source, rather than from the default index. + +For example, to use some other than the official Python Package Index (PyPI) at https://pypi.org/simple, you can use the `index` parameter: + +```toml +torch = { version = "*", index = "https://download.pytorch.org/whl/cu118" } +``` + +This is useful for PyTorch specifically, as the registries are pinned to different CUDA versions. + ##### `git` A git repository to install from. diff --git a/schema/examples/valid/full.toml b/schema/examples/valid/full.toml index 6ad5e0c4c..6bdf2ac50 100644 --- a/schema/examples/valid/full.toml +++ b/schema/examples/valid/full.toml @@ -49,6 +49,7 @@ requests = { version = ">= 2.8.1, ==2.8.*", extras = [ "security", "tests", ] } # Using the map allows the user to add `extras` +test-pinning-index = { version = "*", index = "https://example.com/test" } testpypi = "*" testpypi1 = "*" diff --git a/schema/model.py b/schema/model.py index 5d5dc9586..928f621f9 100644 --- a/schema/model.py +++ b/schema/model.py @@ -248,6 +248,10 @@ class PyPIVersion(_PyPIRequirement): None, description="The version of the package in [PEP 440](https://www.python.org/dev/peps/pep-0440/) format", ) + index: NonEmptyStr | None = Field( + None, + description="The index to fetch the package from", + ) PyPIRequirement = ( diff --git a/schema/schema.json b/schema/schema.json index 5016438c2..4ca4e331c 100644 --- a/schema/schema.json +++ b/schema/schema.json @@ -1081,6 +1081,12 @@ "minLength": 1 } }, + "index": { + "title": "Index", + "description": "The index to fetch the package from", + "type": "string", + "minLength": 1 + }, "version": { "title": "Version", "description": "The version of the package in [PEP 440](https://www.python.org/dev/peps/pep-0440/) format", diff --git a/src/cli/project/export/conda_environment.rs b/src/cli/project/export/conda_environment.rs index a4103b650..442a57e38 100644 --- a/src/cli/project/export/conda_environment.rs +++ b/src/cli/project/export/conda_environment.rs @@ -106,7 +106,9 @@ fn format_pip_dependency(name: &PyPiPackageName, requirement: &PyPiRequirement) url_string } - PyPiRequirement::Version { version, extras } => { + PyPiRequirement::Version { + version, extras, .. + } => { format!( "{name}{extras}{version}", name = name.as_normalized(), diff --git a/tests/integration_rust/add_tests.rs b/tests/integration_rust/add_tests.rs index 618e043c1..2be8ee37f 100644 --- a/tests/integration_rust/add_tests.rs +++ b/tests/integration_rust/add_tests.rs @@ -440,6 +440,7 @@ async fn add_pypi_extra_functionality() { PyPiRequirement::Version { version: VersionOrStar::from_str("==24.8.0").unwrap(), extras: vec![pep508_rs::ExtraName::from_str("cli").unwrap()], + index: None } ); } diff --git a/tests/integration_rust/pypi_tests.rs b/tests/integration_rust/pypi_tests.rs index 8be455080..e4f923c6a 100644 --- a/tests/integration_rust/pypi_tests.rs +++ b/tests/integration_rust/pypi_tests.rs @@ -163,3 +163,82 @@ async fn test_index_strategy() { Some("3.0.0".into()) ); } + +#[tokio::test] +#[cfg_attr(not(feature = "slow_integration_tests"), ignore)] +/// This test checks if we can pin a package from a PyPI index, by explicitly specifying the index. +async fn test_pinning_index() { + let pypi_indexes = Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/data/pypi-indexes"); + let pypi_indexes_url = Url::from_directory_path(pypi_indexes.clone()).unwrap(); + + let pixi = PixiControl::from_manifest(&format!( + r#" + [project] + name = "pypi-pinning-index" + platforms = ["{platform}"] + channels = ["conda-forge"] + + [dependencies] + python = "~=3.12.0" + + [pypi-dependencies] + foo = {{ version = "*", index = "{pypi_indexes}multiple-indexes-a/index" }} + + "#, + platform = Platform::current(), + pypi_indexes = pypi_indexes_url, + )); + + let lock_file = pixi.unwrap().update_lock_file().await.unwrap(); + + assert_eq!( + lock_file + .get_pypi_package_url("default", Platform::current(), "foo") + .unwrap() + .as_path() + .unwrap(), + pypi_indexes + .join("multiple-indexes-a/index/foo") + .join("foo-1.0.0-py2.py3-none-any.whl") + ); +} + +#[tokio::test] +#[cfg_attr(not(feature = "slow_integration_tests"), ignore)] +/// This test checks if we can receive torch correctly from the whl/cu124 index. +async fn pin_torch() { + // Do some platform magic, as the index does not contain wheels for each platform. + let platform = Platform::current(); + let platforms = match platform { + Platform::Linux64 => "\"linux-64\"".to_string(), + _ => format!("\"{platform}\", \"linux-64\"", platform = platform), + }; + + let pixi = PixiControl::from_manifest(&format!( + r#" + [project] + name = "pypi-pinning-index" + platforms = [{platforms}] + channels = ["conda-forge"] + + [dependencies] + python = "~=3.12.0" + + [target.linux-64.pypi-dependencies] + torch = {{ version = "*", index = "https://download.pytorch.org/whl/cu124" }} + "#, + platforms = platforms, + )); + + let lock_file = pixi.unwrap().update_lock_file().await.unwrap(); + // So the check is as follows: + // 1. The PyPI index is the main index-url, so normally torch would be taken from there. + // 2. We manually check if it is taken from the whl/cu124 index instead. + assert!(lock_file + .get_pypi_package_url("default", Platform::Linux64, "torch") + .unwrap() + .as_url() + .unwrap() + .path() + .contains("/whl/cu124")); +}