Skip to content

Commit

Permalink
feat(text_to_sql): anthropic working with tool use. other providers stubbed out
Browse files Browse the repository at this point in the history
  • Loading branch information
jgpruitt committed Jan 24, 2025
1 parent 2288058 commit 028b7ac
Show file tree
Hide file tree
Showing 8 changed files with 294 additions and 49 deletions.
12 changes: 11 additions & 1 deletion projects/extension/sql/idempotent/900-semantic-catalog-init.sql
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@ create or replace function ai.create_semantic_catalog
, scheduling pg_catalog.jsonb default ai.scheduling_default()
, processing pg_catalog.jsonb default ai.processing_default()
, grant_to pg_catalog.name[] default ai.grant_to()
-- TODO: need to specify text search config https://www.postgresql.org/docs/current/textsearch-configuration.html
, text_to_sql pg_catalog.jsonb default null
, catalog_name pg_catalog.name default 'default'
) returns pg_catalog.int4
as $func$
declare
_catalog_name pg_catalog.name = catalog_name;
_text_to_sql pg_catalog.jsonb = text_to_sql;
_catalog_id pg_catalog.int4;
_obj_vec_id pg_catalog.int4;
_sql_vec_id pg_catalog.int4;
Expand Down Expand Up @@ -39,6 +43,8 @@ begin
) into strict _obj_vec_id
;

-- TODO: create text search index on vectorizer target table

select ai.create_vectorizer
( 'ai.semantic_catalog_sql'::pg_catalog.regclass
, destination=>pg_catalog.format('semantic_catalog_sql_%s', _catalog_id)
Expand All @@ -52,17 +58,21 @@ begin
) into strict _sql_vec_id
;

-- TODO: create text search index on vectorizer target table

