diff --git a/scrapegraphai/graphs/base_graph.py b/scrapegraphai/graphs/base_graph.py
index 71e42760..0b9f5517 100644
--- a/scrapegraphai/graphs/base_graph.py
+++ b/scrapegraphai/graphs/base_graph.py
@@ -4,8 +4,8 @@
 import time
 import warnings
 from typing import Tuple
-from langchain_community.callbacks import get_openai_callback
 from ..telemetry import log_graph_execution
+from ..utils import CustomOpenAiCallbackManager
 
 class BaseGraph:
     """
@@ -52,6 +52,7 @@ def __init__(self, nodes: list, edges: list, entry_point: str, use_burr: bool =
         self.entry_point = entry_point.node_name
         self.graph_name = graph_name
         self.initial_state = {}
+        self.callback_manager = CustomOpenAiCallbackManager()
 
         if nodes[0].node_name != entry_point.node_name:
             # raise a warning if the entry point is not the first node in the list
@@ -154,7 +155,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
             except Exception as e:
                 schema = None
 
-            with get_openai_callback() as cb:
+            with self.callback_manager.exclusive_get_openai_callback() as cb:
                 try:
                     result = current_node.execute(state)
                 except Exception as e:
@@ -176,23 +177,24 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
             node_exec_time = time.time() - curr_time
             total_exec_time += node_exec_time
 
-            cb_data = {
-                "node_name": current_node.node_name,
-                "total_tokens": cb.total_tokens,
-                "prompt_tokens": cb.prompt_tokens,
-                "completion_tokens": cb.completion_tokens,
-                "successful_requests": cb.successful_requests,
-                "total_cost_USD": cb.total_cost,
-                "exec_time": node_exec_time,
-            }
-
-            exec_info.append(cb_data)
-
-            cb_total["total_tokens"] += cb_data["total_tokens"]
-            cb_total["prompt_tokens"] += cb_data["prompt_tokens"]
-            cb_total["completion_tokens"] += cb_data["completion_tokens"]
-            cb_total["successful_requests"] += cb_data["successful_requests"]
-            cb_total["total_cost_USD"] += cb_data["total_cost_USD"]
+            if cb is not None:
+                cb_data = {
+                    "node_name": current_node.node_name,
+                    "total_tokens": cb.total_tokens,
+                    "prompt_tokens": cb.prompt_tokens,
+                    "completion_tokens": cb.completion_tokens,
+                    "successful_requests": cb.successful_requests,
+                    "total_cost_USD": cb.total_cost,
+                    "exec_time": node_exec_time,
+                }
+
+                exec_info.append(cb_data)
+
+                cb_total["total_tokens"] += cb_data["total_tokens"]
+                cb_total["prompt_tokens"] += cb_data["prompt_tokens"]
+                cb_total["completion_tokens"] += cb_data["completion_tokens"]
+                cb_total["successful_requests"] += cb_data["successful_requests"]
+                cb_total["total_cost_USD"] += cb_data["total_cost_USD"]
 
             if current_node.node_type == "conditional_node":
                 current_node_name = result
diff --git a/scrapegraphai/utils/__init__.py b/scrapegraphai/utils/__init__.py
index fbd03800..0132c775 100644
--- a/scrapegraphai/utils/__init__.py
+++ b/scrapegraphai/utils/__init__.py
@@ -17,3 +17,4 @@
 from .screenshot_scraping.text_detection import detect_text
 from .tokenizer import num_tokens_calculus
 from .split_text_into_chunks import split_text_into_chunks
+from .custom_openai_callback import CustomOpenAiCallbackManager
diff --git a/scrapegraphai/utils/custom_openai_callback.py b/scrapegraphai/utils/custom_openai_callback.py
new file mode 100644
index 00000000..e0efa723
--- /dev/null
+++ b/scrapegraphai/utils/custom_openai_callback.py
@@ -0,0 +1,17 @@
+import threading
+from contextlib import contextmanager
+from langchain_community.callbacks import get_openai_callback
+
+class CustomOpenAiCallbackManager:
+    _lock = threading.Lock()
+
+    @contextmanager
+    def exclusive_get_openai_callback(self):
+        if CustomOpenAiCallbackManager._lock.acquire(blocking=False):
+            try:
+                with get_openai_callback() as cb:
+                    yield cb
+            finally:
+                CustomOpenAiCallbackManager._lock.release()
+        else:
+            yield None
\ No newline at end of file
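
Not part of the diff: a minimal sketch of how `exclusive_get_openai_callback` behaves under contention, assuming this branch of `scrapegraphai` and `langchain-community` are installed. Because `_lock` is a class attribute, it is shared by every `CustomOpenAiCallbackManager` instance, so only one caller at a time receives a live cost-tracking callback; concurrent callers get `None` and skip token accounting, which is exactly what the new `if cb is not None:` guard in `_execute_standard` handles.

```python
import threading
import time

from scrapegraphai.utils import CustomOpenAiCallbackManager

manager = CustomOpenAiCallbackManager()
results = []

def worker(name: str) -> None:
    with manager.exclusive_get_openai_callback() as cb:
        # Only the thread that wins the non-blocking acquire gets a real
        # OpenAI callback handler; threads entering while the lock is held
        # receive None instead of blocking.
        time.sleep(0.1)  # hold the context long enough for the threads to overlap
        results.append((name, cb is not None))

threads = [threading.Thread(target=worker, args=(f"t{i}",)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(results)  # e.g. [('t0', True), ('t1', False), ('t2', False)]
```

The trade-off of `acquire(blocking=False)` is that overlapping graph runs simply lose their token/cost metrics for the contested nodes rather than waiting on each other or sharing one counter.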