Skip to content

Commit

Permalink
Merge pull request #670 from LorenzoPaleari/576-exec-info-misses-nest…
Browse files Browse the repository at this point in the history
…ed-graphs

Added CustomOpenaiCallback to ensure exclusive access to nested data.
  • Loading branch information
VinciGit00 authored Sep 14, 2024
2 parents 063dd1a + e657113 commit d7afdb1
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 19 deletions.
40 changes: 21 additions & 19 deletions scrapegraphai/graphs/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import time
import warnings
from typing import Tuple
from langchain_community.callbacks import get_openai_callback
from ..telemetry import log_graph_execution
from ..utils import CustomOpenAiCallbackManager

class BaseGraph:
"""
Expand Down Expand Up @@ -52,6 +52,7 @@ def __init__(self, nodes: list, edges: list, entry_point: str, use_burr: bool =
self.entry_point = entry_point.node_name
self.graph_name = graph_name
self.initial_state = {}
self.callback_manager = CustomOpenAiCallbackManager()

if nodes[0].node_name != entry_point.node_name:
# raise a warning if the entry point is not the first node in the list
Expand Down Expand Up @@ -154,7 +155,7 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
except Exception as e:
schema = None

with get_openai_callback() as cb:
with self.callback_manager.exclusive_get_openai_callback() as cb:
try:
result = current_node.execute(state)
except Exception as e:
Expand All @@ -176,23 +177,24 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
node_exec_time = time.time() - curr_time
total_exec_time += node_exec_time

cb_data = {
"node_name": current_node.node_name,
"total_tokens": cb.total_tokens,
"prompt_tokens": cb.prompt_tokens,
"completion_tokens": cb.completion_tokens,
"successful_requests": cb.successful_requests,
"total_cost_USD": cb.total_cost,
"exec_time": node_exec_time,
}

exec_info.append(cb_data)

cb_total["total_tokens"] += cb_data["total_tokens"]
cb_total["prompt_tokens"] += cb_data["prompt_tokens"]
cb_total["completion_tokens"] += cb_data["completion_tokens"]
cb_total["successful_requests"] += cb_data["successful_requests"]
cb_total["total_cost_USD"] += cb_data["total_cost_USD"]
if cb is not None:
cb_data = {
"node_name": current_node.node_name,
"total_tokens": cb.total_tokens,
"prompt_tokens": cb.prompt_tokens,
"completion_tokens": cb.completion_tokens,
"successful_requests": cb.successful_requests,
"total_cost_USD": cb.total_cost,
"exec_time": node_exec_time,
}

exec_info.append(cb_data)

cb_total["total_tokens"] += cb_data["total_tokens"]
cb_total["prompt_tokens"] += cb_data["prompt_tokens"]
cb_total["completion_tokens"] += cb_data["completion_tokens"]
cb_total["successful_requests"] += cb_data["successful_requests"]
cb_total["total_cost_USD"] += cb_data["total_cost_USD"]

if current_node.node_type == "conditional_node":
current_node_name = result
Expand Down
1 change: 1 addition & 0 deletions scrapegraphai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
from .screenshot_scraping.text_detection import detect_text
from .tokenizer import num_tokens_calculus
from .split_text_into_chunks import split_text_into_chunks
from .custom_openai_callback import CustomOpenAiCallbackManager
17 changes: 17 additions & 0 deletions scrapegraphai/utils/custom_openai_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import threading
from contextlib import contextmanager
from langchain_community.callbacks import get_openai_callback

class CustomOpenAiCallbackManager:
    """Hand out the OpenAI usage callback to at most one holder at a time.

    The lock is a class attribute, so it is shared by every instance of
    this manager: while one caller is inside the context manager, any
    other caller (for example a graph executed from within another
    graph's node) receives ``None`` instead of a callback.
    """

    # Shared across all instances; guards entry to get_openai_callback().
    _lock = threading.Lock()

    @contextmanager
    def exclusive_get_openai_callback(self):
        """Yield an OpenAI callback handler, or ``None`` if one is in use.

        The lock is tried without blocking. If it is already held, the
        context manager yields ``None`` and leaves the lock untouched, so
        callers must check the yielded value before reading token counts.
        Otherwise it yields the handler produced by
        ``get_openai_callback()`` and releases the lock on exit, even if
        the managed body raises.
        """
        shared_lock = CustomOpenAiCallbackManager._lock

        # Guard clause: someone else already owns the callback.
        if not shared_lock.acquire(blocking=False):
            yield None
            return

        try:
            with get_openai_callback() as openai_cb:
                yield openai_cb
        finally:
            shared_lock.release()

0 comments on commit d7afdb1

Please sign in to comment.