forked from pyspark-ai/pyspark-ai
-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support RAG init version. #2
Open
yao531441
wants to merge
2
commits into
oap-project:master
Choose a base branch
from
yao531441:master
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
|
||
from pyspark_ai import SparkAI | ||
from pyspark.sql import DataFrame, SparkSession | ||
|
||
from langchain_community.embeddings import HuggingFaceEmbeddings | ||
from langchain_community.vectorstores import FAISS | ||
from langchain_community.llms import VLLM | ||
|
||
# Embedding model used to vectorize the TPC-H schema text for retrieval.
# all-mpnet-base-v2 is a general-purpose sentence-transformers encoder.
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {'device': 'cpu'}  # run the encoder on CPU (no GPU required)
encode_kwargs = {'normalize_embeddings': False}  # keep raw, unnormalized vectors
hf_embedding = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs
)

# Local directory name the FAISS index is saved to / loaded from.
index_name = 'tpch_index'
|
||
def load_tpch_schema(path="tpch.sql"):
    """Load a SQL DDL file and split it into one text chunk per statement.

    Args:
        path: Path to the SQL schema file. Defaults to "tpch.sql" for
            backward compatibility with existing callers.

    Returns:
        List of statement strings produced by splitting on ';' after
        flattening newlines to spaces. Note: if the file ends with ';',
        the last element is an empty/whitespace string (same behavior as
        the original implementation).
    """
    with open(path, "r") as f:
        # Flatten newlines so each CREATE TABLE becomes a single-line chunk,
        # which embeds more cleanly in the vector store.
        tpch_texts = f.read().replace("\n", " ")
    return tpch_texts.split(";")
|
||
|
||
def store_in_faiss(texts):
    """Embed the given text chunks and persist them as a local FAISS index."""
    vector_store = FAISS.from_texts(texts, hf_embedding)
    vector_store.save_local(index_name)
|
||
|
||
if __name__ == '__main__':
    # The question is used twice (similarity-search preview and the final
    # RAG-assisted generation); hoisted so the two uses cannot drift apart.
    question = "What is the customer's name who has placed the most orders in year of 1995? "

    # Split the TPC-H schema into per-table chunks and store them in FAISS.
    tpch_texts = load_tpch_schema()
    store_in_faiss(tpch_texts)

    # Reload the index we just wrote.
    # NOTE(review): recent langchain releases require
    # allow_dangerous_deserialization=True for load_local — confirm against
    # the pinned langchain_community version.
    db = FAISS.load_local(index_name, hf_embedding)

    # Initialize the vLLM-backed LLM.
    # Arguments for vLLM engine: https://github.com/bigPYJ1151/vllm/blob/e394e2b72c0e0d6e57dc818613d1ea3fc8109ace/vllm/engine/arg_utils.py#L12
    llm = VLLM(
        # model="defog/sqlcoder-7b-2",
        # model="deepseek-ai/deepseek-coder-7b-instruct-v1.5",
        model="microsoft/Phi-3-mini-4k-instruct",
        trust_remote_code=True,
        download_dir="/mnt/DP_disk2/models/Huggingface/"
    )

    # Show which schema chunks the vector store retrieves for the question.
    docs = db.similarity_search(question)
    for doc in docs:
        print(doc.page_content)

    # Spark session with Hive support so the pre-loaded `tpch` database is visible.
    spark_session = SparkSession.builder.appName("text2sql").master("local[*]").enableHiveSupport().getOrCreate()
    spark_session.sql("show databases").show()
    spark_session.sql("use tpch;").show()

    # Initialize and activate SparkAI with the RAG vector store attached.
    spark_ai = SparkAI(llm=llm, verbose=True, spark_session=spark_session, vector_db=db)
    spark_ai.activate()
    spark_ai.transform_rag(question).show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
-- TPC-H benchmark schema (8 tables). These statements are split on ';' and
-- embedded into a FAISS vector store as RAG context for text-to-SQL.
-- Keyword case normalized: NOT NULL uppercased to match CREATE TABLE.

-- Nations; each nation belongs to one region (n_regionkey).
CREATE TABLE nation
(
    n_nationkey INTEGER NOT NULL,
    n_name CHAR(25) NOT NULL,
    n_regionkey INTEGER NOT NULL,
    n_comment VARCHAR(152)
);

