python -m pydantic_ai_examples.sql_gen "find me errors"
uv run -m pydantic_ai_examples.sql_gen "find me errors"
This example uses gemini-1.5-flash by default, since Gemini is good at single-shot queries of this kind.
Example Code
sql_gen.py
"""Generate SQL queries from natural-language requests with a pydantic-ai agent.

The agent is given the schema of a ``records`` table plus a few example
request/response pairs, and every generated query is checked against a
PostgreSQL connection with ``EXPLAIN`` before it is accepted.

The prompt is taken from the first command-line argument; a default prompt is
used when no argument is given.
"""

import asyncio
import sys
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import date
from typing import Annotated, Any, Union

import asyncpg
import logfire
from annotated_types import MinLen
from devtools import debug
from pydantic import BaseModel, Field
from typing_extensions import TypeAlias

from pydantic_ai import Agent, ModelRetry, RunContext
from pydantic_ai.format_as_xml import format_as_xml

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
logfire.configure(send_to_logfire='if-token-present')
logfire.instrument_asyncpg()

# Schema of the table the agent writes queries against. It is both embedded in
# the system prompt and executed to create the table in a fresh database.
DB_SCHEMA = """
CREATE TABLE records (
    created_at timestamptz,
    start_timestamp timestamptz,
    end_timestamp timestamptz,
    trace_id text,
    span_id text,
    parent_span_id text,
    level log_level,
    span_name text,
    message text,
    attributes_json_schema text,
    attributes jsonb,
    tags text[],
    is_exception boolean,
    otel_status_message text,
    service_name text
);
"""

# Example request/response pairs rendered into the system prompt.
SQL_EXAMPLES = [
    {
        'request': 'show me records where foobar is false',
        # NOTE(review): `->>` yields text in PostgreSQL, so comparing it with the
        # boolean `false` looks like it would be rejected -- confirm and adjust.
        'response': "SELECT * FROM records WHERE attributes->>'foobar' = false",
    },
    {
        'request': 'show me records where attributes include the key "foobar"',
        'response': "SELECT * FROM records WHERE attributes ? 'foobar'",
    },
    {
        'request': 'show me records from yesterday',
        'response': "SELECT * FROM records WHERE start_timestamp::date > CURRENT_TIMESTAMP - INTERVAL '1 day'",
    },
    {
        'request': 'show me error records with the tag "foobar"',
        'response': "SELECT * FROM records WHERE level = 'error' and 'foobar' = ANY(tags)",
    },
]


# Dependencies handed to the agent run: the open database connection that the
# result validator uses to EXPLAIN generated queries.
@dataclass
class Deps:
    conn: asyncpg.Connection


class Success(BaseModel):
    """Response when SQL could be successfully generated."""

    sql_query: Annotated[str, MinLen(1)]
    explanation: str = Field(
        '', description='Explanation of the SQL query, as markdown'
    )


class InvalidRequest(BaseModel):
    """Response the user input didn't include enough information to generate SQL."""

    error_message: str


Response: TypeAlias = Union[Success, InvalidRequest]
agent: Agent[Deps, Response] = Agent(
    'google-gla:gemini-1.5-flash',
    # Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else
    result_type=Response,  # type: ignore
    deps_type=Deps,
    instrument=True,
)


@agent.system_prompt
async def system_prompt() -> str:
    """Build the system prompt: task description, schema, today's date and examples."""
    return f"""\
Given the following PostgreSQL table of records, your job is to
write a SQL query that suits the user's request.

Database schema:

{DB_SCHEMA}

today's date = {date.today()}

{format_as_xml(SQL_EXAMPLES)}
"""


@agent.result_validator
async def validate_result(ctx: RunContext[Deps], result: Response) -> Response:
    """Validate the agent's result, asking the model to retry on bad SQL.

    ``InvalidRequest`` results are returned unchanged. For ``Success`` results
    the query has backslashes removed, must start with ``SELECT``
    (case-insensitive), and must be accepted by ``EXPLAIN`` on the connection
    in ``ctx.deps``.

    Raises:
        ModelRetry: if the query is not a SELECT or PostgreSQL rejects it.
    """
    if isinstance(result, InvalidRequest):
        return result

    # gemini often adds extraneous backslashes to SQL
    result.sql_query = result.sql_query.replace('\\', '')
    if not result.sql_query.upper().startswith('SELECT'):
        raise ModelRetry('Please create a SELECT query')

    try:
        # EXPLAIN is run in place of the query itself, so the database checks
        # it without this code fetching any rows.
        await ctx.deps.conn.execute(f'EXPLAIN {result.sql_query}')
    except asyncpg.exceptions.PostgresError as e:
        raise ModelRetry(f'Invalid query: {e}') from e
    else:
        return result


async def main() -> None:
    """Run the agent once on the CLI prompt (or a default) and print the result."""
    if len(sys.argv) == 1:
        prompt = 'show me logs from yesterday, with level "error"'
    else:
        prompt = sys.argv[1]

    async with database_connect(
        'postgresql://postgres:postgres@localhost:54320', 'pydantic_ai_sql_gen'
    ) as conn:
        deps = Deps(conn)
        result = await agent.run(prompt, deps=deps)
    debug(result.data)


# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false
@asynccontextmanager
async def database_connect(server_dsn: str, database: str) -> AsyncGenerator[Any, None]:
    """Yield a connection to ``database``, creating it and its schema if missing.

    First connects to ``server_dsn`` to check ``pg_database`` and create the
    database when it does not exist. Then connects to
    ``{server_dsn}/{database}`` and, only when the database was just created,
    creates the ``log_level`` enum and the table from ``DB_SCHEMA`` inside a
    transaction. The yielded connection is closed when the context exits.

    Args:
        server_dsn: DSN of the PostgreSQL server, without a database path.
        database: Name of the database to use. It is interpolated directly
            into ``CREATE DATABASE``, so it must come from trusted code.
    """
    with logfire.span('check and create DB'):
        conn = await asyncpg.connect(server_dsn)
        try:
            db_exists = await conn.fetchval(
                'SELECT 1 FROM pg_database WHERE datname = $1', database
            )
            if not db_exists:
                # The database name is an identifier, not a value, so it is
                # formatted into the statement rather than passed as $1.
                await conn.execute(f'CREATE DATABASE {database}')
        finally:
            await conn.close()

    conn = await asyncpg.connect(f'{server_dsn}/{database}')
    try:
        with logfire.span('create schema'):
            async with conn.transaction():
                if not db_exists:
                    await conn.execute(
                        "CREATE TYPE log_level AS ENUM ('debug', 'info', 'warning', 'error', 'critical')"
                    )
                    await conn.execute(DB_SCHEMA)
        yield conn
    finally:
        await conn.close()


if __name__ == '__main__':
    asyncio.run(main())