Skip to content

Commit

Permalink
types
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenfreund committed Apr 8, 2024
1 parent 8eae6b9 commit bc4fc30
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/chatdbg/util/config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import argparse
import os
from textwrap import TextWrapper

from traitlets import Bool, Int, Unicode
from traitlets.config import Configurable

from assistant.listeners import BaseAssistantListener
from chatdbg.util.markdown import ChatDBGMarkdownPrinter
from chatdbg.util.printer import ChatDBGPrinter

from io import StringIO
from io import StringIO, TextIOWrapper
from types import *
from typing import *

Expand Down Expand Up @@ -107,7 +109,7 @@ def _parser(self):

return parser

def to_json(self):
def to_json(self) -> Dict[str, Union[int, str, bool]]:
"""Serialize the object to a JSON string."""
return {
"model": self.model,
Expand All @@ -125,7 +127,7 @@ def to_json(self):
"instructions": self.instructions,
}

def parse_user_flags(self, argv):
def parse_user_flags(self, argv: List[str]) -> None:

args, unknown_args = self._parser().parse_known_args(argv)

Expand All @@ -134,23 +136,23 @@ def parse_user_flags(self, argv):

return unknown_args

def user_flags_help(self):
def user_flags_help(self) -> str:
return "\n".join(
[
self.class_get_trait_help(x, self).replace("ChatDBGConfig.", "")
for x in self._user_configurable
]
)

def user_flags(self):
def user_flags(self) -> str:
return "\n".join(
[
f" --{x.name:10}{self._trait_values[x.name]}"
for x in self._user_configurable
]
)

def parse_only_user_flags(self, args):
def parse_only_user_flags(self, args: List[str]) -> str:
try:
unknown = chatdbg_config.parse_user_flags(args)
if unknown:
Expand All @@ -162,7 +164,9 @@ def parse_only_user_flags(self, args):
except Exception as e:
return str(e) + f"\nChatDBG arguments:\n\n{self.user_flags_help()}"

def make_printer(self, stdout, prompt, prefix, width):
def make_printer(
self, stdout: TextIOWrapper, prompt: str, prefix: str, width: int
) -> BaseAssistantListener:
format = chatdbg_config.format
split = format.split(":")
if split[0] == "md":
Expand Down

0 comments on commit bc4fc30

Please sign in to comment.