Utility model for quickly testing apps built with PydanticAI.
Here's a minimal example, saved as `test_model_usage.py`:
from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

my_agent = Agent('openai:gpt-4o', system_prompt='...')


async def test_my_agent():
    """Exercise `my_agent` against `TestModel` instead of a real LLM; meant to be run by pytest."""
    test_model = TestModel()
    with my_agent.override(model=test_model):
        run_result = await my_agent.run('Testing my agent...')
        # No tools are registered on `my_agent`, so TestModel answers with its canned text.
        assert run_result.data == 'success (no tool calls)'
    # TestModel records the parameters of the last request it served; no function tools were offered.
    assert test_model.last_model_request_parameters.function_tools == []
@dataclass
class TestModel(Model):
    """A model specifically for testing purposes.

    This will (by default) call all tools in the agent, then return a tool response if possible,
    otherwise a plain response.

    How useful this model is will vary significantly.

    Apart from `__init__` derived by the `dataclass` decorator, all methods are private or match those
    of the base class.
    """

    # NOTE: Avoid test discovery by pytest.
    __test__ = False

    call_tools: list[str] | Literal['all'] = 'all'
    """List of tools to call. If `'all'`, all tools will be called."""
    custom_result_text: str | None = None
    """If set, this text is returned as the final result."""
    custom_result_args: Any | None = None
    """If set, these args will be passed to the result tool."""
    seed: int = 0
    """Seed for generating random data."""
    last_model_request_parameters: ModelRequestParameters | None = field(default=None, init=False)
    """The last ModelRequestParameters passed to the model in a request.

    The ModelRequestParameters contains information about the function and result tools available during request handling.

    This is set when a request is made, so will reflect the function tools from the last step of the last run.
    """
    _model_name: str = field(default='test', repr=False)
    _system: str = field(default='test', repr=False)

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, Usage]:
        """Build the canned response for `messages` and estimate its usage.

        Records `model_request_parameters` on `last_model_request_parameters` first.

        Returns:
            The response from `_request`, plus usage estimated over the messages and that response.
        """
        self.last_model_request_parameters = model_request_parameters
        model_response = self._request(messages, model_settings, model_request_parameters)
        usage = _estimate_usage([*messages, model_response])
        return model_response, usage

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        """Streaming variant of `request`: yield the canned response wrapped in a `TestStreamedResponse`."""
        self.last_model_request_parameters = model_request_parameters
        model_response = self._request(messages, model_settings, model_request_parameters)
        yield TestStreamedResponse(
            _model_name=self._model_name, _structured_response=model_response, _messages=messages
        )

    @property
    def model_name(self) -> str:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The system / model provider."""
        return self._system

    def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
        """Generate test arguments from `tool_def`'s parameters JSON schema, seeded with `self.seed`."""
        return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()

    def _get_tool_calls(self, model_request_parameters: ModelRequestParameters) -> list[tuple[str, ToolDefinition]]:
        """Return `(name, definition)` pairs for the function tools this model should call.

        Raises:
            KeyError: if `call_tools` names a tool that is not among the request's function tools.
        """
        if self.call_tools == 'all':
            return [(r.name, r) for r in model_request_parameters.function_tools]
        else:
            function_tools_lookup = {t.name: t for t in model_request_parameters.function_tools}
            tools_to_call = (function_tools_lookup[name] for name in self.call_tools)
            return [(r.name, r) for r in tools_to_call]

    def _select_result_tool(self, result_tools: list[ToolDefinition]) -> ToolDefinition:
        """Pick the result tool to call, based on `seed`.

        Shared by `_get_result` and `_request` so that the tool whose `outer_typed_dict_key` is used to
        wrap custom args is the same tool that is actually called. Callers must ensure `result_tools`
        is non-empty.
        """
        return result_tools[self.seed % len(result_tools)]

    def _get_result(self, model_request_parameters: ModelRequestParameters) -> _TextResult | _FunctionToolResult:
        """Decide the shape of the final (non function-tool) response.

        Precedence: `custom_result_text`, then `custom_result_args`, then a plain text response if
        allowed, then a result-tool call with generated args.
        """
        if self.custom_result_text is not None:
            assert model_request_parameters.allow_text_result, (
                'Plain response not allowed, but `custom_result_text` is set.'
            )
            assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
            return _TextResult(self.custom_result_text)
        elif self.custom_result_args is not None:
            # Truthiness rather than `is not None`: an empty tool list must fail here with a clear
            # message instead of raising IndexError when a tool is selected below.
            assert model_request_parameters.result_tools, (
                'No result tools provided, but `custom_result_args` is set.'
            )
            # Select the same tool `_request` will call, so the wrapping key matches that tool.
            result_tool = self._select_result_tool(model_request_parameters.result_tools)

            if k := result_tool.outer_typed_dict_key:
                return _FunctionToolResult({k: self.custom_result_args})
            else:
                return _FunctionToolResult(self.custom_result_args)
        elif model_request_parameters.allow_text_result:
            return _TextResult(None)
        elif model_request_parameters.result_tools:
            return _FunctionToolResult(None)
        else:
            return _TextResult(None)

    def _request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        """Produce the next canned `ModelResponse` for the conversation so far.

        1. If there are function tools to call and `messages` contains no `ModelResponse` yet, call
           them all with generated args.
        2. If the last message contains `RetryPromptPart`s, re-issue calls for the named function
           and result tools.
        3. Otherwise return the final result chosen by `_get_result`. This is either text (custom text,
           a JSON summary of all tool returns, or 'success (no tool calls)') or a result-tool call.
        """
        tool_calls = self._get_tool_calls(model_request_parameters)
        result = self._get_result(model_request_parameters)
        result_tools = model_request_parameters.result_tools

        # if there are tools, the first thing we want to do is call all of them
        if tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
            return ModelResponse(
                parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls],
                model_name=self._model_name,
            )

        if messages:
            last_message = messages[-1]
            assert isinstance(last_message, ModelRequest), 'Expected last message to be a `ModelRequest`.'

            # check if there are any retry prompts, if so retry them
            new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
            if new_retry_names:
                # Handle retries for both function tools and result tools
                # Check function tools first
                retry_parts: list[ModelResponsePart] = [
                    ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls if name in new_retry_names
                ]
                # Check result tools
                if result_tools:
                    retry_parts.extend(
                        [
                            ToolCallPart(
                                tool.name,
                                result.value
                                if isinstance(result, _FunctionToolResult) and result.value is not None
                                else self.gen_tool_args(tool),
                            )
                            for tool in result_tools
                            if tool.name in new_retry_names
                        ]
                    )
                return ModelResponse(parts=retry_parts, model_name=self._model_name)

        if isinstance(result, _TextResult):
            if (response_text := result.value) is None:
                # build up details of tool responses
                output: dict[str, Any] = {}
                for message in messages:
                    if isinstance(message, ModelRequest):
                        for part in message.parts:
                            if isinstance(part, ToolReturnPart):
                                output[part.tool_name] = part.content
                if output:
                    return ModelResponse(
                        parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self._model_name
                    )
                else:
                    return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self._model_name)
            else:
                return ModelResponse(parts=[TextPart(response_text)], model_name=self._model_name)
        else:
            assert result_tools, 'No result tools provided'
            custom_result_args = result.value
            result_tool = self._select_result_tool(result_tools)
            if custom_result_args is not None:
                return ModelResponse(
                    parts=[ToolCallPart(result_tool.name, custom_result_args)], model_name=self._model_name
                )
            else:
                response_args = self.gen_tool_args(result_tool)
                return ModelResponse(parts=[ToolCallPart(result_tool.name, response_args)], model_name=self._model_name)
