from __future__ import annotations as _annotations

from dataclasses import dataclass, field
from pathlib import Path

import logfire

# NOTE(review): `BaseModel` comes from `groq`, not `pydantic` — presumably an
# auto-import slip that makes this example depend on the optional groq SDK;
# confirm and consider `from pydantic import BaseModel`.
from groq import BaseModel
from pydantic_graph import (
    BaseNode,
    End,
    Graph,
    GraphRunContext,
)
from pydantic_graph.persistence.file import FileStatePersistence

from pydantic_ai import Agent
from pydantic_ai.format_as_xml import format_as_xml
from pydantic_ai.messages import ModelMessage

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
logfire.configure(send_to_logfire='if-token-present')

# Agent that produces the question text (plain `str` result).
ask_agent = Agent('openai:gpt-4o', result_type=str, instrument=True)


# Shared graph state: the current question plus the message history of each
# agent, so every agent call can be given its own earlier messages.
@dataclass
class QuestionState:
    question: str | None = None
    ask_agent_messages: list[ModelMessage] = field(default_factory=list)
    evaluate_agent_messages: list[ModelMessage] = field(default_factory=list)


# Node: have `ask_agent` generate a question, record it on the state and
# move on to `Answer`.
@dataclass
class Ask(BaseNode[QuestionState]):
    async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer:
        """Generate a question, store it in the state and return the `Answer` node."""
        result = await ask_agent.run(
            'Ask a simple question with a single correct answer.',
            message_history=ctx.state.ask_agent_messages,
        )
        ctx.state.ask_agent_messages += result.all_messages()
        ctx.state.question = result.data
        return Answer(result.data)


# Node: show the question on stdin/stdout via `input()` and pass the typed
# answer to `Evaluate`.
@dataclass
class Answer(BaseNode[QuestionState]):
    question: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate:
        """Prompt the user with the question and return an `Evaluate` node for the reply."""
        answer = input(f'{self.question}: ')
        return Evaluate(answer)


# Structured result of `evaluate_agent`. The strings below the fields are
# attribute docstrings picked up through `use_attribute_docstrings=True`.
class EvaluationResult(BaseModel, use_attribute_docstrings=True):
    correct: bool
    """Whether the answer is correct."""
    comment: str
    """Comment on the answer, reprimand the user if the answer is wrong."""


# Agent that judges a question/answer pair and returns an `EvaluationResult`.
evaluate_agent = Agent(
    'openai:gpt-4o',
    result_type=EvaluationResult,
    system_prompt='Given a question and answer, evaluate if the answer is correct.',
)


# Node: ask `evaluate_agent` whether the answer is correct; end the graph
# with the agent's comment if so, otherwise go to `Reprimand`.
@dataclass
class Evaluate(BaseNode[QuestionState, None, str]):
    answer: str

    async def run(
        self,
        ctx: GraphRunContext[QuestionState],
    ) -> End[str] | Reprimand:
        """Evaluate `self.answer` against the question held in the state."""
        # Internal invariant: this node is only meaningful once a question
        # has been stored on the state.
        assert ctx.state.question is not None
        result = await evaluate_agent.run(
            format_as_xml({'question': ctx.state.question, 'answer': self.answer}),
            message_history=ctx.state.evaluate_agent_messages,
        )
        ctx.state.evaluate_agent_messages += result.all_messages()
        if result.data.correct:
            return End(result.data.comment)
        else:
            return Reprimand(result.data.comment)


# Node: print the evaluator's comment, clear the current question and loop
# back to `Ask`.
@dataclass
class Reprimand(BaseNode[QuestionState]):
    comment: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
        """Print the comment, reset the stored question and return a fresh `Ask` node."""
        print(f'Comment: {self.comment}')
        ctx.state.question = None
        return Ask()


question_graph = Graph(
    nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState
)


async def run_as_continuous():
    """Run the whole graph in one go, starting from `Ask`, and print the final output."""
    state = QuestionState()
    node = Ask()
    end = await question_graph.run(node, state=state)
    print('END:', end.output)


async def run_as_cli(answer: str | None):
    """Advance the graph by one CLI invocation, persisting state in `question_graph.json`.

    If the persistence layer returns a snapshot, the run resumes from the
    snapshot's state with an `Evaluate` node built from `answer`; otherwise a
    fresh run starts at `Ask`. The loop stops when the graph ends or when an
    `Answer` node is reached, in which case the question is printed.

    Args:
        answer: The user's answer; required when a snapshot is resumed.

    Raises:
        AssertionError: If a snapshot exists but `answer` is `None`.
    """
    persistence = FileStatePersistence(Path('question_graph.json'))
    persistence.set_graph_types(question_graph)

    if snapshot := await persistence.load_next():
        state = snapshot.state
        # Explicit check instead of `assert` so the validation also runs
        # under `python -O`; the exception type is kept for compatibility.
        if answer is None:
            raise AssertionError(
                'answer required, usage "uv run -m pydantic_ai_examples.question_graph cli <answer>"'
            )
        node = Evaluate(answer)
    else:
        state = QuestionState()
        node = Ask()
    # debug(state, node)

    async with question_graph.iter(node, state=state, persistence=persistence) as run:
        while True:
            node = await run.next()
            if isinstance(node, End):
                print('END:', node.data)
                history = await persistence.load_all()
                print('history:', '\n'.join(str(e.node) for e in history), sep='\n')
                print('Finished!')
                break
            elif isinstance(node, Answer):
                # Stop here and only show the question; presumably the next
                # CLI invocation supplies the answer — verify against usage.
                print(node.question)
                break
            # otherwise just continue


if __name__ == '__main__':
    import asyncio
    import sys

    # Validate explicitly rather than with `assert`: asserts are stripped
    # under `python -O`, which would let an unknown sub-command fall through
    # to the `cli` branch below.
    sub_command = sys.argv[1] if len(sys.argv) > 1 else None
    if sub_command not in ('continuous', 'cli', 'mermaid'):
        print(
            'Usage:\n'
            ' uv run -m pydantic_ai_examples.question_graph mermaid\n'
            'or:\n'
            ' uv run -m pydantic_ai_examples.question_graph continuous\n'
            'or:\n'
            ' uv run -m pydantic_ai_examples.question_graph cli [answer]',
            file=sys.stderr,
        )
        sys.exit(1)

    if sub_command == 'mermaid':
        print(question_graph.mermaid_code(start_node=Ask))
    elif sub_command == 'continuous':
        asyncio.run(run_as_continuous())
    else:
        a = sys.argv[2] if len(sys.argv) > 2 else None
        asyncio.run(run_as_cli(a))
The mermaid diagram generated in this example looks like this: