diff --git a/agentstack/__init__.py b/agentstack/__init__.py index e432836..da5ea50 100644 --- a/agentstack/__init__.py +++ b/agentstack/__init__.py @@ -4,11 +4,12 @@ Methods that have been imported into this file are expected to be used by the end user inside of their project. """ -from agentstack.exceptions import ValidationError +from pathlib import Path +from agentstack import conf from agentstack.inputs import get_inputs ___all___ = [ - "ValidationError", + "conf", "get_inputs", ] diff --git a/agentstack/agents.py b/agentstack/agents.py index a83c91a..1c5ab29 100644 --- a/agentstack/agents.py +++ b/agentstack/agents.py @@ -4,7 +4,8 @@ import pydantic from ruamel.yaml import YAML, YAMLError from ruamel.yaml.scalarstring import FoldedScalarString -from agentstack import ValidationError +from agentstack import conf +from agentstack.exceptions import ValidationError AGENTS_FILENAME: Path = Path("src/config/agents.yaml") @@ -46,11 +47,8 @@ class AgentConfig(pydantic.BaseModel): backstory: str = "" llm: str = "" - def __init__(self, name: str, path: Optional[Path] = None): - if not path: - path = Path() - - filename = path / AGENTS_FILENAME + def __init__(self, name: str): + filename = conf.PATH / AGENTS_FILENAME if not os.path.exists(filename): os.makedirs(filename.parent, exist_ok=True) filename.touch() @@ -69,9 +67,6 @@ def __init__(self, name: str, path: Optional[Path] = None): error_str += f"{' '.join([str(loc) for loc in error['loc']])}: {error['msg']}\n" raise ValidationError(f"Error loading agent {name} from {filename}.\n{error_str}") - # store the path *after* loading data - self._path = path - def model_dump(self, *args, **kwargs) -> dict: dump = super().model_dump(*args, **kwargs) dump.pop('name') # name is the key, so keep it out of the data @@ -81,7 +76,7 @@ def model_dump(self, *args, **kwargs) -> dict: return {self.name: dump} def write(self): - filename = self._path / AGENTS_FILENAME + filename = conf.PATH / AGENTS_FILENAME with open(filename, 'r') as f: data = yaml.load(f) or {} @@ -98,10 +93,8 @@ def __exit__(self, *args): self.write() -def get_all_agent_names(path: Optional[Path] = None) -> list[str]: - if not path: - path = Path() - filename = path / AGENTS_FILENAME +def get_all_agent_names() -> list[str]: + filename = conf.PATH / AGENTS_FILENAME if not os.path.exists(filename): return [] with open(filename, 'r') as f: @@ -109,5 +102,5 @@ def get_all_agent_names(path: Optional[Path] = None) -> list[str]: return list(data.keys()) -def get_all_agents(path: Optional[Path] = None) -> list[AgentConfig]: - return [AgentConfig(name, path) for name in get_all_agent_names(path)] +def get_all_agents() -> list[AgentConfig]: + return [AgentConfig(name) for name in get_all_agent_names()] diff --git a/agentstack/cli/cli.py b/agentstack/cli/cli.py index 1ee4843..99a6a39 100644 --- a/agentstack/cli/cli.py +++ b/agentstack/cli/cli.py @@ -19,9 +19,11 @@ CookiecutterData, ) from agentstack.logger import log +from agentstack import conf +from agentstack.conf import ConfigFile from agentstack.utils import get_package_path from agentstack.tools import get_all_tools -from agentstack.generation.files import ConfigFile, ProjectFile +from agentstack.generation.files import ProjectFile from agentstack import frameworks from agentstack import generation from agentstack import inputs @@ -117,9 +119,12 @@ def init_project_builder( log.debug(f"project_details: {project_details}" f"framework: {framework}" f"design: {design}") insert_template(project_details, framework, design, template_data) - path = Path(project_details['name']) + + # we have an agentstack.json file in the directory now + conf.set_path(project_details['name']) + for tool_data in tools: - generation.add_tool(tool_data['name'], agents=tool_data['agents'], path=path) + generation.add_tool(tool_data['name'], agents=tool_data['agents']) def welcome_message(): @@ -135,9 +140,9 @@ def welcome_message(): print(border) -def configure_default_model(path: Optional[str] = None): +def configure_default_model(): """Set the default model""" - agentstack_config = ConfigFile(path) + agentstack_config = ConfigFile() if agentstack_config.default_model: return # Default model already set @@ -152,7 +157,7 @@ def configure_default_model(path: Optional[str] = None): print('A list of available models is available at: "https://docs.litellm.ai/docs/providers"') model = inquirer.text(message="Enter the model name") - with ConfigFile(path) as agentstack_config: + with ConfigFile() as agentstack_config: agentstack_config.default_model = model @@ -385,6 +390,7 @@ def insert_template( template_path = get_package_path() / f'templates/{framework.name}' with open(f"{template_path}/cookiecutter.json", "w") as json_file: json.dump(cookiecutter_data.to_dict(), json_file) + # TODO this should not be written to the package directory # copy .env.example to .env shutil.copy( @@ -453,22 +459,19 @@ def list_tools(): print(" https://docs.agentstack.sh/tools/core") -def export_template(output_filename: str, path: str = ''): +def export_template(output_filename: str): """ Export the current project as a template. """ - _path = Path(path) - framework = get_framework(_path) - try: - metadata = ProjectFile(_path) + metadata = ProjectFile() except Exception as e: print(term_color(f"Failed to load project metadata: {e}", 'red')) sys.exit(1) # Read all the agents from the project's agents.yaml file agents: list[TemplateConfig.Agent] = [] - for agent in get_all_agents(_path): + for agent in get_all_agents(): agents.append( TemplateConfig.Agent( name=agent.name, @@ -481,7 +484,7 @@ def export_template(output_filename: str, path: str = ''): # Read all the tasks from the project's tasks.yaml file tasks: list[TemplateConfig.Task] = [] - for task in get_all_tasks(_path): + for task in get_all_tasks(): tasks.append( TemplateConfig.Task( name=task.name, @@ -493,8 +496,8 @@ def export_template(output_filename: str, path: str = ''): # Export all of the configured tools from the project tools_agents: dict[str, list[str]] = {} - for agent_name in frameworks.get_agent_names(framework, _path): - for tool_name in frameworks.get_agent_tool_names(framework, agent_name, _path): + for agent_name in frameworks.get_agent_names(): + for tool_name in frameworks.get_agent_tool_names(agent_name): if not tool_name: continue if tool_name not in tools_agents: @@ -514,7 +517,7 @@ def export_template(output_filename: str, path: str = ''): template_version=2, name=metadata.project_name, description=metadata.project_description, - framework=framework, + framework=get_framework(), method="sequential", # TODO this needs to be stored in the project somewhere agents=agents, tasks=tasks, @@ -523,8 +526,8 @@ def export_template(output_filename: str, path: str = ''): ) try: - template.write_to_file(_path / output_filename) - print(term_color(f"Template saved to: {_path / output_filename}", 'green')) + template.write_to_file(conf.PATH / output_filename) + print(term_color(f"Template saved to: {conf.PATH / output_filename}", 'green')) except Exception as e: print(term_color(f"Failed to write template to file: {e}", 'red')) sys.exit(1) diff --git a/agentstack/cli/run.py b/agentstack/cli/run.py index 5368674..b5bb848 100644 --- a/agentstack/cli/run.py +++ b/agentstack/cli/run.py @@ -4,7 +4,8 @@ import importlib.util from dotenv import load_dotenv -from agentstack import ValidationError +from agentstack import conf +from agentstack.exceptions import ValidationError from agentstack import inputs from agentstack import frameworks from agentstack.utils import term_color, get_framework @@ -31,17 +32,14 @@ def _import_project_module(path: Path): return project_module -def run_project(command: str = 'run', path: Optional[str] = None, cli_args: Optional[str] = None): +def run_project(command: str = 'run', cli_args: Optional[str] = None): """Validate that the project is ready to run and then run it.""" - _path = Path(path) if path else Path.cwd() - framework = get_framework(_path) - - if framework not in frameworks.SUPPORTED_FRAMEWORKS: - print(term_color(f"Framework {framework} is not supported by agentstack.", 'red')) + if conf.get_framework() not in frameworks.SUPPORTED_FRAMEWORKS: + print(term_color(f"Framework {conf.get_framework()} is not supported by agentstack.", 'red')) sys.exit(1) try: - frameworks.validate_project(framework, _path) + frameworks.validate_project() except ValidationError as e: print(term_color(f"Project validation failed:\n{e}", 'red')) sys.exit(1) @@ -55,11 +53,11 @@ def run_project(command: str = 'run', path: Optional[str] = None, cli_args: Opti inputs.add_input_for_run(key, value) load_dotenv(Path.home() / '.env') # load the user's .env file - load_dotenv(_path / '.env', override=True) # load the project's .env file + load_dotenv(conf.PATH / '.env', override=True) # load the project's .env file # import src/main.py from the project path try: - project_main = _import_project_module(_path) + project_main = _import_project_module(conf.PATH) except ImportError as e: print(term_color(f"Failed to import project. Does '{MAIN_FILENAME}' exist?:\n{e}", 'red')) sys.exit(1) diff --git a/agentstack/conf.py b/agentstack/conf.py new file mode 100644 index 0000000..22003ca --- /dev/null +++ b/agentstack/conf.py @@ -0,0 +1,89 @@ +from typing import Optional, Union +import os, sys +import json +from pathlib import Path +from pydantic import BaseModel +from agentstack.utils import get_version + + +DEFAULT_FRAMEWORK = "crewai" +CONFIG_FILENAME = "agentstack.json" + +PATH: Path = Path() + + +def set_path(path: Union[str, Path, None]): + """Set the path to the project directory.""" + global PATH + PATH = Path(path) if path else Path() + + +def get_framework() -> Optional[str]: + """The framework used in the project. Will be available after PATH has been set + and if we are inside a project directory. + """ + try: + config = ConfigFile() + return config.framework + except FileNotFoundError: + return None # not in a project directory; that's okay + + +class ConfigFile(BaseModel): + """ + Interface for interacting with the agentstack.json file inside a project directory. + Handles both data validation and file I/O. + + Use it as a context manager to make and save edits: + ```python + with ConfigFile() as config: + config.tools.append('tool_name') + ``` + + Config Schema + ------------- + framework: str + The framework used in the project. Defaults to 'crewai'. + tools: list[str] + A list of tools that are currently installed in the project. + telemetry_opt_out: Optional[bool] + Whether the user has opted out of telemetry. + default_model: Optional[str] + The default model to use when generating agent configurations. + agentstack_version: Optional[str] + The version of agentstack used to generate the project. + template: Optional[str] + The template used to generate the project. + template_version: Optional[str] + The version of the template system used to generate the project. + """ + + framework: str = DEFAULT_FRAMEWORK # TODO this should probably default to None + tools: list[str] = [] + telemetry_opt_out: Optional[bool] = None + default_model: Optional[str] = None + agentstack_version: Optional[str] = get_version() + template: Optional[str] = None + template_version: Optional[str] = None + + def __init__(self): + if os.path.exists(PATH / CONFIG_FILENAME): + with open(PATH / CONFIG_FILENAME, 'r') as f: + super().__init__(**json.loads(f.read())) + else: + raise FileNotFoundError(f"File {PATH / CONFIG_FILENAME} does not exist.") + + def model_dump(self, *args, **kwargs) -> dict: + # Ignore None values + dump = super().model_dump(*args, **kwargs) + return {key: value for key, value in dump.items() if value is not None} + + def write(self): + with open(PATH / CONFIG_FILENAME, 'w') as f: + f.write(json.dumps(self.model_dump(), indent=4)) + + def __enter__(self) -> 'ConfigFile': + return self + + def __exit__(self, *args): + self.write() diff --git a/agentstack/frameworks/__init__.py b/agentstack/frameworks/__init__.py index 9c828d9..1a38e5e 100644 --- a/agentstack/frameworks/__init__.py +++ b/agentstack/frameworks/__init__.py @@ -2,7 +2,9 @@ from types import ModuleType from importlib import import_module from pathlib import Path -from agentstack import ValidationError +from agentstack import conf +from agentstack.exceptions import ValidationError +from agentstack.utils import get_framework from agentstack.tools import ToolConfig from agentstack.agents import AgentConfig from agentstack.tasks import TaskConfig @@ -21,50 +23,50 @@ class FrameworkModule(Protocol): ie. `src/crewai.py` """ - def validate_project(self, path: Optional[Path] = None) -> None: + def validate_project(self) -> None: """ Validate that a user's project is ready to run. Raises a `ValidationError` if the project is not valid. """ ... - def get_tool_names(self, path: Optional[Path] = None) -> list[str]: + def get_tool_names(self) -> list[str]: """ Get a list of tool names in the user's project. """ ... - def add_tool(self, tool: ToolConfig, agent_name: str, path: Optional[Path] = None) -> None: + def add_tool(self, tool: ToolConfig, agent_name: str) -> None: """ Add a tool to an agent in the user's project. """ ... - def remove_tool(self, tool: ToolConfig, agent_name: str, path: Optional[Path] = None) -> None: + def remove_tool(self, tool: ToolConfig, agent_name: str) -> None: """ Remove a tool from an agent in user's project. """ ... - def get_agent_names(self, path: Optional[Path] = None) -> list[str]: + def get_agent_names(self) -> list[str]: """ Get a list of agent names in the user's project. """ ... - def get_agent_tool_names(self, agent_name: str, path: Optional[Path] = None) -> list[str]: + def get_agent_tool_names(self, agent_name: str) -> list[str]: """ Get a list of tool names in an agent in the user's project. """ ... - def add_agent(self, agent: AgentConfig, path: Optional[Path] = None) -> None: + def add_agent(self, agent: AgentConfig) -> None: """ Add an agent to the user's project. """ ... - def add_task(self, task: TaskConfig, path: Optional[Path] = None) -> None: + def add_task(self, task: TaskConfig) -> None: """ Add a task to the user's project. """ @@ -80,55 +82,53 @@ def get_framework_module(framework: str) -> FrameworkModule: except ImportError: raise Exception(f"Framework {framework} could not be imported.") -def get_entrypoint_path(framework: str, path: Optional[Path] = None) -> Path: +def get_entrypoint_path(framework: str) -> Path: """ Get the path to the entrypoint file for a framework. """ - if path is None: - path = Path() - return path / get_framework_module(framework).ENTRYPOINT + return conf.PATH / get_framework_module(framework).ENTRYPOINT -def validate_project(framework: str, path: Optional[Path] = None): +def validate_project(): """ Validate that the user's project is ready to run. """ - return get_framework_module(framework).validate_project(path) + return get_framework_module(get_framework()).validate_project() -def add_tool(framework: str, tool: ToolConfig, agent_name: str, path: Optional[Path] = None): +def add_tool(tool: ToolConfig, agent_name: str): """ Add a tool to the user's project. The tool will have aready been installed in the user's application and have all dependencies installed. We're just handling code generation here. """ - return get_framework_module(framework).add_tool(tool, agent_name, path) + return get_framework_module(get_framework()).add_tool(tool, agent_name) -def remove_tool(framework: str, tool: ToolConfig, agent_name: str, path: Optional[Path] = None): +def remove_tool(tool: ToolConfig, agent_name: str): """ Remove a tool from the user's project. """ - return get_framework_module(framework).remove_tool(tool, agent_name, path) + return get_framework_module(get_framework()).remove_tool(tool, agent_name) -def get_agent_names(framework: str, path: Optional[Path] = None) -> list[str]: +def get_agent_names() -> list[str]: """ Get a list of agent names in the user's project. """ - return get_framework_module(framework).get_agent_names(path) + return get_framework_module(get_framework()).get_agent_names() -def get_agent_tool_names(framework: str, agent_name: str, path: Optional[Path] = None) -> list[str]: +def get_agent_tool_names(agent_name: str) -> list[str]: """ Get a list of tool names in the user's project. """ - return get_framework_module(framework).get_agent_tool_names(agent_name, path) + return get_framework_module(get_framework()).get_agent_tool_names(agent_name) -def add_agent(framework: str, agent: AgentConfig, path: Optional[Path] = None): +def add_agent(agent: AgentConfig): """ Add an agent to the user's project. """ - return get_framework_module(framework).add_agent(agent, path) + return get_framework_module(get_framework()).add_agent(agent) -def add_task(framework: str, task: TaskConfig, path: Optional[Path] = None): +def add_task(task: TaskConfig): """ Add a task to the user's project. """ - return get_framework_module(framework).add_task(task, path) + return get_framework_module(get_framework()).add_task(task) diff --git a/agentstack/frameworks/crewai.py b/agentstack/frameworks/crewai.py index 2a3e0c3..a88d4cf 100644 --- a/agentstack/frameworks/crewai.py +++ b/agentstack/frameworks/crewai.py @@ -1,7 +1,8 @@ from typing import Optional, Any from pathlib import Path import ast -from agentstack import ValidationError +from agentstack import conf +from agentstack.exceptions import ValidationError from agentstack.tools import ToolConfig from agentstack.tasks import TaskConfig from agentstack.agents import AgentConfig @@ -190,15 +191,13 @@ def remove_agent_tools(self, agent_name: str, tool: ToolConfig): self.edit_node_range(start, end, existing_node) -def validate_project(path: Optional[Path] = None) -> None: +def validate_project() -> None: """ Validate that a CrewAI project is ready to run. Raises an `agentstack.VaidationError` if the project is not valid. """ - if path is None: - path = Path() try: - crew_file = CrewFile(path / ENTRYPOINT) + crew_file = CrewFile(conf.PATH / ENTRYPOINT) except ValidationError as e: raise e @@ -229,74 +228,59 @@ def validate_project(path: Optional[Path] = None) -> None: ) -def get_task_names(path: Optional[Path] = None) -> list[str]: +def get_task_names() -> list[str]: """ Get a list of task names (methods with an @task decorator). """ - if path is None: - path = Path() - crew_file = CrewFile(path / ENTRYPOINT) + crew_file = CrewFile(conf.PATH / ENTRYPOINT) return [method.name for method in crew_file.get_task_methods()] -def add_task(task: TaskConfig, path: Optional[Path] = None) -> None: +def add_task(task: TaskConfig) -> None: """ Add a task method to the CrewAI entrypoint. """ - if path is None: - path = Path() - with CrewFile(path / ENTRYPOINT) as crew_file: + with CrewFile(conf.PATH / ENTRYPOINT) as crew_file: crew_file.add_task_method(task) -def get_agent_names(path: Optional[Path] = None) -> list[str]: +def get_agent_names() -> list[str]: """ Get a list of agent names (methods with an @agent decorator). """ - if path is None: - path = Path() - crew_file = CrewFile(path / ENTRYPOINT) + crew_file = CrewFile(conf.PATH / ENTRYPOINT) return [method.name for method in crew_file.get_agent_methods()] -def get_agent_tool_names(agent_name: str, path: Optional[Path] = None) -> list[Any]: +def get_agent_tool_names(agent_name: str) -> list[Any]: """ Get a list of tools used by an agent. """ - if path is None: - path = Path() - with CrewFile(path / ENTRYPOINT) as crew_file: + with CrewFile(conf.PATH / ENTRYPOINT) as crew_file: tools = crew_file.get_agent_tools(agent_name) - print([node for node in tools.elts]) return [asttools.get_node_value(node) for node in tools.elts] -def add_agent(agent: AgentConfig, path: Optional[Path] = None) -> None: +def add_agent(agent: AgentConfig) -> None: """ Add an agent method to the CrewAI entrypoint. """ - if path is None: - path = Path() - with CrewFile(path / ENTRYPOINT) as crew_file: + with CrewFile(conf.PATH / ENTRYPOINT) as crew_file: crew_file.add_agent_method(agent) -def add_tool(tool: ToolConfig, agent_name: str, path: Optional[Path] = None): +def add_tool(tool: ToolConfig, agent_name: str): """ Add a tool to the CrewAI entrypoint for the specified agent. The agent should already exist in the crew class and have a keyword argument `tools`. """ - if path is None: - path = Path() - with CrewFile(path / ENTRYPOINT) as crew_file: + with CrewFile(conf.PATH / ENTRYPOINT) as crew_file: crew_file.add_agent_tools(agent_name, tool) -def remove_tool(tool: ToolConfig, agent_name: str, path: Optional[Path] = None): +def remove_tool(tool: ToolConfig, agent_name: str): """ Remove a tool from the CrewAI framework for the specified agent. """ - if path is None: - path = Path() - with CrewFile(path / ENTRYPOINT) as crew_file: + with CrewFile(conf.PATH / ENTRYPOINT) as crew_file: crew_file.remove_agent_tools(agent_name, tool) diff --git a/agentstack/generation/__init__.py b/agentstack/generation/__init__.py index 82e2eb5..477b899 100644 --- a/agentstack/generation/__init__.py +++ b/agentstack/generation/__init__.py @@ -1,4 +1,4 @@ from .agent_generation import add_agent from .task_generation import add_task from .tool_generation import add_tool, remove_tool -from .files import ConfigFile, EnvFile, CONFIG_FILENAME \ No newline at end of file +from .files import EnvFile, ProjectFile \ No newline at end of file diff --git a/agentstack/generation/agent_generation.py b/agentstack/generation/agent_generation.py index 502b7d8..31bbd63 100644 --- a/agentstack/generation/agent_generation.py +++ b/agentstack/generation/agent_generation.py @@ -1,11 +1,11 @@ import sys from typing import Optional from pathlib import Path -from agentstack import ValidationError +from agentstack.exceptions import ValidationError +from agentstack.conf import ConfigFile from agentstack import frameworks from agentstack.utils import verify_agentstack_project from agentstack.agents import AgentConfig, AGENTS_FILENAME -from agentstack.generation.files import ConfigFile def add_agent( @@ -14,15 +14,11 @@ def add_agent( goal: Optional[str] = None, backstory: Optional[str] = None, llm: Optional[str] = None, - path: Optional[Path] = None, ): - if path is None: - path = Path() - verify_agentstack_project(path) - agentstack_config = ConfigFile(path) - framework = agentstack_config.framework + agentstack_config = ConfigFile() + verify_agentstack_project() - agent = AgentConfig(agent_name, path) + agent = AgentConfig(agent_name) with agent as config: config.role = role or "Add your role here" config.goal = goal or "Add your goal here" @@ -30,7 +26,7 @@ def add_agent( config.llm = llm or agentstack_config.default_model or "" try: - frameworks.add_agent(framework, agent, path) + frameworks.add_agent(agent) print(f" > Added to {AGENTS_FILENAME}") except ValidationError as e: print(f"Error adding agent to project:\n{e}") diff --git a/agentstack/generation/asttools.py b/agentstack/generation/asttools.py index 9ab2f04..7d25f1a 100644 --- a/agentstack/generation/asttools.py +++ b/agentstack/generation/asttools.py @@ -14,7 +14,7 @@ import ast import astor import asttokens -from agentstack import ValidationError +from agentstack.exceptions import ValidationError FileT = TypeVar('FileT', bound='File') diff --git a/agentstack/generation/files.py b/agentstack/generation/files.py index 38f1ca8..f2ad90a 100644 --- a/agentstack/generation/files.py +++ b/agentstack/generation/files.py @@ -1,87 +1,18 @@ from typing import Optional, Union import os, sys -import json from pathlib import Path -from pydantic import BaseModel if sys.version_info >= (3, 11): import tomllib else: import tomli as tomllib -from agentstack.utils import get_version +from agentstack import conf -DEFAULT_FRAMEWORK = "crewai" -CONFIG_FILENAME = "agentstack.json" ENV_FILEMANE = ".env" PYPROJECT_FILENAME = "pyproject.toml" -class ConfigFile(BaseModel): - """ - Interface for interacting with the agentstack.json file inside a project directory. - Handles both data validation and file I/O. - - `path` is the directory where the agentstack.json file is located. Defaults - to the current working directory. - - Use it as a context manager to make and save edits: - ```python - with ConfigFile() as config: - config.tools.append('tool_name') - ``` - - Config Schema - ------------- - framework: str - The framework used in the project. Defaults to 'crewai'. - tools: list[str] - A list of tools that are currently installed in the project. - telemetry_opt_out: Optional[bool] - Whether the user has opted out of telemetry. - default_model: Optional[str] - The default model to use when generating agent configurations. - agentstack_version: Optional[str] - The version of agentstack used to generate the project. - template: Optional[str] - The template used to generate the project. - template_version: Optional[str] - The version of the template system used to generate the project. - """ - - framework: str = DEFAULT_FRAMEWORK - tools: list[str] = [] - telemetry_opt_out: Optional[bool] = None - default_model: Optional[str] = None - agentstack_version: Optional[str] = get_version() - template: Optional[str] = None - template_version: Optional[str] = None - - def __init__(self, path: Union[str, Path, None] = None): - path = Path(path) if path else Path.cwd() - if os.path.exists(path / CONFIG_FILENAME): - with open(path / CONFIG_FILENAME, 'r') as f: - super().__init__(**json.loads(f.read())) - else: - raise FileNotFoundError(f"File {path / CONFIG_FILENAME} does not exist.") - self._path = path # attribute needs to be set after init - - def model_dump(self, *args, **kwargs) -> dict: - # Ignore None values - dump = super().model_dump(*args, **kwargs) - return {key: value for key, value in dump.items() if value is not None} - - def write(self): - with open(self._path / CONFIG_FILENAME, 'w') as f: - f.write(json.dumps(self.model_dump(), indent=4)) - - def __enter__(self) -> 'ConfigFile': - return self - - def __exit__(self, *args): - self.write() - - class EnvFile: """ Interface for interacting with the .env file inside a project directory. @@ -103,8 +34,7 @@ class EnvFile: variables: dict[str, str] - def __init__(self, path: Union[str, Path, None] = None, filename: str = ENV_FILEMANE): - self._path = Path(path) if path else Path.cwd() + def __init__(self, filename: str = ENV_FILEMANE): self._filename = filename self.read() @@ -129,15 +59,15 @@ def parse_line(line): key, value = line.split('=') return key.strip(), value.strip() - if os.path.exists(self._path / self._filename): - with open(self._path / self._filename, 'r') as f: + if os.path.exists(conf.PATH / self._filename): + with open(conf.PATH / self._filename, 'r') as f: self.variables = dict([parse_line(line) for line in f.readlines() if '=' in line]) else: self.variables = {} self._new_variables = {} def write(self): - with open(self._path / self._filename, 'a') as f: + with open(conf.PATH / self._filename, 'a') as f: for key, value in self._new_variables.items(): f.write(f"\n{key}={value}") @@ -158,8 +88,7 @@ class ProjectFile: _data: dict - def __init__(self, path: Union[str, Path, None] = None, filename: str = PYPROJECT_FILENAME): - self._path = Path(path) if path else Path.cwd() + def __init__(self, filename: str = PYPROJECT_FILENAME): self._filename = filename self.read() @@ -183,8 +112,8 @@ def project_description(self) -> str: return self.project_metadata.get('description', '') def read(self): - if os.path.exists(self._path / self._filename): - with open(self._path / self._filename, 'rb') as f: + if os.path.exists(conf.PATH / self._filename): + with open(conf.PATH / self._filename, 'rb') as f: self._data = tomllib.load(f) else: - raise FileNotFoundError(f"File {self._path / self._filename} does not exist.") + raise FileNotFoundError(f"File {conf.PATH / self._filename} does not exist.") diff --git a/agentstack/generation/task_generation.py b/agentstack/generation/task_generation.py index a6e1d66..f15f7e5 100644 --- a/agentstack/generation/task_generation.py +++ b/agentstack/generation/task_generation.py @@ -1,11 +1,10 @@ import sys from typing import Optional from pathlib import Path -from agentstack import ValidationError +from agentstack.exceptions import ValidationError from agentstack import frameworks from agentstack.utils import verify_agentstack_project from agentstack.tasks import TaskConfig, TASKS_FILENAME -from agentstack.generation.files import ConfigFile def add_task( @@ -13,22 +12,17 @@ def add_task( description: Optional[str] = None, expected_output: Optional[str] = None, agent: Optional[str] = None, - path: Optional[Path] = None, ): - if path is None: - path = Path() - verify_agentstack_project(path) - agentstack_config = ConfigFile(path) - framework = agentstack_config.framework + verify_agentstack_project() - task = TaskConfig(task_name, path) + task = TaskConfig(task_name) with task as config: config.description = description or "Add your description here" config.expected_output = expected_output or "Add your expected_output here" config.agent = agent or "agent_name" try: - frameworks.add_task(framework, task, path) + frameworks.add_task(task) print(f" > Added to {TASKS_FILENAME}") except ValidationError as e: print(f"Error adding task to project:\n{e}") diff --git a/agentstack/generation/tool_generation.py b/agentstack/generation/tool_generation.py index 4ecb2b2..0490e84 100644 --- a/agentstack/generation/tool_generation.py +++ b/agentstack/generation/tool_generation.py @@ -5,13 +5,15 @@ import shutil import ast +from agentstack import conf +from agentstack.conf import ConfigFile +from agentstack.exceptions import ValidationError from agentstack import frameworks from agentstack import packaging -from agentstack import ValidationError from agentstack.utils import term_color from agentstack.tools import ToolConfig from agentstack.generation import asttools -from agentstack.generation.files import ConfigFile, EnvFile +from agentstack.generation.files import EnvFile # This is the filename of the location of tool imports in the user's project. @@ -45,7 +47,7 @@ def get_import_for_tool(self, tool: ToolConfig) -> Optional[ast.ImportFrom]: except IndexError: return None - def add_import_for_tool(self, framework: str, tool: ToolConfig): + def add_import_for_tool(self, tool: ToolConfig, framework: str): """ Add an import for a tool. raises a ValidationError if the tool is already imported. @@ -63,7 +65,7 @@ def add_import_for_tool(self, framework: str, tool: ToolConfig): import_statement = tool.get_import_statement(framework) self.edit_node_range(end, end, f"\n{import_statement}") - def remove_import_for_tool(self, framework: str, tool: ToolConfig): + def remove_import_for_tool(self, tool: ToolConfig, framework: str): """ Remove an import for a tool. raises a ValidationError if the tool is not imported. @@ -76,42 +78,39 @@ def remove_import_for_tool(self, framework: str, tool: ToolConfig): self.edit_node_range(start, end, "") -def add_tool(tool_name: str, agents: Optional[list[str]] = [], path: Optional[Path] = None): - if path is None: - path = Path() - agentstack_config = ConfigFile(path) - framework = agentstack_config.framework +def add_tool(tool_name: str, agents: Optional[list[str]] = []): + agentstack_config = ConfigFile() if tool_name in agentstack_config.tools: print(term_color(f'Tool {tool_name} is already installed', 'red')) sys.exit(1) tool = ToolConfig.from_tool_name(tool_name) - tool_file_path = tool.get_impl_file_path(framework) + tool_file_path = tool.get_impl_file_path(agentstack_config.framework) if tool.packages: packaging.install(' '.join(tool.packages)) # Move tool from package to project - shutil.copy(tool_file_path, path / f'src/tools/{tool.module_name}.py') + shutil.copy(tool_file_path, conf.PATH / f'src/tools/{tool.module_name}.py') try: # Edit the user's project tool init file to include the tool - with ToolsInitFile(path / TOOLS_INIT_FILENAME) as tools_init: - tools_init.add_import_for_tool(framework, tool) + with ToolsInitFile(conf.PATH / TOOLS_INIT_FILENAME) as tools_init: + tools_init.add_import_for_tool(tool, agentstack_config.framework) except ValidationError as e: print(term_color(f"Error adding tool:\n{e}", 'red')) # Edit the framework entrypoint file to include the tool in the agent definition if not agents: # If no agents are specified, add the tool to all agents - agents = frameworks.get_agent_names(framework, path) + agents = frameworks.get_agent_names() for agent_name in agents: - frameworks.add_tool(framework, tool, agent_name, path) + frameworks.add_tool(tool, agent_name) if tool.env: # add environment variables which don't exist - with EnvFile(path) as env: + with EnvFile() as env: for var, value in tool.env.items(): env.append_if_new(var, value) - with EnvFile(path, filename=".env.example") as env: + with EnvFile(".env.example") as env: for var, value in tool.env.items(): env.append_if_new(var, value) @@ -126,11 +125,8 @@ def add_tool(tool_name: str, agents: Optional[list[str]] = [], path: Optional[Pa print(term_color(f'🪩 {tool.cta}', 'blue')) -def remove_tool(tool_name: str, agents: Optional[list[str]] = [], path: Optional[Path] = None): - if path is None: - path = Path() - agentstack_config = ConfigFile(path) - framework = agentstack_config.framework +def remove_tool(tool_name: str, agents: Optional[list[str]] = []): + agentstack_config = ConfigFile() if tool_name not in agentstack_config.tools: print(term_color(f'Tool {tool_name} is not installed', 'red')) @@ -142,21 +138,21 @@ def remove_tool(tool_name: str, agents: Optional[list[str]] = [], path: Optional # TODO ensure that other agents in the project are not using the tool. try: - os.remove(path / f'src/tools/{tool.module_name}.py') + os.remove(conf.PATH / f'src/tools/{tool.module_name}.py') except FileNotFoundError: print(f'"src/tools/{tool.module_name}.py" not found') try: # Edit the user's project tool init file to exclude the tool - with ToolsInitFile(path / TOOLS_INIT_FILENAME) as tools_init: - tools_init.remove_import_for_tool(framework, tool) + with ToolsInitFile(conf.PATH / TOOLS_INIT_FILENAME) as tools_init: + tools_init.remove_import_for_tool(tool, agentstack_config.framework) except ValidationError as e: print(term_color(f"Error removing tool:\n{e}", 'red')) # Edit the framework entrypoint file to exclude the tool in the agent definition if not agents: # If no agents are specified, remove the tool from all agents - agents = frameworks.get_agent_names(framework, path) + agents = frameworks.get_agent_names() for agent_name in agents: - frameworks.remove_tool(framework, tool, agent_name, path) + frameworks.remove_tool(tool, agent_name) if tool.post_remove: os.system(tool.post_remove) diff --git a/agentstack/inputs.py b/agentstack/inputs.py index 209d5a5..248e0d7 100644 --- a/agentstack/inputs.py +++ b/agentstack/inputs.py @@ -3,7 +3,8 @@ from pathlib import Path from ruamel.yaml import YAML, YAMLError from ruamel.yaml.scalarstring import FoldedScalarString -from agentstack import ValidationError +from agentstack import conf +from agentstack.exceptions import ValidationError INPUTS_FILENAME: Path = Path("src/config/inputs.yaml") @@ -28,9 +29,8 @@ class InputsConfig: _attributes: dict[str, str] - def __init__(self, path: Optional[Path] = None): - self.path = path if path else Path() - filename = self.path / INPUTS_FILENAME + def __init__(self): + filename = conf.PATH / INPUTS_FILENAME if not os.path.exists(filename): os.makedirs(filename.parent, exist_ok=True) @@ -62,7 +62,7 @@ def model_dump(self) -> dict: return dump def write(self): - with open(self.path / INPUTS_FILENAME, 'w') as f: + with open(conf.PATH / INPUTS_FILENAME, 'w') as f: yaml.dump(self.model_dump(), f) def __enter__(self) -> 'InputsConfig': @@ -72,12 +72,11 @@ def __exit__(self, *args): self.write() -def get_inputs(path: Optional[Path] = None) -> dict: +def get_inputs() -> dict: """ Get the inputs configuration file and override with run_inputs. """ - path = path if path else Path() - config = InputsConfig(path).to_dict() + config = InputsConfig().to_dict() # run_inputs override saved inputs for key, value in run_inputs.items(): config[key] = value diff --git a/agentstack/main.py b/agentstack/main.py index e5e004f..1ac0457 100644 --- a/agentstack/main.py +++ b/agentstack/main.py @@ -1,6 +1,8 @@ -import argparse import sys +import argparse +import webbrowser +from agentstack import conf from agentstack.cli import ( init_project_builder, list_tools, @@ -13,12 +15,18 @@ from agentstack import generation from agentstack.update import check_for_updates -import webbrowser - def main(): + global_parser = argparse.ArgumentParser(add_help=False) + global_parser.add_argument( + "--path", + "-p", + help="Path to the project directory, defaults to current working directory", + dest="project_path", + ) + parser = argparse.ArgumentParser( - description="AgentStack CLI - The easiest way to build an agent application" + parents=[global_parser], description="AgentStack CLI - The easiest way to build an agent application" ) parser.add_argument("-v", "--version", action="store_true", help="Show the version") @@ -36,7 +44,9 @@ def main(): subparsers.add_parser("templates", help="View Agentstack templates") # 'init' command - init_parser = subparsers.add_parser("init", aliases=["i"], help="Initialize a directory for the project") + init_parser = subparsers.add_parser( + "init", aliases=["i"], help="Initialize a directory for the project", parents=[global_parser] + ) init_parser.add_argument("slug_name", nargs="?", help="The directory name to place the project in") init_parser.add_argument("--wizard", "-w", action="store_true", help="Use the setup wizard") init_parser.add_argument("--template", "-t", help="Agent template to use") @@ -46,6 +56,7 @@ def main(): "run", aliases=["r"], help="Run your agent", + parents=[global_parser], formatter_class=argparse.RawDescriptionHelpFormatter, epilog=''' --input-=VALUE Specify inputs to be passed to the run. @@ -60,15 +71,11 @@ def main(): default="run", dest="function", ) - run_parser.add_argument( - "--path", - "-p", - help="Path to the project directory, defaults to current working directory", - dest="path", - ) # 'generate' command - generate_parser = subparsers.add_parser("generate", aliases=["g"], help="Generate agents or tasks") + generate_parser = subparsers.add_parser( + "generate", aliases=["g"], help="Generate agents or tasks", parents=[global_parser] + ) # Subparsers under 'generate' generate_subparsers = generate_parser.add_subparsers( @@ -76,7 +83,9 @@ def main(): ) # 'agent' command under 'generate' - agent_parser = generate_subparsers.add_parser("agent", aliases=["a"], help="Generate an agent") + agent_parser = generate_subparsers.add_parser( + "agent", aliases=["a"], help="Generate an agent", parents=[global_parser] + ) agent_parser.add_argument("name", help="Name of the agent") agent_parser.add_argument("--role", "-r", help="Role of the agent") agent_parser.add_argument("--goal", "-g", help="Goal of the agent") @@ -84,7 +93,9 @@ def main(): agent_parser.add_argument("--llm", "-l", help="Language model to use") # 'task' command under 'generate' - task_parser = generate_subparsers.add_parser("task", aliases=["t"], help="Generate a task") + task_parser = generate_subparsers.add_parser( + "task", aliases=["t"], help="Generate a task", parents=[global_parser] + ) task_parser.add_argument("name", help="Name of the task") task_parser.add_argument("--description", "-d", help="Description of the task") task_parser.add_argument("--expected_output", "-e", help="Expected output of the task") @@ -100,7 +111,9 @@ def main(): _ = tools_subparsers.add_parser("list", aliases=["l"], help="List tools") # 'add' command under 'tools' - tools_add_parser = tools_subparsers.add_parser("add", aliases=["a"], help="Add a new tool") + tools_add_parser = tools_subparsers.add_parser( + "add", aliases=["a"], help="Add a new tool", parents=[global_parser] + ) tools_add_parser.add_argument("name", help="Name of the tool to add") tools_add_parser.add_argument( "--agents", "-a", help="Name of agents to add this tool to, comma separated" @@ -108,21 +121,28 @@ def main(): tools_add_parser.add_argument("--agent", help="Name of agent to add this tool to") # 'remove' command under 'tools' - tools_remove_parser = tools_subparsers.add_parser("remove", aliases=["r"], help="Remove a tool") + tools_remove_parser = tools_subparsers.add_parser( + "remove", aliases=["r"], help="Remove a tool", parents=[global_parser] + ) tools_remove_parser.add_argument("name", help="Name of the tool to remove") - export_parser = subparsers.add_parser('export', aliases=['e'], help='Export your agent as a template') + export_parser = subparsers.add_parser( + 'export', aliases=['e'], help='Export your agent as a template', parents=[global_parser] + ) export_parser.add_argument('filename', help='The name of the file to export to') - update = subparsers.add_parser('update', aliases=['u'], help='Check for updates') + update = subparsers.add_parser('update', aliases=['u'], help='Check for updates', parents=[global_parser]) # Parse known args and store unknown args in extras; some commands use them later on args, extra_args = parser.parse_known_args() + # Set the project path from --path if it is provided in the global_parser + conf.set_path(args.project_path) + # Handle version if args.version: print(f"AgentStack CLI version: {get_version()}") - return + sys.exit(0) track_cli_command(args.command) check_for_updates(update_requested=args.command in ('update', 'u')) @@ -137,7 +157,7 @@ def main(): elif args.command in ["init", "i"]: init_project_builder(args.slug_name, args.template, args.wizard) elif args.command in ["run", "r"]: - run_project(command=args.function, path=args.path, cli_args=extra_args) + run_project(command=args.function, cli_args=extra_args) elif args.command in ['generate', 'g']: if args.generate_command in ['agent', 'a']: if not args.llm: diff --git a/agentstack/proj_templates.py b/agentstack/proj_templates.py index a3cc112..e90aaca 100644 --- a/agentstack/proj_templates.py +++ b/agentstack/proj_templates.py @@ -4,7 +4,7 @@ import pydantic import requests import json -from agentstack import ValidationError +from agentstack.exceptions import ValidationError from agentstack.utils import get_package_path diff --git a/agentstack/tasks.py b/agentstack/tasks.py index 4600fd1..f5e7984 100644 --- a/agentstack/tasks.py +++ b/agentstack/tasks.py @@ -4,7 +4,8 @@ import pydantic from ruamel.yaml import YAML, YAMLError from ruamel.yaml.scalarstring import FoldedScalarString -from agentstack import ValidationError +from agentstack import conf +from agentstack.exceptions import ValidationError TASKS_FILENAME: Path = Path("src/config/tasks.yaml") @@ -42,11 +43,8 @@ class TaskConfig(pydantic.BaseModel): expected_output: str = "" agent: str = "" - def __init__(self, name: str, path: Optional[Path] = None): - if not path: - path = Path() - - filename = path / TASKS_FILENAME + def __init__(self, name: str): + filename = conf.PATH / TASKS_FILENAME if not os.path.exists(filename): os.makedirs(filename.parent, exist_ok=True) filename.touch() @@ -65,9 +63,6 @@ def __init__(self, name: str, path: Optional[Path] = None): error_str += f"{' '.join([str(loc) for loc in error['loc']])}: {error['msg']}\n" raise ValidationError(f"Error loading task {name} from {filename}.\n{error_str}") - # store the path *after* loading data - self._path = path - def model_dump(self, *args, **kwargs) -> dict: dump = super().model_dump(*args, **kwargs) dump.pop('name') # name is the key, so keep it out of the data @@ -77,7 +72,7 @@ def model_dump(self, *args, **kwargs) -> dict: return {self.name: dump} def write(self): - filename = self._path / TASKS_FILENAME + filename = conf.PATH / TASKS_FILENAME with open(filename, 'r') as f: data = yaml.load(f) or {} @@ -94,10 +89,8 @@ def __exit__(self, *args): self.write() -def get_all_task_names(path: Optional[Path] = None) -> list[str]: - if not path: - path = Path() - filename = path / TASKS_FILENAME +def get_all_task_names() -> list[str]: + filename = conf.PATH / TASKS_FILENAME if not os.path.exists(filename): return [] with open(filename, 'r') as f: @@ -105,5 +98,5 @@ def get_all_task_names(path: Optional[Path] = None) -> list[str]: return list(data.keys()) -def get_all_tasks(path: Optional[Path] = None) -> list[TaskConfig]: - return [TaskConfig(name, path) for name in get_all_task_names(path)] +def get_all_tasks() -> list[TaskConfig]: + return [TaskConfig(name) for name in get_all_task_names()] diff --git a/agentstack/telemetry.py b/agentstack/telemetry.py index 1e8c3e4..6fc0165 100644 --- a/agentstack/telemetry.py +++ b/agentstack/telemetry.py @@ -28,6 +28,7 @@ import socket import psutil import requests +from agentstack import conf from agentstack.utils import get_telemetry_opt_out, get_framework, get_version TELEMETRY_URL = 'https://api.agentstack.sh/telemetry' @@ -48,7 +49,7 @@ def collect_machine_telemetry(command: str): } if command != "init": - telemetry_data['framework'] = get_framework() + telemetry_data['framework'] = conf.get_framework() else: telemetry_data['framework'] = "n/a" diff --git a/agentstack/utils.py b/agentstack/utils.py index 647e4ee..e1564dc 100644 --- a/agentstack/utils.py +++ b/agentstack/utils.py @@ -6,6 +6,7 @@ from importlib.metadata import version from pathlib import Path import importlib.resources +from agentstack import conf def get_version(package: str = 'agentstack'): @@ -16,11 +17,9 @@ def get_version(package: str = 'agentstack'): return "Unknown version" -def verify_agentstack_project(path: Optional[Path] = None): - from agentstack.generation import ConfigFile - +def verify_agentstack_project(): try: - agentstack_config = ConfigFile(path) + agentstack_config = conf.ConfigFile() except FileNotFoundError: print( "\033[31mAgentStack Error: This does not appear to be an AgentStack project." @@ -37,35 +36,25 @@ def get_package_path() -> Path: return importlib.resources.files('agentstack') # type: ignore[return-value] -def get_framework(path: Union[str, Path, None] = None) -> str: - from agentstack.generation import ConfigFile - - try: - agentstack_config = ConfigFile(path) - framework = agentstack_config.framework - - if framework.lower() not in ['crewai', 'autogen', 'litellm']: - print(term_color("agentstack.json contains an invalid framework", "red")) - - return framework - except FileNotFoundError: - print("\033[31mFile agentstack.json does not exist. Are you in the right directory?\033[0m") - sys.exit(1) +def get_framework() -> str: + """Assert that we're inside a valid project and return the framework name.""" + verify_agentstack_project() + framework = conf.get_framework() + assert framework # verify_agentstack_project should catch this + return framework -def get_telemetry_opt_out(path: Optional[str] = None) -> bool: +def get_telemetry_opt_out() -> bool: """ Gets the telemetry opt out setting. First checks the environment variable AGENTSTACK_TELEMETRY_OPT_OUT. If that is not set, it checks the agentstack.json file. Otherwise we can assume the user has not opted out. """ - from agentstack.generation import ConfigFile - try: return bool(os.environ['AGENTSTACK_TELEMETRY_OPT_OUT']) except KeyError: - agentstack_config = ConfigFile(path) + agentstack_config = conf.ConfigFile() return bool(agentstack_config.telemetry_opt_out) except FileNotFoundError: return False diff --git a/tests/test_agents_config.py b/tests/test_agents_config.py index 657d931..5b11019 100644 --- a/tests/test_agents_config.py +++ b/tests/test_agents_config.py @@ -4,6 +4,7 @@ import unittest import importlib.resources from pathlib import Path +from agentstack import conf from agentstack.agents import AgentConfig, AGENTS_FILENAME BASE_PATH = Path(__file__).parent @@ -12,13 +13,14 @@ class AgentConfigTest(unittest.TestCase): def setUp(self): self.project_dir = BASE_PATH / 'tmp/agent_config' + conf.set_path(self.project_dir) os.makedirs(self.project_dir / 'src/config') def tearDown(self): shutil.rmtree(self.project_dir) def test_empty_file(self): - config = AgentConfig("agent_name", self.project_dir) + config = AgentConfig("agent_name") assert config.name == "agent_name" assert config.role is "" assert config.goal is "" @@ -27,7 +29,7 @@ def test_empty_file(self): def test_read_minimal_yaml(self): shutil.copy(BASE_PATH / "fixtures/agents_min.yaml", self.project_dir / AGENTS_FILENAME) - config = AgentConfig("agent_name", self.project_dir) + config = AgentConfig("agent_name") assert config.name == "agent_name" assert config.role == "" assert config.goal == "" @@ -36,7 +38,7 @@ def test_read_minimal_yaml(self): def test_read_maximal_yaml(self): shutil.copy(BASE_PATH / "fixtures/agents_max.yaml", self.project_dir / AGENTS_FILENAME) - config = AgentConfig("agent_name", self.project_dir) + config = AgentConfig("agent_name") assert config.name == "agent_name" assert config.role == "role" assert config.goal == "this is a goal" @@ -44,7 +46,7 @@ def test_read_maximal_yaml(self): assert config.llm == "provider/model" def test_write_yaml(self): - with AgentConfig("agent_name", self.project_dir) as config: + with AgentConfig("agent_name") as config: config.role = "role" config.goal = "this is a goal" config.backstory = "backstory" @@ -65,7 +67,7 @@ def test_write_yaml(self): ) def test_write_none_values(self): - with AgentConfig("agent_name", self.project_dir) as config: + with AgentConfig("agent_name") as config: config.role = None config.goal = None config.backstory = None diff --git a/tests/test_cli_loads.py b/tests/test_cli_loads.py index 819983a..6ac8fca 100644 --- a/tests/test_cli_loads.py +++ b/tests/test_cli_loads.py @@ -22,6 +22,9 @@ def run_cli(self, *args): def test_version(self): """Test the --version command.""" result = self.run_cli("--version") + print(result.stdout) + print(result.stderr) + print(result.returncode) self.assertEqual(result.returncode, 0) self.assertIn("AgentStack CLI version:", result.stdout) diff --git a/tests/test_frameworks.py b/tests/test_frameworks.py index 8b27b84..4b8e3cf 100644 --- a/tests/test_frameworks.py +++ b/tests/test_frameworks.py @@ -4,7 +4,8 @@ import unittest from parameterized import parameterized_class -from agentstack import ValidationError +from agentstack.conf import ConfigFile, set_path +from agentstack.exceptions import ValidationError from agentstack import frameworks from agentstack.tools import ToolConfig @@ -23,17 +24,22 @@ def setUp(self): (self.project_dir / 'src' / '__init__.py').touch() (self.project_dir / 'src' / 'tools' / '__init__.py').touch() + shutil.copy(BASE_PATH / 'fixtures' / 'agentstack.json', self.project_dir / 'agentstack.json') + set_path(self.project_dir) + with ConfigFile() as config: + config.framework = self.framework + def tearDown(self): shutil.rmtree(self.project_dir) def _populate_min_entrypoint(self): """This entrypoint does not have any tools or agents.""" - entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + entrypoint_path = frameworks.get_entrypoint_path(self.framework) shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_min.py", entrypoint_path) def _populate_max_entrypoint(self): """This entrypoint has tools and agents.""" - entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + entrypoint_path = frameworks.get_entrypoint_path(self.framework) shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) def _get_test_tool(self) -> ToolConfig: @@ -54,55 +60,55 @@ def test_get_framework_module_invalid(self): def test_validate_project(self): self._populate_max_entrypoint() - frameworks.validate_project(self.framework, self.project_dir) + frameworks.validate_project() def test_validate_project_invalid(self): self._populate_min_entrypoint() with self.assertRaises(ValidationError) as context: - frameworks.validate_project(self.framework, self.project_dir) + frameworks.validate_project() def test_add_tool(self): self._populate_max_entrypoint() - frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) + frameworks.add_tool(self._get_test_tool(), 'test_agent') - entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework)).read() # TODO these asserts are not framework agnostic assert 'tools=[tools.test_tool' in entrypoint_src def test_add_tool_starred(self): self._populate_max_entrypoint() - frameworks.add_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) + frameworks.add_tool(self._get_test_tool_starred(), 'test_agent') - entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework)).read() assert 'tools=[*tools.test_tool_star' in entrypoint_src def test_add_tool_invalid(self): self._populate_min_entrypoint() with self.assertRaises(ValidationError) as context: - frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) + frameworks.add_tool(self._get_test_tool(), 'test_agent') def test_remove_tool(self): self._populate_max_entrypoint() - frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) - frameworks.remove_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) + frameworks.add_tool(self._get_test_tool(), 'test_agent') + frameworks.remove_tool(self._get_test_tool(), 'test_agent') - entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework)).read() assert 'tools=[tools.test_tool' not in entrypoint_src def test_remove_tool_starred(self): self._populate_max_entrypoint() - frameworks.add_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) - frameworks.remove_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) + frameworks.add_tool(self._get_test_tool_starred(), 'test_agent') + frameworks.remove_tool(self._get_test_tool_starred(), 'test_agent') - entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework)).read() assert 'tools=[*tools.test_tool_star' not in entrypoint_src def test_add_multiple_tools(self): self._populate_max_entrypoint() - frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) - frameworks.add_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) + frameworks.add_tool(self._get_test_tool(), 'test_agent') + frameworks.add_tool(self._get_test_tool_starred(), 'test_agent') - entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework)).read() assert ( # ordering is not guaranteed 'tools=[tools.test_tool, *tools.test_tool_star' in entrypoint_src or 'tools=[*tools.test_tool_star, tools.test_tool' in entrypoint_src @@ -110,10 +116,10 @@ def test_add_multiple_tools(self): def test_remove_one_tool_of_multiple(self): self._populate_max_entrypoint() - frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) - frameworks.add_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) - frameworks.remove_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) + frameworks.add_tool(self._get_test_tool(), 'test_agent') + frameworks.add_tool(self._get_test_tool_starred(), 'test_agent') + frameworks.remove_tool(self._get_test_tool(), 'test_agent') - entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework)).read() assert 'tools=[tools.test_tool' not in entrypoint_src assert 'tools=[*tools.test_tool_star' in entrypoint_src diff --git a/tests/test_generation_agent.py b/tests/test_generation_agent.py index 2f836e5..f2b39f5 100644 --- a/tests/test_generation_agent.py +++ b/tests/test_generation_agent.py @@ -5,8 +5,9 @@ from parameterized import parameterized_class import ast -from agentstack import frameworks, ValidationError -from agentstack.generation.files import ConfigFile +from agentstack.conf import ConfigFile, set_path +from agentstack import frameworks +from agentstack.exceptions import ValidationError from agentstack.generation.agent_generation import add_agent BASE_PATH = Path(__file__).parent @@ -22,15 +23,16 @@ def setUp(self): os.makedirs(self.project_dir / 'src' / 'config') (self.project_dir / 'src' / '__init__.py').touch() - # populate the entrypoint - entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) - shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) - # set the framework in agentstack.json shutil.copy(BASE_PATH / 'fixtures' / 'agentstack.json', self.project_dir / 'agentstack.json') - with ConfigFile(self.project_dir) as config: + set_path(self.project_dir) + with ConfigFile() as config: config.framework = self.framework + # populate the entrypoint + entrypoint_path = frameworks.get_entrypoint_path(self.framework) + shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) + def tearDown(self): shutil.rmtree(self.project_dir) @@ -41,10 +43,9 @@ def test_add_agent(self): goal='goal', backstory='backstory', llm='llm', - path=self.project_dir, ) - entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + entrypoint_path = frameworks.get_entrypoint_path(self.framework) entrypoint_src = open(entrypoint_path).read() # agents.yaml is covered in test_agents_config.py # TODO framework-specific validation for code structure @@ -60,5 +61,4 @@ def test_add_agent_exists(self): goal='goal', backstory='backstory', llm='llm', - path=self.project_dir, ) diff --git a/tests/test_generation_files.py b/tests/test_generation_files.py index 3a8f0a0..900efdf 100644 --- a/tests/test_generation_files.py +++ b/tests/test_generation_files.py @@ -2,7 +2,9 @@ import unittest from pathlib import Path import shutil -from agentstack.generation.files import ConfigFile, EnvFile +from agentstack import conf +from agentstack.conf import ConfigFile +from agentstack.generation.files import EnvFile from agentstack.utils import ( verify_agentstack_project, get_framework, @@ -13,9 +15,20 @@ BASE_PATH = Path(__file__).parent +# TODO copy files to working directory class GenerationFilesTest(unittest.TestCase): + def setUp(self): + self.project_dir = BASE_PATH / "tmp" / "generation_files" + os.makedirs(self.project_dir) + + shutil.copy(BASE_PATH / "fixtures/agentstack.json", self.project_dir / "agentstack.json") + conf.set_path(self.project_dir) + + def tearDown(self): + shutil.rmtree(self.project_dir) + def test_read_config(self): - config = ConfigFile(BASE_PATH / "fixtures") # + agentstack.json + config = ConfigFile() # + agentstack.json assert config.framework == "crewai" assert config.tools == [] assert config.telemetry_opt_out is None @@ -25,23 +38,19 @@ def test_read_config(self): assert config.template_version is None def test_write_config(self): - try: - os.makedirs(BASE_PATH / "tmp", exist_ok=True) - shutil.copy(BASE_PATH / "fixtures/agentstack.json", BASE_PATH / "tmp/agentstack.json") - - with ConfigFile(BASE_PATH / "tmp") as config: - config.framework = "crewai" - config.tools = ["tool1", "tool2"] - config.telemetry_opt_out = True - config.default_model = "openai/gpt-4o" - config.agentstack_version = "0.2.1" - config.template = "default" - config.template_version = "1" - - tmp_data = open(BASE_PATH / "tmp/agentstack.json").read() - assert ( - tmp_data - == """{ + with ConfigFile() as config: + config.framework = "crewai" + config.tools = ["tool1", "tool2"] + config.telemetry_opt_out = True + config.default_model = "openai/gpt-4o" + config.agentstack_version = "0.2.1" + config.template = "default" + config.template_version = "1" + + tmp_data = open(self.project_dir / "agentstack.json").read() + assert ( + tmp_data + == """{ "framework": "crewai", "tools": [ "tool1", @@ -53,31 +62,33 @@ def test_write_config(self): "template": "default", "template_version": "1" }""" - ) - except Exception as e: - raise e - finally: - os.remove(BASE_PATH / "tmp/agentstack.json") - # os.rmdir(BASE_PATH / "tmp") + ) def test_read_missing_config(self): + conf.set_path(BASE_PATH / "missing") with self.assertRaises(FileNotFoundError) as _: - _ = ConfigFile(BASE_PATH / "missing") + _ = ConfigFile() def test_verify_agentstack_project_valid(self): - verify_agentstack_project(BASE_PATH / "fixtures") + verify_agentstack_project() def test_verify_agentstack_project_invalid(self): + conf.set_path(BASE_PATH / "missing") with self.assertRaises(SystemExit) as _: - verify_agentstack_project(BASE_PATH / "missing") + verify_agentstack_project() def test_get_framework(self): - assert get_framework(BASE_PATH / "fixtures") == "crewai" + assert get_framework() == "crewai" + + def test_get_framework_missing(self): + conf.set_path(BASE_PATH / "missing") with self.assertRaises(SystemExit) as _: - get_framework(BASE_PATH / "missing") + get_framework() def test_read_env(self): - env = EnvFile(BASE_PATH / "fixtures") + shutil.copy(BASE_PATH / "fixtures/.env", self.project_dir / ".env") + + env = EnvFile() assert env.variables == {"ENV_VAR1": "value1", "ENV_VAR2": "value2"} assert env["ENV_VAR1"] == "value1" assert env["ENV_VAR2"] == "value2" @@ -85,18 +96,11 @@ def test_read_env(self): env["ENV_VAR3"] def test_write_env(self): - try: - os.makedirs(BASE_PATH / "tmp", exist_ok=True) - shutil.copy(BASE_PATH / "fixtures/.env", BASE_PATH / "tmp/.env") - - with EnvFile(BASE_PATH / "tmp") as env: - env.append_if_new("ENV_VAR1", "value100") # Should not be updated - env.append_if_new("ENV_VAR100", "value2") # Should be added - - tmp_data = open(BASE_PATH / "tmp/.env").read() - assert tmp_data == """\nENV_VAR1=value1\nENV_VAR2=value2\nENV_VAR100=value2""" - except Exception as e: - raise e - finally: - os.remove(BASE_PATH / "tmp/.env") - # os.rmdir(BASE_PATH / "tmp") + shutil.copy(BASE_PATH / "fixtures/.env", self.project_dir / ".env") + + with EnvFile() as env: + env.append_if_new("ENV_VAR1", "value100") # Should not be updated + env.append_if_new("ENV_VAR100", "value2") # Should be added + + tmp_data = open(self.project_dir / ".env").read() + assert tmp_data == """\nENV_VAR1=value1\nENV_VAR2=value2\nENV_VAR100=value2""" diff --git a/tests/test_generation_tasks.py b/tests/test_generation_tasks.py index 430a369..106ec12 100644 --- a/tests/test_generation_tasks.py +++ b/tests/test_generation_tasks.py @@ -5,8 +5,9 @@ from parameterized import parameterized_class import ast -from agentstack import frameworks, ValidationError -from agentstack.generation.files import ConfigFile +from agentstack.conf import ConfigFile, set_path +from agentstack.exceptions import ValidationError +from agentstack import frameworks from agentstack.generation.task_generation import add_task BASE_PATH = Path(__file__).parent @@ -22,15 +23,16 @@ def setUp(self): os.makedirs(self.project_dir / 'src' / 'config') (self.project_dir / 'src' / '__init__.py').touch() - # populate the entrypoint - entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) - shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) - # set the framework in agentstack.json shutil.copy(BASE_PATH / 'fixtures' / 'agentstack.json', self.project_dir / 'agentstack.json') - with ConfigFile(self.project_dir) as config: + set_path(self.project_dir) + with ConfigFile() as config: config.framework = self.framework + # populate the entrypoint + entrypoint_path = frameworks.get_entrypoint_path(self.framework) + shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) + def tearDown(self): shutil.rmtree(self.project_dir) @@ -40,10 +42,9 @@ def test_add_task(self): description='description', expected_output='expected_output', agent='agent', - path=self.project_dir, ) - entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + entrypoint_path = frameworks.get_entrypoint_path(self.framework) entrypoint_src = open(entrypoint_path).read() # agents.yaml is covered in test_agents_config.py # TODO framework-specific validation for code structure @@ -58,5 +59,4 @@ def test_add_agent_exists(self): description='description', expected_output='expected_output', agent='agent', - path=self.project_dir, ) diff --git a/tests/test_generation_tool.py b/tests/test_generation_tool.py index d212268..779f7f9 100644 --- a/tests/test_generation_tool.py +++ b/tests/test_generation_tool.py @@ -5,9 +5,9 @@ from parameterized import parameterized_class import ast +from agentstack.conf import ConfigFile, set_path from agentstack import frameworks from agentstack.tools import get_all_tools, ToolConfig -from agentstack.generation.files import ConfigFile from agentstack.generation.tool_generation import add_tool, remove_tool, TOOLS_INIT_FILENAME @@ -26,23 +26,24 @@ def setUp(self): (self.project_dir / 'src' / '__init__.py').touch() (self.project_dir / TOOLS_INIT_FILENAME).touch() - # populate the entrypoint - entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) - shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) - # set the framework in agentstack.json shutil.copy(BASE_PATH / 'fixtures' / 'agentstack.json', self.project_dir / 'agentstack.json') - with ConfigFile(self.project_dir) as config: + set_path(self.project_dir) + with ConfigFile() as config: config.framework = self.framework + # populate the entrypoint + entrypoint_path = frameworks.get_entrypoint_path(self.framework) + shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) + def tearDown(self): shutil.rmtree(self.project_dir) def test_add_tool(self): tool_conf = ToolConfig.from_tool_name('agent_connect') - add_tool('agent_connect', path=self.project_dir) + add_tool('agent_connect') - entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + entrypoint_path = frameworks.get_entrypoint_path(self.framework) entrypoint_src = open(entrypoint_path).read() ast.parse(entrypoint_src) tools_init_src = open(self.project_dir / TOOLS_INIT_FILENAME).read() @@ -55,10 +56,10 @@ def test_add_tool(self): def test_remove_tool(self): tool_conf = ToolConfig.from_tool_name('agent_connect') - add_tool('agent_connect', path=self.project_dir) - remove_tool('agent_connect', path=self.project_dir) + add_tool('agent_connect') + remove_tool('agent_connect') - entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + entrypoint_path = frameworks.get_entrypoint_path(self.framework) entrypoint_src = open(entrypoint_path).read() ast.parse(entrypoint_src) tools_init_src = open(self.project_dir / TOOLS_INIT_FILENAME).read() diff --git a/tests/test_inputs_config.py b/tests/test_inputs_config.py index 1f20ace..3ca883b 100644 --- a/tests/test_inputs_config.py +++ b/tests/test_inputs_config.py @@ -2,6 +2,7 @@ import shutil import unittest from pathlib import Path +from agentstack import conf from agentstack.inputs import InputsConfig BASE_PATH = Path(__file__).parent @@ -13,17 +14,19 @@ def setUp(self): os.makedirs(self.project_dir) os.makedirs(self.project_dir / "src/config") + conf.set_path(self.project_dir) + def tearDown(self): shutil.rmtree(self.project_dir) def test_minimal_input_config(self): shutil.copy(BASE_PATH / "fixtures/inputs_min.yaml", self.project_dir / "src/config/inputs.yaml") - config = InputsConfig(self.project_dir) + config = InputsConfig() assert config.to_dict() == {} def test_maximal_input_config(self): shutil.copy(BASE_PATH / "fixtures/inputs_max.yaml", self.project_dir / "src/config/inputs.yaml") - config = InputsConfig(self.project_dir) + config = InputsConfig() assert config['input_name'] == "This in an input" assert config['input_name_2'] == "This is another input" assert config.to_dict() == {'input_name': "This in an input", 'input_name_2': "This is another input"} diff --git a/tests/test_project_run.py b/tests/test_project_run.py index 0740aad..cb30f13 100644 --- a/tests/test_project_run.py +++ b/tests/test_project_run.py @@ -4,9 +4,10 @@ import unittest from parameterized import parameterized_class +from agentstack import conf +from agentstack.conf import ConfigFile from agentstack import frameworks from agentstack.cli import run_project -from agentstack.generation.files import ConfigFile BASE_PATH = Path(__file__).parent @@ -14,7 +15,7 @@ @parameterized_class([{"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS]) class ProjectRunTest(unittest.TestCase): def setUp(self): - self.project_dir = BASE_PATH / 'tmp' / self.framework + self.project_dir = BASE_PATH / 'tmp/project_run' / self.framework os.makedirs(self.project_dir) os.makedirs(self.project_dir / 'src') @@ -25,11 +26,12 @@ def setUp(self): # set the framework in agentstack.json shutil.copy(BASE_PATH / 'fixtures' / 'agentstack.json', self.project_dir / 'agentstack.json') - with ConfigFile(self.project_dir) as config: + conf.set_path(self.project_dir) + with ConfigFile() as config: config.framework = self.framework # populate the entrypoint - entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + entrypoint_path = frameworks.get_entrypoint_path(self.framework) shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) # write a basic .env file @@ -39,11 +41,11 @@ def tearDown(self): shutil.rmtree(self.project_dir) def test_run_project(self): - run_project(path=self.project_dir) + run_project() def test_env_is_set(self): """ After running a project, the environment variables should be set from project_dir/.env. """ - run_project(path=self.project_dir) + run_project() assert os.getenv('ENV_VAR1') == 'value1' diff --git a/tests/test_tasks_config.py b/tests/test_tasks_config.py index c95665b..498ea9c 100644 --- a/tests/test_tasks_config.py +++ b/tests/test_tasks_config.py @@ -4,6 +4,7 @@ import unittest import importlib.resources from pathlib import Path +from agentstack import conf from agentstack.tasks import TaskConfig, TASKS_FILENAME BASE_PATH = Path(__file__).parent @@ -13,12 +14,13 @@ class AgentConfigTest(unittest.TestCase): def setUp(self): self.project_dir = BASE_PATH / 'tmp/task_config' os.makedirs(self.project_dir / 'src/config') + conf.set_path(self.project_dir) def tearDown(self): shutil.rmtree(self.project_dir) def test_empty_file(self): - config = TaskConfig("task_name", self.project_dir) + config = TaskConfig("task_name") assert config.name == "task_name" assert config.description is "" assert config.expected_output is "" @@ -26,7 +28,7 @@ def test_empty_file(self): def test_read_minimal_yaml(self): shutil.copy(BASE_PATH / "fixtures/tasks_min.yaml", self.project_dir / TASKS_FILENAME) - config = TaskConfig("task_name", self.project_dir) + config = TaskConfig("task_name") assert config.name == "task_name" assert config.description is "" assert config.expected_output is "" @@ -34,14 +36,14 @@ def test_read_minimal_yaml(self): def test_read_maximal_yaml(self): shutil.copy(BASE_PATH / "fixtures/tasks_max.yaml", self.project_dir / TASKS_FILENAME) - config = TaskConfig("task_name", self.project_dir) + config = TaskConfig("task_name") assert config.name == "task_name" assert config.description == "Add your description here" assert config.expected_output == "Add your expected output here" assert config.agent == "default_agent" def test_write_yaml(self): - with TaskConfig("task_name", self.project_dir) as config: + with TaskConfig("task_name") as config: config.description = "Add your description here" config.expected_output = "Add your expected output here" config.agent = "default_agent" @@ -60,7 +62,7 @@ def test_write_yaml(self): ) def test_write_none_values(self): - with TaskConfig("task_name", self.project_dir) as config: + with TaskConfig("task_name") as config: config.description = None config.expected_output = None config.agent = None diff --git a/tests/test_templates_config.py b/tests/test_templates_config.py index 226e0c2..47926c0 100644 --- a/tests/test_templates_config.py +++ b/tests/test_templates_config.py @@ -2,7 +2,7 @@ import json import unittest from parameterized import parameterized -from agentstack import ValidationError +from agentstack.exceptions import ValidationError from agentstack.proj_templates import TemplateConfig, get_all_template_names, get_all_template_paths BASE_PATH = Path(__file__).parent diff --git a/tests/test_tool_generation_init.py b/tests/test_tool_generation_init.py index 7bb7958..8eb0b65 100644 --- a/tests/test_tool_generation_init.py +++ b/tests/test_tool_generation_init.py @@ -4,10 +4,11 @@ import unittest from parameterized import parameterized_class -from agentstack import ValidationError +from agentstack import conf +from agentstack.conf import ConfigFile +from agentstack.exceptions import ValidationError from agentstack import frameworks from agentstack.tools import ToolConfig -from agentstack.generation.files import ConfigFile from agentstack.generation.tool_generation import ToolsInitFile, TOOLS_INIT_FILENAME @@ -24,8 +25,9 @@ def setUp(self): (self.project_dir / 'src' / '__init__.py').touch() (self.project_dir / 'src' / 'tools' / '__init__.py').touch() shutil.copy(BASE_PATH / 'fixtures' / 'agentstack.json', self.project_dir / 'agentstack.json') - # set the framework in agentstack.json - with ConfigFile(self.project_dir) as config: + + conf.set_path(self.project_dir) + with ConfigFile() as config: config.framework = self.framework def tearDown(self): @@ -38,18 +40,18 @@ def _get_test_tool_alt(self) -> ToolConfig: return ToolConfig(name='test_tool_alt', category='test', tools=['test_tool_alt']) def test_tools_init_file(self): - tools_init = ToolsInitFile(self.project_dir / TOOLS_INIT_FILENAME) + tools_init = ToolsInitFile(conf.PATH / TOOLS_INIT_FILENAME) # file is empty assert tools_init.get_import_for_tool(self._get_test_tool()) == None def test_tools_init_file_missing(self): with self.assertRaises(ValidationError) as context: - tools_init = ToolsInitFile(self.project_dir / 'missing') + tools_init = ToolsInitFile(conf.PATH / 'missing') def test_tools_init_file_add_import(self): tool = self._get_test_tool() - with ToolsInitFile(self.project_dir / TOOLS_INIT_FILENAME) as tools_init: - tools_init.add_import_for_tool(self.framework, tool) + with ToolsInitFile(conf.PATH / TOOLS_INIT_FILENAME) as tools_init: + tools_init.add_import_for_tool(tool, self.framework) tool_init_src = open(self.project_dir / TOOLS_INIT_FILENAME).read() assert tool.get_import_statement(self.framework) in tool_init_src @@ -57,16 +59,16 @@ def test_tools_init_file_add_import(self): def test_tools_init_file_add_import_multiple(self): tool = self._get_test_tool() tool_alt = self._get_test_tool_alt() - with ToolsInitFile(self.project_dir / TOOLS_INIT_FILENAME) as tools_init: - tools_init.add_import_for_tool(self.framework, tool) + with ToolsInitFile(conf.PATH / TOOLS_INIT_FILENAME) as tools_init: + tools_init.add_import_for_tool(tool, self.framework) - with ToolsInitFile(self.project_dir / TOOLS_INIT_FILENAME) as tools_init: - tools_init.add_import_for_tool(self.framework, tool_alt) + with ToolsInitFile(conf.PATH / TOOLS_INIT_FILENAME) as tools_init: + tools_init.add_import_for_tool(tool_alt, self.framework) # Should not be able to re-add a tool import with self.assertRaises(ValidationError) as context: - with ToolsInitFile(self.project_dir / TOOLS_INIT_FILENAME) as tools_init: - tools_init.add_import_for_tool(self.framework, tool) + with ToolsInitFile(conf.PATH / TOOLS_INIT_FILENAME) as tools_init: + tools_init.add_import_for_tool(tool, self.framework) tool_init_src = open(self.project_dir / TOOLS_INIT_FILENAME).read() assert tool.get_import_statement(self.framework) in tool_init_src