-
Notifications
You must be signed in to change notification settings - Fork 121
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #54 from ag2ai/captain1
CaptainAgent PR part 3
- Loading branch information
Showing
42 changed files
with
1,218 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Introduction | ||
|
||
This directory contains a library of manually created python tools. These tools have three categories: math, data_analysis and information_retrieval. | ||
|
||
# Directory Layout | ||
``` | ||
tools | ||
├── README.md | ||
├── data_analysis | ||
│ ├── calculate_correlation.py | ||
│ └── ... | ||
├── information_retrieval | ||
│ ├── arxiv_download.py | ||
│ ├── arxiv_search.py | ||
│ └── ... | ||
├── math | ||
│ ├── calculate_circle_area_from_diameter.py | ||
│ └── ... | ||
└── tool_description.tsv | ||
``` | ||
|
||
Tools can be imported from `tools/{category}/{tool_name}.py` with exactly the same function name. | ||
|
||
`tool_description.tsv` contains descriptions of tools for retrieval. | ||
|
||
# How to use | ||
Some tools require Bing Search API key and RapidAPI key. For Bing API, you can read more about how to get an API on the [Bing Web Search API](https://www.microsoft.com/en-us/bing/apis/bing-web-search-api) page. For RapidAPI, you can [sign up](https://rapidapi.com/auth/sign-up) and subscribe to these two links([link1](https://rapidapi.com/solid-api-solid-api-default/api/youtube-transcript3), [link2](https://rapidapi.com/420vijay47/api/youtube-mp3-downloader2)). These apis have free billing options and there is no need to worry about extra costs. | ||
|
||
To install the requirements for running tools, use pip install. | ||
```bash | ||
pip install -r autogen/agentchat/contrib/captainagent/tools/requirements.txt | ||
``` | ||
|
||
Whenever you run the tool-related code, remember to export the api keys to system variables. | ||
```bash | ||
export BING_API_KEY="" | ||
export RAPID_API_KEY="" | ||
``` | ||
or | ||
```python | ||
import os | ||
os.environ["BING_API_KEY"] = "" | ||
os.environ["RAPID_API_KEY"] = "" | ||
``` |
38 changes: 38 additions & 0 deletions
38
autogen/agentchat/contrib/captainagent/tools/data_analysis/calculate_correlation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
def calculate_correlation(csv_path: str, column1: str, column2: str, method: str = "pearson") -> float: | ||
""" | ||
Calculate the correlation between two columns in a CSV file. | ||
Args: | ||
csv_path (str): The path to the CSV file. | ||
column1 (str): The name of the first column. | ||
column2 (str): The name of the second column. | ||
method (str or callable, optional): The method used to calculate the correlation. | ||
- 'pearson' (default): Pearson correlation coefficient. | ||
- 'kendall': Kendall Tau correlation coefficient. | ||
- 'spearman': Spearman rank correlation coefficient. | ||
- callable: A custom correlation function that takes two arrays and returns a scalar. | ||
Returns: | ||
float: The correlation coefficient between the two columns. | ||
""" | ||
import pandas as pd | ||
|
||
# Read the CSV file into a pandas DataFrame | ||
df = pd.read_csv(csv_path) | ||
|
||
# Select the specified columns | ||
selected_columns = df[[column1, column2]] | ||
|
||
# Calculate the correlation based on the specified method | ||
if method == "pearson": | ||
correlation = selected_columns.corr().iloc[0, 1] | ||
elif method == "kendall": | ||
correlation = selected_columns.corr(method="kendall").iloc[0, 1] | ||
elif method == "spearman": | ||
correlation = selected_columns.corr(method="spearman").iloc[0, 1] | ||
elif callable(method): | ||
correlation = selected_columns.corr(method=method).iloc[0, 1] | ||
else: | ||
raise ValueError("Invalid correlation method. Please choose 'pearson', 'kendall', 'spearman', or a callable.") | ||
|
||
return correlation |
26 changes: 26 additions & 0 deletions
26
...gen/agentchat/contrib/captainagent/tools/data_analysis/calculate_skewness_and_kurtosis.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
def calculate_skewness_and_kurtosis(csv_file: str, column_name: str) -> tuple: | ||
""" | ||
Calculate the skewness and kurtosis of a specified column in a CSV file. The kurtosis is calculated using the Fisher definition. | ||
The two metrics are computed using scipy.stats functions. | ||
Args: | ||
csv_file (str): The path to the CSV file. | ||
column_name (str): The name of the column to calculate skewness and kurtosis for. | ||
Returns: | ||
tuple: (skewness, kurtosis) | ||
""" | ||
import pandas as pd | ||
from scipy.stats import kurtosis, skew | ||
|
||
# Read the CSV file into a pandas DataFrame | ||
df = pd.read_csv(csv_file) | ||
|
||
# Extract the specified column | ||
column = df[column_name] | ||
|
||
# Calculate the skewness and kurtosis | ||
skewness = skew(column) | ||
kurt = kurtosis(column) | ||
|
||
return skewness, kurt |
26 changes: 26 additions & 0 deletions
26
autogen/agentchat/contrib/captainagent/tools/data_analysis/detect_outlier_iqr.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
def detect_outlier_iqr(csv_file: str, column_name: str): | ||
""" | ||
Detect outliers in a specified column of a CSV file using the IQR method. | ||
Args: | ||
csv_file (str): The path to the CSV file. | ||
column_name (str): The name of the column to detect outliers in. | ||
Returns: | ||
list: A list of row indices that correspond to the outliers. | ||
""" | ||
import pandas as pd | ||
|
||
# Read the CSV file into a pandas DataFrame | ||
df = pd.read_csv(csv_file) | ||
|
||
# Calculate the quartiles and IQR for the specified column | ||
q1 = df[column_name].quantile(0.25) | ||
q3 = df[column_name].quantile(0.75) | ||
iqr = q3 - q1 | ||
|
||
# Find the outliers based on the defined criteria | ||
outliers = df[(df[column_name] < q1 - 1.5 * iqr) | (df[column_name] > q3 + 1.5 * iqr)] | ||
|
||
# Return the row indices of the outliers | ||
return outliers.index.tolist() |
26 changes: 26 additions & 0 deletions
26
autogen/agentchat/contrib/captainagent/tools/data_analysis/detect_outlier_zscore.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
def detect_outlier_zscore(csv_file, column_name, threshold=3): | ||
""" | ||
Detect outliers in a CSV file based on a specified column. The outliers are determined by calculating the z-score of the data points in the column. | ||
Args: | ||
csv_file (str): The path to the CSV file. | ||
column_name (str): The name of the column to calculate z-scores for. | ||
threshold (float, optional): The threshold value for determining outliers. By default set to 3. | ||
Returns: | ||
list: A list of row indices where the z-score is above the threshold. | ||
""" | ||
import numpy as np | ||
import pandas as pd | ||
|
||
# Read the CSV file into a pandas DataFrame | ||
df = pd.read_csv(csv_file) | ||
|
||
# Calculate the z-score for the specified column | ||
z_scores = np.abs((df[column_name] - df[column_name].mean()) / df[column_name].std()) | ||
|
||
# Find the row indices where the z-score is above the threshold | ||
outlier_indices = np.where(z_scores > threshold)[0] | ||
|
||
# Return the row indices of the outliers | ||
return outlier_indices |
19 changes: 19 additions & 0 deletions
19
autogen/agentchat/contrib/captainagent/tools/data_analysis/explore_csv.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
def explore_csv(file_path, num_lines=5): | ||
""" | ||
Reads a CSV file and prints the column names, shape, data types, and the first few lines of data. | ||
Args: | ||
file_path (str): The path to the CSV file. | ||
num_lines (int, optional): The number of lines to print. Defaults to 5. | ||
""" | ||
import pandas as pd | ||
|
||
df = pd.read_csv(file_path) | ||
header = df.columns | ||
print("Columns:") | ||
print(", ".join(header)) | ||
print("Shape:", df.shape) | ||
print("Data Types:") | ||
print(df.dtypes) | ||
print("First", num_lines, "lines:") | ||
print(df.head(num_lines)) |
28 changes: 28 additions & 0 deletions
28
autogen/agentchat/contrib/captainagent/tools/data_analysis/shapiro_wilk_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from autogen.coding.func_with_reqs import with_requirements | ||
|
||
|
||
@with_requirements(["pandas", "scipy"]) | ||
def shapiro_wilk_test(csv_file, column_name): | ||
""" | ||
Perform the Shapiro-Wilk test on a specified column of a CSV file. | ||
Args: | ||
csv_file (str): The path to the CSV file. | ||
column_name (str): The name of the column to perform the test on. | ||
Returns: | ||
float: The p-value resulting from the Shapiro-Wilk test. | ||
""" | ||
import pandas as pd | ||
from scipy.stats import shapiro | ||
|
||
# Read the CSV file into a pandas DataFrame | ||
df = pd.read_csv(csv_file) | ||
|
||
# Extract the specified column as a numpy array | ||
column_data = df[column_name].values | ||
|
||
# Perform the Shapiro-Wilk test | ||
_, p_value = shapiro(column_data) | ||
|
||
return p_value |
23 changes: 23 additions & 0 deletions
23
autogen/agentchat/contrib/captainagent/tools/information_retrieval/arxiv_download.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import arxiv | ||
|
||
from autogen.coding.func_with_reqs import with_requirements | ||
|
||
|
||
@with_requirements(["arxiv"], ["arxiv"]) | ||
def arxiv_download(id_list: list, download_dir="./"): | ||
""" | ||
Downloads PDF files from ArXiv based on a list of arxiv paper IDs. | ||
Args: | ||
id_list (list): A list of paper IDs to download. e.g. [2302.00006v1] | ||
download_dir (str, optional): The directory to save the downloaded PDF files. Defaults to './'. | ||
Returns: | ||
list: A list of paths to the downloaded PDF files. | ||
""" | ||
paths = [] | ||
for paper in arxiv.Client().results(arxiv.Search(id_list=id_list)): | ||
path = paper.download_pdf(download_dir, filename=paper.get_short_id() + ".pdf") | ||
paths.append(path) | ||
print("Paper id:", paper.get_short_id(), "Downloaded to:", path) | ||
return paths |
52 changes: 52 additions & 0 deletions
52
autogen/agentchat/contrib/captainagent/tools/information_retrieval/arxiv_search.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import arxiv | ||
|
||
from autogen.coding.func_with_reqs import with_requirements | ||
|
||
|
||
@with_requirements(["arxiv"], ["arxiv"]) | ||
def arxiv_search(query, max_results=10, sortby="relevance"): | ||
""" | ||
Search for articles on arXiv based on the given query. | ||
Args: | ||
query (str): The search query. | ||
max_results (int, optional): The maximum number of results to retrieve. Defaults to 10. | ||
sortby (str, optional): The sorting criterion for the search results. Can be 'relevance' or 'submittedDate'. Defaults to 'relevance'. | ||
Returns: | ||
list: A list of dictionaries containing information about the search results. Each dictionary contains the following keys: | ||
- 'title': The title of the article. | ||
- 'authors': The authors of the article. | ||
- 'summary': The summary of the article. | ||
- 'entry_id': The entry ID of the article. | ||
- 'doi': The DOI of the article (If applicable). | ||
- 'published': The publication date of the article in the format 'Y-M'. | ||
""" | ||
|
||
def get_author(r): | ||
return ", ".join(a.name for a in r.authors) | ||
|
||
criterion = {"relevance": arxiv.SortCriterion.Relevance, "submittedDate": arxiv.SortCriterion.SubmittedDate}[sortby] | ||
|
||
client = arxiv.Client() | ||
search = arxiv.Search(query=query, max_results=max_results, sort_by=criterion) | ||
res = [] | ||
results = client.results(search) | ||
for r in results: | ||
print("Entry id:", r.entry_id) | ||
print("Title:", r.title) | ||
print("Authors:", get_author(r)) | ||
print("DOI:", r.doi) | ||
print("Published:", r.published.strftime("%Y-%m")) | ||
# print("Summary:", r.summary) | ||
res.append( | ||
{ | ||
"title": r.title, | ||
"authors": get_author(r), | ||
"summary": r.summary, | ||
"entry_id": r.entry_id, | ||
"doi": r.doi, | ||
"published": r.published.strftime("%Y-%m"), | ||
} | ||
) | ||
return res |
Oops, something went wrong.