diff --git a/boa/cli/boa.py b/boa/cli/boa.py
index b10a1bf9..1a26ec8e 100644
--- a/boa/cli/boa.py
+++ b/boa/cli/boa.py
@@ -12,7 +12,7 @@
 from boa.core.config import init_global_config
 from boa._version import __version__
 
-from mamba.utils import init_api_context
+from boa.core.utils import init_api_context
 
 from conda_build.conda_interface import cc_conda_build
 
diff --git a/boa/cli/mambabuild.py b/boa/cli/mambabuild.py
index 3c777874..68b7d3fb 100644
--- a/boa/cli/mambabuild.py
+++ b/boa/cli/mambabuild.py
@@ -20,7 +20,7 @@
 from boa.core.solver import MambaSolver
 from boa.core.utils import normalize_subdir
-from mamba.utils import init_api_context
+from boa.core.utils import init_api_context
 from boa.core.config import boa_config
 
 
 only_dot_or_digit_re = re.compile(r"^[\d\.]+$")
diff --git a/boa/core/solver.py b/boa/core/solver.py
index ef7cf9d1..d8aa2efc 100644
--- a/boa/core/solver.py
+++ b/boa/core/solver.py
@@ -20,17 +20,14 @@
 from conda.core.package_cache_data import PackageCacheData
 
 import libmambapy
-import mamba
-from mamba.utils import get_index, load_channels, to_package_record_from_subjson
+from boa.core.utils import get_index, load_channels, to_package_record_from_subjson
 
 from boa.core.config import boa_config
 
 console = boa_config.console
 
 solver_cache = {}
 
-MAMBA_17_UP = mamba.version_info >= (0, 17, 0)
-
 
 def refresh_solvers():
     for _, v in solver_cache.items():
diff --git a/boa/core/utils.py b/boa/core/utils.py
index aa169e59..3b632693 100644
--- a/boa/core/utils.py
+++ b/boa/core/utils.py
@@ -1,18 +1,30 @@
 # Copyright (C) 2021, QuantStack
 # SPDX-License-Identifier: BSD-3-Clause
 
+from __future__ import absolute_import, division, print_function, unicode_literals
+
 import collections
 import sys
 import os
 import typing
+import json
+import urllib.parse
 
 from conda.base.context import context
 from conda_build import utils
 from conda_build.config import get_or_merge_config
 from conda_build.variants import find_config_files, parse_config_file, combine_specs
 from conda_build import __version__ as cb_version
+from conda.base.constants import ChannelPriority
+from conda.gateways.connection.session import CondaHttpAuth
+from conda.core.index import check_allowlist
+from conda.models.channel import Channel as CondaChannel
+from conda.models.records import PackageRecord
+from conda.common.url import join_url
 
 from boa.core.config import boa_config
 
+import libmambapy as api
+
 if typing.TYPE_CHECKING:
     from typing import Any
@@ -149,3 +161,255 @@ def get_sys_vars_stubs(target_platform):
         "BUILD",
     ]
     return res
+
+
+def get_index(
+    channel_urls=(),
+    prepend=True,
+    platform=None,
+    use_local=False,
+    use_cache=False,
+    unknown=None,
+    prefix=None,
+    repodata_fn="repodata.json",
+):
+    if isinstance(platform, str):
+        platform = [platform, "noarch"]
+
+    all_channels = []
+    if use_local:
+        all_channels.append("local")
+    all_channels.extend(channel_urls)
+    if prepend:
+        all_channels.extend(context.channels)
+    check_allowlist(all_channels)
+
+    # Remove duplicates but retain order
+    all_channels = list(collections.OrderedDict.fromkeys(all_channels))
+
+    dlist = api.DownloadTargetList()
+
+    index = []
+
+    def fixup_channel_spec(spec):
+        at_count = spec.count("@")
+        if at_count > 1:
+            first_at = spec.find("@")
+            spec = (
+                spec[:first_at]
+                + urllib.parse.quote(spec[first_at])
+                + spec[first_at + 1 :]
+            )
+        if platform:
+            spec = spec + "[" + ",".join(platform) + "]"
+        return spec
+
+    all_channels = list(map(fixup_channel_spec, all_channels))
+    pkgs_dirs = api.MultiPackageCache(context.pkgs_dirs)
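+    # make sure the first writable package cache directory exists before repodata is downloaded into it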
+    api.create_cache_dir(str(pkgs_dirs.first_writable_path))
+
+    for channel in api.get_channels(all_channels):
+        for channel_platform, url in channel.platform_urls(with_credentials=True):
+            full_url = CondaHttpAuth.add_binstar_token(url)
+
+            sd = api.SubdirData(
+                channel, channel_platform, full_url, pkgs_dirs, repodata_fn
+            )
+
+            index.append(
+                (sd, {"platform": channel_platform, "url": url, "channel": channel})
+            )
+            dlist.add(sd)
+
+    is_downloaded = dlist.download(api.MAMBA_DOWNLOAD_FAILFAST)
+
+    if not is_downloaded:
+        raise RuntimeError("Error downloading repodata.")
+
+    return index
+
+
+def load_channels(
+    pool,
+    channels,
+    repos,
+    has_priority=None,
+    prepend=True,
+    platform=None,
+    use_local=False,
+    use_cache=True,
+    repodata_fn="repodata.json",
+):
+    index = get_index(
+        channel_urls=channels,
+        prepend=prepend,
+        platform=platform,
+        use_local=use_local,
+        repodata_fn=repodata_fn,
+        use_cache=use_cache,
+    )
+
+    if has_priority is None:
+        has_priority = context.channel_priority in [
+            ChannelPriority.STRICT,
+            ChannelPriority.FLEXIBLE,
+        ]
+
+    subprio_index = len(index)
+    if has_priority:
+        # first, count unique channels
+        n_channels = len(set([entry["channel"].canonical_name for _, entry in index]))
+        current_channel = index[0][1]["channel"].canonical_name
+        channel_prio = n_channels
+
+    for subdir, entry in index:
+        # add priority here
+        if has_priority:
+            if entry["channel"].canonical_name != current_channel:
+                channel_prio -= 1
+                current_channel = entry["channel"].canonical_name
+            priority = channel_prio
+        else:
+            priority = 0
+        if has_priority:
+            subpriority = 0
+        else:
+            subpriority = subprio_index
+            subprio_index -= 1
+
+        if not subdir.loaded() and entry["platform"] != "noarch":
+            # ignore non-loaded subdir if channel is != noarch
+            continue
+
+        if context.verbosity != 0 and not context.json:
+            print(
+                "Channel: {}, platform: {}, prio: {} : {}".format(
+                    entry["channel"], entry["platform"], priority, subpriority
+                )
+            )
+            print("Cache path: ", subdir.cache_path())
+
+        repo = subdir.create_repo(pool)
+        repo.set_priority(priority, subpriority)
+        repos.append(repo)
+
+    return index
+
+
+def init_api_context(use_mamba_experimental=False):
+    api_ctx = api.Context()
+
+    api_ctx.json = context.json
+    api_ctx.dry_run = context.dry_run
+    if context.json:
+        context.always_yes = True
+        context.quiet = True
+        if use_mamba_experimental:
+            context.json = False
+
+    api_ctx.verbosity = context.verbosity
+    api_ctx.set_verbosity(context.verbosity)
+    api_ctx.quiet = context.quiet
+    api_ctx.offline = context.offline
+    api_ctx.local_repodata_ttl = context.local_repodata_ttl
+    api_ctx.use_index_cache = context.use_index_cache
+    api_ctx.always_yes = context.always_yes
+    api_ctx.channels = context.channels
+    api_ctx.platform = context.subdir
+    # Conda uses a frozendict here
+    api_ctx.proxy_servers = dict(context.proxy_servers)
+
+    if "MAMBA_EXTRACT_THREADS" in os.environ:
+        try:
+            max_threads = int(os.environ["MAMBA_EXTRACT_THREADS"])
+            api_ctx.extract_threads = max_threads
+        except ValueError:
+            v = os.environ["MAMBA_EXTRACT_THREADS"]
+            raise ValueError(
+                f"Invalid conversion of env variable 'MAMBA_EXTRACT_THREADS' from value '{v}'"
+            )
+
+    def get_base_url(url, name=None):
+        tmp = url.rsplit("/", 1)[0]
+        if name:
+            if tmp.endswith(name):
+                return tmp.rsplit("/", 1)[0]
+        return tmp
+
+    api_ctx.channel_alias = str(
+        get_base_url(context.channel_alias.url(with_credentials=True))
+    )
+
+    additional_custom_channels = {}
+    for el in context.custom_channels:
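+        # only forward channels other than "local" and "defaults" to libmamba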
+        if context.custom_channels[el].canonical_name not in ["local", "defaults"]:
+            additional_custom_channels[el] = get_base_url(
+                context.custom_channels[el].url(with_credentials=True), el
+            )
+    api_ctx.custom_channels = additional_custom_channels
+
+    additional_custom_multichannels = {}
+    for el in context.custom_multichannels:
+        if el not in ["defaults", "local"]:
+            additional_custom_multichannels[el] = []
+            for c in context.custom_multichannels[el]:
+                additional_custom_multichannels[el].append(
+                    get_base_url(c.url(with_credentials=True))
+                )
+    api_ctx.custom_multichannels = additional_custom_multichannels
+
+    api_ctx.default_channels = [
+        get_base_url(x.url(with_credentials=True)) for x in context.default_channels
+    ]
+
+    if context.ssl_verify is False:
+        # the special string "<false>" disables SSL verification in libmamba
+        api_ctx.ssl_verify = "<false>"
+    elif context.ssl_verify is not True:
+        api_ctx.ssl_verify = context.ssl_verify
+    api_ctx.target_prefix = context.target_prefix
+    api_ctx.root_prefix = context.root_prefix
+    api_ctx.conda_prefix = context.conda_prefix
+    api_ctx.pkgs_dirs = context.pkgs_dirs
+    api_ctx.envs_dirs = context.envs_dirs
+
+    api_ctx.connect_timeout_secs = int(round(context.remote_connect_timeout_secs))
+    api_ctx.max_retries = context.remote_max_retries
+    api_ctx.retry_backoff = context.remote_backoff_factor
+    api_ctx.add_pip_as_python_dependency = context.add_pip_as_python_dependency
+    api_ctx.use_only_tar_bz2 = context.use_only_tar_bz2
+
+    if context.channel_priority is ChannelPriority.STRICT:
+        api_ctx.channel_priority = api.ChannelPriority.kStrict
+    elif context.channel_priority is ChannelPriority.FLEXIBLE:
+        api_ctx.channel_priority = api.ChannelPriority.kFlexible
+    elif context.channel_priority is ChannelPriority.DISABLED:
+        api_ctx.channel_priority = api.ChannelPriority.kDisabled
+
+
+def to_conda_channel(channel, platform):
+    if channel.scheme == "file":
+        return CondaChannel.from_value(
+            channel.platform_url(platform, with_credentials=False)
+        )
+
+    return CondaChannel(
+        channel.scheme,
+        channel.auth,
+        channel.location,
+        channel.token,
+        channel.name,
+        platform,
+        channel.package_filename,
+    )
+
+
+def to_package_record_from_subjson(entry, pkg, jsn_string):
+    channel_url = entry["url"]
+    info = json.loads(jsn_string)
+    info["fn"] = pkg
+    info["channel"] = to_conda_channel(entry["channel"], entry["platform"])
+    info["url"] = join_url(channel_url, pkg)
+    if not info.get("subdir"):
+        info["subdir"] = entry["platform"]
+    package_record = PackageRecord(**info)
+    return package_record
diff --git a/tests/env.yml b/tests/env.yml
index c71c80d2..2090fae7 100644
--- a/tests/env.yml
+++ b/tests/env.yml
@@ -4,7 +4,7 @@ channels:
 dependencies:
   - python>=3.7
   - pip
-  - mamba
+  - conda
   - libmambapy <=1.4.2
   - pytest
   - "conda-build>=3.20"