Skip to content

Commit

Permalink
feat: add support for shortcuts
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri committed Jun 18, 2023
1 parent 9036c59 commit 88ef8db
Show file tree
Hide file tree
Showing 6 changed files with 368 additions and 3 deletions.
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,26 @@ Oh, Olivia gets paid the most.

You can find more examples in the [examples](examples) directory.

### ⚡️ Shortcuts

PandasAI also provides a number of shortcuts (beta) to make it easier to ask questions to your data. For example, you can ask PandasAI to `clean_data`, `impute_missing_values`, `generate_features`, `plot_histogram`, and many many more.

```python
# Clean data
pandas_ai.clean_data(df)

# Impute missing values
pandas_ai.impute_missing_values(df)

# Generate features
pandas_ai.generate_features(df)

# Plot histogram
pandas_ai.plot_histogram(df, column="gdp")
```

Learn more about the shortcuts [here](https://pandas-ai.readthedocs.io/en/latest/shortcuts/).

## 🔒 Privacy & Security

In order to generate the Python code to run, we take the dataframe head, we randomize it (using random generation for sensitive data and shuffling for non-sensitive data) and send just the head.
Expand Down
149 changes: 149 additions & 0 deletions docs/shortcuts.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Shortcuts

Shortcuts are a way to quickly access the most common queries. At the moment, shortcuts are in beta, and only a few are available. More will be added in the future.

## Available shortcuts

### clean_data

```python
df = pd.read_csv('data.csv')
pandas_ai.clean_data(df)
```

This shortcut will do data cleaning on the data frame.

### impute_missing_values

```python
df = pd.read_csv('data.csv')
pandas_ai.impute_missing_values(df)
```

This shortcut will impute missing values in the data frame.

### generate_features

```python
df = pd.read_csv('data.csv')
pandas_ai.generate_features(df)
```

This shortcut will generate features in the data frame.

### plot_pie_chart

```python
df = pd.read_csv('data.csv')
pandas_ai.plot_pie_chart(df, labels = ['a', 'b', 'c'], values = [1, 2, 3])
```

This shortcut will plot a pie chart of the data frame.

### plot_bar_chart

```python
df = pd.read_csv('data.csv')
pandas_ai.plot_bar_chart(df, x = ['a', 'b', 'c'], y = [1, 2, 3])
```

This shortcut will plot a bar chart of the data frame.

### plot_bar_chart

```python
df = pd.read_csv('data.csv')
pandas_ai.plot_bar_chart(df, x = ['a', 'b', 'c'])
```

This shortcut will plot a bar chart of the data frame.

### plot_histogram

```python
df = pd.read_csv('data.csv')
pandas_ai.plot_histogram(df, column = 'a')
```

This shortcut will plot a histogram of the data frame.

### plot_line_chart

```python
df = pd.read_csv('data.csv')
pandas_ai.plot_line_chart(df, x = ['a', 'b', 'c'], y = [1, 2, 3])
```

This shortcut will plot a line chart of the data frame.

### plot_scatter_chart

```python
df = pd.read_csv('data.csv')
pandas_ai.plot_scatter_chart(df, x = ['a', 'b', 'c'], y = [1, 2, 3])
```

This shortcut will plot a scatter chart of the data frame.

### plot_correlation_heatmap

```python
df = pd.read_csv('data.csv')
pandas_ai.plot_correlation_heatmap(df)
```

This shortcut will plot a correlation heatmap of the data frame.

### plot_confusion_matrix

```python
df = pd.read_csv('data.csv')
pandas_ai.plot_confusion_matrix(df, y_true = [1, 2, 3], y_pred = [1, 2, 3])
```

This shortcut will plot a confusion matrix of the data frame.

### plot_roc_curve

```python
df = pd.read_csv('data.csv')
pandas_ai.plot_roc_curve(df, y_true = [1, 2, 3], y_pred = [1, 2, 3])
```

This shortcut will plot a ROC curve of the data frame.

### rolling_mean

```python
df = pd.read_csv('data.csv')
pandas_ai.rolling_mean(df, column = 'a', window = 5)
```

This shortcut will calculate the rolling mean of the data frame.

### rolling_median

```python
df = pd.read_csv('data.csv')
pandas_ai.rolling_median(df, column = 'a', window = 5)
```

This shortcut will calculate the rolling median of the data frame.

### rolling_std

```python
df = pd.read_csv('data.csv')
pandas_ai.rolling_std(df, column = 'a', window = 5)
```

This shortcut will calculate the rolling standard deviation of the data frame.

### segment_customers

```python
df = pd.read_csv('data.csv')
pandas_ai.segment_customers(df, features = ['a', 'b', 'c'], n_clusters = 5)
```

This shortcut will segment customers in the data frame.
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ nav:
- Home: index.md
- Getting Started: getting-started.md
- Cache: cache.md
- Shortcuts: shortcuts.md
- Middlewares: middlewares.md
- Custom Optional Arguments: custom_optional_arguments.md
- Command Line Tool: pai_cli.md
Expand Down
3 changes: 2 additions & 1 deletion pandasai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from .helpers.cache import Cache
from .helpers.notebook import Notebook
from .helpers.save_chart import add_save_chart
from .helpers.shortcuts import Shortcuts
from .llm.base import LLM
from .llm.langchain import LangchainLLM
from .middlewares.base import Middleware
Expand All @@ -66,7 +67,7 @@
from .prompts.multiple_dataframes import MultipleDataframesPrompt


class PandasAI:
class PandasAI(Shortcuts):
"""
PandasAI is a wrapper around a LLM to make dataframes conversational.
Expand Down
174 changes: 174 additions & 0 deletions pandasai/helpers/shortcuts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
from typing import Union
import pandas as pd
from abc import ABC, abstractmethod


class Shortcuts(ABC):
@abstractmethod
def run(self, df: pd.DataFrame, prompt: str) -> Union[str, pd.DataFrame]:
"""Run method from PandasAI class."""

pass

def clean_data(self, df: pd.DataFrame) -> pd.DataFrame:
"""Do data cleaning and return the dataframe."""

return self.run(
df,
"""
1. Copy the dataframe to a new variable named df_cleaned.
2. Do data cleaning.
3. Return df_cleaned.
""",
)

def impute_missing_values(self, df: pd.DataFrame) -> pd.DataFrame:
"""Do missing value imputation and return the dataframe."""

return self.run(
df,
"""
1. Copy the dataframe to a new variable named df_imputed.
2. Do the imputation of missing values.
3. Return df_imputed.
""",
)

def generate_features(self, df: pd.DataFrame) -> pd.DataFrame:
"""Do feature generation and return the dataframe."""

return self.run(
df,
"""
1. Copy the dataframe to a new variable named df_features.
2. Do feature generation.
3. Return df_features.
""",
)

def plot_pie_chart(self, df: pd.DataFrame, labels: list, values: list) -> None:
"""Plot a pie chart."""

self.run(
df,
f"""
Plot a pie chart with the following labels and values:
labels = {labels}
values = {values}
""",
)

def plot_bar_chart(self, df: pd.DataFrame, x: list, y: list) -> None:
"""Plot a bar chart."""

self.run(
df,
f"""
Plot a bar chart with the following x and y:
x = {x}
y = {y}
""",
)

def plot_histogram(self, df: pd.DataFrame, column: str) -> None:
"""Plot a histogram."""

self.run(df, f"Plot a histogram of the column {column}.")

def plot_line_chart(self, df: pd.DataFrame, x: list, y: list) -> None:
"""Plot a line chart."""

self.run(
df,
f"""
Plot a line chart with the following x and y:
x = {x}
y = {y}
""",
)

def plot_scatter_chart(self, df: pd.DataFrame, x: list, y: list) -> None:
"""Plot a scatter chart."""

self.run(
df,
f"""
Plot a scatter chart with the following x and y:
x = {x}
y = {y}
""",
)

def plot_correlation_heatmap(self, df: pd.DataFrame) -> None:
"""Plot a correlation heatmap."""

self.run(df, "Plot a correlation heatmap.")

def plot_confusion_matrix(
self, df: pd.DataFrame, y_true: list, y_pred: list
) -> None:
"""Plot a confusion matrix."""

self.run(
df,
f"""
Plot a confusion matrix with the following y_true and y_pred:
y_true = {y_true}
y_pred = {y_pred}
""",
)

def plot_roc_curve(self, df: pd.DataFrame, y_true: list, y_pred: list) -> None:
"""Plot a ROC curve."""

self.run(
df,
f"""
Plot a ROC curve with the following y_true and y_pred:
y_true = {y_true}
y_pred = {y_pred}
""",
)

def rolling_mean(self, df: pd.DataFrame, column: str, window: int) -> pd.DataFrame:
"""Calculate the rolling mean."""

return self.run(
df,
f"Calculate the rolling mean of the column {column} with a window"
" of {window}.",
)

def rolling_median(
self, df: pd.DataFrame, column: str, window: int
) -> pd.DataFrame:
"""Calculate the rolling median."""

return self.run(
df,
f"Calculate the rolling median of the column {column} with a window"
" of {window}.",
)

def rolling_std(self, df: pd.DataFrame, column: str, window: int) -> pd.DataFrame:
"""Calculate the rolling standard deviation."""

return self.run(
df,
f"Calculate the rolling standard deviation of the column {column} with a"
"window of {window}.",
)

def segment_customers(
self, df: pd.DataFrame, features: list, n_clusters: int
) -> pd.DataFrame:
"""Segment customers."""

return self.run(
df,
f"""
Segment customers with the following features and number of clusters:
features = {features}
n_clusters = {n_clusters}
""",
)
Loading

0 comments on commit 88ef8db

Please sign in to comment.