Commit
Scoring and discussion functionality (#5)
NeonKirill authored Aug 7, 2023
2 parents a8c3a90 + 8663aa1 commit 84b8cae
Showing 3 changed files with 273 additions and 73 deletions.
3 changes: 1 addition & 2 deletions README.md
@@ -28,8 +28,7 @@ MQ:
neon_llm_fastchat:
password: <neon_fastchat user's password>
user: neon_fastchat
LLM_FASTCHAT:
context_depth: 3
max_tokens: 256
num_parallel_processes: 2
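For illustration only, a minimal sketch of the per-model configuration as the Python dict handed to FastChat.__init__ in the diff below; the num_threads_per_process value is an assumption, the other values mirror the README snippet above:

# Hypothetical config example (not part of the commit)
llm_fastchat_config = {
    "context_depth": 3,            # chat-history turns included in the assembled prompt
    "max_tokens": 256,             # max_decoding_length passed to ctranslate2
    "num_parallel_processes": 2,   # inter_threads of the ctranslate2 Translator
    "num_threads_per_process": 4,  # intra_threads (value assumed for illustration)
}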
169 changes: 139 additions & 30 deletions neon_llm_fastchat/fastchat.py
@@ -24,57 +24,166 @@
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import List

import ctranslate2
from transformers import T5Tokenizer
from huggingface_hub import snapshot_download
import numpy as np


# TODO: make LLM interface generic
class FastChat:

def __init__(self, config):
self.context_depth = config["context_depth"]
self.max_tokens = config["max_tokens"]
self.num_parallel_processes = config["num_parallel_processes"]
self.num_threads_per_process = config["num_threads_per_process"]
self._tokenizer = None
self._model = None

@property
def tokenizer(self) -> T5Tokenizer:
if self._tokenizer is None:
self._tokenizer = T5Tokenizer.from_pretrained(pretrained_model_name_or_path=self.tokenizer_model_name)
return self._tokenizer

@property
def tokenizer_model_name(self) -> str:
return "google/flan-t5-xl"

@property
def model(self) -> ctranslate2.Translator:
if self._model is None:
repo_path = snapshot_download(repo_id=self.llm_model_name)
self._model = ctranslate2.Translator(model_path=repo_path,
intra_threads=self.num_threads_per_process,
inter_threads=self.num_parallel_processes)
return self._model

@property
def llm_model_name(self) -> str:
return "neongeckocom/fastchat-t5-3b-v1.0"

@property
def _system_prompt(self) -> str:
return "A chat between a curious human and an artificial intelligence assistant. " \
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n" \
"### Human: What are the key differences between renewable and non-renewable energy sources?\n" \
"### Assistant: Renewable energy sources are those that can be " \
"replenished naturally in a relatively short amount of time, such as solar, wind, hydro, " \
"geothermal, and biomass. Non-renewable energy sources, on the other hand, " \
"are finite and will eventually be depleted, such as coal, oil, and natural gas.\n"

def ask(self, message: str, chat_history: List[List[str]]) -> str:
""" Generates llm response based on user message and (user, llm) chat history """
prompt = self._assemble_prompt(message, chat_history)
llm_text_output = self._call_model(prompt)
return llm_text_output

def get_sorted_answer_indexes(self, question: str, answers: List[str]) -> List[int]:
"""
        Creates a sorted list of answer indexes, relative to the order given in :param answers, based on PPL score
Answers are sorted from best to worst
:param question: incoming question
:param answers: list of answers to rank
:returns list of indexes
"""
if not answers:
return []
scores = self._ppl(question=question, answers=answers)
sorted_items = sorted(zip(range(len(answers)), scores), key=lambda x: x[1])
sorted_items_indexes = [x[0] for x in sorted_items]
return sorted_items_indexes
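    # Illustrative example (not in the original diff): for PPL scores [12.4, 3.1, 7.9] the
    # zip/sort above returns [1, 2, 0], i.e. the lowest-perplexity answer is ranked first.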

def _call_model(self, prompt: str) -> str:
"""
Wrapper for FastChat Model generation logic
:param prompt: Input text sequence
:returns: Output text sequence generated by model
"""
tokens = self._tokenize(prompt)

results = self.model.translate_batch(
[tokens],
beam_size=1,
max_decoding_length=self.max_tokens,
repetition_penalty=1.2,
)

output_tokens = results[0].hypotheses[0]
text = self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(output_tokens),
spaces_between_special_tokens=False)
return text

    def _assemble_prompt(self, message: str, chat_history: List[List[str]]) -> str:
        """
        Assembles the prompt according to FastChat prompt-engineering guidance
        Setup Guidance:
        https://github.com/lm-sys/FastChat/blob/4e2c942b8d785eb5e2aef1d0df2150e756f381ab/fastchat/conversation.py#L279
        :param message: Incoming prompt
        :param chat_history: History of preceding conversation
        :returns: assembled prompt
        """
        prompt = self._system_prompt
        # Context N messages
        for role, content in chat_history[-self.context_depth:]:
            role_fastchat = self._convert_role(role)
            prompt += f"### {role_fastchat}: {content}\n"
        prompt += f"### Human: {message}\n### Assistant:"
        return prompt
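    # Illustrative example (not in the original diff): with chat_history=[["user", "Hi"], ["llm", "Hello!"]]
    # and message="What is FastChat?", the assembled prompt is the system prompt followed by
    # "### Human: Hi\n### Assistant: Hello!\n### Human: What is FastChat?\n### Assistant:"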

@staticmethod
def _convert_role(role: str) -> str:
""" Maps MQ role to FastChat internal domain """
if role == "user":
role_fastchat = "Human"
elif role == "llm":
role_fastchat = "Assistant"
else:
raise ValueError(f"role={role} is undefined, supported are: ('user', 'llm')")
return role_fastchat

def _call_score(self, prompt: str, targets: List[str]) -> List[List[float]]:
"""
Calculates logarithmic probabilities for the list of provided text sequences
:param prompt: Input text sequence
:param targets: Output text sequences
:returns: List of calculated logarithmic probabilities per output text sequence
"""
tokens = self._tokenize(prompt)
tokens_list = len(targets) * [tokens]

        target_tokens_list = [self._tokenize(target) for target in targets]

        results = self.model.score_batch(
            source=tokens_list,
            target=target_tokens_list,
        )

log_probs_list = [result.log_probs for result in results]
return log_probs_list
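    # Illustrative note (not in the original diff): each result.log_probs holds one log-probability
    # per target token, so log_probs_list contains one list of token log-probs per entry in targets.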

def _tokenize(self, prompt: str) -> List[str]:
tokens = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(prompt))
return tokens

def _ppl(self, question: str, answers: List[str]) -> List[float]:
"""
Computes PPL value for the list of provided answers
        :param question: Question for LLM to respond to
        :param answers: List of provided answers
        :returns: PPL value for each answer
"""
question_prompt = self._assemble_prompt(question, [])
log_probs_list = self._call_score(question_prompt, answers)
ppl_list = [self._compute_ppl(log_probs) for log_probs in log_probs_list]
return ppl_list

@staticmethod
def _compute_ppl(log_probs: List[float]) -> float:
""" Calculates perplexity value: https://en.wikipedia.org/wiki/Perplexity """
ppl = np.exp(-np.mean(log_probs))
return ppl
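To tie the pieces together, a hedged usage sketch of the new generation and scoring flow; the config values and texts are hypothetical, while FastChat and its public methods come from the diff above:

# Hypothetical usage example (not part of the commit)
from neon_llm_fastchat.fastchat import FastChat

config = {"context_depth": 3, "max_tokens": 256,
          "num_parallel_processes": 2, "num_threads_per_process": 4}
llm = FastChat(config)

# Generation: chat_history is a list of [role, content] pairs with roles "user" / "llm"
history = [["user", "Hi"], ["llm", "Hello! How can I help?"]]
reply = llm.ask("Name two renewable energy sources.", history)

# Scoring: answers are ranked by perplexity, PPL = exp(-mean(log_probs)); lower is better
answers = ["Solar and wind.", "Coal and oil."]
ranking = llm.get_sorted_answer_indexes("Name two renewable energy sources.", answers)
print(reply, ranking)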