-- Geographic regions.
CREATE TABLE region
(
    r_regionkey INTEGER NOT NULL,
    r_name CHAR(25) NOT NULL,
    r_comment VARCHAR(152)
);

-- Parts catalog.
CREATE TABLE part
(
    p_partkey BIGINT NOT NULL,
    p_name VARCHAR(55) NOT NULL,
    p_mfgr CHAR(25) NOT NULL,
    p_brand CHAR(10) NOT NULL,
    p_type VARCHAR(25) NOT NULL,
    p_size INTEGER NOT NULL,
    p_container CHAR(10) NOT NULL,
    p_retailprice DOUBLE PRECISION NOT NULL,
    p_comment VARCHAR(23) NOT NULL
);

-- Suppliers; located in a nation (s_nationkey).
CREATE TABLE supplier
(
    s_suppkey BIGINT NOT NULL,
    s_name CHAR(25) NOT NULL,
    s_address VARCHAR(40) NOT NULL,
    s_nationkey INTEGER NOT NULL,
    s_phone CHAR(15) NOT NULL,
    s_acctbal DOUBLE PRECISION NOT NULL,
    s_comment VARCHAR(101) NOT NULL
);

-- Part/supplier association: availability and cost per (part, supplier).
CREATE TABLE partsupp
(
    ps_partkey BIGINT NOT NULL,
    ps_suppkey BIGINT NOT NULL,
    ps_availqty BIGINT NOT NULL,
    ps_supplycost DOUBLE PRECISION NOT NULL,
    ps_comment VARCHAR(199) NOT NULL
);

-- Customers; located in a nation (c_nationkey).
CREATE TABLE customer
(
    c_custkey BIGINT NOT NULL,
    c_name VARCHAR(25) NOT NULL,
    c_address VARCHAR(40) NOT NULL,
    c_nationkey INTEGER NOT NULL,
    c_phone CHAR(15) NOT NULL,
    c_acctbal DOUBLE PRECISION NOT NULL,
    c_mktsegment CHAR(10) NOT NULL,
    c_comment VARCHAR(117) NOT NULL
);

-- Orders placed by customers (o_custkey references customer).
CREATE TABLE orders
(
    o_orderkey BIGINT NOT NULL,
    o_custkey BIGINT NOT NULL,
    o_orderstatus CHAR(1) NOT NULL,
    o_totalprice DOUBLE PRECISION NOT NULL,
    o_orderdate DATE NOT NULL,
    o_orderpriority CHAR(15) NOT NULL,
    o_clerk CHAR(15) NOT NULL,
    o_shippriority INTEGER NOT NULL,
    o_comment VARCHAR(79) NOT NULL
);

-- Order line items: one row per (order, line number).
CREATE TABLE lineitem
(
    l_orderkey BIGINT NOT NULL,
    l_partkey BIGINT NOT NULL,
    l_suppkey BIGINT NOT NULL,
    l_linenumber BIGINT NOT NULL,
    l_quantity DOUBLE PRECISION NOT NULL,
    l_extendedprice DOUBLE PRECISION NOT NULL,
    l_discount DOUBLE PRECISION NOT NULL,
    l_tax DOUBLE PRECISION NOT NULL,
    l_returnflag CHAR(1) NOT NULL,
    l_linestatus CHAR(1) NOT NULL,
    l_shipdate DATE NOT NULL,
    l_commitdate DATE NOT NULL,
    l_receiptdate DATE NOT NULL,
    l_shipinstruct CHAR(25) NOT NULL,
    l_shipmode CHAR(10) NOT NULL,
    l_comment VARCHAR(44) NOT NULL
);
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,6 +62,61 @@ | |
|
||
sql_answer2 = "SELECT COUNT(`Student`) FROM `spark_ai_temp_view_12qcl3` WHERE `Birthday` = 'January 1, 2006'" | ||
|
||
# Few-shot example 3: two-table join with COUNT aggregation (orders per customer).
sql_question3 = """QUESTION: Given some Spark tables metadata or sqls:
```
CREATE TABLE Customers (
customer_id INT,
customer_name STRING,
customer_email STRING
);

CREATE TABLE Orders (
order_id INT,
customer_id INT,
order_date DATE,
order_total DECIMAL(10, 2)
);
```
Write a Spark SQL query to retrieve data based on the provided information: How many orders has each customer placed?
"""

