diff --git a/nexa/gguf/llama/_utils_transformers.py b/nexa/gguf/llama/_utils_transformers.py index 945c1478..0049e9cc 100644 --- a/nexa/gguf/llama/_utils_transformers.py +++ b/nexa/gguf/llama/_utils_transformers.py @@ -17,7 +17,7 @@ class suppress_stdout_stderr(object): sys = sys os = os - def __init__(self, disable: bool = True): + def __init__(self, disable: bool = False): self.disable = disable # Oddly enough this works better than the contextlib version diff --git a/nexa/utils.py b/nexa/utils.py index e59793fd..ea187d8a 100644 --- a/nexa/utils.py +++ b/nexa/utils.py @@ -1,17 +1,18 @@ import itertools -import os +import platform import sys import threading import time from functools import partial, wraps from importlib.metadata import PackageNotFoundError, distribution -import platform -from contextlib import redirect_stdout, redirect_stderr from prompt_toolkit import HTML, prompt from prompt_toolkit.styles import Style from nexa.constants import EXIT_COMMANDS, EXIT_REMINDER +from nexa.gguf.llama._utils_transformers import ( + suppress_stdout_stderr, +) # re-import, don't comment out def is_package_installed(package_name: str) -> bool: @@ -30,41 +31,26 @@ def is_nexa_gpu_installed() -> bool: def is_metal_available(): arch = platform.machine().lower() - return sys.platform == "darwin" and ('arm' in arch or 'aarch' in arch) # ARM architecture for Apple Silicon + return sys.platform == "darwin" and ( + "arm" in arch or "aarch" in arch + ) # ARM architecture for Apple Silicon def is_x86() -> bool: """Check if the architecture is x86.""" return platform.machine().startswith("x86") + def is_arm64() -> bool: """Check if the architecture is ARM64.""" return platform.machine().startswith("arm") -class suppress_stdout_stderr: - """Context manager to suppress stdout and stderr.""" - def __enter__(self): - self.null_file = open(os.devnull, "w") - self.old_stdout = sys.stdout - self.old_stderr = sys.stderr - self.stdout_redirect = redirect_stdout(self.null_file) - self.stderr_redirect = redirect_stderr(self.null_file) - self.stdout_redirect.__enter__() - self.stderr_redirect.__enter__() - return self - - def __exit__(self, *args): - self.stdout_redirect.__exit__(*args) - self.stderr_redirect.__exit__(*args) - sys.stdout = self.old_stdout - sys.stderr = self.old_stderr - self.null_file.close() - - -_style = Style.from_dict({ - "prompt": "ansiblue", -}) +_style = Style.from_dict( + { + "prompt": "ansiblue", + } +) _prompt = partial(prompt, ">>> ", style=_style) @@ -117,19 +103,25 @@ def _load_model(self): obj = MyClass() """ - def __init__(self): + + def __init__(self, alternate_stream: bool = True): frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] self.spinner = itertools.cycle(frames) self.stop_spinning = threading.Event() + self._use_alternate_stream = alternate_stream + self.stream = sys.stdout def _spin(self): while not self.stop_spinning.is_set(): - print(f"\r{next(self.spinner)} ", flush=True, end="") + self.stream.write(f"\r{next(self.spinner)} ") + self.stream.flush() time.sleep(0.1) if self.stop_spinning.is_set(): break def __enter__(self): + if self._use_alternate_stream: + self.stream = open("/dev/tty", "w") self.thread = threading.Thread(target=self._spin) self.thread.start() return self @@ -137,7 +129,10 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.stop_spinning.set() self.thread.join() - print("\r", flush=True, end="") + self.stream.write("\r") + self.stream.flush() + if self._use_alternate_stream: + self.stream.close() def __call__(self, func): @wraps(func)