From db009a43357f2b36f0b638255d6d9e36325626a3 Mon Sep 17 00:00:00 2001 From: Olivier Lacroix Date: Mon, 12 Aug 2024 23:05:20 +1000 Subject: [PATCH] Refactor upgrade --- src/cli/global/upgrade.rs | 147 +++++++++----------------------------- 1 file changed, 35 insertions(+), 112 deletions(-) diff --git a/src/cli/global/upgrade.rs b/src/cli/global/upgrade.rs index dfbd2c85b..38b7af1ed 100644 --- a/src/cli/global/upgrade.rs +++ b/src/cli/global/upgrade.rs @@ -1,20 +1,22 @@ -use std::{collections::HashMap, sync::Arc, time::Duration}; +use itertools::Itertools; +use std::iter::once; +use std::{sync::Arc, time::Duration}; use clap::Parser; use indexmap::IndexMap; use indicatif::ProgressBar; -use itertools::Itertools; -use miette::{Context, IntoDiagnostic, Report}; +use miette::{IntoDiagnostic, Report}; use pixi_utils::reqwest::build_reqwest_clients; -use rattler_conda_types::{Channel, GenericVirtualPackage, MatchSpec, PackageName, Platform}; -use rattler_solve::{resolvo::Solver, SolverImpl, SolverTask}; -use rattler_virtual_packages::VirtualPackage; +use rattler_conda_types::{Channel, MatchSpec, PackageName, Platform}; + use tokio::task::JoinSet; use super::{common::find_installed_package, install::globally_install_package}; -use crate::cli::{cli_config::ChannelsConfig, has_specs::HasSpecs}; +use crate::cli::{ + cli_config::ChannelsConfig, global::common::solve_package_records, has_specs::HasSpecs, +}; use pixi_config::Config; -use pixi_progress::{global_multi_progress, long_running_progress_style, wrap_in_progress}; +use pixi_progress::{global_multi_progress, long_running_progress_style}; /// Upgrade specific package which is installed globally. #[derive(Parser, Debug)] @@ -51,117 +53,48 @@ pub(super) async fn upgrade_packages( platform: Platform, ) -> miette::Result<()> { let channel_cli = cli_channels.resolve_from_config(&config); - - // Get channels and version of globally installed packages in parallel - let mut channels = HashMap::with_capacity(specs.len()); - let mut versions = HashMap::with_capacity(specs.len()); - let mut set: JoinSet> = JoinSet::new(); - for package_name in specs.keys().cloned() { - let channel_config = config.global_channel_config().clone(); - set.spawn(async move { - let p = find_installed_package(&package_name).await?; - let channel = - Channel::from_str(p.repodata_record.channel, &channel_config).into_diagnostic()?; - let version = p.repodata_record.package_record.version.into_version(); - Ok((package_name, channel, version)) - }); - } - while let Some(data) = set.join_next().await { - let (package_name, channel, version) = data.into_diagnostic()??; - channels.insert(package_name.clone(), channel); - versions.insert(package_name, version); - } - - // Fetch repodata across all channels - - // Start by aggregating all channels that we need to iterate - let all_channels: Vec = channels - .values() - .cloned() - .chain(channel_cli.iter().cloned()) - .unique() - .collect(); - - // Now ask gateway to query repodata for these channels - let (_, authenticated_client) = build_reqwest_clients(Some(&config)); - let gateway = config.gateway(authenticated_client.clone()); - let repodata = gateway - .query( - all_channels, - [platform, Platform::NoArch], - specs.values().cloned().collect_vec(), - ) - .recursive(true) - .await - .into_diagnostic()?; + let (_, client) = build_reqwest_clients(Some(&config)); + let gateway = config.gateway(client.clone()); // Resolve environments in parallel let mut set: JoinSet> = JoinSet::new(); - // Create arcs for these structs // as they later will be captured by closure - let repodata = Arc::new(repodata); - let config = Arc::new(config); + let channel_config = Arc::new(config.global_channel_config().clone()); let channel_cli = Arc::new(channel_cli); - let channels = Arc::new(channels); for (package_name, package_matchspec) in specs { - let repodata = repodata.clone(); - let config = config.clone(); + let channel_config = channel_config.clone(); let channel_cli = channel_cli.clone(); - let channels = channels.clone(); + let gateway = gateway.clone(); // Already an Arc under the hood - set.spawn_blocking(move || { - // Filter repodata based on channels specific to the package (and from the CLI) - let specific_repodata: Vec<_> = repodata - .iter() - .filter_map(|repodata| { - let filtered: Vec<_> = repodata - .iter() - .filter(|item| { - let item_channel = - Channel::from_str(&item.channel, config.global_channel_config()) - .expect("should be parseable"); - channel_cli.contains(&item_channel) - || channels - .get(&package_name) - .map_or(false, |c| c == &item_channel) - }) - .collect(); - - (!filtered.is_empty()).then_some(filtered) - }) - .collect(); - - // Determine virtual packages of the current platform - let virtual_packages = VirtualPackage::current() - .into_diagnostic() - .context("failed to determine virtual packages")? + set.spawn(async move { + let record = find_installed_package(&package_name).await?.repodata_record; + let channel = Channel::from_str(record.channel, &channel_config).into_diagnostic()?; + let version = record.package_record.version.into_version(); + + let channels = channel_cli .iter() .cloned() - .map(GenericVirtualPackage::from) - .collect(); - - // Solve the environment - let solver_matchspec = package_matchspec.clone(); - let solved_records = wrap_in_progress("solving environment", move || { - Solver.solve(SolverTask { - specs: vec![solver_matchspec], - virtual_packages, - ..SolverTask::from_iter(specific_repodata) - }) - }) - .into_diagnostic() - .context("failed to solve environment")?; - - Ok((package_name, package_matchspec.clone(), solved_records)) + .chain(once(channel).into_iter()) + .unique(); + let records = solve_package_records( + &gateway, + platform, + channels, + vec![package_matchspec.clone()], + ) + .await?; + + Ok((package_name, package_matchspec, records, version)) }); } // Upgrade each package when relevant let mut upgraded = false; while let Some(data) = set.join_next().await { - let (package_name, package_matchspec, records) = data.into_diagnostic()??; + let (package_name, package_matchspec, records, installed_version) = + data.into_diagnostic()??; let toinstall_version = records .iter() .find(|r| r.package_record.name == package_name) @@ -172,10 +105,6 @@ pub(super) async fn upgrade_packages( package_name.as_normalized() ) })?; - let installed_version = versions - .get(&package_name) - .expect("should have the installed version") - .to_owned(); // Perform upgrade if a specific version was requested // OR if a more recent version is available @@ -195,13 +124,7 @@ pub(super) async fn upgrade_packages( console::style("Updating").green(), message )); - globally_install_package( - &package_name, - records, - authenticated_client.clone(), - platform, - ) - .await?; + globally_install_package(&package_name, records, client.clone(), platform).await?; pb.finish_with_message(format!("{} {}", console::style("Updated").green(), message)); upgraded = true; }