Skip to content

Commit

Permalink
feat: expand ResolveOptions and adds ability pass WheelBuilder direct…
Browse files Browse the repository at this point in the history
…ly into resolve(#218)

This makes sure that the WheelBuilder is actually shared for potential resolve and install code.
  • Loading branch information
nichmor authored Feb 16, 2024
1 parent a6ab742 commit 8c301ba
Show file tree
Hide file tree
Showing 14 changed files with 235 additions and 256 deletions.
168 changes: 41 additions & 127 deletions crates/rattler_installs_packages/src/artifacts/sdist.rs

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions crates/rattler_installs_packages/src/index/direct_url/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use url::Url;
pub(crate) async fn get_sdist_from_file_path(
normalized_package_name: &NormalizedPackageName,
path: &PathBuf,
wheel_builder: &WheelBuilder,
wheel_builder: &Arc<WheelBuilder>,
) -> miette::Result<((Vec<u8>, WheelCoreMetadata), SDist)> {
let distribution = PackageName::from(normalized_package_name.clone());

Expand Down Expand Up @@ -71,7 +71,7 @@ pub(crate) async fn get_stree_from_file_path(
normalized_package_name: &NormalizedPackageName,
url: Url,
path: Option<PathBuf>,
wheel_builder: &WheelBuilder,
wheel_builder: &Arc<WheelBuilder>,
) -> miette::Result<((Vec<u8>, WheelCoreMetadata), STree)> {
let distribution = PackageName::from(normalized_package_name.clone());
let path = match path {
Expand Down Expand Up @@ -113,7 +113,7 @@ pub(crate) async fn get_stree_from_file_path(
pub(crate) async fn get_artifacts_and_metadata<P: Into<NormalizedPackageName>>(
p: P,
url: Url,
wheel_builder: &WheelBuilder,
wheel_builder: &Arc<WheelBuilder>,
) -> miette::Result<DirectUrlArtifactResponse> {
let path = if let Ok(path) = url.to_file_path() {
path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use url::Url;
pub(crate) async fn get_artifacts_and_metadata<P: Into<NormalizedPackageName>>(
p: P,
url: Url,
wheel_builder: &WheelBuilder,
wheel_builder: &Arc<WheelBuilder>,
) -> miette::Result<DirectUrlArtifactResponse> {
let normalized_package_name = p.into();

Expand Down
4 changes: 2 additions & 2 deletions crates/rattler_installs_packages/src/index/direct_url/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub(crate) async fn get_artifacts_and_metadata<P: Into<NormalizedPackageName>>(
http: &Http,
p: P,
url: Url,
wheel_builder: &WheelBuilder,
wheel_builder: &Arc<WheelBuilder>,
) -> miette::Result<crate::index::package_database::DirectUrlArtifactResponse> {
let str_name = url.path();
let url_hash = url.fragment().and_then(parse_hash);
Expand Down Expand Up @@ -122,7 +122,7 @@ async fn get_sdist_from_bytes(
normalized_package_name: &NormalizedPackageName,
url: Url,
bytes: Box<dyn ReadAndSeek + Send>,
wheel_builder: &WheelBuilder,
wheel_builder: &Arc<WheelBuilder>,
) -> miette::Result<((Vec<u8>, WheelCoreMetadata), SDist)> {
// it's probably an sdist
let distribution = PackageName::from(normalized_package_name.clone());
Expand Down
4 changes: 3 additions & 1 deletion crates/rattler_installs_packages/src/index/direct_url/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use crate::index::http::Http;
use crate::index::package_database::DirectUrlArtifactResponse;
use crate::types::NormalizedPackageName;
Expand All @@ -13,7 +15,7 @@ pub(crate) async fn fetch_artifact_and_metadata_by_direct_url<P: Into<Normalized
http: &Http,
p: P,
url: Url,
wheel_builder: &WheelBuilder,
wheel_builder: &Arc<WheelBuilder>,
) -> miette::Result<DirectUrlArtifactResponse> {
let p = p.into();

Expand Down
17 changes: 8 additions & 9 deletions crates/rattler_installs_packages/src/index/package_database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ use std::borrow::Borrow;
use std::path::PathBuf;

use itertools::Itertools;
use std::ops::Deref;
use std::sync::Arc;
use std::{fmt::Display, io::Read, path::Path};

Expand Down Expand Up @@ -174,7 +173,7 @@ impl PackageDb {
url,
wheel_builder,
} => {
self.get_artifact_by_direct_url(name, url, wheel_builder.deref())
self.get_artifact_by_direct_url(name, url, &wheel_builder)
.await
}
}
Expand All @@ -185,7 +184,7 @@ impl PackageDb {
pub async fn get_metadata<'a, A: Borrow<ArtifactInfo>>(
&self,
artifacts: &'a [A],
wheel_builder: Option<&WheelBuilder>,
wheel_builder: Option<&Arc<WheelBuilder>>,
) -> miette::Result<Option<(&'a A, WheelCoreMetadata)>> {
// Check if we already have information about any of the artifacts cached.
// Return if we do
Expand Down Expand Up @@ -239,7 +238,7 @@ impl PackageDb {
pub async fn get_wheel(
&self,
artifact_info: &ArtifactInfo,
builder: Option<&'async_recursion WheelBuilder>,
builder: Option<Arc<WheelBuilder>>,
) -> miette::Result<(Wheel, Option<DirectUrlJson>)> {
// TODO: add support for this currently there are not saved
if artifact_info.is_direct_url {
Expand All @@ -248,7 +247,7 @@ impl PackageDb {
&self.http,
artifact_info.filename.distribution_name(),
artifact_info.url.clone(),
builder,
&builder,
)
.await?;

Expand Down Expand Up @@ -316,7 +315,7 @@ impl PackageDb {
&self,
p: P,
url: Url,
wheel_builder: &WheelBuilder,
wheel_builder: &Arc<WheelBuilder>,
) -> miette::Result<&IndexMap<PypiVersion, Vec<Arc<ArtifactInfo>>>> {
let p = p.into();

Expand Down Expand Up @@ -428,7 +427,7 @@ impl PackageDb {
async fn get_metadata_wheels<'a, A: Borrow<ArtifactInfo>>(
&self,
artifacts: &'a [A],
wheel_builder: Option<&WheelBuilder>,
wheel_builder: Option<&Arc<WheelBuilder>>,
) -> miette::Result<Option<(&'a A, WheelCoreMetadata)>> {
let wheels = artifacts
.iter()
Expand Down Expand Up @@ -495,7 +494,7 @@ impl PackageDb {
async fn get_metadata_sdists<'a, A: Borrow<ArtifactInfo>>(
&self,
artifacts: &'a [A],
wheel_builder: &WheelBuilder,
wheel_builder: &Arc<WheelBuilder>,
) -> miette::Result<Option<(&'a A, WheelCoreMetadata)>> {
let sdists = artifacts
.iter()
Expand Down Expand Up @@ -551,7 +550,7 @@ impl PackageDb {
async fn get_metadata_stree<'a, A: Borrow<ArtifactInfo>>(
&self,
artifacts: &'a [A],
wheel_builder: &WheelBuilder,
wheel_builder: &Arc<WheelBuilder>,
) -> miette::Result<Option<(&'a A, WheelCoreMetadata)>> {
let stree = artifacts
.iter()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{
pypi_version_types::PypiPackageName,
solve_options::{PreReleaseResolution, ResolveOptions, SDistResolution},
PinnedPackage, PypiVersion, PypiVersionSet,
PypiVersion, PypiVersionSet,
};
use crate::{
artifacts::{SDist, Wheel},
Expand All @@ -14,17 +14,15 @@ use crate::{
};
use elsa::FrozenMap;
use itertools::Itertools;
use miette::{Diagnostic, IntoDiagnostic, MietteDiagnostic};
use miette::{Diagnostic, MietteDiagnostic};
use parking_lot::Mutex;
use pep440_rs::{Operator, VersionSpecifier, VersionSpecifiers};
use pep508_rs::{MarkerEnvironment, Requirement, VersionOrUrl};
use resolvo::{
Candidates, Dependencies, DependencyProvider, KnownDependencies, NameId, Pool, SolvableId,
SolverCache,
};
use std::{
any::Any, borrow::Borrow, cmp::Ordering, collections::HashMap, rc::Rc, str::FromStr, sync::Arc,
};
use std::{any::Any, borrow::Borrow, cmp::Ordering, rc::Rc, str::FromStr, sync::Arc};
use thiserror::Error;
use url::Url;

Expand All @@ -38,48 +36,29 @@ pub(crate) struct PypiDependencyProvider {
markers: Arc<MarkerEnvironment>,
compatible_tags: Option<Arc<WheelTags>>,

favored_packages: HashMap<NormalizedPackageName, PinnedPackage>,
locked_packages: HashMap<NormalizedPackageName, PinnedPackage>,

options: ResolveOptions,
should_cancel_with_value: Mutex<Option<MetadataError>>,
}

impl PypiDependencyProvider {
/// Creates a new PypiDependencyProvider
/// for use with the [`resolvo`] crate
#[allow(clippy::too_many_arguments)]
pub fn new(
pool: Pool<PypiVersionSet, PypiPackageName>,
package_db: Arc<PackageDb>,
markers: Arc<MarkerEnvironment>,
compatible_tags: Option<Arc<WheelTags>>,
locked_packages: HashMap<NormalizedPackageName, PinnedPackage>,
favored_packages: HashMap<NormalizedPackageName, PinnedPackage>,
name_to_url: FrozenMap<NormalizedPackageName, String>,
wheel_builder: Arc<WheelBuilder>,
options: ResolveOptions,
env_variables: HashMap<String, String>,
) -> miette::Result<Self> {
let wheel_builder = Arc::new(
WheelBuilder::new(
package_db.clone(),
markers.clone(),
compatible_tags.clone(),
options.clone(),
env_variables,
)
.into_diagnostic()?,
);

Ok(Self {
pool: Rc::new(pool),
package_db,
wheel_builder,
markers,
compatible_tags,
cached_artifacts: Default::default(),
favored_packages,
locked_packages,
name_to_url,
options,
should_cancel_with_value: Default::default(),
Expand Down Expand Up @@ -338,8 +317,8 @@ impl<'p> DependencyProvider<PypiVersionSet, PypiPackageName> for &'p PypiDepende
}
};
let mut candidates = Candidates::default();
let locked_package = self.locked_packages.get(package_name.base());
let favored_package = self.favored_packages.get(package_name.base());
let locked_package = self.options.locked_packages.get(package_name.base());
let favored_package = self.options.favored_packages.get(package_name.base());

let should_package_allow_prerelease = match &self.options.pre_release_resolution {
PreReleaseResolution::Disallow => false,
Expand Down Expand Up @@ -406,7 +385,7 @@ impl<'p> DependencyProvider<PypiVersionSet, PypiPackageName> for &'p PypiDepende
}

// Add a locked dependency
if let Some(locked) = self.locked_packages.get(package_name.base()) {
if let Some(locked) = self.options.locked_packages.get(package_name.base()) {
let version = if let Some(url) = &locked.url {
PypiVersion::Url(url.clone())
} else {
Expand All @@ -423,7 +402,7 @@ impl<'p> DependencyProvider<PypiVersionSet, PypiPackageName> for &'p PypiDepende
}

// Add a favored dependency
if let Some(favored) = self.favored_packages.get(package_name.base()) {
if let Some(favored) = self.options.favored_packages.get(package_name.base()) {
let version = if let Some(url) = &favored.url {
PypiVersion::Url(url.clone())
} else {
Expand Down Expand Up @@ -490,7 +469,7 @@ impl<'p> DependencyProvider<PypiVersionSet, PypiPackageName> for &'p PypiDepende
// TODO: rework this so it makes more sense from an API perspective later, I think we should add the concept of installed_and_locked or something
// It is locked the package data may be available externally
// So it's fine if there are no artifacts, we can just assume this has been taken care of
let locked_package = self.locked_packages.get(package_name.base());
let locked_package = self.options.locked_packages.get(package_name.base());
match package_version {
PypiVersion::Url(url) => {
if locked_package.map(|p| &p.url) == Some(&Some(url.clone())) {
Expand Down
20 changes: 5 additions & 15 deletions crates/rattler_installs_packages/src/resolve/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::python_env::WheelTags;
use crate::resolve::dependency_provider::PypiDependencyProvider;
use crate::resolve::pypi_version_types::PypiVersion;
use crate::types::PackageName;
use crate::wheel_builder::WheelBuilder;
use crate::{types::ArtifactInfo, types::Extra, types::NormalizedPackageName};
use elsa::FrozenMap;
use pep440_rs::Version;
Expand Down Expand Up @@ -51,17 +52,13 @@ pub struct PinnedPackage {
/// If `compatible_tags` is defined then the available artifacts of a distribution are filtered to
/// include only artifacts that are compatible with the specified tags. If `None` is passed, the
/// artifacts are not filtered at all
// TODO: refactor this into an input type of sorts later
#[allow(clippy::too_many_arguments)]
pub async fn resolve(
package_db: Arc<PackageDb>,
requirements: impl IntoIterator<Item = &Requirement>,
env_markers: Arc<MarkerEnvironment>,
compatible_tags: Option<Arc<WheelTags>>,
locked_packages: HashMap<NormalizedPackageName, PinnedPackage>,
favored_packages: HashMap<NormalizedPackageName, PinnedPackage>,
wheel_builder: Arc<WheelBuilder>,
options: ResolveOptions,
env_variables: HashMap<String, String>,
) -> miette::Result<Vec<PinnedPackage>> {
let requirements: Vec<_> = requirements.into_iter().cloned().collect();
tokio::task::spawn_blocking(move || {
Expand All @@ -70,10 +67,8 @@ pub async fn resolve(
&requirements,
env_markers,
compatible_tags,
locked_packages,
favored_packages,
wheel_builder,
options,
env_variables,
)
})
.await
Expand All @@ -86,16 +81,13 @@ pub async fn resolve(
)
}

#[allow(clippy::too_many_arguments)]
fn resolve_inner<'r>(
package_db: Arc<PackageDb>,
requirements: impl IntoIterator<Item = &'r Requirement>,
env_markers: Arc<MarkerEnvironment>,
compatible_tags: Option<Arc<WheelTags>>,
locked_packages: HashMap<NormalizedPackageName, PinnedPackage>,
favored_packages: HashMap<NormalizedPackageName, PinnedPackage>,
wheel_buider: Arc<WheelBuilder>,
options: ResolveOptions,
env_variables: HashMap<String, String>,
) -> miette::Result<Vec<PinnedPackage>> {
// Construct the pool
let pool = Pool::new();
Expand Down Expand Up @@ -147,11 +139,9 @@ fn resolve_inner<'r>(
package_db,
env_markers,
compatible_tags,
locked_packages,
favored_packages,
name_to_url,
wheel_buider,
options,
env_variables,
)?;

// Invoke the solver to get a solution to the requirements
Expand Down
Loading

0 comments on commit 8c301ba

Please sign in to comment.