Skip to content

Commit

Permalink
Rename Cost to Usage
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Dec 18, 2024
1 parent 4553b3c commit 62bef6a
Show file tree
Hide file tree
Showing 27 changed files with 289 additions and 289 deletions.
8 changes: 4 additions & 4 deletions docs/api/models/ollama.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ agent = Agent('ollama:llama3.2', result_type=CityLocation)
result = agent.run_sync('Where were the olympics held in 2012?')
print(result.data)
#> city='London' country='United Kingdom'
print(result.cost())
#> Cost(request_tokens=57, response_tokens=8, total_tokens=65, details=None)
print(result.usage())
#> Usage(request_tokens=57, response_tokens=8, total_tokens=65, details=None)
```

## Example using a remote server
Expand All @@ -59,8 +59,8 @@ agent = Agent(model=ollama_model, result_type=CityLocation)
result = agent.run_sync('Where were the olympics held in 2012?')
print(result.data)
#> city='London' country='United Kingdom'
print(result.cost())
#> Cost(request_tokens=57, response_tokens=8, total_tokens=65, details=None)
print(result.usage())
#> Usage(request_tokens=57, response_tokens=8, total_tokens=65, details=None)
```

1. The name of the model running on the remote server
Expand Down
2 changes: 1 addition & 1 deletion docs/api/result.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
- ResultData
- RunResult
- StreamedRunResult
- Cost
- Usage
6 changes: 3 additions & 3 deletions docs/results.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Results are the final values returned from [running an agent](agents.md#running-agents).
The result values are wrapped in [`RunResult`][pydantic_ai.result.RunResult] and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so you can access other data like [cost][pydantic_ai.result.Cost] of the run and [message history](message-history.md#accessing-messages-from-results)
The result values are wrapped in [`RunResult`][pydantic_ai.result.RunResult] and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so you can access other data like [usage][pydantic_ai.result.Usage] of the run and [message history](message-history.md#accessing-messages-from-results)

Both `RunResult` and `StreamedRunResult` are generic in the data they wrap, so typing information about the data returned by the agent is preserved.

Expand All @@ -18,8 +18,8 @@ agent = Agent('gemini-1.5-flash', result_type=CityLocation)
result = agent.run_sync('Where were the olympics held in 2012?')
print(result.data)
#> city='London' country='United Kingdom'
print(result.cost())
#> Cost(request_tokens=57, response_tokens=8, total_tokens=65, details=None)
print(result.usage())
#> Usage(request_tokens=57, response_tokens=8, total_tokens=65, details=None)
```

_(This example is complete, it can be run "as is")_
Expand Down
2 changes: 1 addition & 1 deletion docs/testing-evals.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Unless you're really sure you know better, you'll probably want to follow roughl
* Use [`pytest`](https://docs.pytest.org/en/stable/) as your test harness
* If you find yourself typing out long assertions, use [inline-snapshot](https://15r10nk.github.io/inline-snapshot/latest/)
* Similarly, [dirty-equals](https://dirty-equals.helpmanual.io/latest/) can be useful for comparing large data structures
* Use [`TestModel`][pydantic_ai.models.test.TestModel] or [`FunctionModel`][pydantic_ai.models.function.FunctionModel] in place of your actual model to avoid the cost, latency and variability of real LLM calls
* Use [`TestModel`][pydantic_ai.models.test.TestModel] or [`FunctionModel`][pydantic_ai.models.function.FunctionModel] in place of your actual model to avoid the costs, latency and variability of real LLM calls
* Use [`Agent.override`][pydantic_ai.agent.Agent.override] to replace your model inside your application logic
* Set [`ALLOW_MODEL_REQUESTS=False`][pydantic_ai.models.ALLOW_MODEL_REQUESTS] globally to block any requests from being made to non-test models accidentally

Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_examples/pydantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ class MyModel(BaseModel):
if __name__ == '__main__':
result = agent.run_sync('The windy city in the US of A.')
print(result.data)
print(result.cost())
print(result.usage())
2 changes: 1 addition & 1 deletion pydantic_ai_examples/stream_markdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def main():
async with agent.run_stream(prompt, model=model) as result:
async for message in result.stream():
live.update(Markdown(message))
console.log(result.cost())
console.log(result.usage())
else:
console.log(f'{model} requires {env_var} to be set.')

Expand Down
22 changes: 11 additions & 11 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ async def run(
for tool in self._function_tools.values():
tool.current_retry = 0

cost = result.Cost()
usage = result.Usage()

model_settings = merge_model_settings(self.model_settings, model_settings)

Expand All @@ -248,12 +248,12 @@ async def run(
agent_model = await self._prepare_model(model_used, deps, messages)

with _logfire.span('model request', run_step=run_step) as model_req_span:
model_response, request_cost = await agent_model.request(messages, model_settings)
model_response, request_usage = await agent_model.request(messages, model_settings)
model_req_span.set_attribute('response', model_response)
model_req_span.set_attribute('cost', request_cost)
model_req_span.set_attribute('usage', request_usage)

messages.append(model_response)
cost += request_cost
usage += request_usage

with _logfire.span('handle model response', run_step=run_step) as handle_span:
final_result, tool_responses = await self._handle_model_response(model_response, deps, messages)
Expand All @@ -266,10 +266,10 @@ async def run(
if final_result is not None:
result_data = final_result.data
run_span.set_attribute('all_messages', messages)
run_span.set_attribute('cost', cost)
run_span.set_attribute('usage', usage)
handle_span.set_attribute('result', result_data)
handle_span.message = 'handle model response -> final result'
return result.RunResult(messages, new_message_index, result_data, cost)
return result.RunResult(messages, new_message_index, result_data, usage)
else:
# continue the conversation
handle_span.set_attribute('tool_responses', tool_responses)
Expand Down Expand Up @@ -385,7 +385,7 @@ async def main():
for tool in self._function_tools.values():
tool.current_retry = 0

cost = result.Cost()
usage = result.Usage()
model_settings = merge_model_settings(self.model_settings, model_settings)

run_step = 0
Expand Down Expand Up @@ -434,7 +434,7 @@ async def on_complete():
yield result.StreamedRunResult(
messages,
new_message_index,
cost,
usage,
result_stream,
self._result_schema,
deps,
Expand All @@ -455,8 +455,8 @@ async def on_complete():
handle_span.set_attribute('tool_responses', tool_responses)
tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
handle_span.message = f'handle model response -> {tool_responses_str}'
# the model_response should have been fully streamed by now, we can add it's cost
cost += model_response.cost()
# the model_response should have been fully streamed by now, we can add its usage
usage += model_response.usage()

@contextmanager
def override(
Expand Down Expand Up @@ -988,7 +988,7 @@ async def _handle_streamed_model_response(
response = _messages.RetryPromptPart(
content='Plain text responses are not permitted, please call one of the functions instead.',
)
# stream the response, so cost is correct
# stream the response, so usage is correct
async for _ in model_response:
pass

Expand Down
18 changes: 9 additions & 9 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ..settings import ModelSettings

if TYPE_CHECKING:
from ..result import Cost
from ..result import Usage
from ..tools import ToolDefinition


Expand Down Expand Up @@ -122,7 +122,7 @@ class AgentModel(ABC):
@abstractmethod
async def request(
self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> tuple[ModelResponse, Cost]:
) -> tuple[ModelResponse, Usage]:
"""Make a request to the model."""
raise NotImplementedError()

Expand Down Expand Up @@ -164,10 +164,10 @@ def get(self, *, final: bool = False) -> Iterable[str]:
raise NotImplementedError()

@abstractmethod
def cost(self) -> Cost:
"""Return the cost of the request.
def usage(self) -> Usage:
"""Return the usage of the request.
NOTE: this won't return the ful cost until the stream is finished.
NOTE: this won't return the full usage until the stream is finished.
"""
raise NotImplementedError()

Expand Down Expand Up @@ -205,10 +205,10 @@ def get(self, *, final: bool = False) -> ModelResponse:
raise NotImplementedError()

@abstractmethod
def cost(self) -> Cost:
"""Get the cost of the request.
def usage(self) -> Usage:
"""Get the usage of the request.
NOTE: this won't return the full cost until the stream is finished.
NOTE: this won't return the full usage until the stream is finished.
"""
raise NotImplementedError()

Expand All @@ -235,7 +235,7 @@ def timestamp(self) -> datetime:
def check_allow_model_requests() -> None:
"""Check if model requests are allowed.
If you're defining your own models that have cost or latency associated with their use, you should call this in
If you're defining your own models that have costs or latency associated with their use, you should call this in
[`Model.agent_model`][pydantic_ai.models.Model.agent_model].
Raises:
Expand Down
10 changes: 5 additions & 5 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ class AnthropicAgentModel(AgentModel):

async def request(
self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> tuple[ModelResponse, result.Cost]:
) -> tuple[ModelResponse, result.Usage]:
response = await self._messages_create(messages, False, model_settings)
return self._process_response(response), _map_cost(response)
return self._process_response(response), _map_usage(response)

@asynccontextmanager
async def request_stream(
Expand Down Expand Up @@ -315,7 +315,7 @@ def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
)


def _map_cost(message: AnthropicMessage | RawMessageStreamEvent) -> result.Cost:
def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> result.Usage:
if isinstance(message, AnthropicMessage):
usage = message.usage
else:
Expand All @@ -332,11 +332,11 @@ def _map_cost(message: AnthropicMessage | RawMessageStreamEvent) -> result.Cost:
usage = None

if usage is None:
return result.Cost()
return result.Usage()

request_tokens = getattr(usage, 'input_tokens', None)

return result.Cost(
return result.Usage(
# Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
request_tokens=request_tokens,
response_tokens=usage.output_tokens,
Expand Down
30 changes: 15 additions & 15 deletions pydantic_ai_slim/pydantic_ai/models/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class FunctionAgentModel(AgentModel):

async def request(
self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> tuple[ModelResponse, result.Cost]:
) -> tuple[ModelResponse, result.Usage]:
agent_info = replace(self.agent_info, model_settings=model_settings)

assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
Expand All @@ -155,7 +155,7 @@ async def request(
assert isinstance(response_, ModelResponse), response_
response = response_
# TODO is `messages` right here? Should it just be new messages?
return response, _estimate_cost(chain(messages, [response]))
return response, _estimate_usage(chain(messages, [response]))

@asynccontextmanager
async def request_stream(
Expand Down Expand Up @@ -198,8 +198,8 @@ def get(self, *, final: bool = False) -> Iterable[str]:
yield from self._buffer
self._buffer.clear()

def cost(self) -> result.Cost:
return result.Cost()
def usage(self) -> result.Usage:
return result.Usage()

def timestamp(self) -> datetime:
return self._timestamp
Expand Down Expand Up @@ -236,15 +236,15 @@ def get(self, *, final: bool = False) -> ModelResponse:

return ModelResponse(calls, timestamp=self._timestamp)

def cost(self) -> result.Cost:
return result.Cost()
def usage(self) -> result.Usage:
return result.Usage()

def timestamp(self) -> datetime:
return self._timestamp


def _estimate_cost(messages: Iterable[ModelMessage]) -> result.Cost:
"""Very rough guesstimate of the number of tokens associate with a series of messages.
def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
"""Very rough guesstimate of the token usage associated with a series of messages.
This is designed to be used solely to give plausible numbers for testing!
"""
Expand All @@ -255,32 +255,32 @@ def _estimate_cost(messages: Iterable[ModelMessage]) -> result.Cost:
if isinstance(message, ModelRequest):
for part in message.parts:
if isinstance(part, (SystemPromptPart, UserPromptPart)):
request_tokens += _string_cost(part.content)
request_tokens += _string_usage(part.content)
elif isinstance(part, ToolReturnPart):
request_tokens += _string_cost(part.model_response_str())
request_tokens += _string_usage(part.model_response_str())
elif isinstance(part, RetryPromptPart):
request_tokens += _string_cost(part.model_response())
request_tokens += _string_usage(part.model_response())
else:
assert_never(part)
elif isinstance(message, ModelResponse):
for part in message.parts:
if isinstance(part, TextPart):
response_tokens += _string_cost(part.content)
response_tokens += _string_usage(part.content)
elif isinstance(part, ToolCallPart):
call = part
if isinstance(call.args, ArgsJson):
args_str = call.args.args_json
else:
args_str = pydantic_core.to_json(call.args.args_dict).decode()
response_tokens += 1 + _string_cost(args_str)
response_tokens += 1 + _string_usage(args_str)
else:
assert_never(part)
else:
assert_never(message)
return result.Cost(
return result.Usage(
request_tokens=request_tokens, response_tokens=response_tokens, total_tokens=request_tokens + response_tokens
)


def _string_cost(content: str) -> int:
def _string_usage(content: str) -> int:
return len(re.split(r'[\s",.:]+', content))
Loading

0 comments on commit 62bef6a

Please sign in to comment.