insert into ai.semantic_catalog
( id
, catalog_name
, obj_vectorizer_id
, sql_vectorizer_id
, text_to_sql
)
values
( _catalog_id
, create_semantic_catalog.catalog_name
, _catalog_name
, _obj_vec_id
, _sql_vec_id
, _text_to_sql
)
returning id
into strict _catalog_id
Expand Down
22 changes: 16 additions & 6 deletions projects/extension/sql/idempotent/904-semantic-catalog-search.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
-------------------------------------------------------------------------------
-- _search_semantic_catalog_obj
create function ai._search_semantic_catalog_obj
( question text
( question text -- TODO: provide a version where we pass the question already embedded as a vector
, catalog_name text default 'default'
, max_results bigint default 5
, max_vector_dist float8 default null
Expand All @@ -22,6 +22,7 @@ begin

-- embed the user's question using the embedding settings of the vectorizer on this catalog
-- also grab that table in which the vectorizer is storing the embeddings
raise debug 'embedding question';
select
ai.vectorizer_embed(v.config->'embedding', question, 'question')
, v.target_schema
Expand Down Expand Up @@ -92,8 +93,10 @@ begin
return query select * from ai.semantic_catalog_obj where false;
end if;

-- TODO: look up the text search config on the semantic_catalog so we can use the index

-- construct a tsquery by ORing all the keywords
select to_tsquery(string_agg(keyword, ' | '))
select to_tsquery(string_agg(format($$'%s'$$, keyword), ' | '))
into strict _tsquery
from unnest(keywords) keyword
;
Expand Down Expand Up @@ -159,6 +162,8 @@ begin
raise exception 'question and keywords must not both be null';
end if;

-- TODO: do a real reranking?

return query
select *
from
Expand All @@ -179,7 +184,7 @@ begin
, min_ts_rank
)
) x
limit max_results
-- limit max_results -- TODO: consider whether to do an outer limit or not
;
end;
$func$ language plpgsql stable security invoker
Expand All @@ -189,7 +194,7 @@ set search_path to pg_catalog, pg_temp
-------------------------------------------------------------------------------
-- _search_semantic_catalog_sql
create function ai._search_semantic_catalog_sql
( question text
( question text -- TODO: provide a version where we pass the question already embedded as a vector
, catalog_name text default 'default'
, max_results bigint default 5
, max_vector_dist float8 default null
Expand All @@ -208,6 +213,7 @@ begin

-- embed the user's question using the embedding settings of the vectorizer on this catalog
-- also grab that table in which the vectorizer is storing the embeddings
raise debug 'embedding question';
select
ai.vectorizer_embed(v.config->'embedding', question, 'question')
, v.target_schema
Expand Down Expand Up @@ -274,11 +280,13 @@ begin
end if;

-- construct a tsquery by ORing all the keywords
select to_tsquery(string_agg(keyword, ' | '))
select to_tsquery(string_agg(format($$'%s'$$, keyword), ' | '))
into strict _tsquery
from unnest(keywords) keyword
;

-- TODO: look up the text search config on the semantic_catalog so we can use the index

select format
( $sql$
select q.*
Expand Down Expand Up @@ -334,6 +342,8 @@ begin
raise exception 'question and keywords must not both be null';
end if;

-- TODO: do a real reranking?

return query
select *
from
Expand All @@ -354,7 +364,7 @@ begin
, min_ts_rank
)
) x
-- limit max_results
-- limit max_results -- TODO: consider whether to do an outer limit or not
;
end;
$func$ language plpgsql stable security invoker
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
--FEATURE-FLAG: text_to_sql

-------------------------------------------------------------------------------
-- _text_to_sql
create function ai.text_to_sql
-- _text_to_sql_anthropic
create function ai._text_to_sql_anthropic
( question text
, catalog_name text default 'default'
, max_results bigint default 5
Expand All @@ -11,8 +11,8 @@ create function ai.text_to_sql
, max_iter int2 default 3
, obj_renderer regprocedure default ('ai.render_semantic_catalog_obj(bigint, oid, oid)'::pg_catalog.regprocedure)
, sql_renderer regprocedure default ('ai.render_semantic_catalog_sql(bigint, text, text)'::pg_catalog.regprocedure)
, config jsonb default null -- LLM configuration
) returns text
, config jsonb default null -- TODO: use this for LLM configuration
) returns jsonb
as $func$
declare
_iter_remaining int2 = max_iter;
Expand All @@ -26,7 +26,6 @@ declare
_prompt text;
_tools jsonb;
_response jsonb;
_message_history jsonb = jsonb_build_array();
_message record;
_answer text;
begin
Expand All @@ -38,7 +37,7 @@ begin
-- search -------------------------------------------------------------

-- search obj
if jsonb_array_length(_questions) > 0 or jsonb_array_length(_keywords) > 0 then
if jsonb_array_length(_questions) > 0 /*or jsonb_array_length(_keywords) > 0*/ then
raise debug 'searching for database objects';
select jsonb_agg(x.obj)
into _ctx_obj
Expand Down Expand Up @@ -83,7 +82,7 @@ begin
end if;

-- search sql
if jsonb_array_length(_questions) > 0 or jsonb_array_length(_keywords) > 0 then
if jsonb_array_length(_questions) > 0 /*or jsonb_array_length(_keywords) > 0*/ then
raise debug 'searching for sql examples';
select jsonb_agg(x)
into _ctx_sql
Expand Down Expand Up @@ -206,37 +205,36 @@ begin
"required": ["relevant_ids"]
}
},
*/
_tools = $json$
[

{
"name": "request_more_context_by_question",
"description": "If you do not have enough context to confidently answer the user's question, use this tool to ask for more context by providing a question to be used for semantic search.",
"name": "request_more_context_by_keywords",
"description": "If you do not have enough context to confidently answer the user's question, use this tool to ask for more context by providing a list of keywords to use in performing a full-text search.",
"input_schema": {
"type": "object",
"properties" : {
"question": {
"type": "string",
"description": "A new question relevant to the user's question that will be used to perform a semantic search to gather more context"
"keywords": {
"type": "array",
"items": {"type": "string"},
"description": "A list of keywords relevant to the user's question that will be used to perform a full-text search to gather more context. Each item must be a single word with no whitespace."
}
},
"required": ["question"]
"required": ["keywords"]
}
},
*/
_tools = $json$
[
{
"name": "request_more_context_by_keywords",
"description": "If you do not have enough context to confidently answer the user's question, use this tool to ask for more context by providing a list of keywords to use in performing a full-text search.",
"name": "request_more_context_by_question",
"description": "If you do not have enough context to confidently answer the user's question, use this tool to ask for more context by providing a question to be used for semantic search.",
"input_schema": {
"type": "object",
"properties" : {
"keywords": {
"type": "array",
"items": {"type": "string"},
"description": "A list of keywords relevant to the user's question that will be used to perform a full-text search to gather more context"
"question": {
"type": "string",
"description": "A new natural language question relevant to the user's question that will be used to perform a semantic search to gather more context"
}
},
"required": ["keywords"]
"required": ["question"]
}
},
{
Expand All @@ -248,6 +246,16 @@ begin
"sql_statement": {
"type": "string",
"description": "A valid SQL statement that addresses the user's question."
},
"relevant_database_object_ids": {
"type": "array",
"items": {"type": "integer"},
"description": "Provide a list of the ids of the database examples which were relevant to the user's question and useful in providing the answer."
},
"relevant_sql_example_ids": {
"type": "array",
"items": {"type": "integer"},
"description": "Provide a list of the ids of the SQL examples which were relevant to the user's question and useful in providing the answer."
}
},
"required": ["sql_statement"]
Expand All @@ -259,7 +267,7 @@ begin

raise debug 'calling llm';
select ai.anthropic_generate
( 'claude-3-5-sonnet-latest'
( 'claude-3-5-sonnet-latest' -- TODO: use config argument for this and other options
, jsonb_build_array(jsonb_build_object('role', 'user', 'content', _prompt))
, system_prompt=>
concat_ws
Expand All @@ -268,6 +276,7 @@ begin
, 'You have access to tools, but only use them when necessary. You may use multiple tools on each interaction.'
)
, tools=>_tools
, tool_choice=>jsonb_build_object('type', 'any')
) into strict _response
;

Expand All @@ -291,22 +300,22 @@ begin
raise info '%', _message.text;
when 'tool_use' then
case _message.name
when 'identify_relevant_obj' then
raise debug 'tool use: identify_relevant_obj';
-- throw out any obj that the LLM did NOT mark as relevant
select jsonb_agg(r) into _ctx_obj
from jsonb_array_elements_text(_message.input->'relevant_ids') i
inner join jsonb_to_recordset(_ctx_obj) r(id bigint, objtype text, objid oid)
on (i::bigint = r.id)
;
when 'identify_relevant_sql' then
raise debug 'tool use: identify_relevant_sql';
-- throw out any sql that the LLM did NOT mark as relevant
select jsonb_agg(r) into _ctx_sql
from jsonb_array_elements_text(_message.input->'relevant_ids') i
inner join jsonb_to_recordset(_ctx_sql) r(id bigint, sql text, description text)
on (i::int = r.id)
;
--when 'identify_relevant_obj' then
-- raise debug 'tool use: identify_relevant_obj';
-- -- throw out any obj that the LLM did NOT mark as relevant
-- select jsonb_agg(r) into _ctx_obj
-- from jsonb_array_elements_text(_message.input->'relevant_ids') i
-- inner join jsonb_to_recordset(_ctx_obj) r(id bigint, objtype text, objid oid)
-- on (i::bigint = r.id)
-- ;
--when 'identify_relevant_sql' then
-- raise debug 'tool use: identify_relevant_sql';
-- -- throw out any sql that the LLM did NOT mark as relevant
-- select jsonb_agg(r) into _ctx_sql
-- from jsonb_array_elements_text(_message.input->'relevant_ids') i
-- inner join jsonb_to_recordset(_ctx_sql) r(id bigint, sql text, description text)
-- on (i::int = r.id)
-- ;
when 'request_more_context_by_question' then
raise debug 'tool use: request_more_context_by_question: %', _message.input->'question';
-- append the question to the list of questions to use on the next iteration
Expand All @@ -322,13 +331,36 @@ begin
when 'answer_user_question_with_sql_statement' then
raise debug 'tool use: answer_user_question_with_sql_statement';
select _message.input->>'sql_statement' into strict _answer;
if _message.input->'relevant_database_object_ids' is not null and jsonb_array_length(_message.input->'relevant_database_object_ids') > 0 then
-- throw out any obj that the LLM did NOT mark as relevant
select jsonb_agg(r) into _ctx_obj
from jsonb_array_elements_text(_message.input->'relevant_database_object_ids') i
inner join jsonb_to_recordset(_ctx_obj) r(id bigint, classid oid, objid oid)
on (i::bigint = r.id)
;
end if;
if _message.input->'relevant_sql_example_ids' is not null and jsonb_array_length(_message.input->'relevant_sql_example_ids') > 0 then
-- throw out any sql that the LLM did NOT mark as relevant
select jsonb_agg(r) into _ctx_sql
from jsonb_array_elements_text(_message.input->'relevant_sql_example_ids') i
inner join jsonb_to_recordset(_ctx_sql) r(id bigint, sql text, description text)
on (i::int = r.id)
;
end if;
end case
;
end case
;
-- if we got our answer, return
if _answer is not null then
return _answer;
raise debug 'relevant database objects %', jsonb_pretty(_ctx_obj);
raise debug 'relevant sql examples %', jsonb_pretty(_ctx_sql);
return jsonb_build_object
( 'sql_statement', _answer
, 'relevant_database_objects', _ctx_obj
, 'relevant_sql_examples', _ctx_sql
, 'iterations', (max_iter - _iter_remaining)
);
end if;
end loop;
_iter_remaining = _iter_remaining - 1;
Expand Down
23 changes: 23 additions & 0 deletions projects/extension/sql/idempotent/907-text-to-sql-ollama.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
--FEATURE-FLAG: text_to_sql

-------------------------------------------------------------------------------
-- _text_to_sql_ollama
-- Ollama provider stub for text-to-sql.
-- Mirrors the signature of ai._text_to_sql_anthropic so the providers are
-- interchangeable; the implementation does not exist yet, so any call
-- raises an exception (fails loudly rather than returning null).
create function ai._text_to_sql_ollama
( question text -- the user's natural-language question
, catalog_name text default 'default' -- which semantic catalog to use for context
, max_results bigint default 5 -- result cap passed to the catalog searches
, max_vector_dist float8 default null -- optional semantic-search distance cutoff
, min_ts_rank real default null -- optional full-text-search rank cutoff
, max_iter int2 default 3 -- max LLM round-trips before giving up
, obj_renderer regprocedure default ('ai.render_semantic_catalog_obj(bigint, oid, oid)'::pg_catalog.regprocedure) -- renders a database object for the prompt
, sql_renderer regprocedure default ('ai.render_semantic_catalog_sql(bigint, text, text)'::pg_catalog.regprocedure) -- renders a sql example for the prompt
, config jsonb default null -- TODO: use this for LLM configuration
) returns jsonb -- presumably the same answer object shape as the anthropic variant -- TODO confirm once implemented
as $func$
declare
begin
-- stub body: always raise until the ollama integration is written
raise exception 'not implemented yet';
end
$func$ language plpgsql stable security invoker
set search_path to pg_catalog, pg_temp
;
Loading

0 comments on commit 028b7ac

Please sign in to comment.