[GPT] enable historical signals fetch #1089

Merged
merged 7 commits on Oct 27, 2023
177 changes: 122 additions & 55 deletions Evaluator/TA/ai_evaluator/ai.py
@@ -31,28 +31,31 @@


class GPTEvaluator(evaluators.TAEvaluator):
GLOBAL_VERSION = 1
PREPROMPT = "Predict: {up or down} {confidence%} (no other information)"
PASSED_DATA_LEN = 10
MAX_CONFIDENCE_PERCENT = 100
HIGH_CONFIDENCE_PERCENT = 80
MEDIUM_CONFIDENCE_PERCENT = 50
LOW_CONFIDENCE_PERCENT = 30
INDICATORS = {
"No indicator: the raw value of the selected source": lambda data, period: data,
"No indicator: raw candles price data": lambda data, period: data,
"EMA: Exponential Moving Average": tulipy.ema,
"SMA: Simple Moving Average": tulipy.sma,
"Kaufman Adaptive Moving Average": tulipy.kama,
"Hull Moving Average": tulipy.kama,
"RSI: Relative Strength Index": tulipy.rsi,
"Detrended Price Oscillator": tulipy.dpo,
}
SOURCES = ["Open", "High", "Low", "Close", "Volume"]
SOURCES = ["Open", "High", "Low", "Close", "Volume", "Full candle (For no indicator only)"]
GPT_MODELS = []

def __init__(self, tentacles_setup_config):
super().__init__(tentacles_setup_config)
self.indicator = None
self.source = None
self.period = None
self.min_confidence_threshold = 0
self.gpt_model = gpt_service.GPTService.DEFAULT_MODEL
self.is_backtesting = False
self.min_allowed_timeframe = os.getenv("MIN_GPT_TIMEFRAME", None)
@@ -65,13 +68,21 @@ def __init__(self, tentacles_setup_config):
except ValueError:
self.logger.error(f"Invalid timeframe configuration: unknown timeframe: '{self.min_allowed_timeframe}'")
self.allow_reevaluations = os_util.parse_boolean_environment_var("ALLOW_GPT_REEVALUATIONS", "True")
self.services_config = None

def enable_reevaluation(self) -> bool:
"""
Override when artificial re-evaluations from the evaluator channel can be disabled
"""
return self.allow_reevaluations

@classmethod
def get_signals_history_type(cls):
"""
Override when this evaluator uses a specific type of signal history
"""
return commons_enums.SignalHistoryTypes.GPT

async def load_and_save_user_inputs(self, bot_id: str) -> dict:
"""
instance method API for user inputs
@@ -98,7 +109,12 @@ def init_user_inputs(self, inputs: dict) -> None:
self.period = self.UI.user_input(
"period", enums.UserInputTypes.INT,
self.period, inputs, min_val=1,
title="Period: length of the indicator period."
title="Period: length of the indicator period or the number of candles to give to ChatGPT."
)
self.min_confidence_threshold = self.UI.user_input(
"min_confidence_threshold", enums.UserInputTypes.INT,
self.min_confidence_threshold, inputs, min_val=0, max_val=100,
title="Minimum confidence threshold: % confidence value starting from which to return 1 or -1."
)
if len(self.GPT_MODELS) > 1 and self.enable_model_selector:
self.gpt_model = self.UI.user_input(
@@ -112,7 +128,9 @@ async def _init_GPT_models(self):
self.GPT_MODELS = [gpt_service.GPTService.DEFAULT_MODEL]
if self.enable_model_selector and not self.is_backtesting:
try:
service = await services_api.get_service(gpt_service.GPTService, self.is_backtesting)
service = await services_api.get_service(
gpt_service.GPTService, self.is_backtesting, self.services_config
)
self.GPT_MODELS = service.models
except Exception as err:
self.logger.exception(err, True, f"Impossible to fetch GPT models: {err}")
@@ -128,67 +146,110 @@ async def _init_registered_topics(self, all_symbols_by_crypto_currencies, curren

async def ohlcv_callback(self, exchange: str, exchange_id: str,
cryptocurrency: str, symbol: str, time_frame, candle, inc_in_construction_data):
candle_data = self.get_candles_data_api()(
self.get_exchange_symbol_data(exchange, exchange_id, symbol), time_frame,
include_in_construction=inc_in_construction_data
)
candle_data = self.get_candles_data(exchange, exchange_id, symbol, time_frame, inc_in_construction_data)
await self.evaluate(cryptocurrency, symbol, time_frame, candle_data, candle)

async def evaluate(self, cryptocurrency, symbol, time_frame, candle_data, candle):
self.eval_note = commons_constants.START_PENDING_EVAL_NOTE
if self._check_timeframe(time_frame):
try:
computed_data = self.call_indicator(candle_data)
reduced_data = computed_data[-self.PASSED_DATA_LEN:]
formatted_data = ", ".join(str(datum).replace('[', '').replace(']', '') for datum in reduced_data)
prediction = await self.ask_gpt(self.PREPROMPT, formatted_data, symbol, time_frame)
cleaned_prediction = prediction.strip().replace("\n", "").replace(".", "").lower()
prediction_side = self._parse_prediction_side(cleaned_prediction)
if prediction_side == 0:
self.logger.error(f"Error when reading GPT answer: {cleaned_prediction}")
return
confidence = self._parse_confidence(cleaned_prediction) / 100
self.eval_note = prediction_side * confidence
except services_errors.InvalidRequestError as e:
self.logger.error(f"Invalid GPT request: {e}")
except services_errors.RateLimitError as e:
self.logger.error(f"Too many requests: {e}")
except services_errors.UnavailableInBacktestingError:
# error already logged for backtesting in use_backtesting_init_timeout
pass
except evaluators_errors.UnavailableEvaluatorError as e:
self.logger.exception(e, True, f"Evaluation error: {e}")
except tulipy.lib.InvalidOptionError as e:
self.logger.warning(
f"Error when computing {self.indicator} on {self.period} period with {len(candle_data)} "
f"candles: {e}"
)
self.logger.exception(e, False)
else:
self.logger.debug(f"Ignored {time_frame} time frame as the shorted allowed time frame is "
f"{self.min_allowed_timeframe}")
await self.evaluation_completed(cryptocurrency, symbol, time_frame,
eval_time=evaluators_util.get_eval_time(full_candle=candle,
time_frame=time_frame))
async with self.async_evaluation():
self.eval_note = commons_constants.START_PENDING_EVAL_NOTE
if self._check_timeframe(time_frame):
try:
candle_time = candle[commons_enums.PriceIndexes.IND_PRICE_TIME.value]
computed_data = self.call_indicator(candle_data)
formatted_data = self.get_formatted_data(computed_data)
prediction = await self.ask_gpt(self.PREPROMPT, formatted_data, symbol, time_frame, candle_time)
cleaned_prediction = prediction.strip().replace("\n", "").replace(".", "").lower()
prediction_side = self._parse_prediction_side(cleaned_prediction)
if prediction_side == 0 and not self.is_backtesting:
self.logger.error(f"Error when reading GPT answer: {cleaned_prediction}")
return
confidence = self._parse_confidence(cleaned_prediction) / 100
self.eval_note = prediction_side * confidence
except services_errors.InvalidRequestError as e:
self.logger.error(f"Invalid GPT request: {e}")
except services_errors.RateLimitError as e:
self.logger.error(f"Too many requests: {e}")
except services_errors.UnavailableInBacktestingError:
# error already logged for backtesting in use_backtesting_init_timeout
pass
except evaluators_errors.UnavailableEvaluatorError as e:
self.logger.exception(e, True, f"Evaluation error: {e}")
except tulipy.lib.InvalidOptionError as e:
self.logger.warning(
f"Error when computing {self.indicator} on {self.period} period with {len(candle_data)} "
f"candles: {e}"
)
self.logger.exception(e, False)
else:
self.logger.debug(f"Ignored {time_frame} time frame as the shorted allowed time frame is "
f"{self.min_allowed_timeframe}")
await self.evaluation_completed(cryptocurrency, symbol, time_frame,
eval_time=evaluators_util.get_eval_time(full_candle=candle,
time_frame=time_frame))

def get_formatted_data(self, computed_data) -> str:
if self.source in self.get_unformated_sources():
return str(computed_data)
reduced_data = computed_data[-self.PASSED_DATA_LEN:]
return ", ".join(str(datum).replace('[', '').replace(']', '') for datum in reduced_data)

async def ask_gpt(self, preprompt, inputs, symbol, time_frame) -> str:
async def ask_gpt(self, preprompt, inputs, symbol, time_frame, candle_time) -> str:
try:
service = await services_api.get_service(gpt_service.GPTService, self.is_backtesting)
service = await services_api.get_service(
gpt_service.GPTService,
self.is_backtesting,
{} if self.is_backtesting else self.services_config
)
resp = await service.get_chat_completion(
[
service.create_message("system", preprompt),
service.create_message("user", inputs),
],
model=self.gpt_model if self.enable_model_selector else None
model=self.gpt_model if self.enable_model_selector else None,
exchange=self.exchange_name,
symbol=symbol,
time_frame=time_frame,
version=self.get_version(),
candle_open_time=candle_time,
use_stored_signals=self.is_backtesting
)
self.logger.info(f"GPT's answer is '{resp}' for {symbol} on {time_frame} with input: {inputs}")
return resp
except services_errors.CreationError as err:
raise evaluators_errors.UnavailableEvaluatorError(f"Impossible to get ChatGPT prediction: {err}") from err
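
The added keyword arguments are what enable the historical signals fetch: signals are identified by (exchange, symbol, time_frame, version, candle_open_time), and use_stored_signals replays stored answers in backtesting. A hedged illustration of the call shape; the literal values below are examples only, not defaults from the code:

resp = await service.get_chat_completion(
    [
        service.create_message("system", "Predict: {up or down} {confidence%} (no other information)"),
        service.create_message("user", "1.15, 1.16, 1.17"),  # hypothetical formatted data
    ],
    model=None,                   # None falls back to the service default model
    exchange="binance",           # hypothetical exchange name
    symbol="BTC/USDT",
    time_frame="1h",
    version="0.0.0",              # see get_version() below
    candle_open_time=1698364800,  # open time of the evaluated candle
    use_stored_signals=True,      # backtesting: read stored historical answers
)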

def get_version(self):
# later on, identify by its specs
# return f"{self.gpt_model}-{self.source}-{self.indicator}-{self.period}-{self.GLOBAL_VERSION}"
return "0.0.0"
Review comment (Member): I don't understand why the version should contain all this data and not just a number (here GLOBAL_VERSION).
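
For reference, a sketch of what the commented-out spec-based identifier would produce; all values below are illustrative, not defaults from the code:

gpt_model = "gpt-3.5-turbo"  # hypothetical model name
source = "Close"
indicator = "EMA: Exponential Moving Average"
period = 14
GLOBAL_VERSION = 1
version = f"{gpt_model}-{source}-{indicator}-{period}-{GLOBAL_VERSION}"
# version == "gpt-3.5-turbo-Close-EMA: Exponential Moving Average-14-1"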


def call_indicator(self, candle_data):
if self.source in self.get_unformated_sources():
return candle_data
return data_util.drop_nan(self.INDICATORS[self.indicator](candle_data, self.period))

def get_candles_data(self, exchange, exchange_id, symbol, time_frame, inc_in_construction_data):
if self.source in self.get_unformated_sources():
limit = self.period if inc_in_construction_data else self.period + 1
full_candles = trading_api.get_candles_as_list(
trading_api.get_symbol_historical_candles(
self.get_exchange_symbol_data(exchange, exchange_id, symbol), time_frame, limit=limit
)
)
# remove time value
for candle in full_candles:
candle.pop(commons_enums.PriceIndexes.IND_PRICE_TIME.value)
if inc_in_construction_data:
return full_candles
return full_candles[:-1]
return self.get_candles_data_api()(
self.get_exchange_symbol_data(exchange, exchange_id, symbol), time_frame,
include_in_construction=inc_in_construction_data
)
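
A worked example of the limit arithmetic above: when the in-construction candle is excluded, one extra candle is fetched so that exactly `period` closed candles remain after dropping the unfinished one.

period = 5
inc_in_construction_data = False
limit = period if inc_in_construction_data else period + 1  # 6 candles fetched
# full_candles[:-1] then drops the still-forming candle, leaving 5 closed candles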

def get_unformated_sources(self):
return (self.SOURCES[5], )

def get_candles_data_api(self):
return {
self.SOURCES[0]: trading_api.get_symbol_open_candles,
@@ -216,14 +277,20 @@ def _parse_confidence(self, cleaned_prediction):
up with 70% confidence
up with high confidence
"""
value = self.LOW_CONFIDENCE_PERCENT
if "%" in cleaned_prediction:
percent_index = cleaned_prediction.index("%")
return float(cleaned_prediction[:percent_index].split(" ")[-1])
if "high" in cleaned_prediction:
return self.HIGH_CONFIDENCE_PERCENT
if "medium" in cleaned_prediction or "intermediate" in cleaned_prediction:
return self.MEDIUM_CONFIDENCE_PERCENT
if "low" in cleaned_prediction:
return self.LOW_CONFIDENCE_PERCENT
self.logger.warning(f"Impossible to parse confidence in {cleaned_prediction}. Using low confidence")
return self.LOW_CONFIDENCE_PERCENT
value = float(cleaned_prediction[:percent_index].split(" ")[-1])
elif "high" in cleaned_prediction:
value = self.HIGH_CONFIDENCE_PERCENT
elif "medium" in cleaned_prediction or "intermediate" in cleaned_prediction:
value = self.MEDIUM_CONFIDENCE_PERCENT
elif "low" in cleaned_prediction:
value = self.LOW_CONFIDENCE_PERCENT
elif not cleaned_prediction:
value = 0
else:
self.logger.warning(f"Impossible to parse confidence in {cleaned_prediction}. Using low confidence")
if value >= self.min_confidence_threshold:
return self.MAX_CONFIDENCE_PERCENT
return value
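
A minimal standalone sketch of the parsing and thresholding rules above (constants copied from the class; the function name and asserts are illustrative only). Note that, as written, the default threshold of 0 promotes every parsed value to MAX_CONFIDENCE_PERCENT.

MAX_CONFIDENCE_PERCENT = 100
HIGH_CONFIDENCE_PERCENT = 80
MEDIUM_CONFIDENCE_PERCENT = 50
LOW_CONFIDENCE_PERCENT = 30

def parse_confidence(cleaned_prediction: str, min_confidence_threshold: int) -> float:
    # mirrors GPTEvaluator._parse_confidence: an explicit "%" wins, then keyword buckets
    value = LOW_CONFIDENCE_PERCENT
    if "%" in cleaned_prediction:
        percent_index = cleaned_prediction.index("%")
        value = float(cleaned_prediction[:percent_index].split(" ")[-1])
    elif "high" in cleaned_prediction:
        value = HIGH_CONFIDENCE_PERCENT
    elif "medium" in cleaned_prediction or "intermediate" in cleaned_prediction:
        value = MEDIUM_CONFIDENCE_PERCENT
    elif "low" in cleaned_prediction:
        value = LOW_CONFIDENCE_PERCENT
    elif not cleaned_prediction:
        value = 0
    if value >= min_confidence_threshold:
        return MAX_CONFIDENCE_PERCENT
    return value

assert parse_confidence("up 70%", min_confidence_threshold=50) == 100
assert parse_confidence("down with low confidence", min_confidence_threshold=80) == 30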
6 changes: 4 additions & 2 deletions Evaluator/TA/ai_evaluator/resources/GPTEvaluator.md
@@ -1,5 +1,7 @@
Uses [Chat GPT](https://chat.openai.com/) to predict the next moves of the market.

Evaluates between -1 to 1 according to chat GPT's prediction of the selected data and its confidence.
Evaluates between -1 and 1 according to ChatGPT's prediction of the selected data and its confidence.

*This evaluator can't be used in backtesting.*
Note: this evaluator can only be used in backtesting for markets where historical ChatGPT data are available.

Find the full list of supported historical markets at https://www.octobot.cloud/features/chatgpt-trading
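
For a concrete reading of the -1 to 1 scale: evaluate() sets eval_note = prediction_side * confidence, so for example:

prediction_side = -1       # ChatGPT answered "down"
confidence_percent = 80.0  # "80%" parsed from the answer
eval_note = prediction_side * (confidence_percent / 100)
# eval_note == -0.8: a fairly confident bearish evaluation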
@@ -39,8 +39,7 @@ class DisplayedElements(display.DisplayTranslator):
}

async def fill_from_database(self, trading_mode, database_manager, exchange_name, symbol, time_frame, exchange_id,
with_inputs=True):

with_inputs=True, symbols=None, time_frames=None):
async with databases.MetaDatabase.database(database_manager) as meta_db:
graphs_by_parts = {}
inputs = []
@@ -52,6 +51,10 @@ async def fill_from_database(self, trading_mode, database_manager, exchange_name
run_db = meta_db.get_run_db()
metadata_rows = await run_db.all(commons_enums.DBTables.METADATA.value)
metadata = metadata_rows[0] if metadata_rows else None
if symbols is not None:
symbols.extend(metadata[commons_enums.BacktestingMetadata.SYMBOLS.value])
if time_frames is not None:
time_frames.extend(metadata[commons_enums.BacktestingMetadata.TIME_FRAMES.value])
account_type = trading_api.get_account_type_from_run_metadata(metadata) \
if database_manager.is_backtesting() \
else trading_api.get_account_type_from_exchange_manager(
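
A hypothetical caller sketch for the new fill_from_database keyword arguments (all surrounding names assumed in scope): passing empty lists lets the caller collect the run's symbols and time frames, since both lists are extended in place from the backtesting metadata.

symbols, time_frames = [], []
await displayed_elements.fill_from_database(
    trading_mode, database_manager, exchange_name, symbol, time_frame, exchange_id,
    with_inputs=True, symbols=symbols, time_frames=time_frames,
)
# symbols and time_frames are now populated from the run metadata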
@@ -36,7 +36,7 @@ async def fake_backtesting(backtesting_config):
config=backtesting_config,
exchange_ids=[],
matrix_id="",
backtesting_files=[]
backtesting_files=[],
)


@@ -110,9 +110,8 @@ def profiles_management(action):
return util.get_rest_reply(flask.jsonify(data))
if action == "duplicate":
profile_id = flask.request.args.get("profile_id")
new_profile = models.duplicate_profile(profile_id)
models.select_profile(new_profile.profile_id)
flask.flash(f"New profile successfully created and selected.", "success")
models.duplicate_profile(profile_id)
flask.flash(f"New profile successfully created.", "success")
return util.get_rest_reply(flask.jsonify("Profile created"))
if action == "use_as_live":
profile_id = flask.request.args.get("profile_id")