From 3afa8bbf8d24987fb29d8b6a289b0dd224c35c9f Mon Sep 17 00:00:00 2001 From: noooey <20203065@kookmin.ac.kr> Date: Wed, 29 May 2024 22:41:40 +0900 Subject: [PATCH 1/3] =?UTF-8?q?[feat/#505]=20:zap:=20Improve:=20=EA=B2=80?= =?UTF-8?q?=EC=83=89=EB=90=9C=20=EB=AC=B8=EC=84=9C=EC=99=80=20=EC=9A=94?= =?UTF-8?q?=EC=95=BD=EC=9D=98=20=EC=9C=A0=EC=82=AC=EB=8F=84=20=EA=B3=84?= =?UTF-8?q?=EC=82=B0=20=EC=B6=94=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/app/app.py | 4 ++-- model/app/data/generator.py | 9 ++------- model/app/model/chain.py | 36 +++++++++++++++++++++++++++++++----- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/model/app/app.py b/model/app/app.py index 35cb9b8e..b742364e 100644 --- a/model/app/app.py +++ b/model/app/app.py @@ -168,7 +168,7 @@ async def generate_quiz_async(generate_request, id, summary_split_docs, vector_s async for doc in vector_split_docs: vector_docs.append(doc) # Generate quiz - quiz_generator = QuizGenerator(vector_docs) + quiz_generator = QuizGenerator(vector_docs, summary) async def generate_question(idx, i): max_attempts = generate_request.num_of_quiz @@ -268,7 +268,7 @@ async def generate(generate_request: GenerateRequest): # keyword 추출 keywords = await extract_keywords(keyword_docs, top_n=min(generate_request.num_of_quiz * 2, len(sentences) - 1)) # 키워드는 개수를 여유롭게 생성합니다. 
- log('info', f'[app.py > quiz] Extracted Keywords: {keywords}') + log('info', f'[app.py > quiz] Extracted Keywords for creating {generate_request.num_of_quiz} quizzes: {keywords}') # quiz 생성 (비동기) asyncio.create_task(generate_quiz_async(generate_request, res["id"], summary_docs, vector_docs, keywords)) diff --git a/model/app/data/generator.py b/model/app/data/generator.py index 1e09c0be..ff0507d3 100644 --- a/model/app/data/generator.py +++ b/model/app/data/generator.py @@ -12,13 +12,8 @@ setup_logging() class QuizGenerator(): - def __init__(self, split_docs): - self.embedding = OpenAIEmbeddings(model="text-embedding-3-large") - try: - self.indices = FAISS.from_documents(split_docs, self.embedding) - except Exception as e: - raise e - self.chain = QuizPipeline(self.indices) + def __init__(self, split_docs, summary): + self.chain = QuizPipeline(split_docs, summary) def get_type(self, text): # self.types = ['1. Multiple choice', '2. Short answer type that can be easily answered in one word', '3. 
yes/no quiz'] diff --git a/model/app/model/chain.py b/model/app/model/chain.py index 6a1576f3..7ea99d2b 100644 --- a/model/app/model/chain.py +++ b/model/app/model/chain.py @@ -7,6 +7,9 @@ from langchain.chains.base import Chain from langchain.schema import BaseRetriever from typing import List, Dict, Any +from langchain_openai import OpenAIEmbeddings +from langchain_community.vectorstores import FAISS +from sklearn.metrics.pairwise import cosine_similarity from model.prompt import CHOICE_PROB_TEMPLATE, SHORT_PROB_TEMPLATE, GENERATE_QUIZ_TEMPLATE, JSON_FORMAT_TEMPLATE from data.settings import * @@ -18,6 +21,8 @@ # Retrieval 체인 class RetrievalChain(Chain): retriever: BaseRetriever + embeddings: OpenAIEmbeddings + summary: str @property def input_keys(self) -> List[str]: @@ -30,7 +35,23 @@ def output_keys(self) -> List[str]: def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: message = inputs["message"] relevant_docs = self.retriever.invoke(message) - context = " ".join([doc.page_content for doc in relevant_docs]) + log('info', f'[chain.py > RetrievalChain] Retrieved {len(relevant_docs)} Docs for {message}') + + # 검색된 문서들과 summary 임베딩 + doc_embeddings = self.embeddings.embed_documents([doc.page_content for doc in relevant_docs]) + summary_embedding = self.embeddings.embed_query(self.summary) + summary_embedding = [summary_embedding] # 1D 배열을 2D 배열로 변환 + + # 유사도 계산 + similarities = cosine_similarity(summary_embedding, doc_embeddings) + + # 유사도가 높은 문서 K개 선택 + top_k_indices = similarities.argsort()[0][-K:][::-1] + top_k_docs = [relevant_docs[i] for i in top_k_indices] + log('info', f'[chain.py > RetrievalChain] Selected {len(top_k_docs)} Docs for {message} with summary.') + + context = "\n".join([doc.page_content for doc in top_k_docs]) + if len(context) < VECTOR_CHUNK_SIZE * 0.3: log('error', f'[chain.py > RetrievalChain] Retrieved Context is not sufficient. 
- "{message}", len: {len(context)}') raise ValueError(f'[chain.py > RetrievalChain] Retrieved Context is not sufficient. - "{message}", len: {len(context)}') @@ -90,9 +111,14 @@ def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: # 문제 생성 클래스 class QuizPipeline: - def __init__(self, indices: object): - self.indices = indices - self.retriever = self.indices.as_retriever(search_kwargs=dict(k=K)) # 인덱스 객체로부터 retriever 초기화 + def __init__(self, split_docs, summary): + self.embeddings = OpenAIEmbeddings(model="text-embedding-3-large") + try: + self.indices = FAISS.from_documents(split_docs, self.embeddings) + self.retriever = self.indices.as_retriever(search_kwargs=dict(k=K+1)) # 인덱스 객체로부터 retriever 초기화 + except Exception as e: + raise e + self.summary = summary self.llm = ChatOpenAI(model="gpt-3.5-turbo-0125") self.question_templates = [CHOICE_PROB_TEMPLATE, SHORT_PROB_TEMPLATE] # self.output_schemas = [ChoiceOutput, ShortOutput] @@ -114,7 +140,7 @@ def generate_quiz(self, message: str) -> dict: partial_variables={"format_instructions": json_output_parser.get_format_instructions()} ) - retrieval_chain = RetrievalChain(retriever=self.retriever) + retrieval_chain = RetrievalChain(retriever=self.retriever, embeddings=self.embeddings, summary=self.summary) quiz_generation_chain = QuizGenerationChain(prompt= quiz_prompt, llm=self.llm) json_formatter_chain = JSONFormatterChain(prompt=json_prompt, llm=self.llm, output_parser=json_output_parser) From 3f01e8b17082765e4a3e7909b6ec39f32d97d8d9 Mon Sep 17 00:00:00 2001 From: noooey <20203065@kookmin.ac.kr> Date: Wed, 29 May 2024 22:42:06 +0900 Subject: [PATCH 2/3] =?UTF-8?q?[feat/#505]=20:wrench:=20Config:=20vector?= =?UTF-8?q?=20chunk=20size=20=EC=A1=B0=EC=A0=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/app/data/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/app/data/settings.py b/model/app/data/settings.py index 
d01edeb9..86399fef 100644 --- a/model/app/data/settings.py +++ b/model/app/data/settings.py @@ -20,7 +20,7 @@ SUMMARY_SENTENCE_OVERLAP = 2 # vector chunk -VECTOR_CHUNK_SIZE = 320 +VECTOR_CHUNK_SIZE = 400 VECTOR_SENTENCE_OVERLAP = 0 K = 2 From 3b1609b219ced01c914216912d565982904bcb19 Mon Sep 17 00:00:00 2001 From: noooey <20203065@kookmin.ac.kr> Date: Wed, 29 May 2024 22:47:09 +0900 Subject: [PATCH 3/3] =?UTF-8?q?[feat/#505]=20:pencil2:=20Typo:=20log=20?= =?UTF-8?q?=EC=88=98=EC=A0=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/app/app.py | 2 +- model/app/data/generator.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/model/app/app.py b/model/app/app.py index b742364e..f3655046 100644 --- a/model/app/app.py +++ b/model/app/app.py @@ -203,7 +203,7 @@ async def generate_question(idx, i): # 5문제가 생성되었거나 마지막 문제인 경우 DynamoDB에 업데이트 if idx % 5 == 0 or idx == generate_request.num_of_quiz: - log('info', f'[app.py > quiz] cur idx: {idx} - quiz batch is ready to push. {questions}') + log('info', f'[app.py > quiz] cur idx: {idx} - quiz batch is ready to push.') retries = 0 while retries < QUIZ_UPDATE_RETRY: diff --git a/model/app/data/generator.py b/model/app/data/generator.py index ff0507d3..a39513b0 100644 --- a/model/app/data/generator.py +++ b/model/app/data/generator.py @@ -25,7 +25,6 @@ def get_type(self, text): return "OX퀴즈" def adjust_result(self, type, option_list, correct): - log("warning", f"[generator.py > quiz] set type({type}), options({option_list}), correct({correct}) ") if type == "단답형" or type == None: if (correct.strip().lower() in NO_LIST) or (correct.strip().lower() in YES_LIST): type = "OX퀴즈"