From 47d3e5c0c5e8900c32165cc82e218198d69df9f9 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 18 Dec 2024 09:32:20 -0700 Subject: [PATCH] Rename `Cost` to `Usage` (#403) --- docs/api/models/ollama.md | 8 +- docs/api/result.md | 2 +- docs/results.md | 6 +- docs/testing-evals.md | 2 +- pydantic_ai_examples/pydantic_model.py | 2 +- pydantic_ai_examples/stream_markdown.py | 2 +- pydantic_ai_slim/pydantic_ai/agent.py | 22 +-- .../pydantic_ai/models/__init__.py | 18 +- .../pydantic_ai/models/anthropic.py | 10 +- .../pydantic_ai/models/function.py | 30 ++-- pydantic_ai_slim/pydantic_ai/models/gemini.py | 28 ++-- pydantic_ai_slim/pydantic_ai/models/groq.py | 36 ++-- .../pydantic_ai/models/mistral.py | 38 ++--- pydantic_ai_slim/pydantic_ai/models/openai.py | 36 ++-- pydantic_ai_slim/pydantic_ai/models/test.py | 24 +-- pydantic_ai_slim/pydantic_ai/result.py | 44 ++--- tests/models/test_anthropic.py | 24 +-- tests/models/test_gemini.py | 14 +- tests/models/test_groq.py | 8 +- tests/models/test_mistral.py | 156 +++++++++--------- tests/models/test_model_function.py | 6 +- tests/models/test_ollama.py | 6 +- tests/models/test_openai.py | 22 +-- tests/test_agent.py | 8 +- tests/test_live.py | 18 +- tests/test_logfire.py | 4 +- tests/test_streaming.py | 4 +- 27 files changed, 289 insertions(+), 289 deletions(-) diff --git a/docs/api/models/ollama.md b/docs/api/models/ollama.md index 94efe490..8443fa04 100644 --- a/docs/api/models/ollama.md +++ b/docs/api/models/ollama.md @@ -31,8 +31,8 @@ agent = Agent('ollama:llama3.2', result_type=CityLocation) result = agent.run_sync('Where were the olympics held in 2012?') print(result.data) #> city='London' country='United Kingdom' -print(result.cost()) -#> Cost(request_tokens=57, response_tokens=8, total_tokens=65, details=None) +print(result.usage()) +#> Usage(request_tokens=57, response_tokens=8, total_tokens=65, details=None) ``` ## Example using a remote server @@ -59,8 +59,8 @@ agent = Agent(model=ollama_model, result_type=CityLocation) result = agent.run_sync('Where were the olympics held in 2012?') print(result.data) #> city='London' country='United Kingdom' -print(result.cost()) -#> Cost(request_tokens=57, response_tokens=8, total_tokens=65, details=None) +print(result.usage()) +#> Usage(request_tokens=57, response_tokens=8, total_tokens=65, details=None) ``` 1. The name of the model running on the remote server diff --git a/docs/api/result.md b/docs/api/result.md index 83d61af8..d4310881 100644 --- a/docs/api/result.md +++ b/docs/api/result.md @@ -7,4 +7,4 @@ - ResultData - RunResult - StreamedRunResult - - Cost + - Usage diff --git a/docs/results.md b/docs/results.md index 42573fa2..f3f9700f 100644 --- a/docs/results.md +++ b/docs/results.md @@ -1,5 +1,5 @@ Results are the final values returned from [running an agent](agents.md#running-agents). 
-The result values are wrapped in [`RunResult`][pydantic_ai.result.RunResult] and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so you can access other data like [cost][pydantic_ai.result.Cost] of the run and [message history](message-history.md#accessing-messages-from-results) +The result values are wrapped in [`RunResult`][pydantic_ai.result.RunResult] and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so you can access other data like [usage][pydantic_ai.result.Usage] of the run and [message history](message-history.md#accessing-messages-from-results) Both `RunResult` and `StreamedRunResult` are generic in the data they wrap, so typing information about the data returned by the agent is preserved. @@ -18,8 +18,8 @@ agent = Agent('gemini-1.5-flash', result_type=CityLocation) result = agent.run_sync('Where were the olympics held in 2012?') print(result.data) #> city='London' country='United Kingdom' -print(result.cost()) -#> Cost(request_tokens=57, response_tokens=8, total_tokens=65, details=None) +print(result.usage()) +#> Usage(request_tokens=57, response_tokens=8, total_tokens=65, details=None) ``` _(This example is complete, it can be run "as is")_ diff --git a/docs/testing-evals.md b/docs/testing-evals.md index b0d23e7f..81cd2efe 100644 --- a/docs/testing-evals.md +++ b/docs/testing-evals.md @@ -18,7 +18,7 @@ Unless you're really sure you know better, you'll probably want to follow roughl * Use [`pytest`](https://docs.pytest.org/en/stable/) as your test harness * If you find yourself typing out long assertions, use [inline-snapshot](https://15r10nk.github.io/inline-snapshot/latest/) * Similarly, [dirty-equals](https://dirty-equals.helpmanual.io/latest/) can be useful for comparing large data structures -* Use [`TestModel`][pydantic_ai.models.test.TestModel] or [`FunctionModel`][pydantic_ai.models.function.FunctionModel] in place of your actual model to avoid the cost, latency and variability of real LLM calls +* Use [`TestModel`][pydantic_ai.models.test.TestModel] or [`FunctionModel`][pydantic_ai.models.function.FunctionModel] in place of your actual model to avoid the usage, latency and variability of real LLM calls * Use [`Agent.override`][pydantic_ai.agent.Agent.override] to replace your model inside your application logic * Set [`ALLOW_MODEL_REQUESTS=False`][pydantic_ai.models.ALLOW_MODEL_REQUESTS] globally to block any requests from being made to non-test models accidentally diff --git a/pydantic_ai_examples/pydantic_model.py b/pydantic_ai_examples/pydantic_model.py index 9c654e9e..5e4dd833 100644 --- a/pydantic_ai_examples/pydantic_model.py +++ b/pydantic_ai_examples/pydantic_model.py @@ -30,4 +30,4 @@ class MyModel(BaseModel): if __name__ == '__main__': result = agent.run_sync('The windy city in the US of A.') print(result.data) - print(result.cost()) + print(result.usage()) diff --git a/pydantic_ai_examples/stream_markdown.py b/pydantic_ai_examples/stream_markdown.py index bba30767..bc2362f7 100644 --- a/pydantic_ai_examples/stream_markdown.py +++ b/pydantic_ai_examples/stream_markdown.py @@ -43,7 +43,7 @@ async def main(): async with agent.run_stream(prompt, model=model) as result: async for message in result.stream(): live.update(Markdown(message)) - console.log(result.cost()) + console.log(result.usage()) else: console.log(f'{model} requires {env_var} to be set.') diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 542086f2..72c76eb7 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ 
b/pydantic_ai_slim/pydantic_ai/agent.py @@ -237,7 +237,7 @@ async def run( for tool in self._function_tools.values(): tool.current_retry = 0 - cost = result.Cost() + usage = result.Usage() model_settings = merge_model_settings(self.model_settings, model_settings) @@ -248,12 +248,12 @@ async def run( agent_model = await self._prepare_model(model_used, deps, messages) with _logfire.span('model request', run_step=run_step) as model_req_span: - model_response, request_cost = await agent_model.request(messages, model_settings) + model_response, request_usage = await agent_model.request(messages, model_settings) model_req_span.set_attribute('response', model_response) - model_req_span.set_attribute('cost', request_cost) + model_req_span.set_attribute('usage', request_usage) messages.append(model_response) - cost += request_cost + usage += request_usage with _logfire.span('handle model response', run_step=run_step) as handle_span: final_result, tool_responses = await self._handle_model_response(model_response, deps, messages) @@ -266,10 +266,10 @@ async def run( if final_result is not None: result_data = final_result.data run_span.set_attribute('all_messages', messages) - run_span.set_attribute('cost', cost) + run_span.set_attribute('usage', usage) handle_span.set_attribute('result', result_data) handle_span.message = 'handle model response -> final result' - return result.RunResult(messages, new_message_index, result_data, cost) + return result.RunResult(messages, new_message_index, result_data, usage) else: # continue the conversation handle_span.set_attribute('tool_responses', tool_responses) @@ -385,7 +385,7 @@ async def main(): for tool in self._function_tools.values(): tool.current_retry = 0 - cost = result.Cost() + usage = result.Usage() model_settings = merge_model_settings(self.model_settings, model_settings) run_step = 0 @@ -434,7 +434,7 @@ async def on_complete(): yield result.StreamedRunResult( messages, new_message_index, - cost, + usage, result_stream, self._result_schema, deps, @@ -455,8 +455,8 @@ async def on_complete(): handle_span.set_attribute('tool_responses', tool_responses) tool_responses_str = ' '.join(r.part_kind for r in tool_responses) handle_span.message = f'handle model response -> {tool_responses_str}' - # the model_response should have been fully streamed by now, we can add it's cost - cost += model_response.cost() + # the model_response should have been fully streamed by now, we can add its usage + usage += model_response.usage() @contextmanager def override( @@ -990,7 +990,7 @@ async def _handle_streamed_model_response( response = _messages.RetryPromptPart( content='Plain text responses are not permitted, please call one of the functions instead.', ) - # stream the response, so cost is correct + # stream the response, so usage is correct async for _ in model_response: pass diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 0291e6bc..0c4a4cff 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -20,7 +20,7 @@ from ..settings import ModelSettings if TYPE_CHECKING: - from ..result import Cost + from ..result import Usage from ..tools import ToolDefinition @@ -122,7 +122,7 @@ class AgentModel(ABC): @abstractmethod async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> tuple[ModelResponse, Cost]: + ) -> tuple[ModelResponse, Usage]: """Make a request to the model.""" raise NotImplementedError() @@ 
-164,10 +164,10 @@ def get(self, *, final: bool = False) -> Iterable[str]: raise NotImplementedError() @abstractmethod - def cost(self) -> Cost: - """Return the cost of the request. + def usage(self) -> Usage: + """Return the usage of the request. - NOTE: this won't return the ful cost until the stream is finished. + NOTE: this won't return the full usage until the stream is finished. """ raise NotImplementedError() @@ -205,10 +205,10 @@ def get(self, *, final: bool = False) -> ModelResponse: raise NotImplementedError() @abstractmethod - def cost(self) -> Cost: - """Get the cost of the request. + def usage(self) -> Usage: + """Get the usage of the request. - NOTE: this won't return the full cost until the stream is finished. + NOTE: this won't return the full usage until the stream is finished. """ raise NotImplementedError() @@ -235,7 +235,7 @@ def timestamp(self) -> datetime: def check_allow_model_requests() -> None: """Check if model requests are allowed. - If you're defining your own models that have cost or latency associated with their use, you should call this in + If you're defining your own models that have costs or latency associated with their use, you should call this in [`Model.agent_model`][pydantic_ai.models.Model.agent_model]. Raises: diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 4c4c4384..325de9ff 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -158,9 +158,9 @@ class AnthropicAgentModel(AgentModel): async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> tuple[ModelResponse, result.Cost]: + ) -> tuple[ModelResponse, result.Usage]: response = await self._messages_create(messages, False, model_settings) - return self._process_response(response), _map_cost(response) + return self._process_response(response), _map_usage(response) @asynccontextmanager async def request_stream( @@ -315,7 +315,7 @@ def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam: ) -def _map_cost(message: AnthropicMessage | RawMessageStreamEvent) -> result.Cost: +def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> result.Usage: if isinstance(message, AnthropicMessage): usage = message.usage else: @@ -332,11 +332,11 @@ def _map_cost(message: AnthropicMessage | RawMessageStreamEvent) -> result.Cost: usage = None if usage is None: - return result.Cost() + return result.Usage() request_tokens = getattr(usage, 'input_tokens', None) - return result.Cost( + return result.Usage( # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr request_tokens=request_tokens, response_tokens=usage.output_tokens, diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 007538f4..077b29c6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -144,7 +144,7 @@ class FunctionAgentModel(AgentModel): async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> tuple[ModelResponse, result.Cost]: + ) -> tuple[ModelResponse, result.Usage]: agent_info = replace(self.agent_info, model_settings=model_settings) assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests' @@ -155,7 +155,7 @@ async def request( assert isinstance(response_, ModelResponse), response_ response = response_ # TODO 
is `messages` right here? Should it just be new messages? - return response, _estimate_cost(chain(messages, [response])) + return response, _estimate_usage(chain(messages, [response])) @asynccontextmanager async def request_stream( @@ -198,8 +198,8 @@ def get(self, *, final: bool = False) -> Iterable[str]: yield from self._buffer self._buffer.clear() - def cost(self) -> result.Cost: - return result.Cost() + def usage(self) -> result.Usage: + return result.Usage() def timestamp(self) -> datetime: return self._timestamp @@ -236,15 +236,15 @@ def get(self, *, final: bool = False) -> ModelResponse: return ModelResponse(calls, timestamp=self._timestamp) - def cost(self) -> result.Cost: - return result.Cost() + def usage(self) -> result.Usage: + return result.Usage() def timestamp(self) -> datetime: return self._timestamp -def _estimate_cost(messages: Iterable[ModelMessage]) -> result.Cost: - """Very rough guesstimate of the number of tokens associate with a series of messages. +def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage: + """Very rough guesstimate of the token usage associated with a series of messages. This is designed to be used solely to give plausible numbers for testing! """ @@ -255,32 +255,32 @@ def _estimate_cost(messages: Iterable[ModelMessage]) -> result.Cost: if isinstance(message, ModelRequest): for part in message.parts: if isinstance(part, (SystemPromptPart, UserPromptPart)): - request_tokens += _string_cost(part.content) + request_tokens += _string_usage(part.content) elif isinstance(part, ToolReturnPart): - request_tokens += _string_cost(part.model_response_str()) + request_tokens += _string_usage(part.model_response_str()) elif isinstance(part, RetryPromptPart): - request_tokens += _string_cost(part.model_response()) + request_tokens += _string_usage(part.model_response()) else: assert_never(part) elif isinstance(message, ModelResponse): for part in message.parts: if isinstance(part, TextPart): - response_tokens += _string_cost(part.content) + response_tokens += _string_usage(part.content) elif isinstance(part, ToolCallPart): call = part if isinstance(call.args, ArgsJson): args_str = call.args.args_json else: args_str = pydantic_core.to_json(call.args.args_dict).decode() - response_tokens += 1 + _string_cost(args_str) + response_tokens += 1 + _string_usage(args_str) else: assert_never(part) else: assert_never(message) - return result.Cost( + return result.Usage( request_tokens=request_tokens, response_tokens=response_tokens, total_tokens=request_tokens + response_tokens ) -def _string_cost(content: str) -> int: +def _string_usage(content: str) -> int: return len(re.split(r'[\s",.:]+', content)) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 72906428..ff0d4803 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -172,10 +172,10 @@ def __init__( async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> tuple[ModelResponse, result.Cost]: + ) -> tuple[ModelResponse, result.Usage]: async with self._make_request(messages, False, model_settings) as http_response: response = _gemini_response_ta.validate_json(await http_response.aread()) - return self._process_response(response), _metadata_as_cost(response) + return self._process_response(response), _metadata_as_usage(response) @asynccontextmanager async def request_stream( @@ -301,7 +301,7 @@ class GeminiStreamTextResponse(StreamTextResponse): 
_stream: AsyncIterator[bytes] _position: int = 0 _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) - _cost: result.Cost = field(default_factory=result.Cost, init=False) + _usage: result.Usage = field(default_factory=result.Usage, init=False) async def __anext__(self) -> None: chunk = await self._stream.__anext__() @@ -321,7 +321,7 @@ def get(self, *, final: bool = False) -> Iterable[str]: new_items, experimental_allow_partial='trailing-strings' ) for r in new_responses: - self._cost += _metadata_as_cost(r) + self._usage += _metadata_as_usage(r) parts = r['candidates'][0]['content']['parts'] if _all_text_parts(parts): for part in parts: @@ -331,8 +331,8 @@ def get(self, *, final: bool = False) -> Iterable[str]: 'Streamed response with unexpected content, expected all parts to be text' ) - def cost(self) -> result.Cost: - return self._cost + def usage(self) -> result.Usage: + return self._usage def timestamp(self) -> datetime: return self._timestamp @@ -345,7 +345,7 @@ class GeminiStreamStructuredResponse(StreamStructuredResponse): _content: bytearray _stream: AsyncIterator[bytes] _timestamp: datetime = field(default_factory=_utils.now_utc, init=False) - _cost: result.Cost = field(default_factory=result.Cost, init=False) + _usage: result.Usage = field(default_factory=result.Usage, init=False) async def __anext__(self) -> None: chunk = await self._stream.__anext__() @@ -365,15 +365,15 @@ def get(self, *, final: bool = False) -> ModelResponse: experimental_allow_partial='off' if final else 'trailing-strings', ) combined_parts: list[_GeminiPartUnion] = [] - self._cost = result.Cost() + self._usage = result.Usage() for r in responses: - self._cost += _metadata_as_cost(r) + self._usage += _metadata_as_usage(r) candidate = r['candidates'][0] combined_parts.extend(candidate['content']['parts']) return _process_response_from_parts(combined_parts, timestamp=self._timestamp) - def cost(self) -> result.Cost: - return self._cost + def usage(self) -> result.Usage: + return self._usage def timestamp(self) -> datetime: return self._timestamp @@ -640,14 +640,14 @@ class _GeminiUsageMetaData(TypedDict, total=False): cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]] -def _metadata_as_cost(response: _GeminiResponse) -> result.Cost: +def _metadata_as_usage(response: _GeminiResponse) -> result.Usage: metadata = response.get('usage_metadata') if metadata is None: - return result.Cost() + return result.Usage() details: dict[str, int] = {} if cached_content_token_count := metadata.get('cached_content_token_count'): details['cached_content_token_count'] = cached_content_token_count - return result.Cost( + return result.Usage( request_tokens=metadata.get('prompt_token_count', 0), response_tokens=metadata.get('candidates_token_count', 0), total_tokens=metadata.get('total_token_count', 0), diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 20be8a33..3af295d4 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -25,7 +25,7 @@ ToolReturnPart, UserPromptPart, ) -from ..result import Cost +from ..result import Usage from ..settings import ModelSettings from ..tools import ToolDefinition from . 
import ( @@ -158,9 +158,9 @@ class GroqAgentModel(AgentModel): async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> tuple[ModelResponse, result.Cost]: + ) -> tuple[ModelResponse, result.Usage]: response = await self._completions_create(messages, False, model_settings) - return self._process_response(response), _map_cost(response) + return self._process_response(response), _map_usage(response) @asynccontextmanager async def request_stream( @@ -228,7 +228,7 @@ def _process_response(response: chat.ChatCompletion) -> ModelResponse: async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" timestamp: datetime | None = None - start_cost = Cost() + start_usage = Usage() # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content` while True: try: @@ -236,19 +236,19 @@ async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) except StopAsyncIteration as e: raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc) - start_cost += _map_cost(chunk) + start_usage += _map_usage(chunk) if chunk.choices: delta = chunk.choices[0].delta if delta.content is not None: - return GroqStreamTextResponse(delta.content, response, timestamp, start_cost) + return GroqStreamTextResponse(delta.content, response, timestamp, start_usage) elif delta.tool_calls is not None: return GroqStreamStructuredResponse( response, {c.index: c for c in delta.tool_calls}, timestamp, - start_cost, + start_usage, ) @classmethod @@ -308,7 +308,7 @@ class GroqStreamTextResponse(StreamTextResponse): _first: str | None _response: AsyncStream[ChatCompletionChunk] _timestamp: datetime - _cost: result.Cost + _usage: result.Usage _buffer: list[str] = field(default_factory=list, init=False) async def __anext__(self) -> None: @@ -318,7 +318,7 @@ async def __anext__(self) -> None: return None chunk = await self._response.__anext__() - self._cost = _map_cost(chunk) + self._usage = _map_usage(chunk) try: choice = chunk.choices[0] @@ -335,8 +335,8 @@ def get(self, *, final: bool = False) -> Iterable[str]: yield from self._buffer self._buffer.clear() - def cost(self) -> Cost: - return self._cost + def usage(self) -> Usage: + return self._usage def timestamp(self) -> datetime: return self._timestamp @@ -349,11 +349,11 @@ class GroqStreamStructuredResponse(StreamStructuredResponse): _response: AsyncStream[ChatCompletionChunk] _delta_tool_calls: dict[int, ChoiceDeltaToolCall] _timestamp: datetime - _cost: result.Cost + _usage: result.Usage async def __anext__(self) -> None: chunk = await self._response.__anext__() - self._cost = _map_cost(chunk) + self._usage = _map_usage(chunk) try: choice = chunk.choices[0] @@ -384,8 +384,8 @@ def get(self, *, final: bool = False) -> ModelResponse: return ModelResponse(items, timestamp=self._timestamp) - def cost(self) -> Cost: - return self._cost + def usage(self) -> Usage: + return self._usage def timestamp(self) -> datetime: return self._timestamp @@ -400,7 +400,7 @@ def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam: ) -def _map_cost(completion: ChatCompletionChunk | ChatCompletion) -> result.Cost: +def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> result.Usage: usage = None if isinstance(completion, 
ChatCompletion): usage = completion.usage @@ -408,9 +408,9 @@ def _map_cost(completion: ChatCompletionChunk | ChatCompletion) -> result.Cost: usage = completion.x_groq.usage if usage is None: - return result.Cost() + return result.Usage() - return result.Cost( + return result.Usage( request_tokens=usage.prompt_tokens, response_tokens=usage.completion_tokens, total_tokens=usage.total_tokens, diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 639b9375..a7c956b5 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -26,7 +26,7 @@ ToolReturnPart, UserPromptPart, ) -from ..result import Cost +from ..result import Usage from ..settings import ModelSettings from ..tools import ToolDefinition from . import ( @@ -156,10 +156,10 @@ class MistralAgentModel(AgentModel): async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> tuple[ModelResponse, Cost]: + ) -> tuple[ModelResponse, Usage]: """Make a non-streaming request to the model from Pydantic AI call.""" response = await self._completions_create(messages, model_settings) - return self._process_response(response), _map_cost(response) + return self._process_response(response), _map_usage(response) @asynccontextmanager async def request_stream( @@ -297,7 +297,7 @@ async def _process_streamed_response( response: MistralEventStreamAsync[MistralCompletionEvent], ) -> EitherStreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" - start_cost = Cost() + start_usage = Usage() # Iterate until we get either `tool_calls` or `content` from the first chunk. while True: @@ -307,7 +307,7 @@ async def _process_streamed_response( except StopAsyncIteration as e: raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e - start_cost += _map_cost(chunk) + start_usage += _map_usage(chunk) if chunk.created: timestamp = datetime.fromtimestamp(chunk.created, tz=timezone.utc) @@ -329,11 +329,11 @@ async def _process_streamed_response( response, content, timestamp, - start_cost, + start_usage, ) elif content: - return MistralStreamTextResponse(content, response, timestamp, start_cost) + return MistralStreamTextResponse(content, response, timestamp, start_usage) @staticmethod def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall: @@ -474,7 +474,7 @@ class MistralStreamTextResponse(StreamTextResponse): _first: str | None _response: MistralEventStreamAsync[MistralCompletionEvent] _timestamp: datetime - _cost: Cost + _usage: Usage _buffer: list[str] = field(default_factory=list, init=False) async def __anext__(self) -> None: @@ -484,7 +484,7 @@ async def __anext__(self) -> None: return None chunk = await self._response.__anext__() - self._cost += _map_cost(chunk.data) + self._usage += _map_usage(chunk.data) try: choice = chunk.data.choices[0] @@ -502,8 +502,8 @@ def get(self, *, final: bool = False) -> Iterable[str]: yield from self._buffer self._buffer.clear() - def cost(self) -> Cost: - return self._cost + def usage(self) -> Usage: + return self._usage def timestamp(self) -> datetime: return self._timestamp @@ -518,11 +518,11 @@ class MistralStreamStructuredResponse(StreamStructuredResponse): _response: MistralEventStreamAsync[MistralCompletionEvent] _delta_content: str | None _timestamp: datetime - _cost: Cost + _usage: Usage async def __anext__(self) -> None: chunk = await self._response.__anext__() - self._cost += 
_map_cost(chunk.data) + self._usage += _map_usage(chunk.data) try: choice = chunk.data.choices[0] @@ -571,8 +571,8 @@ def get(self, *, final: bool = False) -> ModelResponse: return ModelResponse(calls, timestamp=self._timestamp) - def cost(self) -> Cost: - return self._cost + def usage(self) -> Usage: + return self._usage def timestamp(self) -> datetime: return self._timestamp @@ -645,17 +645,17 @@ def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPa ) -def _map_cost(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Cost: - """Maps a Mistral Completion Chunk or Chat Completion Response to a Cost.""" +def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage: + """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage.""" if response.usage: - return Cost( + return Usage( request_tokens=response.usage.prompt_tokens, response_tokens=response.usage.completion_tokens, total_tokens=response.usage.total_tokens, details=None, ) else: - return Cost() + return Usage() def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None: diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 95759598..3d69cdda 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -25,7 +25,7 @@ ToolReturnPart, UserPromptPart, ) -from ..result import Cost +from ..result import Usage from ..settings import ModelSettings from ..tools import ToolDefinition from . import ( @@ -147,9 +147,9 @@ class OpenAIAgentModel(AgentModel): async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None - ) -> tuple[ModelResponse, result.Cost]: + ) -> tuple[ModelResponse, result.Usage]: response = await self._completions_create(messages, False, model_settings) - return self._process_response(response), _map_cost(response) + return self._process_response(response), _map_usage(response) @asynccontextmanager async def request_stream( @@ -218,7 +218,7 @@ def _process_response(response: chat.ChatCompletion) -> ModelResponse: async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" timestamp: datetime | None = None - start_cost = Cost() + start_usage = Usage() # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content` while True: try: @@ -227,19 +227,19 @@ async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc) - start_cost += _map_cost(chunk) + start_usage += _map_usage(chunk) if chunk.choices: delta = chunk.choices[0].delta if delta.content is not None: - return OpenAIStreamTextResponse(delta.content, response, timestamp, start_cost) + return OpenAIStreamTextResponse(delta.content, response, timestamp, start_usage) elif delta.tool_calls is not None: return OpenAIStreamStructuredResponse( response, {c.index: c for c in delta.tool_calls}, timestamp, - start_cost, + start_usage, ) # else continue until we get either delta.content or delta.tool_calls @@ -302,7 +302,7 @@ class OpenAIStreamTextResponse(StreamTextResponse): _first: str | None _response: AsyncStream[ChatCompletionChunk] _timestamp: datetime - _cost: 
result.Cost + _usage: result.Usage _buffer: list[str] = field(default_factory=list, init=False) async def __anext__(self) -> None: @@ -312,7 +312,7 @@ async def __anext__(self) -> None: return None chunk = await self._response.__anext__() - self._cost += _map_cost(chunk) + self._usage += _map_usage(chunk) try: choice = chunk.choices[0] except IndexError: @@ -328,8 +328,8 @@ def get(self, *, final: bool = False) -> Iterable[str]: yield from self._buffer self._buffer.clear() - def cost(self) -> Cost: - return self._cost + def usage(self) -> Usage: + return self._usage def timestamp(self) -> datetime: return self._timestamp @@ -342,11 +342,11 @@ class OpenAIStreamStructuredResponse(StreamStructuredResponse): _response: AsyncStream[ChatCompletionChunk] _delta_tool_calls: dict[int, ChoiceDeltaToolCall] _timestamp: datetime - _cost: result.Cost + _usage: result.Usage async def __anext__(self) -> None: chunk = await self._response.__anext__() - self._cost += _map_cost(chunk) + self._usage += _map_usage(chunk) try: choice = chunk.choices[0] except IndexError: @@ -376,8 +376,8 @@ def get(self, *, final: bool = False) -> ModelResponse: return ModelResponse(items, timestamp=self._timestamp) - def cost(self) -> Cost: - return self._cost + def usage(self) -> Usage: + return self._usage def timestamp(self) -> datetime: return self._timestamp @@ -392,17 +392,17 @@ def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam: ) -def _map_cost(response: chat.ChatCompletion | ChatCompletionChunk) -> result.Cost: +def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> result.Usage: usage = response.usage if usage is None: - return result.Cost() + return result.Usage() else: details: dict[str, int] = {} if usage.completion_tokens_details is not None: details.update(usage.completion_tokens_details.model_dump(exclude_none=True)) if usage.prompt_tokens_details is not None: details.update(usage.prompt_tokens_details.model_dump(exclude_none=True)) - return result.Cost( + return result.Usage( request_tokens=usage.prompt_tokens, response_tokens=usage.completion_tokens, total_tokens=usage.total_tokens, diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 41670914..f953d8fa 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -21,7 +21,7 @@ ToolCallPart, ToolReturnPart, ) -from ..result import Cost +from ..result import Usage from ..settings import ModelSettings from ..tools import ToolDefinition from . 
import (
@@ -131,15 +131,15 @@ class TestAgentModel(AgentModel):
 
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, Cost]:
-        return self._request(messages, model_settings), Cost()
+    ) -> tuple[ModelResponse, Usage]:
+        return self._request(messages, model_settings), Usage()
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[EitherStreamedResponse]:
         msg = self._request(messages, model_settings)
-        cost = Cost()
+        usage = Usage()
 
         # TODO: Rework this once we make StreamTextResponse more general
         texts: list[str] = []
@@ -153,9 +153,9 @@ async def request_stream(
             assert_never(item)
 
         if texts:
-            yield TestStreamTextResponse('\n\n'.join(texts), cost)
+            yield TestStreamTextResponse('\n\n'.join(texts), usage)
         else:
-            yield TestStreamStructuredResponse(msg, cost)
+            yield TestStreamStructuredResponse(msg, usage)
 
     def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
         return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
@@ -213,7 +213,7 @@ class TestStreamTextResponse(StreamTextResponse):
     """A text response that streams test data."""
 
     _text: str
-    _cost: Cost
+    _usage: Usage
     _iter: Iterator[str] = field(init=False)
     _timestamp: datetime = field(default_factory=_utils.now_utc)
     _buffer: list[str] = field(default_factory=list, init=False)
@@ -234,8 +234,8 @@ def get(self, *, final: bool = False) -> Iterable[str]:
         yield from self._buffer
         self._buffer.clear()
 
-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp
@@ -246,7 +246,7 @@ class TestStreamStructuredResponse(StreamStructuredResponse):
     """A structured response that streams test data."""
 
     _structured_response: ModelResponse
-    _cost: Cost
+    _usage: Usage
     _iter: Iterator[None] = field(default_factory=lambda: iter([None]))
     _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
@@ -256,8 +256,8 @@ async def __anext__(self) -> None:
 
     def get(self, *, final: bool = False) -> ModelResponse:
         return self._structured_response
 
-    def cost(self) -> Cost:
-        return self._cost
+    def usage(self) -> Usage:
+        return self._usage
 
     def timestamp(self) -> datetime:
         return self._timestamp
diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py
index 281b8f5f..26921e41 100644
--- a/pydantic_ai_slim/pydantic_ai/result.py
+++ b/pydantic_ai_slim/pydantic_ai/result.py
@@ -13,7 +13,7 @@
 
 __all__ = (
     'ResultData',
-    'Cost',
+    'Usage',
     'RunResult',
     'StreamedRunResult',
 )
@@ -26,27 +26,27 @@
 
 
 @dataclass
-class Cost:
-    """Cost of a request or run.
+class Usage:
+    """LLM usage associated with a request or run.
 
-    Responsibility for calculating costs is on the model used, PydanticAI simply sums the cost of requests.
+    Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests.
 
-    You'll need to look up the documentation of the model you're using to convent "token count" costs to monetary costs.
+    You'll need to look up the documentation of the model you're using to convert usage to monetary costs.
""" request_tokens: int | None = None - """Tokens used in processing the request.""" + """Tokens used in processing requests.""" response_tokens: int | None = None - """Tokens used in generating the response.""" + """Tokens used in generating responses.""" total_tokens: int | None = None """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`.""" details: dict[str, int] | None = None """Any extra details returned by the model.""" - def __add__(self, other: Cost) -> Cost: - """Add two costs together. + def __add__(self, other: Usage) -> Usage: + """Add two Usages together. - This is provided so it's trivial to sum costs from multiple requests and runs. + This is provided so it's trivial to sum usage information from multiple requests and runs. """ counts: dict[str, int] = {} for f in 'request_tokens', 'response_tokens', 'total_tokens': @@ -61,7 +61,7 @@ def __add__(self, other: Cost) -> Cost: for key, value in other.details.items(): details[key] = details.get(key, 0) + value - return Cost(**counts, details=details or None) + return Usage(**counts, details=details or None) @dataclass @@ -95,7 +95,7 @@ def new_messages_json(self) -> bytes: return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages()) @abstractmethod - def cost(self) -> Cost: + def usage(self) -> Usage: raise NotImplementedError() @@ -105,19 +105,19 @@ class RunResult(_BaseRunResult[ResultData]): data: ResultData """Data from the final response in the run.""" - _cost: Cost + _usage: Usage - def cost(self) -> Cost: - """Return the cost of the whole run.""" - return self._cost + def usage(self) -> Usage: + """Return the usage of the whole run.""" + return self._usage @dataclass class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]): """Result of a streamed run that returns structured data via a tool call.""" - cost_so_far: Cost - """Cost of the run up until the last request.""" + usage_so_far: Usage + """Usage of the run up until the last request.""" _stream_response: models.EitherStreamedResponse _result_schema: _result.ResultSchema[ResultData] | None _deps: AgentDeps @@ -266,13 +266,13 @@ def is_structured(self) -> bool: """Return whether the stream response contains structured data (as opposed to text).""" return isinstance(self._stream_response, models.StreamStructuredResponse) - def cost(self) -> Cost: - """Return the cost of the whole run. + def usage(self) -> Usage: + """Return the usage of the whole run. !!! note - This won't return the full cost until the stream is finished. + This won't return the full usage until the stream is finished. 
""" - return self.cost_so_far + self._stream_response.cost() + return self.usage_so_far + self._stream_response.usage() def timestamp(self) -> datetime: """Get the timestamp of the response.""" diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index a21427f4..d73da5a7 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -20,7 +20,7 @@ ToolReturnPart, UserPromptPart, ) -from pydantic_ai.result import Cost +from pydantic_ai.result import Usage from ..conftest import IsNow, try_import @@ -31,7 +31,7 @@ Message as AnthropicMessage, TextBlock, ToolUseBlock, - Usage, + Usage as AnthropicUsage, ) from pydantic_ai.models.anthropic import AnthropicModel @@ -71,7 +71,7 @@ async def messages_create(self, *_args: Any, **_kwargs: Any) -> AnthropicMessage return response -def completion_message(content: list[ContentBlock], usage: Usage) -> AnthropicMessage: +def completion_message(content: list[ContentBlock], usage: AnthropicUsage) -> AnthropicMessage: return AnthropicMessage( id='123', content=content, @@ -84,21 +84,21 @@ def completion_message(content: list[ContentBlock], usage: Usage) -> AnthropicMe async def test_sync_request_text_response(allow_model_requests: None): - c = completion_message([TextBlock(text='world', type='text')], Usage(input_tokens=5, output_tokens=10)) + c = completion_message([TextBlock(text='world', type='text')], AnthropicUsage(input_tokens=5, output_tokens=10)) mock_client = MockAnthropic.create_mock(c) m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client) agent = Agent(m) result = await agent.run('hello') assert result.data == 'world' - assert result.cost() == snapshot(Cost(request_tokens=5, response_tokens=10, total_tokens=15)) + assert result.usage() == snapshot(Usage(request_tokens=5, response_tokens=10, total_tokens=15)) # reset the index so we get the same response again mock_client.index = 0 # type: ignore result = await agent.run('hello', message_history=result.new_messages()) assert result.data == 'world' - assert result.cost() == snapshot(Cost(request_tokens=5, response_tokens=10, total_tokens=15)) + assert result.usage() == snapshot(Usage(request_tokens=5, response_tokens=10, total_tokens=15)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), @@ -112,7 +112,7 @@ async def test_sync_request_text_response(allow_model_requests: None): async def test_async_request_text_response(allow_model_requests: None): c = completion_message( [TextBlock(text='world', type='text')], - usage=Usage(input_tokens=3, output_tokens=5), + usage=AnthropicUsage(input_tokens=3, output_tokens=5), ) mock_client = MockAnthropic.create_mock(c) m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client) @@ -120,13 +120,13 @@ async def test_async_request_text_response(allow_model_requests: None): result = await agent.run('hello') assert result.data == 'world' - assert result.cost() == snapshot(Cost(request_tokens=3, response_tokens=5, total_tokens=8)) + assert result.usage() == snapshot(Usage(request_tokens=3, response_tokens=5, total_tokens=8)) async def test_request_structured_response(allow_model_requests: None): c = completion_message( [ToolUseBlock(id='123', input={'response': [1, 2, 3]}, name='final_result', type='tool_use')], - usage=Usage(input_tokens=3, output_tokens=5), + usage=AnthropicUsage(input_tokens=3, output_tokens=5), ) mock_client = MockAnthropic.create_mock(c) m = 
AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client) @@ -165,15 +165,15 @@ async def test_request_tool_call(allow_model_requests: None): responses = [ completion_message( [ToolUseBlock(id='1', input={'loc_name': 'San Francisco'}, name='get_location', type='tool_use')], - usage=Usage(input_tokens=2, output_tokens=1), + usage=AnthropicUsage(input_tokens=2, output_tokens=1), ), completion_message( [ToolUseBlock(id='2', input={'loc_name': 'London'}, name='get_location', type='tool_use')], - usage=Usage(input_tokens=3, output_tokens=2), + usage=AnthropicUsage(input_tokens=3, output_tokens=2), ), completion_message( [TextBlock(text='final response', type='text')], - usage=Usage(input_tokens=3, output_tokens=5), + usage=AnthropicUsage(input_tokens=3, output_tokens=5), ), ] diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index d08d28ca..ef6f4b26 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -40,7 +40,7 @@ _GeminiTools, _GeminiUsageMetaData, ) -from pydantic_ai.result import Cost +from pydantic_ai.result import Usage from pydantic_ai.tools import ToolDefinition from ..conftest import ClientWithHandler, IsNow, TestEnv @@ -396,7 +396,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient): ModelResponse.from_text(content='Hello world', timestamp=IsNow(tz=timezone.utc)), ] ) - assert result.cost() == snapshot(Cost(request_tokens=1, response_tokens=2, total_tokens=3)) + assert result.usage() == snapshot(Usage(request_tokens=1, response_tokens=2, total_tokens=3)) result = await agent.run('Hello', message_history=result.new_messages()) assert result.data == 'Hello world' @@ -517,7 +517,7 @@ async def get_location(loc_name: str) -> str: ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)), ] ) - assert result.cost() == snapshot(Cost(request_tokens=3, response_tokens=6, total_tokens=9)) + assert result.usage() == snapshot(Usage(request_tokens=3, response_tokens=6, total_tokens=9)) async def test_unexpected_response(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None): @@ -572,12 +572,12 @@ async def test_stream_text(get_gemini_client: GetGeminiClient): async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream(debounce_by=None)] assert chunks == snapshot(['Hello ', 'Hello world']) - assert result.cost() == snapshot(Cost(request_tokens=2, response_tokens=4, total_tokens=6)) + assert result.usage() == snapshot(Usage(request_tokens=2, response_tokens=4, total_tokens=6)) async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream_text(delta=True, debounce_by=None)] assert chunks == snapshot(['', 'Hello ', 'world']) - assert result.cost() == snapshot(Cost(request_tokens=2, response_tokens=4, total_tokens=6)) + assert result.usage() == snapshot(Usage(request_tokens=2, response_tokens=4, total_tokens=6)) async def test_stream_text_no_data(get_gemini_client: GetGeminiClient): @@ -609,7 +609,7 @@ async def test_stream_structured(get_gemini_client: GetGeminiClient): async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream(debounce_by=None)] assert chunks == snapshot([(1, 2), (1, 2), (1, 2)]) - assert result.cost() == snapshot(Cost(request_tokens=1, response_tokens=2, total_tokens=3)) + assert result.usage() == snapshot(Usage(request_tokens=1, response_tokens=2, total_tokens=3)) async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient): 
@@ -652,7 +652,7 @@ async def bar(y: str) -> str: async with agent.run_stream('Hello') as result: response = await result.get_data() assert response == snapshot((1, 2)) - assert result.cost() == snapshot(Cost(request_tokens=3, response_tokens=6, total_tokens=9)) + assert result.usage() == snapshot(Usage(request_tokens=3, response_tokens=6, total_tokens=9)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index e374c9f5..07983e2e 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -22,7 +22,7 @@ ToolReturnPart, UserPromptPart, ) -from pydantic_ai.result import Cost +from pydantic_ai.result import Usage from ..conftest import IsNow, try_import @@ -128,14 +128,14 @@ async def test_request_simple_success(allow_model_requests: None): result = await agent.run('hello') assert result.data == 'world' - assert result.cost() == snapshot(Cost()) + assert result.usage() == snapshot(Usage()) # reset the index so we get the same response again mock_client.index = 0 # type: ignore result = await agent.run('hello', message_history=result.new_messages()) assert result.data == 'world' - assert result.cost() == snapshot(Cost()) + assert result.usage() == snapshot(Usage()) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), @@ -411,7 +411,7 @@ async def test_stream_structured(allow_model_requests: None): ) assert result.is_complete - assert result.cost() == snapshot(Cost()) + assert result.usage() == snapshot(Usage()) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='', timestamp=IsNow(tz=timezone.utc))]), diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 86532dee..12b8e163 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -228,15 +228,15 @@ async def test_multiple_completions(allow_model_requests: None): # Then assert result.data == 'world' - assert result.cost().request_tokens == 1 - assert result.cost().response_tokens == 1 - assert result.cost().total_tokens == 1 + assert result.usage().request_tokens == 1 + assert result.usage().response_tokens == 1 + assert result.usage().total_tokens == 1 result = await agent.run('hello again', message_history=result.new_messages()) assert result.data == 'hello again' - assert result.cost().request_tokens == 1 - assert result.cost().response_tokens == 1 - assert result.cost().total_tokens == 1 + assert result.usage().request_tokens == 1 + assert result.usage().response_tokens == 1 + assert result.usage().total_tokens == 1 assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), @@ -266,21 +266,21 @@ async def test_three_completions(allow_model_requests: None): # Them assert result.data == 'world' - assert result.cost().request_tokens == 1 - assert result.cost().response_tokens == 1 - assert result.cost().total_tokens == 1 + assert result.usage().request_tokens == 1 + assert result.usage().response_tokens == 1 + assert result.usage().total_tokens == 1 result = await agent.run('hello again', message_history=result.all_messages()) assert result.data == 'hello again' - assert result.cost().request_tokens == 1 - assert result.cost().response_tokens == 1 - assert result.cost().total_tokens == 1 + assert result.usage().request_tokens == 1 + assert 
result.usage().response_tokens == 1 + assert result.usage().total_tokens == 1 result = await agent.run('final message', message_history=result.all_messages()) assert result.data == 'final message' - assert result.cost().request_tokens == 1 - assert result.cost().response_tokens == 1 - assert result.cost().total_tokens == 1 + assert result.usage().request_tokens == 1 + assert result.usage().response_tokens == 1 + assert result.usage().total_tokens == 1 assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), @@ -314,9 +314,9 @@ async def test_stream_text(allow_model_requests: None): ['hello ', 'hello world ', 'hello world welcome ', 'hello world welcome mistral'] ) assert result.is_complete - assert result.cost().request_tokens == 5 - assert result.cost().response_tokens == 5 - assert result.cost().total_tokens == 5 + assert result.usage().request_tokens == 5 + assert result.usage().response_tokens == 5 + assert result.usage().total_tokens == 5 async def test_stream_text_finish_reason(allow_model_requests: None): @@ -349,9 +349,9 @@ async def test_no_delta(allow_model_requests: None): assert not result.is_complete assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete - assert result.cost().request_tokens == 3 - assert result.cost().response_tokens == 3 - assert result.cost().total_tokens == 3 + assert result.usage().request_tokens == 3 + assert result.usage().response_tokens == 3 + assert result.usage().total_tokens == 3 ##################### @@ -388,9 +388,9 @@ class CityLocation(BaseModel): # Then assert result.data == CityLocation(city='paris', country='france') - assert result.cost().request_tokens == 1 - assert result.cost().response_tokens == 2 - assert result.cost().total_tokens == 3 + assert result.usage().request_tokens == 1 + assert result.usage().response_tokens == 2 + assert result.usage().total_tokens == 3 assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='User prompt value', timestamp=IsNow(tz=timezone.utc))]), @@ -448,10 +448,10 @@ class CityLocation(BaseModel): # Then assert result.data == CityLocation(city='paris', country='france') - assert result.cost().request_tokens == 1 - assert result.cost().response_tokens == 1 - assert result.cost().total_tokens == 1 - assert result.cost().details is None + assert result.usage().request_tokens == 1 + assert result.usage().response_tokens == 1 + assert result.usage().total_tokens == 1 + assert result.usage().details is None assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='User prompt value', timestamp=IsNow(tz=timezone.utc))]), @@ -503,10 +503,10 @@ async def test_request_result_type_with_arguments_str_response(allow_model_reque # Then assert result.data == 42 - assert result.cost().request_tokens == 1 - assert result.cost().response_tokens == 1 - assert result.cost().total_tokens == 1 - assert result.cost().details is None + assert result.usage().request_tokens == 1 + assert result.usage().response_tokens == 1 + assert result.usage().total_tokens == 1 + assert result.usage().details is None assert result.all_messages() == snapshot( [ ModelRequest( @@ -642,12 +642,12 @@ class MyTypedDict(TypedDict, total=False): ] ) assert result.is_complete - assert result.cost().request_tokens == 10 - assert result.cost().response_tokens == 10 - assert result.cost().total_tokens == 10 + assert result.usage().request_tokens == 10 + 
assert result.usage().response_tokens == 10 + assert result.usage().total_tokens == 10 - # double check cost matches stream count - assert result.cost().response_tokens == len(stream) + # double check usage matches stream count + assert result.usage().response_tokens == len(stream) async def test_stream_result_type_primitif_dict(allow_model_requests: None): @@ -733,12 +733,12 @@ class MyTypedDict(TypedDict, total=False): ] ) assert result.is_complete - assert result.cost().request_tokens == 34 - assert result.cost().response_tokens == 34 - assert result.cost().total_tokens == 34 + assert result.usage().request_tokens == 34 + assert result.usage().response_tokens == 34 + assert result.usage().total_tokens == 34 - # double check cost matches stream count - assert result.cost().response_tokens == len(stream) + # double check usage matches stream count + assert result.usage().response_tokens == len(stream) async def test_stream_result_type_primitif_int(allow_model_requests: None): @@ -766,12 +766,12 @@ async def test_stream_result_type_primitif_int(allow_model_requests: None): v = [c async for c in result.stream(debounce_by=None)] assert v == snapshot([1, 1, 1]) assert result.is_complete - assert result.cost().request_tokens == 6 - assert result.cost().response_tokens == 6 - assert result.cost().total_tokens == 6 + assert result.usage().request_tokens == 6 + assert result.usage().response_tokens == 6 + assert result.usage().total_tokens == 6 - # double check cost matches stream count - assert result.cost().response_tokens == len(stream) + # double check usage matches stream count + assert result.usage().response_tokens == len(stream) async def test_stream_result_type_primitif_array(allow_model_requests: None): @@ -861,12 +861,12 @@ async def test_stream_result_type_primitif_array(allow_model_requests: None): ] ) assert result.is_complete - assert result.cost().request_tokens == 35 - assert result.cost().response_tokens == 35 - assert result.cost().total_tokens == 35 + assert result.usage().request_tokens == 35 + assert result.usage().response_tokens == 35 + assert result.usage().total_tokens == 35 - # double check cost matches stream count - assert result.cost().response_tokens == len(stream) + # double check usage matches stream count + assert result.usage().response_tokens == len(stream) async def test_stream_result_type_basemodel(allow_model_requests: None): @@ -950,12 +950,12 @@ class MyTypedBaseModel(BaseModel): ] ) assert result.is_complete - assert result.cost().request_tokens == 34 - assert result.cost().response_tokens == 34 - assert result.cost().total_tokens == 34 + assert result.usage().request_tokens == 34 + assert result.usage().response_tokens == 34 + assert result.usage().total_tokens == 34 - # double check cost matches stream count - assert result.cost().response_tokens == len(stream) + # double check usage matches stream count + assert result.usage().response_tokens == len(stream) ##################### @@ -1020,9 +1020,9 @@ async def get_location(loc_name: str) -> str: # Then assert result.data == 'final response' - assert result.cost().request_tokens == 6 - assert result.cost().response_tokens == 4 - assert result.cost().total_tokens == 10 + assert result.usage().request_tokens == 6 + assert result.usage().response_tokens == 4 + assert result.usage().total_tokens == 10 assert result.all_messages() == snapshot( [ ModelRequest( @@ -1156,9 +1156,9 @@ async def get_location(loc_name: str) -> str: # Then assert result.data == {'lat': 51, 'lng': 0} - assert 
-    assert result.cost().request_tokens == 7
-    assert result.cost().response_tokens == 4
-    assert result.cost().total_tokens == 12
+    assert result.usage().request_tokens == 7
+    assert result.usage().response_tokens == 4
+    assert result.usage().total_tokens == 12
     assert result.all_messages() == snapshot(
         [
             ModelRequest(
@@ -1292,12 +1292,12 @@ async def get_location(loc_name: str) -> str:
         assert v == snapshot([{'won': True}, {'won': True}])
         assert result.is_complete
         assert result.timestamp() == datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)
-        assert result.cost().request_tokens == 4
-        assert result.cost().response_tokens == 4
-        assert result.cost().total_tokens == 4
+        assert result.usage().request_tokens == 4
+        assert result.usage().response_tokens == 4
+        assert result.usage().total_tokens == 4

-        # double check cost matches stream count
-        assert result.cost().response_tokens == 4
+        # double check usage matches stream count
+        assert result.usage().response_tokens == 4

         assert result.all_messages() == snapshot(
             [
@@ -1395,12 +1395,12 @@ async def get_location(loc_name: str) -> str:
         assert v == snapshot(['final ', 'final response'])
         assert result.is_complete
         assert result.timestamp() == datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)
-        assert result.cost().request_tokens == 6
-        assert result.cost().response_tokens == 6
-        assert result.cost().total_tokens == 6
+        assert result.usage().request_tokens == 6
+        assert result.usage().response_tokens == 6
+        assert result.usage().total_tokens == 6

-        # double check cost matches stream count
-        assert result.cost().response_tokens == 6
+        # double check usage matches stream count
+        assert result.usage().response_tokens == 6

         assert result.all_messages() == snapshot(
             [
@@ -1496,12 +1496,12 @@ async def get_location(loc_name: str) -> str:
         assert v == snapshot(['final ', 'final response'])
         assert result.is_complete
         assert result.timestamp() == datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)
-        assert result.cost().request_tokens == 7
-        assert result.cost().response_tokens == 7
-        assert result.cost().total_tokens == 7
+        assert result.usage().request_tokens == 7
+        assert result.usage().response_tokens == 7
+        assert result.usage().total_tokens == 7

-        # double check cost matches stream count
-        assert result.cost().response_tokens == 7
+        # double check usage matches stream count
+        assert result.usage().response_tokens == 7

         assert result.all_messages() == snapshot(
             [
diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py
index ea11ee79..540d4259 100644
--- a/tests/models/test_model_function.py
+++ b/tests/models/test_model_function.py
@@ -23,7 +23,7 @@
 )
 from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
 from pydantic_ai.models.test import TestModel
-from pydantic_ai.result import Cost
+from pydantic_ai.result import Usage

 from ..conftest import IsNow

@@ -399,7 +399,7 @@ async def test_stream_text():
             ModelResponse.from_text(content='hello world', timestamp=IsNow(tz=timezone.utc)),
         ]
     )
-    assert result.cost() == snapshot(Cost())
+    assert result.usage() == snapshot(Usage())


 class Foo(BaseModel):
@@ -420,7 +420,7 @@ async def stream_structured_function(
     agent = Agent(FunctionModel(stream_function=stream_structured_function), result_type=Foo)
     async with agent.run_stream('') as result:
         assert await result.get_data() == snapshot(Foo(x=1))
-        assert result.cost() == snapshot(Cost())
+        assert result.usage() == snapshot(Usage())


 async def test_pass_neither():
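For orientation between these test-file diffs: the `FunctionModel` tests above are the pattern users would follow after the rename. Here is a minimal sketch of that pattern, assuming the `stream_function` API shown in the diff; the `stream_hello` function, the prompt, and the `main` wrapper are illustrative, not part of this patch:

```python
import asyncio
from collections.abc import AsyncIterator

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.result import Usage


async def stream_hello(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str]:
    # Yield the response text in chunks, as a real model stream would.
    yield 'hello '
    yield 'world'


agent = Agent(FunctionModel(stream_function=stream_hello))


async def main() -> None:
    async with agent.run_stream('anything') as result:
        assert await result.get_data() == 'hello world'
        # FunctionModel reports no token counts, so usage() compares
        # equal to an empty Usage() — exactly what these tests assert.
        assert result.usage() == Usage()


if __name__ == '__main__':
    asyncio.run(main())
```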
diff --git a/tests/models/test_ollama.py b/tests/models/test_ollama.py
index d0e5088e..bfebab60 100644
--- a/tests/models/test_ollama.py
+++ b/tests/models/test_ollama.py
@@ -11,7 +11,7 @@
     ModelResponse,
     UserPromptPart,
 )
-from pydantic_ai.result import Cost
+from pydantic_ai.result import Usage

 from ..conftest import IsNow, try_import

@@ -44,14 +44,14 @@ async def test_request_simple_success(allow_model_requests: None):

     result = await agent.run('hello')
     assert result.data == 'world'
-    assert result.cost() == snapshot(Cost())
+    assert result.usage() == snapshot(Usage())

     # reset the index so we get the same response again
     mock_client.index = 0  # type: ignore

     result = await agent.run('hello', message_history=result.new_messages())
     assert result.data == 'world'
-    assert result.cost() == snapshot(Cost())
+    assert result.usage() == snapshot(Usage())
     assert result.all_messages() == snapshot(
         [
             ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py
index 50d0c577..094d3390 100644
--- a/tests/models/test_openai.py
+++ b/tests/models/test_openai.py
@@ -22,7 +22,7 @@
     ToolReturnPart,
     UserPromptPart,
 )
-from pydantic_ai.result import Cost
+from pydantic_ai.result import Usage

 from ..conftest import IsNow, try_import

@@ -137,14 +137,14 @@ async def test_request_simple_success(allow_model_requests: None):

     result = await agent.run('hello')
     assert result.data == 'world'
-    assert result.cost() == snapshot(Cost())
+    assert result.usage() == snapshot(Usage())

     # reset the index so we get the same response again
     mock_client.index = 0  # type: ignore

     result = await agent.run('hello', message_history=result.new_messages())
     assert result.data == 'world'
-    assert result.cost() == snapshot(Cost())
+    assert result.usage() == snapshot(Usage())
     assert result.all_messages() == snapshot(
         [
             ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
@@ -166,7 +166,7 @@ async def test_request_simple_usage(allow_model_requests: None):

     result = await agent.run('Hello')
     assert result.data == 'world'
-    assert result.cost() == snapshot(Cost(request_tokens=2, response_tokens=1, total_tokens=3))
+    assert result.usage() == snapshot(Usage(request_tokens=2, response_tokens=1, total_tokens=3))


 async def test_request_structured_response(allow_model_requests: None):
@@ -322,8 +322,8 @@ async def get_location(loc_name: str) -> str:
             ModelResponse.from_text(content='final response', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc)),
         ]
     )
-    assert result.cost() == snapshot(
-        Cost(request_tokens=5, response_tokens=3, total_tokens=9, details={'cached_tokens': 3})
+    assert result.usage() == snapshot(
+        Usage(request_tokens=5, response_tokens=3, total_tokens=9, details={'cached_tokens': 3})
     )


@@ -358,7 +358,7 @@ async def test_stream_text(allow_model_requests: None):
         assert not result.is_complete
         assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world'])
         assert result.is_complete
-        assert result.cost() == snapshot(Cost(request_tokens=6, response_tokens=3, total_tokens=9))
+        assert result.usage() == snapshot(Usage(request_tokens=6, response_tokens=3, total_tokens=9))


 async def test_stream_text_finish_reason(allow_model_requests: None):
@@ -425,9 +425,9 @@ async def test_stream_structured(allow_model_requests: None):
             ]
         )
         assert result.is_complete
-        assert result.cost() == snapshot(Cost(request_tokens=20, response_tokens=10, total_tokens=30))
-        # double check cost matches stream count
-        assert result.cost().response_tokens == len(stream)
+        assert result.usage() == snapshot(Usage(request_tokens=20, response_tokens=10, total_tokens=30))
+        # double check usage matches stream count
+        assert result.usage().response_tokens == len(stream)


 async def test_stream_structured_finish_reason(allow_model_requests: None):
@@ -482,4 +482,4 @@ async def test_no_delta(allow_model_requests: None):
         assert not result.is_complete
         assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world'])
         assert result.is_complete
-        assert result.cost() == snapshot(Cost(request_tokens=6, response_tokens=3, total_tokens=9))
+        assert result.usage() == snapshot(Usage(request_tokens=6, response_tokens=3, total_tokens=9))
diff --git a/tests/test_agent.py b/tests/test_agent.py
index 047d9fa7..7b724e52 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -26,7 +26,7 @@
 from pydantic_ai.models import cached_async_http_client
 from pydantic_ai.models.function import AgentInfo, FunctionModel
 from pydantic_ai.models.test import TestModel
-from pydantic_ai.result import Cost, RunResult
+from pydantic_ai.result import RunResult, Usage
 from pydantic_ai.tools import ToolDefinition

 from .conftest import IsNow, TestEnv
@@ -505,7 +505,7 @@ async def ret_a(x: str) -> str:
             ],
             _new_message_index=4,
             data='{"ret_a":"a-apple"}',
-            _cost=Cost(),
+            _usage=Usage(),
         )
     )
     new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()]
@@ -547,7 +547,7 @@ async def ret_a(x: str) -> str:
             ],
             _new_message_index=4,
             data='{"ret_a":"a-apple"}',
-            _cost=Cost(),
+            _usage=Usage(),
         )
     )

@@ -646,7 +646,7 @@ async def ret_a(x: str) -> str:
                 ),
             ],
             _new_message_index=5,
-            _cost=Cost(),
+            _usage=Usage(),
         )
     )
     new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()]
diff --git a/tests/test_live.py b/tests/test_live.py
index e1d45ebf..2f0cb774 100644
--- a/tests/test_live.py
+++ b/tests/test_live.py
@@ -90,9 +90,9 @@ async def test_text(http_client: httpx.AsyncClient, tmp_path: Path, get_model: G
     result = await agent.run('What is the capital of France?')
     print('Text response:', result.data)
     assert 'paris' in result.data.lower()
-    print('Text cost:', result.cost())
-    cost = result.cost()
-    assert cost.total_tokens is not None and cost.total_tokens > 0
+    print('Text usage:', result.usage())
+    usage = result.usage()
+    assert usage.total_tokens is not None and usage.total_tokens > 0


 stream_params = [p for p in params if p.id != 'anthropic']
@@ -105,10 +105,10 @@ async def test_stream(http_client: httpx.AsyncClient, tmp_path: Path, get_model:
         data = await result.get_data()
     print('Stream response:', data)
     assert 'paris' in data.lower()
-    print('Stream cost:', result.cost())
-    cost = result.cost()
+    print('Stream usage:', result.usage())
+    usage = result.usage()
     if get_model.__name__ != 'ollama':
-        assert cost.total_tokens is not None and cost.total_tokens > 0
+        assert usage.total_tokens is not None and usage.total_tokens > 0


 class MyModel(BaseModel):
@@ -124,6 +124,6 @@ async def test_structured(http_client: httpx.AsyncClient, tmp_path: Path, get_mo
     result = await agent.run('What is the capital of the UK?')
     print('Structured response:', result.data)
     assert result.data.city.lower() == 'london'
-    print('Structured cost:', result.cost())
-    cost = result.cost()
-    assert cost.total_tokens is not None and cost.total_tokens > 0
+    print('Structured usage:', result.usage())
+    usage = result.usage()
+    assert usage.total_tokens is not None and usage.total_tokens > 0
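The logfire test below asserts the serialized shape of the renamed dataclass. For orientation, a hypothetical reconstruction of `Usage`'s fields, inferred only from the JSON in these tests and the `details={'cached_tokens': 3}` value above — not copied from the library source:

```python
from dataclasses import dataclass


@dataclass
class Usage:
    """Sketch of pydantic_ai.result.Usage as implied by the tests below."""

    request_tokens: int | None = None
    response_tokens: int | None = None
    total_tokens: int | None = None
    details: dict[str, int] | None = None
```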
diff --git a/tests/test_logfire.py b/tests/test_logfire.py
index 2d1d1cdb..89505d0b 100644
--- a/tests/test_logfire.py
+++ b/tests/test_logfire.py
@@ -161,7 +161,7 @@ async def my_ret(x: int) -> str:
                 },
             ]
         ),
-        'cost': IsJson({'request_tokens': None, 'response_tokens': None, 'total_tokens': None, 'details': None}),
+        'usage': IsJson({'request_tokens': None, 'response_tokens': None, 'total_tokens': None, 'details': None}),
         'logfire.json_schema': IsJson(
             {
                 'type': 'object',
@@ -254,7 +254,7 @@ async def my_ret(x: int) -> str:
                         },
                     ],
                 },
-                'cost': {'type': 'object', 'title': 'Cost', 'x-python-datatype': 'dataclass'},
+                'usage': {'type': 'object', 'title': 'Usage', 'x-python-datatype': 'dataclass'},
             },
         }
     ),
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index 667f8fe8..7e593f9b 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -22,7 +22,7 @@
 )
 from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
 from pydantic_ai.models.test import TestModel
-from pydantic_ai.result import Cost
+from pydantic_ai.result import Usage

 from .conftest import IsNow

@@ -58,7 +58,7 @@ async def ret_a(x: str) -> str:
         response = await result.get_data()
         assert response == snapshot('{"ret_a":"a-apple"}')
         assert result.is_complete
-        assert result.cost() == snapshot(Cost())
+        assert result.usage() == snapshot(Usage())
         assert result.timestamp() == IsNow(tz=timezone.utc)
         assert result.all_messages() == snapshot(
             [
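Taken together, the caller-side migration this patch implies is mechanical: replace `result.cost()` with `result.usage()` and the `Cost` import with `Usage`. A minimal sketch, mirroring the docs examples earlier in this patch; the model string and prompt here are illustrative, not part of the diff:

```python
from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')

result = agent.run_sync('What is the capital of France?')
print(result.data)
# result.cost() is gone; result.usage() returns the renamed Usage dataclass,
# e.g. Usage(request_tokens=..., response_tokens=..., total_tokens=..., details=None)
print(result.usage())
```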