sql_answer3 = """
SELECT c.customer_name, COUNT(o.order_id) AS number_of_orders
FROM Customers c
JOIN Orders o ON c.customer_id = o.customer_id
GROUP BY c.customer_name;
"""


# Few-shot example 4: join + SUM aggregation with ORDER BY / LIMIT (top product by quantity).
sql_question4 = """QUESTION: Given some Spark tables metadata or sqls:
```
CREATE TABLE Products (
product_id INT,
product_name STRING,
product_price DECIMAL(10, 2),
category STRING
);

CREATE TABLE Sales (
sale_id INT,
product_id INT,
sale_date DATE,
quantity INT,
total_sale_amount DECIMAL(10, 2)
);
```
Write a Spark SQL query to retrieve data based on the provided information: Which product has the highest number of sales in terms of quantity sold?
"""

sql_answer4 = """
SELECT p.product_name, SUM(s.quantity) AS total_quantity_sold
FROM Products p
JOIN Sales s ON p.product_id = s.product_id
GROUP BY p.product_name
ORDER BY total_quantity_sold DESC
LIMIT 1;
"""
|
||
spark_sql_shared_example_1_prefix = f"""{sql_question1} | ||
Thought: The column names are non-descriptive, but from the sample values I see that column `a` contains mountains | ||
and column `c` contains countries. So, I will filter on column `c` for 'Japan' and column `a` for the mountain. | ||
|
@@ -180,13 +235,21 @@ | |
Answer: | ||
""" | ||
|
||
# Prompt suffix for the RAG path: {comment} is filled with the retrieved table
# metadata/DDL chunks and {desc} with the user's natural-language request.
SPARK_SQL_SUFFIX_RAG = """\nQUESTION: Given some Spark tables metadata or sqls:
```
{comment}
```
Write a Spark SQL query to retrieve data based on the provided information: {desc}
Answer:
"""
|
||
SPARK_SQL_SUFFIX_FOR_AGENT = SPARK_SQL_SUFFIX + "\n{agent_scratchpad}" | ||
|
||
# System-style instructions prepended to every Spark SQL generation prompt.
# (Diff residue resolved: the pre-change rule lines with the typo
# "explaination" and a dangling "No" are dropped in favor of the fixed text.)
SPARK_SQL_PREFIX = """You are an assistant for writing professional Spark SQL queries.
Given a question, you need to write a Spark SQL query to answer the question.
The rules that you should follow for answering question:
1.The answer only consists of Spark SQL query. No explanation.
2.SQL statements should be Spark SQL query.
3.ONLY use the verbatim column_name in your resulting SQL query; DO NOT include the type.
4.Use the COUNT SQL function when the query asks for total number of some non-countable column.
5.Use the SUM SQL function to accumulate the total number of countable column values."""
|
@@ -239,6 +302,23 @@ | |
prefix=SPARK_SQL_PREFIX, | ||
) | ||
|
||
# Few-shot examples for the RAG SQL chain: each entry is a question followed
# by its expected answer wrapped in a fenced code block.
SQL_CHAIN_EXAMPLES_RAG = [
    sql_question3 + f"\nAnswer:\n```{sql_answer3}```",
    sql_question4 + f"\nAnswer:\n```{sql_answer4}```",
]

# Prompt template for the RAG path (examples + SPARK_SQL_PREFIX/SUFFIX_RAG).
# NOTE(review): "view_name" and "sample_vals" do not appear as placeholders in
# SPARK_SQL_SUFFIX_RAG (which only uses {comment} and {desc}) — confirm whether
# the RAG chain actually supplies them or whether they can be dropped.
SQL_CHAIN_PROMPT_RAG = PromptTemplate.from_examples(
    examples=SQL_CHAIN_EXAMPLES_RAG,
    suffix=SPARK_SQL_SUFFIX_RAG,
    input_variables=[
        "view_name",
        "sample_vals",
        "comment",
        "desc",
    ],
    prefix=SPARK_SQL_PREFIX,
)
|
||
EXPLAIN_PREFIX = """You are an Apache Spark SQL expert, who can summary what a dataframe retrieves. Given an analyzed | ||
query plan of a dataframe, you will | ||
1. convert the dataframe to SQL query. Note that an explain output contains plan | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So
examples
is the question-answer set that will be stored in vector store and then searched against?