@dataclass
class TestStreamedResponse(StreamedResponse):
    """A structured response that streams test data.

    Replays the parts of a pre-built `ModelResponse` as stream events:

    - Text parts begin with an empty delta. The text is then emitted split on single spaces, with
      each piece except the last keeping its trailing space. A single-piece text longer than two
      characters is instead split into two halves.
    - Every other part is forwarded in one event via `handle_tool_call_part`.
    """

    _model_name: str
    _structured_response: ModelResponse
    _messages: InitVar[Iterable[ModelMessage]]
    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

    def __post_init__(self, _messages: Iterable[ModelMessage]):
        # Start from a usage estimate over the request messages; streamed text chunks add to it below.
        self._usage = _estimate_usage(_messages)

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        for index, part in enumerate(self._structured_response.parts):
            if not isinstance(part, TextPart):
                yield self._parts_manager.handle_tool_call_part(
                    vendor_part_id=index,
                    tool_name=part.tool_name,
                    args=part.args,
                    tool_call_id=part.tool_call_id,
                )
                continue

            content = part.content
            # str.split always returns at least one element, so pieces[-1] is safe.
            pieces = content.split(' ')
            chunks = [f'{piece} ' for piece in pieces[:-1]]
            chunks.append(pieces[-1])
            if len(chunks) == 1 and len(content) > 2:
                half = len(content) // 2
                chunks = [content[:half], content[half:]]

            self._usage += _get_string_usage('')
            yield self._parts_manager.handle_text_delta(vendor_part_id=index, content='')
            for chunk in chunks:
                self._usage += _get_string_usage(chunk)
                yield self._parts_manager.handle_text_delta(vendor_part_id=index, content=chunk)

    @property
    def model_name(self) -> str:
        """Name of the model this streamed response is attributed to."""
        return self._model_name

    @property
    def timestamp(self) -> datetime:
        """Creation time of this response, taken from `_utils.now_utc()` at construction."""
        return self._timestamp