Skip to content

Commit

Permalink
fix: suppress logs
Browse files Browse the repository at this point in the history
AND make SpinningCursorAnimation work!
  • Loading branch information
AgainstEntropy committed Aug 21, 2024
1 parent aca455c commit 047877a
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 31 deletions.
2 changes: 1 addition & 1 deletion nexa/gguf/llama/_utils_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class suppress_stdout_stderr(object):
sys = sys
os = os

def __init__(self, disable: bool = True):
def __init__(self, disable: bool = False):
self.disable = disable

# Oddly enough this works better than the contextlib version
Expand Down
55 changes: 25 additions & 30 deletions nexa/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import itertools
import os
import platform
import sys
import threading
import time
from functools import partial, wraps
from importlib.metadata import PackageNotFoundError, distribution
import platform
from contextlib import redirect_stdout, redirect_stderr

from prompt_toolkit import HTML, prompt
from prompt_toolkit.styles import Style

from nexa.constants import EXIT_COMMANDS, EXIT_REMINDER
from nexa.gguf.llama._utils_transformers import (
suppress_stdout_stderr,
) # re-import, don't comment out


def is_package_installed(package_name: str) -> bool:
Expand All @@ -30,41 +31,26 @@ def is_nexa_gpu_installed() -> bool:

def is_metal_available():
arch = platform.machine().lower()
return sys.platform == "darwin" and ('arm' in arch or 'aarch' in arch) # ARM architecture for Apple Silicon
return sys.platform == "darwin" and (
"arm" in arch or "aarch" in arch
) # ARM architecture for Apple Silicon


def is_x86() -> bool:
"""Check if the architecture is x86."""
return platform.machine().startswith("x86")


def is_arm64() -> bool:
"""Check if the architecture is ARM64."""
return platform.machine().startswith("arm")


class suppress_stdout_stderr:
"""Context manager to suppress stdout and stderr."""
def __enter__(self):
self.null_file = open(os.devnull, "w")
self.old_stdout = sys.stdout
self.old_stderr = sys.stderr
self.stdout_redirect = redirect_stdout(self.null_file)
self.stderr_redirect = redirect_stderr(self.null_file)
self.stdout_redirect.__enter__()
self.stderr_redirect.__enter__()
return self

def __exit__(self, *args):
self.stdout_redirect.__exit__(*args)
self.stderr_redirect.__exit__(*args)
sys.stdout = self.old_stdout
sys.stderr = self.old_stderr
self.null_file.close()


_style = Style.from_dict({
"prompt": "ansiblue",
})
_style = Style.from_dict(
{
"prompt": "ansiblue",
}
)

_prompt = partial(prompt, ">>> ", style=_style)

Expand Down Expand Up @@ -117,27 +103,36 @@ def _load_model(self):
obj = MyClass()
"""
def __init__(self):

def __init__(self, alternate_stream: bool = True):
frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
self.spinner = itertools.cycle(frames)
self.stop_spinning = threading.Event()
self._use_alternate_stream = alternate_stream
self.stream = sys.stdout

def _spin(self):
while not self.stop_spinning.is_set():
print(f"\r{next(self.spinner)} ", flush=True, end="")
self.stream.write(f"\r{next(self.spinner)} ")
self.stream.flush()
time.sleep(0.1)
if self.stop_spinning.is_set():
break

def __enter__(self):
if self._use_alternate_stream:
self.stream = open("/dev/tty", "w")
self.thread = threading.Thread(target=self._spin)
self.thread.start()
return self

def __exit__(self, exc_type, exc_value, traceback):
self.stop_spinning.set()
self.thread.join()
print("\r", flush=True, end="")
self.stream.write("\r")
self.stream.flush()
if self._use_alternate_stream:
self.stream.close()

def __call__(self, func):
@wraps(func)
Expand Down

0 comments on commit 047877a

Please sign in to comment.