Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

integrations compatible with agent #57

Merged
merged 3 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 34 additions & 139 deletions examples/uni_v3_lp/action_agent.py
Original file line number Diff line number Diff line change
@@ -1,170 +1,65 @@
import os
import logging
import pprint

import numpy as np
from addresses import ADDRESSES
from dotenv import find_dotenv, load_dotenv
from lp_tools import get_tick_range
from mint_position import close_position, get_all_user_positions, get_mint_params
from prefect import get_run_logger

from giza.agents.action import action
from giza.agents import AgentResult, GizaAgent
from giza.agents.task import task
from giza.agents import GizaAgent

load_dotenv(find_dotenv())

# Here we load a custom sepolia rpc url from the environment
sepolia_rpc_url = os.environ.get("SEPOLIA_RPC_URL")

MODEL_ID = ... # Update with your model ID
VERSION_ID = ... # Update with your version ID


@task
def process_data(realized_vol, dec_price_change):
pct_change_sq = (100 * dec_price_change) ** 2
X = np.array([[realized_vol, pct_change_sq]])
return X


# Get image
@task
def get_data():
# TODO: implement fetching onchain or from some other source
realized_vol = 4.20
dec_price_change = 0.1
return realized_vol, dec_price_change


@task
def create_agent(
model_id: int, version_id: int, chain: str, contracts: dict, account: str
):
"""
Create a Giza agent for the volatility prediction model
"""
def transmission():
logger = logging.getLogger(__name__)
id = ...
version = ...
account = ...
realized_vol, dec_price_change = get_data()
input_data = process_data(realized_vol, dec_price_change)

agent = GizaAgent(
contracts=contracts,
id=model_id,
version_id=version_id,
chain=chain,
integrations=["UniswapV3"],
id=id,
chain="ethereum:sepolia:geth",
version_id=version,
account=account,
)
return agent


@task
def predict(agent: GizaAgent, X: np.ndarray):
"""
Predict the digit in an image.

Args:
image (np.ndarray): Image to predict.

Returns:
int: Predicted digit.
"""
prediction = agent.predict(input_feed={"val": X}, verifiable=True, job_size="XL")
return prediction


@task
def get_pred_val(prediction: AgentResult):
"""
Get the value from the prediction.

Args:
prediction (dict): Prediction from the model.

Returns:
int: Predicted value.
"""
# This will block the executon until the prediction has generated the proof and the proof has been verified
return prediction.value[0][0]


# Create Action
@action
def transmission(
pred_model_id,
pred_version_id,
account="dev",
chain=f"ethereum:sepolia:{sepolia_rpc_url}",
):
logger = get_run_logger()

nft_manager_address = ADDRESSES["NonfungiblePositionManager"][11155111]
tokenA_address = ADDRESSES["UNI"][11155111]
tokenB_address = ADDRESSES["WETH"][11155111]
pool_address = "0x287B0e934ed0439E2a7b1d5F0FC25eA2c24b64f7"
user_address = "0xCBB090699E0664f0F6A4EFbC616f402233718152"

pool_fee = 3000
tokenA_amount = 1000
tokenB_amount = 1000

logger.info("Fetching input data")
realized_vol, dec_price_change = get_data()

logger.info(f"Input data: {realized_vol}, {dec_price_change}")
X = process_data(realized_vol, dec_price_change)

