diff --git a/src/cli/exec.rs b/src/cli/exec.rs index 438418877..57345ea31 100644 --- a/src/cli/exec.rs +++ b/src/cli/exec.rs @@ -1,6 +1,5 @@ use std::{ hash::{DefaultHasher, Hash, Hasher}, - path::Path, str::FromStr, }; @@ -10,14 +9,12 @@ use rattler::{ install::{IndicatifReporter, Installer}, package_cache::PackageCache, }; -use rattler_conda_types::{GenericVirtualPackage, MatchSpec, PackageName, Platform}; -use rattler_solve::{resolvo::Solver, SolverImpl, SolverTask}; -use rattler_virtual_packages::VirtualPackage; +use rattler_conda_types::{MatchSpec, PackageName, Platform}; use reqwest_middleware::ClientWithMiddleware; -use crate::prefix::Prefix; +use crate::{cli::global::common::solve_package_records, prefix::Prefix}; use pixi_config::{self, Config, ConfigCli}; -use pixi_progress::{await_in_progress, global_multi_progress, wrap_in_progress}; +use pixi_progress::{global_multi_progress, wrap_in_progress}; use pixi_utils::{reqwest::build_reqwest_clients, PrefixGuard}; use super::cli_config::ChannelsConfig; @@ -84,14 +81,13 @@ impl EnvironmentHash { /// CLI entry point for `pixi runx` pub async fn execute(args: Args) -> miette::Result<()> { let config = Config::with_cli_config(&args.config); - let cache_dir = pixi_config::get_cache_dir().context("failed to determine cache directory")?; + let (_, client) = build_reqwest_clients(Some(&config)); let mut command_args = args.command.iter(); let command = command_args.next().ok_or_else(|| miette::miette!(help ="i.e when specifying specs explicitly use a command at the end: `pixi exec -s python==3.12 python`", "missing required command to execute",))?; - let (_, client) = build_reqwest_clients(Some(&config)); // Create the environment to run the command in. - let prefix = create_exec_prefix(&args, &cache_dir, &config, &client).await?; + let prefix = create_exec_prefix(&args, &config, &client).await?; // Get environment variables from the activation let activation_env = run_activation(&prefix).await?; @@ -114,11 +110,11 @@ pub async fn execute(args: Args) -> miette::Result<()> { /// Creates a prefix for the `pixi exec` command. pub async fn create_exec_prefix( args: &Args, - cache_dir: &Path, config: &Config, client: &ClientWithMiddleware, ) -> miette::Result { let environment_name = EnvironmentHash::from_args(args, config).name(); + let cache_dir = pixi_config::get_cache_dir().context("failed to determine cache directory")?; let prefix = Prefix::new(cache_dir.join("cached-envs-v0").join(environment_name)); let mut guard = PrefixGuard::new(prefix.root()) @@ -164,32 +160,12 @@ pub async fn create_exec_prefix( args.specs.clone() }; - // Get the repodata for the specs - let repodata = await_in_progress("fetching repodata for environment", |_| async { - gateway - .query( - args.channels.resolve_from_config(config), - [Platform::current(), Platform::NoArch], - specs.clone(), - ) - .recursive(true) - .execute() - .await - }) - .await - .into_diagnostic() - .context("failed to get repodata")?; - - // Determine virtual packages of the current platform - let virtual_packages = VirtualPackage::current() - .into_diagnostic() - .context("failed to determine virtual packages")? - .iter() - .cloned() - .map(GenericVirtualPackage::from) - .collect(); - // Solve the environment + let channels = args.channels.resolve_from_config(config); + let solved_records = + solve_package_records(&gateway, Platform::current(), channels, specs.clone()).await?; + + // Install the environment tracing::info!( "creating environment in {}", dunce::canonicalize(prefix.root()) @@ -197,17 +173,6 @@ pub async fn create_exec_prefix( .unwrap_or(prefix.root()) .display() ); - let solved_records = wrap_in_progress("solving environment", move || { - Solver.solve(SolverTask { - specs, - virtual_packages, - ..SolverTask::from_iter(&repodata) - }) - }) - .into_diagnostic() - .context("failed to solve environment")?; - - // Install the environment Installer::new() .with_download_client(client.clone()) .with_reporter( diff --git a/src/cli/global/common.rs b/src/cli/global/common.rs index 8f08c210f..6c567a541 100644 --- a/src/cli/global/common.rs +++ b/src/cli/global/common.rs @@ -1,7 +1,14 @@ use std::path::PathBuf; -use miette::IntoDiagnostic; -use rattler_conda_types::{Channel, ChannelConfig, PackageName, PrefixRecord}; +use miette::{Context, IntoDiagnostic}; +use pixi_progress::{await_in_progress, wrap_in_progress}; +use rattler_conda_types::{ + Channel, ChannelConfig, GenericVirtualPackage, MatchSpec, PackageName, Platform, PrefixRecord, + RepoDataRecord, +}; +use rattler_repodata_gateway::Gateway; +use rattler_solve::{resolvo::Solver, SolverImpl, SolverTask}; +use rattler_virtual_packages::VirtualPackage; use crate::{prefix::Prefix, repodata}; use pixi_config::home_path; @@ -110,6 +117,58 @@ pub(super) fn channel_name_from_prefix( .unwrap_or_else(|_| prefix_package.repodata_record.channel.clone()) } +/// Solve package records from [`Gateway`] for the given package MatchSpec +/// +/// # Returns +/// +/// The package records (with dependencies records) for the given package +/// MatchSpec +pub async fn solve_package_records( + gateway: &Gateway, + platform: Platform, + channels: ChannelIter, + specs: Vec, +) -> miette::Result> +where + AsChannel: Into, + ChannelIter: IntoIterator, +{ + // Get the repodata for the specs + let repodata = await_in_progress("fetching repodata for environment", |_| async { + gateway + .query(channels, [platform, Platform::NoArch], specs.clone()) + .recursive(true) + .execute() + .await + }) + .await + .into_diagnostic() + .context("failed to get repodata")?; + + // Determine virtual packages of the current platform + // We cannot infer virtual_packages for another platform + let virtual_packages = VirtualPackage::current() + .into_diagnostic() + .context("failed to determine virtual packages")? + .iter() + .cloned() + .map(GenericVirtualPackage::from) + .collect(); + + // Solve the environment + let solved_records = wrap_in_progress("solving environment", move || { + Solver.solve(SolverTask { + specs, + virtual_packages, + ..SolverTask::from_iter(&repodata) + }) + }) + .into_diagnostic() + .context("failed to solve environment")?; + + Ok(solved_records) +} + /// Find the globally installed package with the given [`PackageName`] /// /// # Returns @@ -128,7 +187,7 @@ pub(super) async fn find_installed_package( find_designated_package(&prefix, package_name).await } -/// Find the designated package in the given [`Prefix`] +/// Find the designated package in the given prefix /// /// # Returns /// diff --git a/src/cli/global/install.rs b/src/cli/global/install.rs index f96dabc05..944b7c75a 100644 --- a/src/cli/global/install.rs +++ b/src/cli/global/install.rs @@ -7,30 +7,28 @@ use std::{ use clap::Parser; use indexmap::IndexMap; use itertools::Itertools; -use miette::{Context, IntoDiagnostic}; +use miette::IntoDiagnostic; + use pixi_utils::reqwest::build_reqwest_clients; use rattler::{ install::{DefaultProgressFormatter, IndicatifReporter, Installer}, package_cache::PackageCache, }; -use rattler_conda_types::{ - GenericVirtualPackage, MatchSpec, PackageName, Platform, PrefixRecord, RepoDataRecord, -}; +use rattler_conda_types::{MatchSpec, PackageName, Platform, PrefixRecord, RepoDataRecord}; use rattler_shell::{ activation::{ActivationVariables, Activator, PathModificationBehavior}, shell::{Shell, ShellEnum}, }; -use rattler_solve::{resolvo::Solver, SolverImpl, SolverTask}; -use rattler_virtual_packages::VirtualPackage; use reqwest_middleware::ClientWithMiddleware; use super::common::{channel_name_from_prefix, find_designated_package, BinDir, BinEnvDir}; use crate::{ - cli::cli_config::ChannelsConfig, cli::has_specs::HasSpecs, prefix::Prefix, + cli::{cli_config::ChannelsConfig, global::common::solve_package_records, has_specs::HasSpecs}, + prefix::Prefix, rlimit::try_increase_rlimit_to_sensible, }; use pixi_config::{self, Config, ConfigCli}; -use pixi_progress::{await_in_progress, global_multi_progress, wrap_in_progress}; +use pixi_progress::{await_in_progress, global_multi_progress}; /// Installs the defined package in a global accessible location. #[derive(Parser, Debug)] @@ -286,8 +284,8 @@ pub fn prompt_user_to_continue( /// Install a global command pub async fn execute(args: Args) -> miette::Result<()> { - // Figure out what channels we are using let config = Config::with_cli_config(&args.config); + let (_, client) = build_reqwest_clients(Some(&config)); let channels = args.channels.resolve_from_config(&config); let specs = args.specs()?; @@ -297,52 +295,22 @@ pub async fn execute(args: Args) -> miette::Result<()> { return Ok(()); } - // Fetch the repodata - let (_, auth_client) = build_reqwest_clients(Some(&config)); - - let gateway = config.gateway(auth_client.clone()); - - let repodata = gateway - .query( - channels, - [args.platform, Platform::NoArch], - specs.values().cloned().collect_vec(), - ) - .recursive(true) - .await - .into_diagnostic()?; - - // Determine virtual packages of the current platform - let virtual_packages = VirtualPackage::current() - .into_diagnostic() - .context("failed to determine virtual packages")? - .iter() - .cloned() - .map(GenericVirtualPackage::from) - .collect(); - - // Solve the environment - let solver_specs = specs.clone(); - let solved_records = wrap_in_progress("solving environment", move || { - Solver.solve(SolverTask { - specs: solver_specs.values().cloned().collect_vec(), - virtual_packages, - ..SolverTask::from_iter(&repodata) - }) - }) - .into_diagnostic() - .context("failed to solve environment")?; + // Construct a gateway to get repodata. + let gateway = config.gateway(client.clone()); // Install the package(s) let mut executables = vec![]; - for (package_name, _) in specs { - let (prefix_package, scripts, _) = globally_install_package( - &package_name, - solved_records.clone(), - auth_client.clone(), + for (package_name, package_matchspec) in specs { + let records = solve_package_records( + &gateway, args.platform, + channels.clone(), + vec![package_matchspec], ) .await?; + + let (prefix_package, scripts, _) = + globally_install_package(&package_name, records, client.clone(), args.platform).await?; let channel_name = channel_name_from_prefix(&prefix_package, config.global_channel_config()); let record = &prefix_package.repodata_record.package_record; diff --git a/src/cli/global/mod.rs b/src/cli/global/mod.rs index 0bc94513f..c1c2c3977 100644 --- a/src/cli/global/mod.rs +++ b/src/cli/global/mod.rs @@ -1,6 +1,6 @@ use clap::Parser; -mod common; +pub mod common; mod install; mod list; mod remove; diff --git a/src/cli/global/upgrade.rs b/src/cli/global/upgrade.rs index dfbd2c85b..58dcb1ee6 100644 --- a/src/cli/global/upgrade.rs +++ b/src/cli/global/upgrade.rs @@ -1,20 +1,22 @@ -use std::{collections::HashMap, sync::Arc, time::Duration}; +use itertools::Itertools; +use std::iter::once; +use std::{sync::Arc, time::Duration}; use clap::Parser; use indexmap::IndexMap; use indicatif::ProgressBar; -use itertools::Itertools; -use miette::{Context, IntoDiagnostic, Report}; +use miette::{IntoDiagnostic, Report}; use pixi_utils::reqwest::build_reqwest_clients; -use rattler_conda_types::{Channel, GenericVirtualPackage, MatchSpec, PackageName, Platform}; -use rattler_solve::{resolvo::Solver, SolverImpl, SolverTask}; -use rattler_virtual_packages::VirtualPackage; +use rattler_conda_types::{Channel, MatchSpec, PackageName, Platform}; + use tokio::task::JoinSet; use super::{common::find_installed_package, install::globally_install_package}; -use crate::cli::{cli_config::ChannelsConfig, has_specs::HasSpecs}; +use crate::cli::{ + cli_config::ChannelsConfig, global::common::solve_package_records, has_specs::HasSpecs, +}; use pixi_config::Config; -use pixi_progress::{global_multi_progress, long_running_progress_style, wrap_in_progress}; +use pixi_progress::{global_multi_progress, long_running_progress_style}; /// Upgrade specific package which is installed globally. #[derive(Parser, Debug)] @@ -51,117 +53,44 @@ pub(super) async fn upgrade_packages( platform: Platform, ) -> miette::Result<()> { let channel_cli = cli_channels.resolve_from_config(&config); - - // Get channels and version of globally installed packages in parallel - let mut channels = HashMap::with_capacity(specs.len()); - let mut versions = HashMap::with_capacity(specs.len()); - let mut set: JoinSet> = JoinSet::new(); - for package_name in specs.keys().cloned() { - let channel_config = config.global_channel_config().clone(); - set.spawn(async move { - let p = find_installed_package(&package_name).await?; - let channel = - Channel::from_str(p.repodata_record.channel, &channel_config).into_diagnostic()?; - let version = p.repodata_record.package_record.version.into_version(); - Ok((package_name, channel, version)) - }); - } - while let Some(data) = set.join_next().await { - let (package_name, channel, version) = data.into_diagnostic()??; - channels.insert(package_name.clone(), channel); - versions.insert(package_name, version); - } - - // Fetch repodata across all channels - - // Start by aggregating all channels that we need to iterate - let all_channels: Vec = channels - .values() - .cloned() - .chain(channel_cli.iter().cloned()) - .unique() - .collect(); - - // Now ask gateway to query repodata for these channels - let (_, authenticated_client) = build_reqwest_clients(Some(&config)); - let gateway = config.gateway(authenticated_client.clone()); - let repodata = gateway - .query( - all_channels, - [platform, Platform::NoArch], - specs.values().cloned().collect_vec(), - ) - .recursive(true) - .await - .into_diagnostic()?; + let (_, client) = build_reqwest_clients(Some(&config)); + let gateway = config.gateway(client.clone()); // Resolve environments in parallel let mut set: JoinSet> = JoinSet::new(); - // Create arcs for these structs // as they later will be captured by closure - let repodata = Arc::new(repodata); - let config = Arc::new(config); + let channel_config = Arc::new(config.global_channel_config().clone()); let channel_cli = Arc::new(channel_cli); - let channels = Arc::new(channels); for (package_name, package_matchspec) in specs { - let repodata = repodata.clone(); - let config = config.clone(); + let channel_config = channel_config.clone(); let channel_cli = channel_cli.clone(); - let channels = channels.clone(); - - set.spawn_blocking(move || { - // Filter repodata based on channels specific to the package (and from the CLI) - let specific_repodata: Vec<_> = repodata - .iter() - .filter_map(|repodata| { - let filtered: Vec<_> = repodata - .iter() - .filter(|item| { - let item_channel = - Channel::from_str(&item.channel, config.global_channel_config()) - .expect("should be parseable"); - channel_cli.contains(&item_channel) - || channels - .get(&package_name) - .map_or(false, |c| c == &item_channel) - }) - .collect(); - - (!filtered.is_empty()).then_some(filtered) - }) - .collect(); - - // Determine virtual packages of the current platform - let virtual_packages = VirtualPackage::current() - .into_diagnostic() - .context("failed to determine virtual packages")? - .iter() - .cloned() - .map(GenericVirtualPackage::from) - .collect(); - - // Solve the environment - let solver_matchspec = package_matchspec.clone(); - let solved_records = wrap_in_progress("solving environment", move || { - Solver.solve(SolverTask { - specs: vec![solver_matchspec], - virtual_packages, - ..SolverTask::from_iter(specific_repodata) - }) - }) - .into_diagnostic() - .context("failed to solve environment")?; - - Ok((package_name, package_matchspec.clone(), solved_records)) + let gateway = gateway.clone(); // Already an Arc under the hood + + set.spawn(async move { + let record = find_installed_package(&package_name).await?.repodata_record; + let channel = Channel::from_str(record.channel, &channel_config).into_diagnostic()?; + let version = record.package_record.version.into_version(); + + let channels = channel_cli.iter().cloned().chain(once(channel)).unique(); + let records = solve_package_records( + &gateway, + platform, + channels, + vec![package_matchspec.clone()], + ) + .await?; + + Ok((package_name, package_matchspec, records, version)) }); } // Upgrade each package when relevant let mut upgraded = false; while let Some(data) = set.join_next().await { - let (package_name, package_matchspec, records) = data.into_diagnostic()??; + let (package_name, package_matchspec, records, installed_version) = + data.into_diagnostic()??; let toinstall_version = records .iter() .find(|r| r.package_record.name == package_name) @@ -172,10 +101,6 @@ pub(super) async fn upgrade_packages( package_name.as_normalized() ) })?; - let installed_version = versions - .get(&package_name) - .expect("should have the installed version") - .to_owned(); // Perform upgrade if a specific version was requested // OR if a more recent version is available @@ -195,13 +120,7 @@ pub(super) async fn upgrade_packages( console::style("Updating").green(), message )); - globally_install_package( - &package_name, - records, - authenticated_client.clone(), - platform, - ) - .await?; + globally_install_package(&package_name, records, client.clone(), platform).await?; pb.finish_with_message(format!("{} {}", console::style("Updated").green(), message)); upgraded = true; } diff --git a/src/project/repodata.rs b/src/project/repodata.rs index 405d824bb..e071b20d6 100644 --- a/src/project/repodata.rs +++ b/src/project/repodata.rs @@ -1,24 +1,11 @@ use crate::project::Project; use rattler_repodata_gateway::Gateway; -use std::path::PathBuf; impl Project { /// Returns the [`Gateway`] used by this project. pub fn repodata_gateway(&self) -> &Gateway { - self.repodata_gateway.get_or_init(|| { - // Determine the cache directory and fall back to sane defaults otherwise. - let cache_dir = pixi_config::get_cache_dir().unwrap_or_else(|e| { - tracing::error!("failed to determine repodata cache directory: {e}"); - std::env::current_dir().unwrap_or_else(|_| PathBuf::from("./")) - }); - - // Construct the gateway - Gateway::builder() - .with_client(self.authenticated_client().clone()) - .with_cache_dir(cache_dir.join("repodata")) - .with_channel_config(self.config().into()) - .finish() - }) + self.repodata_gateway + .get_or_init(|| self.config.gateway(self.authenticated_client().clone())) } }