Support RAG init version. #2

Open · wants to merge 2 commits into base: master
60 changes: 60 additions & 0 deletions examples/rag_example.py
@@ -0,0 +1,60 @@

from pyspark_ai import SparkAI
from pyspark.sql import DataFrame, SparkSession

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import VLLM

model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}
hf_embedding = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs
)

index_name = 'tpch_index'

def load_tpch_schema():
    with open("tpch.sql", "r") as f:
        all_lines = f.readlines()
        tpch_texts = "".join(all_lines).replace("\n", " ")
        tables = tpch_texts.split(";")
        return tables


def store_in_faiss(texts):
    db = FAISS.from_texts(texts, hf_embedding)
    db.save_local(index_name)


if __name__ == '__main__':
    # split TPCH schema and store it in faiss vector store
    tpch_texts = load_tpch_schema()
    store_in_faiss(tpch_texts)

    db = FAISS.load_local(index_name, hf_embedding)
    # Initialize the VLLM
    # Arguments for vLLM engine: https://github.com/bigPYJ1151/vllm/blob/e394e2b72c0e0d6e57dc818613d1ea3fc8109ace/vllm/engine/arg_utils.py#L12
    llm = VLLM(
        # model="defog/sqlcoder-7b-2",
        # model="deepseek-ai/deepseek-coder-7b-instruct-v1.5",
        model="microsoft/Phi-3-mini-4k-instruct",
        trust_remote_code=True,
        download_dir="/mnt/DP_disk2/models/Huggingface/"
    )

    # show reference tables
    docs = db.similarity_search("What is the customer's name who has placed the most orders in year of 1995?")
    for doc in docs:
        print(doc.page_content)

    spark_session = SparkSession.builder.appName("text2sql").master("local[*]").enableHiveSupport().getOrCreate()
    spark_session.sql("show databases").show()
    spark_session.sql("use tpch;").show()
    # Initialize and activate SparkAI
    spark_ai = SparkAI(llm=llm, verbose=True, spark_session=spark_session, vector_db=db)
    spark_ai.activate()
    spark_ai.transform_rag("What is the customer's name who has placed the most orders in year of 1995?").show()
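One practical note on the example above: depending on the installed langchain_community version, FAISS.load_local may refuse to load a pickled index unless deserialization is explicitly allowed. A minimal, hedged adjustment, only needed if load_local raises a deserialization error (the keyword assumes a recent langchain_community release):

```
# Hedged sketch: newer langchain_community releases gate pickle loading behind
# an explicit opt-in. The index was written locally by this same script, so
# opting in is safe here.
db = FAISS.load_local(
    index_name,
    hf_embedding,
    allow_dangerous_deserialization=True,
)
```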
92 changes: 92 additions & 0 deletions examples/tpch.sql
@@ -0,0 +1,92 @@
CREATE TABLE nation
(
n_nationkey INTEGER not null,
n_name CHAR(25) not null,
n_regionkey INTEGER not null,
n_comment VARCHAR(152)
);

CREATE TABLE region
(
r_regionkey INTEGER not null,
r_name CHAR(25) not null,
r_comment VARCHAR(152)
);

CREATE TABLE part
(
p_partkey BIGINT not null,
p_name VARCHAR(55) not null,
p_mfgr CHAR(25) not null,
p_brand CHAR(10) not null,
p_type VARCHAR(25) not null,
p_size INTEGER not null,
p_container CHAR(10) not null,
p_retailprice DOUBLE PRECISION not null,
p_comment VARCHAR(23) not null
);

CREATE TABLE supplier
(
s_suppkey BIGINT not null,
s_name CHAR(25) not null,
s_address VARCHAR(40) not null,
s_nationkey INTEGER not null,
s_phone CHAR(15) not null,
s_acctbal DOUBLE PRECISION not null,
s_comment VARCHAR(101) not null
);

CREATE TABLE partsupp
(
ps_partkey BIGINT not null,
ps_suppkey BIGINT not null,
ps_availqty BIGINT not null,
ps_supplycost DOUBLE PRECISION not null,
ps_comment VARCHAR(199) not null
);

CREATE TABLE customer
(
c_custkey BIGINT not null,
c_name VARCHAR(25) not null,
c_address VARCHAR(40) not null,
c_nationkey INTEGER not null,
c_phone CHAR(15) not null,
c_acctbal DOUBLE PRECISION not null,
c_mktsegment CHAR(10) not null,
c_comment VARCHAR(117) not null
);

CREATE TABLE orders
(
o_orderkey BIGINT not null,
o_custkey BIGINT not null,
o_orderstatus CHAR(1) not null,
o_totalprice DOUBLE PRECISION not null,
o_orderdate DATE not null,
o_orderpriority CHAR(15) not null,
o_clerk CHAR(15) not null,
o_shippriority INTEGER not null,
o_comment VARCHAR(79) not null
);

CREATE TABLE lineitem
(
l_orderkey BIGINT not null,
l_partkey BIGINT not null,
l_suppkey BIGINT not null,
l_linenumber BIGINT not null,
l_quantity DOUBLE PRECISION not null,
l_extendedprice DOUBLE PRECISION not null,
l_discount DOUBLE PRECISION not null,
l_tax DOUBLE PRECISION not null,
l_returnflag CHAR(1) not null,
l_linestatus CHAR(1) not null,
l_shipdate DATE not null,
l_commitdate DATE not null,
l_receiptdate DATE not null,
l_shipinstruct CHAR(25) not null,
l_shipmode CHAR(10) not null,
l_comment VARCHAR(44) not null
);
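For orientation, load_tpch_schema() in rag_example.py splits this file on ";", so each CREATE TABLE statement above becomes one document in the FAISS index. A small, hedged sketch of what that split produces (run inside rag_example.py; counts assume the file contents shown here):

```
# Hedged sketch: each ';'-separated chunk of tpch.sql becomes one retrievable text.
tables = load_tpch_schema()
print(len(tables))     # 9 -- eight CREATE TABLE statements plus a trailing whitespace chunk
print(tables[0][:40])  # roughly: 'CREATE TABLE nation ( n_nationkey INTEG'
```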
84 changes: 82 additions & 2 deletions pyspark_ai/prompt.py
@@ -62,6 +62,61 @@

sql_answer2 = "SELECT COUNT(`Student`) FROM `spark_ai_temp_view_12qcl3` WHERE `Birthday` = 'January 1, 2006'"

sql_question3 = """QUESTION: Given some Spark tables metadata or sqls:
```
CREATE TABLE Customers (
customer_id INT,
customer_name STRING,
customer_email STRING
);

CREATE TABLE Orders (
order_id INT,
customer_id INT,
order_date DATE,
order_total DECIMAL(10, 2)
);
```
Write a Spark SQL query to retrieve data based on the provided information: How many orders has each customer placed?
"""

sql_answer3 = """
SELECT c.customer_name, COUNT(o.order_id) AS number_of_orders
FROM Customers c
JOIN Orders o ON c.customer_id = o.customer_id
GROUP BY c.customer_name;
"""


sql_question4 = """QUESTION: Given some Spark tables metadata or sqls:
```
CREATE TABLE Products (
product_id INT,
product_name STRING,
product_price DECIMAL(10, 2),
category STRING
);

CREATE TABLE Sales (
sale_id INT,
product_id INT,
sale_date DATE,
quantity INT,
total_sale_amount DECIMAL(10, 2)
);
```
Write a Spark SQL query to retrieve data based on the provided information: Which product has the highest number of sales in terms of quantity sold?
"""

sql_answer4 = """
SELECT p.product_name, SUM(s.quantity) AS total_quantity_sold
FROM Products p
JOIN Sales s ON p.product_id = s.product_id
GROUP BY p.product_name
ORDER BY total_quantity_sold DESC
LIMIT 1;
"""

spark_sql_shared_example_1_prefix = f"""{sql_question1}
Thought: The column names are non-descriptive, but from the sample values I see that column `a` contains mountains
and column `c` contains countries. So, I will filter on column `c` for 'Japan' and column `a` for the mountain.
@@ -180,13 +235,21 @@
Answer:
"""

SPARK_SQL_SUFFIX_RAG = """\nQUESTION: Given some Spark tables metadata or sqls:
```
{comment}
```
Write a Spark SQL query to retrieve data based on the provided information: {desc}
Answer:
"""

SPARK_SQL_SUFFIX_FOR_AGENT = SPARK_SQL_SUFFIX + "\n{agent_scratchpad}"

SPARK_SQL_PREFIX = """You are an assistant for writing professional Spark SQL queries.
Given a question, you need to write a Spark SQL query to answer the question.
The rules that you should follow for answering question:
1.The answer only consists of Spark SQL query. No explaination. No
2.SQL statements should be Spark SQL query.
1.The answer only consists of Spark SQL query. No explanation.
2.SQL statements should be Spark SQL query.
3.ONLY use the verbatim column_name in your resulting SQL query; DO NOT include the type.
4.Use the COUNT SQL function when the query asks for total number of some non-countable column.
5.Use the SUM SQL function to accumulate the total number of countable column values."""
@@ -239,6 +302,23 @@
    prefix=SPARK_SQL_PREFIX,
)

SQL_CHAIN_EXAMPLES_RAG = [
    sql_question3 + f"\nAnswer:\n```{sql_answer3}```",
    sql_question4 + f"\nAnswer:\n```{sql_answer4}```",
]

SQL_CHAIN_PROMPT_RAG = PromptTemplate.from_examples(
    examples=SQL_CHAIN_EXAMPLES_RAG,
Collaborator: So examples is the question-answer set that will be stored in the vector store and then searched against?

    suffix=SPARK_SQL_SUFFIX_RAG,
    input_variables=[
Collaborator: How are these input variables used, and in which phase?

"view_name",
"sample_vals",
"comment",
"desc",
],
prefix=SPARK_SQL_PREFIX,
)
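To make the new template concrete, here is a hedged sketch of how SPARK_SQL_SUFFIX_RAG renders at query time, assuming the retrieved schema text is passed as comment and the user question as desc (mirroring _get_transform_sql_query_rag in pyspark_ai.py below); the schema snippets shown are placeholders:

```
# Hedged sketch: SPARK_SQL_SUFFIX_RAG is a plain format string, so rendering it
# directly shows the shape of the question block appended after the few-shot examples.
rendered = SPARK_SQL_SUFFIX_RAG.format(
    comment="CREATE TABLE customer ( c_custkey BIGINT not null, ... );\n"
            "CREATE TABLE orders ( o_orderkey BIGINT not null, ... );",
    desc="What is the customer's name who has placed the most orders in year of 1995?",
)
print(rendered)
# QUESTION: Given some Spark tables metadata or sqls:
# ```
# CREATE TABLE customer ( c_custkey BIGINT not null, ... );
# CREATE TABLE orders ( o_orderkey BIGINT not null, ... );
# ```
# Write a Spark SQL query to retrieve data based on the provided information: What is the customer's ...
# Answer:
```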

EXPLAIN_PREFIX = """You are an Apache Spark SQL expert, who can summary what a dataframe retrieves. Given an analyzed
query plan of a dataframe, you will
1. convert the dataframe to SQL query. Note that an explain output contains plan
32 changes: 31 additions & 1 deletion pyspark_ai/pyspark_ai.py
@@ -12,6 +12,7 @@
from langchain.prompts.base import BasePromptTemplate
from langchain.utilities.google_search import GoogleSearchAPIWrapper
from langchain_community.chat_models import ChatOpenAI
from langchain_core.vectorstores import VectorStore
from pyspark.sql import DataFrame, SparkSession

from pyspark_ai.ai_utils import AIUtils
@@ -23,6 +24,7 @@
    PLOT_PROMPT,
    SEARCH_PROMPT,
    SQL_CHAIN_PROMPT,
    SQL_CHAIN_PROMPT_RAG,
    SQL_PROMPT,
    UDF_PROMPT,
    VERIFY_PROMPT,
@@ -61,6 +63,7 @@ def __init__(
        enable_cache: bool = True,
        cache_file_format: str = "json",
        cache_file_location: Optional[str] = None,
        vector_db: VectorStore = None,
        vector_store_dir: Optional[str] = None,
        vector_store_max_gb: Optional[float] = 16,
        max_tokens_of_web_content: int = 3000,
@@ -110,6 +113,7 @@
            ).search
        else:
            self._cache = None
        self._vector_db = vector_db
        self._vector_store_dir = vector_store_dir
        self._vector_store_max_gb = vector_store_max_gb
        self._max_tokens_of_web_content = max_tokens_of_web_content
@@ -136,8 +140,12 @@ def _create_llm_chain(self, prompt: BasePromptTemplate):
    @property
    def sql_chain(self):
        if self._sql_chain is None:
            if self._vector_db:
                prompt_temp = SQL_CHAIN_PROMPT_RAG
            else:
                prompt_temp = SQL_CHAIN_PROMPT
            self._sql_chain = SparkSQLChain(
                prompt=SQL_CHAIN_PROMPT,
                prompt=prompt_temp,
                llm=self._llm,
                logger=self._logger,
                spark=self._spark,
@@ -576,6 +584,16 @@ def _get_transform_sql_query_tpch(self, desc: str, table: str, cache: bool) -> s
        # print(f"-------------------------Current table comment is-------------------------\n\n {comment}\n")
        return self._get_sql_query(table, sample_vals_str, comment, desc)

    def _get_transform_sql_query_rag(self, desc: str):
        docs = self._vector_db.similarity_search(desc)
        reference_contents = []
        for doc in docs:
            reference_contents.append(doc.page_content)
        reference_str = "\n".join([str(val) for val in reference_contents])
        print(f"-------------------------Current reference contents are:-------------------------\n\n {reference_str}\n")
        return self._get_sql_query('', '', reference_str, desc)


    def transform_df_tpch(self, desc: str, table: str, cache: bool = False) -> DataFrame:
        print(f"---------------------TPCH Table {table}------------------------------\n\n")
        start_time = time.time()
@@ -586,6 +604,18 @@ def transform_df_tpch(self, desc: str, table: str, cache: bool = False) -> DataF
print(f"-------------------------Received query:-------------------------\n\n {sql_query}\n")
return self._spark.sql(sql_query)

    def transform_rag(self, desc: str, cache: bool = False) -> DataFrame:
        print(f"---------------------Start get_transform_sql_query with rag------------------------------\n\n")
        start_time = time.time()
        sql_query = self._get_transform_sql_query_rag(desc)
        end_time = time.time()
        get_transform_sql_query_time = end_time - start_time
        print(f"-------------------------End get_transform_sql_query-------------------------\n\n get_transform_sql_query_time: {get_transform_sql_query_time} seconds\n")
        print(f"-------------------------Received query:-------------------------\n\n {sql_query}\n")
        return self._spark.sql(sql_query)



    def transform_df(self, df: DataFrame, desc: str, cache: bool = True) -> DataFrame:
        """
        This method applies a transformation to a provided Spark DataFrame,