nft_manager_abi_path = "nft_manager_abi.json"
contracts = {
"nft_manager": [nft_manager_address, nft_manager_abi_path],
"tokenA": [tokenA_address],
"tokenB": tokenB_address,
"pool": pool_address,
}
agent = create_agent(
model_id=pred_model_id,
version_id=pred_version_id,
chain=chain,
contracts=contracts,
account=account,
result = agent.predict(
input_feed={"val": input_data}, verifiable=True, dry_run=True
)
result = predict(agent, X)
predicted_value = get_pred_val(result)
logger.info(f"Result: {result}")

logger.info(f"Result: {result}")
with agent.execute() as contracts:
logger.info("Executing contract")
# TODO: fix below
positions = get_all_user_positions(contracts.nft_manager, user_address)
logger.info(f"Found the following positions: {positions}")
# step 4: close all positions
logger.info("Closing all open positions...")
for nft_id in positions:
close_position(user_address, contracts.nft_manager, nft_id)
# step 4: calculate mint params
logger.info("Calculating mint params...")
_, curr_tick, _, _, _, _, _ = contracts.pool.slot0()
tokenA_decimals = contracts.tokenA.decimals()
tokenB_decimals = contracts.tokenB.decimals()
# TODO: confirm input should be result and not result * 100
lower_tick, upper_tick = get_tick_range(
curr_tick, predicted_value, tokenA_decimals, tokenB_decimals, pool_fee
UNI_address = "0x1f9840a85d5aF5bf1D1762F925BDADdC4201F984"
WETH_address = "0xfFf9976782d46CC05630D1f6eBAb18b2324d6B14"
uni = contracts.UniswapV3
volatility_prediction = result.value[0]
pool = uni.get_pool(UNI_address, WETH_address, fee=500)
curr_price = pool.get_pool_price()
lower_price = curr_price * (1 - volatility_prediction)
upper_price = curr_price * (1 + volatility_prediction)
amount0 = 100
amount1 = 100
agent_result = uni.mint_position(
pool, lower_price, upper_price, amount0, amount1
)
mint_params = get_mint_params(
tokenA_address,
tokenB_address,
user_address,
tokenA_amount,
tokenB_amount,
pool_fee,
lower_tick,
upper_tick,
logger.info(
f"Current price: {curr_price}, new bounds: {lower_price}, {upper_price}"
)
# step 5: mint new position
logger.info("Minting new position...")
contract_result = contracts.nft_manager.mint(mint_params)
logger.info("SUCCESSFULLY MINTED A POSITION")
logger.info("Contract executed")
logger.info(f"Minted position: {agent_result}")

logger.info(f"Contract result: {contract_result}")
pprint.pprint(contract_result.__dict__)
logger.info(f"Contract result: {agent_result}")
logger.info("Finished")


transmission(MODEL_ID, VERSION_ID)
transmission()
82 changes: 62 additions & 20 deletions giza/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Callable, Dict, List, Optional, Self, Tuple, Union

from ape import Contract, accounts, networks
from ape.api import AccountAPI
from ape.contracts import ContractInstance
from ape.exceptions import NetworkError
from ape_accounts.accounts import InvalidPasswordError
Expand All @@ -18,6 +19,8 @@
from giza.cli.utils.enums import JobKind, JobStatus
from requests import HTTPError

from giza.agents.exceptions import DuplicateIntegrationError
from giza.agents.integration import IntegrationFactory
from giza.agents.model import GizaModel
from giza.agents.utils import read_json

Expand All @@ -34,7 +37,8 @@ def __init__(
self,
id: int,
version_id: int,
contracts: Dict[str, Union[str, List[str]]],
contracts: Optional[Dict[str, Union[str, List[str]]]] = None,
integrations: Optional[List[str]] = None,
chain: Optional[str] = None,
account: Optional[str] = None,
**kwargs: Any,
Expand All @@ -44,6 +48,7 @@ def __init__(
model_id (int): The ID of the model.
version_id (int): The version of the model.
contracts (Dict[str, str]): The contracts to handle, must be a dictionary with the contract name as the key and the contract address as the value.
integrations (List[str]): The integrations to use.
chain_id (int): The ID of the blockchain network.
**kwargs: Additional keyword arguments.
"""
Expand All @@ -63,11 +68,11 @@ def __init__(
logger.error("Agent is missing required parameters")
raise ValueError(f"Agent is missing required parameters: {e}")

self.contract_handler = ContractHandler(contracts)
self.chain = chain
self.account = account
self._check_passphrase_in_env()
self._check_or_create_account()
self.contract_handler = ContractHandler(contracts, integrations)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add an if statement so if contracts and integrations are both None a error is raised


# Useful for testing
network_parser: Callable = kwargs.get(
Expand Down Expand Up @@ -240,8 +245,8 @@ def execute(self) -> Any:
f"Invalid passphrase for account {self.account}. Could not decrypt account."
) from e
logger.debug("Autosign enabled")
with accounts.use_sender(self._account):
yield self.contract_handler.handle()
with accounts.use_sender(self._account) as sender:
yield self.contract_handler.handle(account=sender)

def predict(
self,
Expand Down Expand Up @@ -452,15 +457,37 @@ class ContractHandler:
which means that it should be done insede the GizaAgent's execute context.
"""

def __init__(self, contracts: Dict[str, Union[str, List[str]]]) -> None:
def __init__(
self,
contracts: Optional[Dict[str, Union[str, List[str]]]] = None,
integrations: Optional[List[str]] = None,
) -> None:
if contracts is None and integrations is None:
raise ValueError("Contracts or integrations must be specified.")
if contracts is None:
contracts = {}
if integrations is None:
integrations = []
contract_names = list(contracts.keys())
duplicates = set(contract_names) & set(integrations)
if duplicates:
duplicate_names = ", ".join(duplicates)
raise DuplicateIntegrationError(
f"Integrations of these names already exist: {duplicate_names}. Choose different contract names."
)
self._contracts = contracts
self._integrations = integrations
self._contracts_instances: Dict[str, ContractInstance] = {}
self._integrations_instances: Dict[str, IntegrationFactory] = {}

def __getattr__(self, name: str) -> ContractInstance:
def __getattr__(self, name: str) -> Union[ContractInstance, IntegrationFactory]:
"""
Get the contract by name.
"""
return self._contracts_instances[name]
if name in self._contracts_instances.keys():
return self._contracts_instances[name]
if name in self._integrations_instances.keys():
return self._integrations_instances[name]

def _initiate_contract(
self, address: str, abi: Optional[str] = None
Expand All @@ -472,26 +499,41 @@ def _initiate_contract(
return Contract(address=address)
return Contract(address=address, abi=abi)

def handle(self) -> Self:
def _initiate_integration(
self, name: str, account: AccountAPI
) -> IntegrationFactory:
"""
Initiate the integration.
"""
return IntegrationFactory.from_name(name, sender=account)

def handle(self, account: Optional[AccountAPI] = None) -> Self:
"""
Handle the contracts.
"""
try:
for name, contract_data in self._contracts.items():
if isinstance(contract_data, str):
address = contract_data
self._contracts_instances[name] = self._initiate_contract(address)
elif isinstance(contract_data, list):
if len(contract_data) == 1:
address = contract_data[0]
if self._contracts:
for name, contract_data in self._contracts.items():
if isinstance(contract_data, str):
address = contract_data
self._contracts_instances[name] = self._initiate_contract(
address
)
else:
address, abi = contract_data
self._contracts_instances[name] = self._initiate_contract(
address, abi
)
elif isinstance(contract_data, list):
if len(contract_data) == 1:
address = contract_data[0]
self._contracts_instances[name] = self._initiate_contract(
address
)
else:
address, abi = contract_data
self._contracts_instances[name] = self._initiate_contract(
address, abi
)
for name in self._integrations:
self._integrations_instances[name] = self._initiate_integration(
name, account
)
except NetworkError as e:
logger.error(f"Failed to initiate contract: {e}")
raise ValueError(
Expand Down
4 changes: 4 additions & 0 deletions giza/agents/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class DuplicateIntegrationError(Exception):
"""Exception raised when there is a duplicate in integration names."""

pass
13 changes: 13 additions & 0 deletions giza/agents/integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from ape.api import AccountAPI

from giza.agents.integrations import Uniswap


class IntegrationFactory:
@staticmethod
def from_name(name: str, sender: AccountAPI) -> Uniswap:
match name:
case "UniswapV3":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that it would be better to have an enum for this, maybe open an issue and will deal with this later

return Uniswap(sender, version=3)
case _:
raise ValueError(f"Integration {name} not found")
3 changes: 3 additions & 0 deletions giza/agents/integrations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from giza.agents.integrations.uniswap.uniswap import Uniswap

__all__ = ["Uniswap"]
Loading
Loading