Skip to content

Commit

Permalink
Filter low-quality context or question
Browse files Browse the repository at this point in the history
  • Loading branch information
gulixin0922 committed Jul 30, 2024
1 parent f40fa8e commit 20bb3b2
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def __init__(
chunk_size: int = 1024,
seed: int = 42,
prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
filter_lowquality_context: bool = False,
filter_lowquality_question: bool = False
) -> None:
self.generator_llm = generator_llm
self.critic_llm = critic_llm
Expand All @@ -150,6 +152,8 @@ def __init__(
self.threshold = 5.0
self.rng = default_rng(seed)
self.prompt = prompt
self.filter_lowquality_context = filter_lowquality_context
self.filter_lowquality_question = filter_lowquality_question

@classmethod
def from_default(
Expand All @@ -158,6 +162,8 @@ def from_default(
chunk_size: int = 512,
trainset_distribution: dict = DEFAULT_TRAIN_DISTRIBUTION,
prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
filter_lowquality_context: bool = False,
filter_lowquality_question: bool = False
):
generator_llm = llm
critic_llm = llm
Expand All @@ -167,6 +173,8 @@ def from_default(
chunk_size=chunk_size,
trainset_distribution=trainset_distribution,
prompt=prompt,
filter_lowquality_context=filter_lowquality_context,
filter_lowquality_question=filter_lowquality_question
)

def _get_evolve_type(self) -> str:
Expand Down Expand Up @@ -309,14 +317,17 @@ def generate(
)

text_chunk = " ".join([node.get_content() for node in nodes])
score = self._filter_context(text_chunk)
if not score:
continue
if self.filter_lowquality_context:
score = self._filter_context(text_chunk)
if not score:
continue
seed_question = self._seed_question(text_chunk)

question = seed_question
# is_valid_question = self._filter_question(question)
is_valid_question = True
if self.filter_lowquality_question:
is_valid_question = self._filter_question(question)
else:
is_valid_question = True
if is_valid_question:
context = [text_chunk] * len(question.split("\n"))
is_conv = len(context) > 1
Expand Down Expand Up @@ -355,6 +366,8 @@ def from_llm(
k: Optional[int] = None,
chunk_size: int = 512,
prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
filter_lowquality_context: bool = False,
filter_lowquality_question: bool = False,
**kwargs: Any,
) -> QAGenerationChainV2:
"""
Expand All @@ -368,7 +381,13 @@ def from_llm(
Returns:
a QAGenerationChain class
"""
generator = TrainsetGenerator.from_default(llm, chunk_size=chunk_size, prompt=prompt)
generator = TrainsetGenerator.from_default(
llm,
chunk_size=chunk_size,
prompt=prompt,
filter_lowquality_context=filter_lowquality_context,
filter_lowquality_question=filter_lowquality_question
)
return cls(documents=documents, generator=generator, k=k, **kwargs)

@property
Expand Down
14 changes: 12 additions & 2 deletions src/bisheng-langchain/tests/chains/test_qa_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,18 @@ def generator():
unstructured_api_url='https://bisheng.dataelem.com/api/v1/etl4llm/predict',
)
documents = loader.load()
# qa_generator = QAGenerationChain.from_llm(documents, llm, k=5)
qa_generator = QAGenerationChainV2.from_llm(documents, llm)
k = 5
chunk_size = 512
filter_lowquality_context = False
filter_lowquality_question = False
# qa_generator = QAGenerationChain.from_llm(documents, llm, k=k, chunk_size=chunk_size, filter_lowquality_context=filter_lowquality_context, filter_lowquality_question=filter_lowquality_question)
qa_generator = QAGenerationChainV2.from_llm(documents,
llm,
k=k,
chunk_size=chunk_size,
filter_lowquality_context=filter_lowquality_context,
filter_lowquality_question=filter_lowquality_question
)
inputs = {'begin': '开始'}
response = qa_generator(inputs)
question_answers = response['questions']
Expand Down

0 comments on commit 20bb3b2

Please sign in to comment.