Commit 6823744 (parent c5b2bd7)
3 changed files, 380 additions, 0 deletions

__init__.py
from .opennessScoreBERT import *
from .opennessScoreSimilarity import *
opennessScoreBERT.py
import language_tool_python
import numpy as np
import torch
from inspect import signature

from convokit import Corpus
from convokit.transformer import Transformer
from transformers import AutoModelForMaskedLM, AutoTokenizer


class OpennessScoreBERT(Transformer):
    """
    A transformer that calculates an openness score for Corpus objects by measuring, with a
    masked language model, how predictable generic backchannel answers ("Mhm", "Okay", ...)
    are as responses to the last question in the text.

    :param obj_type: type of Corpus object to calculate: 'conversation', 'speaker', or 'utterance'; defaults to 'utterance'.
    :param input_field: input field from every utterance object. Defaults to reading 'utt.text'. If a string is provided, the metadata field with that name is read instead.
    :param output_field: metadata field the computed output is written to; defaults to 'openness_score'.
    :param input_filter: a boolean function of signature `input_filter(utterance)`. Attributes will only be computed for utterances where `input_filter` returns `True`. By default, always returns `True`, meaning that attributes are computed for all utterances.
    :param model_name: name of the Hugging Face masked language model to load; defaults to 'bert-base-cased'.
    :param verbosity: frequency at which to print status messages when computing attributes.
    """

    def __init__(
        self,
        obj_type="utterance",
        output_field="openness_score",
        input_field=None,
        input_filter=None,
        model_name="bert-base-cased",
        verbosity=1000,
    ):
        if input_filter:
            if len(signature(input_filter).parameters) == 1:
                self.input_filter = lambda utt: input_filter(utt)
            else:
                self.input_filter = input_filter
        else:
            self.input_filter = lambda utt: True
        self.obj_type = obj_type
        self.input_field = input_field
        self.output_field = output_field
        self.verbosity = verbosity
        # public LanguageTool API client, used to normalize question grammar before scoring
        self.grammar_tool = language_tool_python.LanguageToolPublicAPI("en")
        # generic backchannel answers whose predictability is measured
        self.answer_sample = ["Mhm", "Okay", "I see", "Yup"]
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def _print_output(self, i):
        return (self.verbosity > 0) and (i > 0) and (i % self.verbosity == 0)

    def bert_score(self, question, answer):
        """
        Outputs the pseudo-perplexity score for predicting the answer, given the question.

        :param question: str
        :param answer: str
        :return: pseudo-perplexity of the answer tokens
        """
        sentence = question + " " + answer
        tensor_input = self.tokenizer.encode(sentence, return_tensors="pt")
        # number of question tokens, excluding the [CLS]/[SEP] special tokens
        question_tok_len = len(self.tokenizer.encode(question)) - 2
        # one copy of the input per answer token, each copy masking a different answer token
        repeat_input = tensor_input.repeat(tensor_input.size(-1) - 2 - question_tok_len, 1)
        mask = torch.ones(tensor_input.size(-1) - 1).diag(1)[:-2][question_tok_len:]
        masked_input = repeat_input.masked_fill(mask == 1, self.tokenizer.mask_token_id)
        # only the masked positions contribute to the loss (label -100 is ignored)
        labels = repeat_input.masked_fill(masked_input != self.tokenizer.mask_token_id, -100)
        with torch.inference_mode():
            loss = self.model(masked_input, labels=labels).loss
        return np.exp(loss.item())

    def find_last_question(self, text):
        """
        Finds the last sentence that ends with a question mark.

        :param text: str
        :return: the last question in the text (empty string if there is none)
        """
        end_sent = {".", "?", "!"}
        last_q = text.rfind("?")
        for i in range(last_q - 1, -1, -1):
            if text[i] in end_sent:
                return text[i + 1 : last_q + 1].strip()
        return text[: last_q + 1].strip()

    def bert_openness_score(self, question):
        """
        Averages the pseudo-perplexity of each sample backchannel answer to the
        (grammar-corrected) last question in the text.
        """
        scores = []
        question = self.find_last_question(question)
        question = self.grammar_tool.correct(question)

        for ans in self.answer_sample:
            perp = self.bert_score(question, ans)
            scores.append(perp)
        return np.mean(scores)

    def transform(self, corpus: Corpus) -> Corpus:
        """
        Scores the given objects on their openness and stores the results in the corresponding
        metadata fields.

        :param corpus: Corpus
        :return: the corpus
        """
        if self.obj_type == "utterance":
            total = len(list(corpus.iter_utterances()))

            for idx, utterance in enumerate(corpus.iter_utterances()):
                if self._print_output(idx):
                    print(f"{idx:03d}/{total:03d} {self.obj_type} processed")

                if not self.input_filter(utterance):
                    continue

                if self.input_field is None:
                    text_entry = utterance.text
                elif isinstance(self.input_field, str):
                    text_entry = utterance.retrieve_meta(self.input_field)
                if text_entry is None:
                    continue

                # compute the score and write it to the output field
                score = self.bert_openness_score(text_entry)
                utterance.add_meta(self.output_field, score)

        elif self.obj_type == "conversation":
            total = len(list(corpus.iter_conversations()))
            for idx, convo in enumerate(corpus.iter_conversations()):
                if self._print_output(idx):
                    print(f"{idx:03d}/{total:03d} {self.obj_type} processed")

                if not self.input_filter(convo):
                    continue

                if self.input_field is None:
                    utt_lst = convo.get_utterance_ids()
                    text_entry = " ".join(corpus.get_utterance(x).text for x in utt_lst)
                elif isinstance(self.input_field, str):
                    text_entry = convo.retrieve_meta(self.input_field)
                if text_entry is None:
                    continue

                # compute the score and write it to the output field
                score = self.bert_openness_score(text_entry)
                convo.add_meta(self.output_field, score)

        elif self.obj_type == "speaker":
            total = len(list(corpus.iter_speakers()))
            for idx, sp in enumerate(corpus.iter_speakers()):
                if self._print_output(idx):
                    print(f"{idx:03d}/{total:03d} {self.obj_type} processed")

                if not self.input_filter(sp):
                    continue

                if self.input_field is None:
                    utt_lst = sp.get_utterance_ids()
                    text_entry = " ".join(corpus.get_utterance(x).text for x in utt_lst)
                elif isinstance(self.input_field, str):
                    text_entry = sp.retrieve_meta(self.input_field)
                if text_entry is None:
                    continue

                # compute the score and write it to the output field
                score = self.bert_openness_score(text_entry)
                sp.add_meta(self.output_field, score)

        else:
            raise KeyError("obj_type must be utterance, conversation, or speaker")

        if self.verbosity > 0:
            print(f"{total:03d}/{total:03d} {self.obj_type} processed")
        return corpus
opennessScoreSimilarity.py
import numpy as np
from inspect import signature

from convokit import Corpus
from convokit.transformer import Transformer
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import BM25Retriever
from haystack.pipelines import DocumentSearchPipeline
from nltk.tokenize import word_tokenize
from sentence_transformers import SentenceTransformer, util


class OpennessScoreSimilarity(Transformer):
    """
    A transformer that calculates an openness score via SBERT similarity: it retrieves stored
    answers to questions similar to the input and measures how similar those answers are to
    one another.

    :param obj_type: type of Corpus object to calculate: 'conversation', 'speaker', or 'utterance'; defaults to 'utterance'.
    :param input_field: input field from every utterance object. Defaults to reading 'utt.text'. If a string is provided, the metadata field with that name is read instead.
    :param output_field: metadata field the computed output is written to; defaults to 'openness_score'.
    :param input_filter: a boolean function of signature `input_filter(utterance)`. Attributes will only be computed for utterances where `input_filter` returns `True`. By default, always returns `True`, meaning that attributes are computed for all utterances.
    :param verbosity: frequency at which to print status messages when computing attributes.
    """

    def __init__(
        self,
        obj_type="utterance",
        output_field="openness_score",
        input_field=None,
        input_filter=None,
        model_name="bert-base-cased",
        verbosity=1000,
    ):
        if input_filter:
            if len(signature(input_filter).parameters) == 1:
                self.input_filter = lambda utt: input_filter(utt)
            else:
                self.input_filter = input_filter
        else:
            self.input_filter = lambda utt: True
        self.obj_type = obj_type
        self.input_field = input_field
        self.output_field = output_field
        self.verbosity = verbosity
        self.answer_sample = ["Mhm", "Okay", "I see", "Yup"]
        # BM25-searchable store of (question, answer) pairs, populated in fit()
        self.document_store = InMemoryDocumentStore(use_bm25=True)
        # SBERT model for embedding retrieved answers (fixed; model_name is currently unused)
        self.model = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1", device="cpu")

    def fit(self, corpus: Corpus, y=None):
        """Learn context information for the given corpus."""
        self.corpus = corpus
        self._load_questions(corpus)
        return self

    def _print_output(self, i):
        return (self.verbosity > 0) and (i > 0) and (i % self.verbosity == 0)

    def generated_openness_score_similarity(self, text):
        """
        Retrieves the stored answers whose contexts best match the text and returns the average
        pairwise SBERT similarity among them.
        """
        # for long inputs, keep only the last question so BM25 retrieval stays focused
        if len(text) > 500 and len(word_tokenize(text)) > 100:
            text = " ".join(word_tokenize(self._find_last_question(text)))
        prediction = self.pipe.run(query=text, params={"Retriever": {"top_k": 10}})
        answers = [doc.meta["answer"] for doc in prediction["documents"]]
        return self._avg_bert_sim(answers)

    def transform(self, corpus: Corpus) -> Corpus:
        """
        Scores the given objects on their openness and stores the results in the corresponding
        metadata fields. Requires fit() to have been called first.

        :param corpus: Corpus
        :return: the corpus
        """
        if self.obj_type == "utterance":
            total = len(list(corpus.iter_utterances()))

            for idx, utterance in enumerate(corpus.iter_utterances()):
                if self._print_output(idx):
                    print(f"{idx:03d}/{total:03d} {self.obj_type} processed")

                if not self.input_filter(utterance):
                    continue

                if self.input_field is None:
                    text_entry = utterance.text
                elif isinstance(self.input_field, str):
                    text_entry = utterance.retrieve_meta(self.input_field)
                if text_entry is None:
                    continue

                # compute the score and write it to the output field
                score = self.generated_openness_score_similarity(text_entry)
                utterance.add_meta(self.output_field, score)

        elif self.obj_type == "conversation":
            total = len(list(corpus.iter_conversations()))
            for idx, convo in enumerate(corpus.iter_conversations()):
                if self._print_output(idx):
                    print(f"{idx:03d}/{total:03d} {self.obj_type} processed")

                if not self.input_filter(convo):
                    continue

                if self.input_field is None:
                    utt_lst = convo.get_utterance_ids()
                    text_entry = " ".join(corpus.get_utterance(x).text for x in utt_lst)
                elif isinstance(self.input_field, str):
                    text_entry = convo.retrieve_meta(self.input_field)
                if text_entry is None:
                    continue

                # compute the score and write it to the output field
                score = self.generated_openness_score_similarity(text_entry)
                convo.add_meta(self.output_field, score)

        elif self.obj_type == "speaker":
            total = len(list(corpus.iter_speakers()))
            for idx, sp in enumerate(corpus.iter_speakers()):
                if self._print_output(idx):
                    print(f"{idx:03d}/{total:03d} {self.obj_type} processed")

                if not self.input_filter(sp):
                    continue

                if self.input_field is None:
                    utt_lst = sp.get_utterance_ids()
                    text_entry = " ".join(corpus.get_utterance(x).text for x in utt_lst)
                elif isinstance(self.input_field, str):
                    text_entry = sp.retrieve_meta(self.input_field)
                if text_entry is None:
                    continue

                # compute the score and write it to the output field
                score = self.generated_openness_score_similarity(text_entry)
                sp.add_meta(self.output_field, score)

        else:
            raise KeyError("obj_type must be utterance, conversation, or speaker")

        if self.verbosity > 0:
            print(f"{total:03d}/{total:03d} {self.obj_type} processed")
        return corpus

    # helper function
    def _load_questions(self, corpus):
        """Index every (question, following answer) pair in the corpus for BM25 retrieval.

        Assumes each utterance carries a precomputed 'questions' metadata count.
        """
        docs = []
        convo_ids = corpus.get_conversation_ids()
        for idx in convo_ids:
            convo = corpus.get_conversation(idx)
            utts = convo.get_chronological_utterance_list()
            had_question = False
            before_text = ""
            for utt in utts:
                if had_question:
                    # treat the current utterance as the answer to the preceding question
                    dic_transf = {
                        "content": before_text,
                        "meta": {"convo_id": idx, "answer": utt.text},
                    }
                    docs.append(dic_transf)
                    had_question = False
                if utt.meta["questions"] > 0:
                    had_question = True
                    before_text = utt.text
        self.document_store.write_documents(docs)
        self.retriever = BM25Retriever(document_store=self.document_store)
        self.pipe = DocumentSearchPipeline(retriever=self.retriever)

    def _sbert_embedd_sim(self, embedding1, embedding2):
        return float(util.cos_sim(embedding1, embedding2))

    def _avg_bert_sim(self, texts):
        """Average pairwise cosine similarity between the SBERT embeddings of the texts."""
        embeddings = [self.model.encode(text) for text in texts]

        scores = []
        for i, embedding1 in enumerate(embeddings):
            for j, embedding2 in enumerate(embeddings):
                if i >= j:
                    continue
                scores.append(self._sbert_embedd_sim(embedding1, embedding2))
        return np.mean(scores)

    def _find_last_question(self, text):
        """Finds the last sentence that ends with a question mark (empty string if none)."""
        end_sent = {".", "?", "!"}
        last_q = text.rfind("?")
        for i in range(last_q - 1, -1, -1):
            if text[i] in end_sent:
                return text[i + 1 : last_q + 1].strip()
        return text[: last_q + 1].strip()
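
A corresponding sketch for the similarity scorer. Since `_load_questions` reads `utt.meta["questions"]`, the corpus must carry a per-utterance question count before `fit` is called; the naive '?'-count annotation below is a hypothetical stand-in for whatever upstream annotator produces that field, and the corpus choice is again illustrative.

# Minimal usage sketch; the corpus choice and the '?'-count annotation are
# assumptions standing in for the real upstream pipeline.
from convokit import Corpus, download

corpus = Corpus(filename=download("subreddit-Cornell"))

# _load_questions expects a per-utterance 'questions' metadata count.
for utt in corpus.iter_utterances():
    utt.add_meta("questions", utt.text.count("?"))

scorer = OpennessScoreSimilarity(obj_type="utterance")
scorer.fit(corpus)  # builds the BM25 index of (question, answer) pairs
corpus = scorer.transform(corpus)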