diff --git a/convokit/convokitConfig.py b/convokit/convokitConfig.py index 37f3ac37..aed12f06 100644 --- a/convokit/convokitConfig.py +++ b/convokit/convokitConfig.py @@ -7,6 +7,7 @@ "# Default Backend Parameters\n" "db_host: localhost:27017\n" "data_directory: ~/.convokit/saved-corpora\n" + "model_directory: ~/.convokit/saved-models\n" "default_backend: mem" ) @@ -51,6 +52,10 @@ def db_host(self): def data_directory(self): return self.config_contents.get("data_directory", "~/.convokit/saved-corpora") + @property + def model_directory(self): + return self.config_contents.get("model_directory", "~/.convokit/saved-models") + @property def default_backend(self): return self._get_config_from_env_or_file("default_backend", "mem") diff --git a/convokit/forecaster/CRAFT/CRAFTNN.py b/convokit/forecaster/CRAFT/CRAFTNN.py deleted file mode 100644 index 71fa7a3f..00000000 --- a/convokit/forecaster/CRAFT/CRAFTNN.py +++ /dev/null @@ -1,241 +0,0 @@ -try: - import torch -except (ModuleNotFoundError, ImportError) as e: - raise ModuleNotFoundError( - "torch is not currently installed. Run 'pip install convokit[craft]' if you would like to use the CRAFT model." - ) - -from torch import nn -import os -from urllib.request import urlretrieve -from .CRAFTUtil import CONSTANTS - - -class EncoderRNN(nn.Module): - """ - This module represents the utterance encoder component of CRAFT, - responsible for creating vector representations of utterances - """ - - def __init__(self, hidden_size, embedding, n_layers=1, dropout=0): - super(EncoderRNN, self).__init__() - self.n_layers = n_layers - self.hidden_size = hidden_size - self.embedding = embedding - - # Initialize GRU; the input_size and hidden_size params are both set to 'hidden_size' - # because our input size is a word embedding with number of features == hidden_size - self.gru = nn.GRU( - hidden_size, - hidden_size, - n_layers, - dropout=(0 if n_layers == 1 else dropout), - bidirectional=True, - ) - - def forward(self, input_seq, input_lengths, hidden=None): - # Convert word indexes to embeddings - embedded = self.embedding(input_seq) - # Pack padded batch of sequences for RNN module - packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths) - # Forward pass through GRU - outputs, hidden = self.gru(packed, hidden) - # Unpack padding - outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs) - # Sum bidirectional GRU outputs - outputs = outputs[:, :, : self.hidden_size] + outputs[:, :, self.hidden_size :] - # Return output and final hidden state - return outputs, hidden - - -class ContextEncoderRNN(nn.Module): - """This module represents the context encoder component of CRAFT, responsible for creating an order-sensitive vector representation of conversation context""" - - def __init__(self, hidden_size, n_layers=1, dropout=0): - super(ContextEncoderRNN, self).__init__() - self.n_layers = n_layers - self.hidden_size = hidden_size - - # only unidirectional GRU for context encoding - self.gru = nn.GRU( - hidden_size, - hidden_size, - n_layers, - dropout=(0 if n_layers == 1 else dropout), - bidirectional=False, - ) - - def forward(self, input_seq, input_lengths, hidden=None): - # Pack padded batch of sequences for RNN module - packed = torch.nn.utils.rnn.pack_padded_sequence(input_seq, input_lengths) - # Forward pass through GRU - outputs, hidden = self.gru(packed, hidden) - # Unpack padding - outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs) - # return output and final hidden state - return outputs, hidden - - -class SingleTargetClf(nn.Module): - """This module represents the CRAFT classifier head, which takes the context encoding and uses it to make a forecast""" - - def __init__(self, hidden_size, dropout=0.1): - super(SingleTargetClf, self).__init__() - - self.hidden_size = hidden_size - - # initialize classifier - self.layer1 = nn.Linear(hidden_size, hidden_size) - self.layer1_act = nn.LeakyReLU() - self.layer2 = nn.Linear(hidden_size, hidden_size // 2) - self.layer2_act = nn.LeakyReLU() - self.clf = nn.Linear(hidden_size // 2, 1) - self.dropout = nn.Dropout(p=dropout) - - def forward(self, encoder_outputs, encoder_input_lengths): - # from stackoverflow (https://stackoverflow.com/questions/50856936/taking-the-last-state-from-bilstm-bigru-in-pytorch) - # First we unsqueeze seqlengths two times so it has the same number of - # of dimensions as output_forward - # (batch_size) -> (1, batch_size, 1) - lengths = encoder_input_lengths.unsqueeze(0).unsqueeze(2) - # Then we expand it accordingly - # (1, batch_size, 1) -> (1, batch_size, hidden_size) - lengths = lengths.expand((1, -1, encoder_outputs.size(2))) - - # take only the last state of the encoder for each batch - last_outputs = torch.gather(encoder_outputs, 0, lengths - 1).squeeze(dim=0) - # forward pass through hidden layers - layer1_out = self.layer1_act(self.layer1(self.dropout(last_outputs))) - layer2_out = self.layer2_act(self.layer2(self.dropout(layer1_out))) - # compute and return logits - logits = self.clf(self.dropout(layer2_out)).squeeze(dim=1) - return logits - - -class Predictor(nn.Module): - """This helper module encapsulates the CRAFT pipeline, defining the logic of passing an input through each consecutive sub-module.""" - - def __init__(self, encoder, context_encoder, classifier): - super(Predictor, self).__init__() - self.encoder = encoder - self.context_encoder = context_encoder - self.classifier = classifier - - def forward( - self, - input_batch, - dialog_lengths, - dialog_lengths_list, - utt_lengths, - batch_indices, - dialog_indices, - batch_size, - max_length, - ): - # Forward input through encoder model - _, utt_encoder_hidden = self.encoder(input_batch, utt_lengths) - - # Convert utterance encoder final states to batched dialogs for use by context encoder - context_encoder_input = makeContextEncoderInput( - utt_encoder_hidden, dialog_lengths_list, batch_size, batch_indices, dialog_indices - ) - - # Forward pass through context encoder - context_encoder_outputs, context_encoder_hidden = self.context_encoder( - context_encoder_input, dialog_lengths - ) - - # Forward pass through classifier to get prediction logits - logits = self.classifier(context_encoder_outputs, dialog_lengths) - - # Apply sigmoid activation - predictions = torch.sigmoid(logits) - return predictions - - -def makeContextEncoderInput( - utt_encoder_hidden, dialog_lengths, batch_size, batch_indices, dialog_indices -): - """The utterance encoder takes in utterances in combined batches, with no knowledge of which ones go where in which conversation. - Its output is therefore also unordered. We correct this by using the information computed during tensor conversion to regroup - the utterance vectors into their proper conversational order.""" - # first, sum the forward and backward encoder states - utt_encoder_summed = utt_encoder_hidden[-2, :, :] + utt_encoder_hidden[-1, :, :] - # we now have hidden state of shape [utterance_batch_size, hidden_size] - # split it into a list of [hidden_size,] x utterance_batch_size - last_states = [t.squeeze() for t in utt_encoder_summed.split(1, dim=0)] - - # create a placeholder list of tensors to group the states by source dialog - states_dialog_batched = [[None for _ in range(dialog_lengths[i])] for i in range(batch_size)] - - # group the states by source dialog - for hidden_state, batch_idx, dialog_idx in zip(last_states, batch_indices, dialog_indices): - states_dialog_batched[batch_idx][dialog_idx] = hidden_state - - # stack each dialog into a tensor of shape [dialog_length, hidden_size] - states_dialog_batched = [torch.stack(d) for d in states_dialog_batched] - - # finally, condense all the dialog tensors into a single zero-padded tensor - # of shape [max_dialog_length, batch_size, hidden_size] - return torch.nn.utils.rnn.pad_sequence(states_dialog_batched) - - -def initialize_model( - custom_model_path, - voc, - device, - device_type: str, - hidden_size, - encoder_n_layers, - dropout, - context_encoder_n_layers, -): - print("Loading saved parameters...") - if custom_model_path is None: - if not os.path.isfile("model.tar"): - print("\tDownloading trained CRAFT...") - urlretrieve(CONSTANTS["MODEL_URL"], "model.tar") - print("\t...Done!") - custom_model_path = "model.tar" - # If running in a non-GPU environment, you need to tell PyTorch to convert the parameters to CPU tensor format. - # To do so, replace the previous line with the following: - if device_type == "cpu": - checkpoint = torch.load(custom_model_path, map_location=torch.device("cpu")) - elif device_type == "cuda": - checkpoint = torch.load(custom_model_path) - encoder_sd = checkpoint["en"] - context_sd = checkpoint["ctx"] - if "atk_clf" in checkpoint: - attack_clf_sd = checkpoint["atk_clf"] - - embedding_sd = checkpoint["embedding"] - voc.__dict__ = checkpoint["voc_dict"] - - print("Building encoders, decoder, and classifier...") - # Initialize word embeddings - embedding = nn.Embedding(voc.num_words, hidden_size) - embedding.load_state_dict(embedding_sd) - # Initialize utterance and context encoders - encoder: EncoderRNN = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout) - context_encoder: ContextEncoderRNN = ContextEncoderRNN( - hidden_size, context_encoder_n_layers, dropout - ) - encoder.load_state_dict(encoder_sd) - context_encoder.load_state_dict(context_sd) - # Initialize classifier - attack_clf: SingleTargetClf = SingleTargetClf(hidden_size, dropout) - if "atk_clf" in checkpoint: - attack_clf.load_state_dict(attack_clf_sd) - # Use appropriate device - encoder = encoder.to(device) - context_encoder = context_encoder.to(device) - attack_clf = attack_clf.to(device) - print("Models built and ready to go!") - - # Set dropout layers to eval mode - encoder.eval() - context_encoder.eval() - attack_clf.eval() - - # Initialize the pipeline - return Predictor(encoder, context_encoder, attack_clf) diff --git a/convokit/forecaster/CRAFT/__init__.py b/convokit/forecaster/CRAFT/__init__.py index ce3c1739..d79367ea 100644 --- a/convokit/forecaster/CRAFT/__init__.py +++ b/convokit/forecaster/CRAFT/__init__.py @@ -1,2 +1,3 @@ -from .CRAFTUtil import * -from .CRAFTNN import * +from .data import * +from .model import * +from .runners import * diff --git a/convokit/forecaster/CRAFT/CRAFTUtil.py b/convokit/forecaster/CRAFT/data.py similarity index 67% rename from convokit/forecaster/CRAFT/CRAFTUtil.py rename to convokit/forecaster/CRAFT/data.py index 28dd2d6c..7e0e77cc 100644 --- a/convokit/forecaster/CRAFT/CRAFTUtil.py +++ b/convokit/forecaster/CRAFT/data.py @@ -3,6 +3,7 @@ import nltk import random import itertools +import json try: import torch @@ -13,47 +14,23 @@ from typing import List, Tuple -CONSTANTS = { - "PAD_token": 0, - "SOS_token": 1, - "EOS_token": 2, - "UNK_token": 3, - "WORD2INDEX_URL": "http://zissou.infosci.cornell.edu/convokit/models/craft_wikiconv/word2index.json", - "INDEX2WORD_URL": "http://zissou.infosci.cornell.edu/convokit/models/craft_wikiconv/index2word.json", - "MODEL_URL": "http://zissou.infosci.cornell.edu/convokit/models/craft_wikiconv/craft_full.tar", -} - # Default word tokens PAD_token = 0 # Used for padding short sentences SOS_token = 1 # Start-of-sentence token EOS_token = 2 # End-of-sentence token UNK_token = 3 # Unknown word token -# model download paths -WORD2INDEX_URL = "http://zissou.infosci.cornell.edu/convokit/models/craft_wikiconv/word2index.json" -INDEX2WORD_URL = "http://zissou.infosci.cornell.edu/convokit/models/craft_wikiconv/index2word.json" -MODEL_URL = "http://zissou.infosci.cornell.edu/convokit/models/craft_wikiconv/craft_full.tar" - - class Voc: - """A class for representing the vocabulary used by a CRAFT model""" - def __init__(self, name, word2index=None, index2word=None): self.name = name - self.trimmed = ( - False if not word2index else True - ) # if a precomputed vocab is specified assume the speaker wants to use it as-is + self.trimmed = False if not word2index else True # if a precomputed vocab is specified assume the user wants to use it as-is self.word2index = word2index if word2index else {"UNK": UNK_token} self.word2count = {} - self.index2word = ( - index2word - if index2word - else {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS", UNK_token: "UNK"} - ) + self.index2word = index2word if index2word else {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS", UNK_token: "UNK"} self.num_words = 4 if not index2word else len(index2word) # Count SOS, EOS, PAD, UNK def addSentence(self, sentence): - for word in sentence.split(" "): + for word in sentence.split(' '): self.addWord(word) def addWord(self, word): @@ -77,89 +54,93 @@ def trim(self, min_count): if v >= min_count: keep_words.append(k) - print( - "keep_words {} / {} = {:.4f}".format( - len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index) - ) - ) + print('keep_words {} / {} = {:.4f}'.format( + len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index) + )) # Reinitialize dictionaries self.word2index = {"UNK": UNK_token} self.word2count = {} self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS", UNK_token: "UNK"} - self.num_words = 4 # Count default tokens + self.num_words = 4 # Count default tokens for word in keep_words: self.addWord(word) - -# Create a Voc object from precomputed data structures -def loadPrecomputedVoc(corpus_name, word2index_url, index2word_url): - # load the word-to-index lookup map - r = requests.get(word2index_url) - word2index = r.json() - # load the index-to-word lookup map - r = requests.get(index2word_url) - index2word = r.json() - return Voc(corpus_name, word2index, index2word) - - -# Helper functions for preprocessing and tokenizing text - - # Turn a Unicode string to plain ASCII, thanks to # https://stackoverflow.com/a/518232/2809427 def unicodeToAscii(s): - return "".join(c for c in unicodedata.normalize("NFD", s) if unicodedata.category(c) != "Mn") - + return ''.join( + c for c in unicodedata.normalize('NFD', s) + if unicodedata.category(c) != 'Mn' + ) # Tokenize the string using NLTK -def craft_tokenize(voc, text): - tokenizer = nltk.tokenize.RegexpTokenizer(pattern=r"\w+|[^\w\s]") +def tokenize(voc, text): + tokenizer = nltk.tokenize.RegexpTokenizer(pattern=r'\w+|[^\w\s]') # simplify the problem space by considering only ASCII data cleaned_text = unicodeToAscii(text.lower()) # if the resulting string is empty, nothing else to do if not cleaned_text.strip(): return [] - + tokens = tokenizer.tokenize(cleaned_text) + + # replace out-of-vocabulary tokens for i in range(len(tokens)): if tokens[i] not in voc.word2index: tokens[i] = "UNK" - return tokens + return tokens -# Helper functions for turning dialog and text sequences into tensors, and manipulating those tensors +# Create a Voc object from precomputed data structures +def loadPrecomputedVoc(corpus_name, word2index_path, index2word_path): + with open(word2index_path) as fp: + word2index = json.load(fp) + with open(index2word_path) as fp: + index2word = json.load(fp) + return Voc(corpus_name, word2index, index2word) +# Given a context utterance list from Forecaster, preprocess each utterance's text by tokenizing and truncating. +# Returns the processed dialog entry where text has been replaced with a list of +# tokens, each no longer than MAX_LENGTH - 1 (to leave space for the EOS token) +def processContext(voc, context, is_attack): + processed = [] + for utterance in context: + # since the iterative nature of Forecaster may lead us to see the same utterance + # multiple times, we'll cache the tokenized form of the utterance as metadata + # and look it up if it already exists + if "craft_tokens" not in utterance.meta: + utterance.meta["craft_tokens"] = tokenize(voc, utterance.text) + tokens = utterance.meta["craft_tokens"] + if len(tokens) >= MAX_LENGTH: + tokens = tokens[:(MAX_LENGTH-1)] + processed.append({"tokens": tokens, "is_attack": is_attack, "id": utterance.id}) + return processed def indexesFromSentence(voc, sentence): return [voc.word2index[word] for word in sentence] + [EOS_token] - def zeroPadding(l, fillvalue=PAD_token): return list(itertools.zip_longest(*l, fillvalue=fillvalue)) - -def binaryMatrix(l): +def binaryMatrix(l, value=PAD_token): m = [] for i, seq in enumerate(l): m.append([]) for token in seq: if token == PAD_token: - m[i].append(0) + m[i].append(False) else: - m[i].append(1) + m[i].append(True) return m - # Takes a batch of dialogs (lists of lists of tokens) and converts it into a # batch of utterances (lists of tokens) sorted by length, while keeping track of # the information needed to reconstruct the original batch of dialogs def dialogBatch2UtteranceBatch(dialog_batch): - utt_tuples = ( - [] - ) # will store tuples of (utterance, original position in batch, original position in dialog) + utt_tuples = [] # will store tuples of (utterance, original position in batch, original position in dialog) for batch_idx in range(len(dialog_batch)): dialog = dialog_batch[batch_idx] for dialog_idx in range(len(dialog)): @@ -173,7 +154,6 @@ def dialogBatch2UtteranceBatch(dialog_batch): dialog_indices = [u[2] for u in utt_tuples] return utt_batch, batch_indices, dialog_indices - # Returns padded input sequence tensor and lengths def inputVar(l, voc): indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l] @@ -182,20 +162,18 @@ def inputVar(l, voc): padVar = torch.LongTensor(padList) return padVar, lengths - # Returns padded target sequence tensor, padding mask, and max target length def outputVar(l, voc): indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l] max_target_len = max([len(indexes) for indexes in indexes_batch]) padList = zeroPadding(indexes_batch) mask = binaryMatrix(padList) - mask = torch.ByteTensor(mask) + mask = torch.BoolTensor(mask) padVar = torch.LongTensor(padList) return padVar, mask, max_target_len - # Returns all items for a given batch of pairs -def batch2TrainData(voc, pair_batch: List[Tuple], already_sorted=False): +def batch2TrainData(voc, pair_batch, already_sorted=False): if not already_sorted: pair_batch.sort(key=lambda x: len(x[0]), reverse=True) input_batch, output_batch, label_batch, id_batch = [], [], [], [] @@ -203,28 +181,13 @@ def batch2TrainData(voc, pair_batch: List[Tuple], already_sorted=False): input_batch.append(pair[0]) output_batch.append(pair[1]) label_batch.append(pair[2]) - if len(pair) > 3: - id_batch.append(pair[3]) - else: - id_batch.append(None) + id_batch.append(pair[3]) dialog_lengths = torch.tensor([len(x) for x in input_batch]) input_utterances, batch_indices, dialog_indices = dialogBatch2UtteranceBatch(input_batch) inp, utt_lengths = inputVar(input_utterances, voc) output, mask, max_target_len = outputVar(output_batch, voc) label_batch = torch.FloatTensor(label_batch) if label_batch[0] is not None else None - return ( - inp, - dialog_lengths, - utt_lengths, - batch_indices, - dialog_indices, - label_batch, - id_batch, - output, - mask, - max_target_len, - ) - + return inp, dialog_lengths, utt_lengths, batch_indices, dialog_indices, label_batch, id_batch, output, mask, max_target_len def batchIterator(voc, source_data, batch_size, shuffle=True): cur_idx = 0 @@ -235,14 +198,15 @@ def batchIterator(voc, source_data, batch_size, shuffle=True): cur_idx = 0 if shuffle: random.shuffle(source_data) - batch = source_data[cur_idx : (cur_idx + batch_size)] + batch = source_data[cur_idx:(cur_idx+batch_size)] # the true batch size may be smaller than the given batch size if there is not enough data left true_batch_size = len(batch) # ensure that the dialogs in this batch are sorted by length, as expected by the padding module batch.sort(key=lambda x: len(x[0]), reverse=True) # for analysis purposes, get the source dialogs and labels associated with this batch batch_dialogs = [x[0] for x in batch] + batch_labels = [x[2] for x in batch] # convert batch to tensors batch_tensors = batch2TrainData(voc, batch, already_sorted=True) - yield (batch_tensors, batch_dialogs, true_batch_size) - cur_idx += batch_size + yield (batch_tensors, batch_dialogs, batch_labels, true_batch_size) + cur_idx += batch_size \ No newline at end of file diff --git a/convokit/forecaster/CRAFT/model.py b/convokit/forecaster/CRAFT/model.py new file mode 100644 index 00000000..b9975642 --- /dev/null +++ b/convokit/forecaster/CRAFT/model.py @@ -0,0 +1,208 @@ +try: + import torch +except (ModuleNotFoundError, ImportError) as e: + raise ModuleNotFoundError( + "torch is not currently installed. Run 'pip install convokit[craft]' if you would like to use the CRAFT model." + ) + +from torch import nn +import torch.nn.functional as F + +class EncoderRNN(nn.Module): + def __init__(self, hidden_size, embedding, n_layers=1, dropout=0): + super(EncoderRNN, self).__init__() + self.n_layers = n_layers + self.hidden_size = hidden_size + self.embedding = embedding + + # Initialize GRU; the input_size and hidden_size params are both set to 'hidden_size' + # because our input size is a word embedding with number of features == hidden_size + self.gru = nn.GRU(hidden_size, hidden_size, n_layers, + dropout=(0 if n_layers == 1 else dropout), bidirectional=True) + + def forward(self, input_seq, input_lengths, hidden=None): + # Convert word indexes to embeddings + embedded = self.embedding(input_seq) + # Pack padded batch of sequences for RNN module + packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths.cpu()) + # Forward pass through GRU + outputs, hidden = self.gru(packed, hidden) + # Unpack padding + outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs) + # Sum bidirectional GRU outputs + outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:] + # Return output and final hidden state + return outputs, hidden + +class ContextEncoderRNN(nn.Module): + def __init__(self, hidden_size, n_layers=1, dropout=0): + super(ContextEncoderRNN, self).__init__() + self.n_layers = n_layers + self.hidden_size = hidden_size + + # only unidirectional GRU for context encoding + self.gru = nn.GRU(hidden_size, hidden_size, n_layers, + dropout=(0 if n_layers == 1 else dropout), bidirectional=False) + + def forward(self, input_seq, input_lengths, hidden=None): + # Pack padded batch of sequences for RNN module + packed = torch.nn.utils.rnn.pack_padded_sequence(input_seq, input_lengths.cpu()) + # Forward pass through GRU + outputs, hidden = self.gru(packed, hidden) + # Unpack padding + outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs) + # return output and final hidden state + return outputs, hidden + +# Luong attention layer +class Attn(torch.nn.Module): + def __init__(self, method, hidden_size): + super(Attn, self).__init__() + self.method = method + if self.method not in ['dot', 'general', 'concat']: + raise ValueError(self.method, "is not an appropriate attention method.") + self.hidden_size = hidden_size + if self.method == 'general': + self.attn = torch.nn.Linear(self.hidden_size, hidden_size) + elif self.method == 'concat': + self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size) + self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size)) + + def dot_score(self, hidden, encoder_output): + return torch.sum(hidden * encoder_output, dim=2) + + def general_score(self, hidden, encoder_output): + energy = self.attn(encoder_output) + return torch.sum(hidden * energy, dim=2) + + def concat_score(self, hidden, encoder_output): + energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh() + return torch.sum(self.v * energy, dim=2) + + def forward(self, hidden, encoder_outputs): + # Calculate the attention weights (energies) based on the given method + if self.method == 'general': + attn_energies = self.general_score(hidden, encoder_outputs) + elif self.method == 'concat': + attn_energies = self.concat_score(hidden, encoder_outputs) + elif self.method == 'dot': + attn_energies = self.dot_score(hidden, encoder_outputs) + + # Transpose max_length and batch_size dimensions + attn_energies = attn_energies.t() + + # Return the softmax normalized probability scores (with added dimension) + return F.softmax(attn_energies, dim=1).unsqueeze(1) + +class LuongAttnDecoderRNN(nn.Module): + def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1): + super(LuongAttnDecoderRNN, self).__init__() + + # Keep for reference + self.attn_model = attn_model + self.hidden_size = hidden_size + self.output_size = output_size + self.n_layers = n_layers + self.dropout = dropout + + # Define layers + self.embedding = embedding + self.embedding_dropout = nn.Dropout(dropout) + self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout)) + self.concat = nn.Linear(hidden_size * 2, hidden_size) + self.out = nn.Linear(hidden_size, output_size) + + self.attn = Attn(attn_model, hidden_size) + + def forward(self, input_step, last_hidden, encoder_outputs): + # Note: we run this one step (word) at a time + # Get embedding of current input word + embedded = self.embedding(input_step) + embedded = self.embedding_dropout(embedded) + # Forward through unidirectional GRU + rnn_output, hidden = self.gru(embedded, last_hidden) + # Calculate attention weights from the current GRU output + attn_weights = self.attn(rnn_output, encoder_outputs) + # Multiply attention weights to encoder outputs to get new "weighted sum" context vector + context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) + # Concatenate weighted context vector and GRU output using Luong eq. 5 + rnn_output = rnn_output.squeeze(0) + context = context.squeeze(1) + concat_input = torch.cat((rnn_output, context), 1) + concat_output = torch.tanh(self.concat(concat_input)) + # Predict next word using Luong eq. 6 + output = self.out(concat_output) + output = F.softmax(output, dim=1) + # Return output and final hidden state + return output, hidden + +class AttnSingleTargetClf(nn.Module): + def __init__(self, hidden_size, dropout=0.1): + super(AttnSingleTargetClf, self).__init__() + + self.hidden_size = hidden_size + + # initialize attention + self.attn = nn.Linear(hidden_size, 1) + + # initialize classifier + self.layer1 = nn.Linear(hidden_size, hidden_size) + self.layer1_act = nn.LeakyReLU() + self.layer2 = nn.Linear(hidden_size, hidden_size // 2) + self.layer2_act = nn.LeakyReLU() + self.clf = nn.Linear(hidden_size // 2, 1) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, encoder_outputs): + # compute attention weights + self.attn_weights = self.attn(encoder_outputs).squeeze().transpose(0,1) + # softmax normalize weights + self.attn_weights = F.softmax(self.attn_weights, dim=1).unsqueeze(1) + # transpose context encoder outputs so we can apply batch matrix multiply + encoder_outputs_transp = encoder_outputs.transpose(0,1) + # compute weighted context vector + context_vec = torch.bmm(self.attn_weights, encoder_outputs_transp).squeeze() + # forward pass through hidden layers + layer1_out = self.layer1_act(self.layer1(self.dropout(context_vec))) + layer2_out = self.layer2_act(self.layer2(self.dropout(layer1_out))) + # compute and return logits + logits = self.clf(self.dropout(layer2_out)).squeeze() + return logits + +class SingleTargetClf(nn.Module): + """ + Single-target classifier head with no attention layer (predicts only from + the last state vector of the RNN) + """ + def __init__(self, hidden_size, dropout=0.1): + super(SingleTargetClf, self).__init__() + + self.hidden_size = hidden_size + + # initialize classifier + self.layer1 = nn.Linear(hidden_size, hidden_size) + self.layer1_act = nn.LeakyReLU() + self.layer2 = nn.Linear(hidden_size, hidden_size // 2) + self.layer2_act = nn.LeakyReLU() + self.clf = nn.Linear(hidden_size // 2, 1) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, encoder_outputs, encoder_input_lengths): + # from stackoverflow (https://stackoverflow.com/questions/50856936/taking-the-last-state-from-bilstm-bigru-in-pytorch) + # First we unsqueeze seqlengths two times so it has the same number of + # of dimensions as output_forward + # (batch_size) -> (1, batch_size, 1) + lengths = encoder_input_lengths.unsqueeze(0).unsqueeze(2) + # Then we expand it accordingly + # (1, batch_size, 1) -> (1, batch_size, hidden_size) + lengths = lengths.expand((1, -1, encoder_outputs.size(2))) + + # take only the last state of the encoder for each batch + last_outputs = torch.gather(encoder_outputs, 0, lengths-1).squeeze() + # forward pass through hidden layers + layer1_out = self.layer1_act(self.layer1(self.dropout(last_outputs))) + layer2_out = self.layer2_act(self.layer2(self.dropout(layer1_out))) + # compute and return logits + logits = self.clf(self.dropout(layer2_out)).squeeze() + return logits + diff --git a/convokit/forecaster/CRAFT/runners.py b/convokit/forecaster/CRAFT/runners.py new file mode 100644 index 00000000..d2f6dbc1 --- /dev/null +++ b/convokit/forecaster/CRAFT/runners.py @@ -0,0 +1,244 @@ +try: + import torch +except (ModuleNotFoundError, ImportError) as e: + raise ModuleNotFoundError( + "torch is not currently installed. Run 'pip install convokit[craft]' if you would like to use the CRAFT model." + ) + +import numpy as np +import pandas as pd +import torch.nn.functional as F +from torch import nn + +class Predictor(nn.Module): + """This helper module encapsulates the CRAFT pipeline, defining the logic of passing an input through each consecutive sub-module.""" + def __init__(self, encoder, context_encoder, classifier): + super(Predictor, self).__init__() + self.encoder = encoder + self.context_encoder = context_encoder + self.classifier = classifier + + def forward(self, input_batch, dialog_lengths, dialog_lengths_list, utt_lengths, batch_indices, dialog_indices, batch_size, max_length): + # Forward input through encoder model + _, utt_encoder_hidden = self.encoder(input_batch, utt_lengths) + + # Convert utterance encoder final states to batched dialogs for use by context encoder + context_encoder_input = makeContextEncoderInput(utt_encoder_hidden, dialog_lengths_list, batch_size, batch_indices, dialog_indices) + + # Forward pass through context encoder + context_encoder_outputs, context_encoder_hidden = self.context_encoder(context_encoder_input, dialog_lengths) + + # Forward pass through classifier to get prediction logits + logits = self.classifier(context_encoder_outputs, dialog_lengths) + + # Apply sigmoid activation + predictions = F.sigmoid(logits) + return predictions + +def makeContextEncoderInput(utt_encoder_hidden, dialog_lengths, batch_size, batch_indices, dialog_indices): + # first, sum the forward and backward encoder states + utt_encoder_summed = utt_encoder_hidden[-2,:,:] + utt_encoder_hidden[-1,:,:] + # we now have hidden state of shape [utterance_batch_size, hidden_size] + # split it into a list of [hidden_size,] x utterance_batch_size + last_states = [t.squeeze() for t in utt_encoder_summed.split(1, dim=0)] + + # create a placeholder list of tensors to group the states by source dialog + states_dialog_batched = [[None for _ in range(dialog_lengths[i])] for i in range(batch_size)] + + # group the states by source dialog + for hidden_state, batch_idx, dialog_idx in zip(last_states, batch_indices, dialog_indices): + states_dialog_batched[batch_idx][dialog_idx] = hidden_state + + # stack each dialog into a tensor of shape [dialog_length, hidden_size] + states_dialog_batched = [torch.stack(d) for d in states_dialog_batched] + + # finally, condense all the dialog tensors into a single zero-padded tensor + # of shape [max_dialog_length, batch_size, hidden_size] + return torch.nn.utils.rnn.pad_sequence(states_dialog_batched) + +def train(input_variable, dialog_lengths, dialog_lengths_list, utt_lengths, batch_indices, dialog_indices, labels, # input/output arguments + encoder, context_encoder, attack_clf, # network arguments + encoder_optimizer, context_encoder_optimizer, attack_clf_optimizer, # optimization arguments + batch_size, clip, max_length=MAX_LENGTH): # misc arguments + + # Zero gradients + encoder_optimizer.zero_grad() + context_encoder_optimizer.zero_grad() + attack_clf_optimizer.zero_grad() + + # Set device options + input_variable = input_variable.to(device) + dialog_lengths = dialog_lengths.to(device) + utt_lengths = utt_lengths.to(device) + labels = labels.to(device) + + # Forward pass through utterance encoder + _, utt_encoder_hidden = encoder(input_variable, utt_lengths) + + # Convert utterance encoder final states to batched dialogs for use by context encoder + context_encoder_input = makeContextEncoderInput(utt_encoder_hidden, dialog_lengths_list, batch_size, batch_indices, dialog_indices) + + # Forward pass through context encoder + context_encoder_outputs, _ = context_encoder(context_encoder_input, dialog_lengths) + + # Forward pass through classifier to get prediction logits + logits = attack_clf(context_encoder_outputs, dialog_lengths) + + # Calculate loss + loss = F.binary_cross_entropy_with_logits(logits, labels) + + # Perform backpropatation + loss.backward() + + # Clip gradients: gradients are modified in place + _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip) + _ = torch.nn.utils.clip_grad_norm_(context_encoder.parameters(), clip) + _ = torch.nn.utils.clip_grad_norm_(attack_clf.parameters(), clip) + + # Adjust model weights + encoder_optimizer.step() + context_encoder_optimizer.step() + attack_clf_optimizer.step() + + return loss.item() + +def evaluateBatch(encoder, context_encoder, predictor, voc, input_batch, dialog_lengths, + dialog_lengths_list, utt_lengths, batch_indices, dialog_indices, batch_size, device, max_length=MAX_LENGTH): + # Set device options + input_batch = input_batch.to(device) + dialog_lengths = dialog_lengths.to(device) + utt_lengths = utt_lengths.to(device) + # Predict future attack using predictor + scores = predictor(input_batch, dialog_lengths, dialog_lengths_list, utt_lengths, batch_indices, dialog_indices, batch_size, max_length) + predictions = (scores > 0.5).float() + return predictions, scores + +def validate(dataset, encoder, context_encoder, predictor, voc, batch_size, device): + # create a batch iterator for the given data + batch_iterator = batchIterator(voc, dataset, batch_size, shuffle=False) + # find out how many iterations we will need to cover the whole dataset + n_iters = len(dataset) // batch_size + int(len(dataset) % batch_size > 0) + # containers for full prediction results so we can compute accuracy at the end + all_preds = [] + all_labels = [] + for iteration in range(1, n_iters+1): + batch, batch_dialogs, _, true_batch_size = next(batch_iterator) + # Extract fields from batch + input_variable, dialog_lengths, utt_lengths, batch_indices, dialog_indices, labels, convo_ids, target_variable, mask, max_target_len = batch + dialog_lengths_list = [len(x) for x in batch_dialogs] + # run the model + predictions, scores = evaluateBatch(encoder, context_encoder, predictor, voc, input_variable, + dialog_lengths, dialog_lengths_list, utt_lengths, batch_indices, dialog_indices, + true_batch_size, device) + # aggregate results for computing accuracy at the end + all_preds += [p.item() for p in predictions] + all_labels += [l.item() for l in labels] + print("Iteration: {}; Percent complete: {:.1f}%".format(iteration, iteration / n_iters * 100)) + + # compute and return the accuracy + return (np.asarray(all_preds) == np.asarray(all_labels)).mean() + +def trainIters(voc, pairs, val_pairs, encoder, context_encoder, attack_clf, + encoder_optimizer, context_encoder_optimizer, attack_clf_optimizer, embedding, + n_iteration, batch_size, print_every, validate_every, clip): + + # create a batch iterator for training data + batch_iterator = batchIterator(voc, pairs, batch_size) + + # Initializations + print('Initializing ...') + start_iteration = 1 + print_loss = 0 + + # Training loop + print("Training...") + # keep track of best validation accuracy - only save when we have a model that beats the current best + best_acc = 0 + best_model = None + for iteration in range(start_iteration, n_iteration + 1): + training_batch, training_dialogs, _, true_batch_size = next(batch_iterator) + # Extract fields from batch + input_variable, dialog_lengths, utt_lengths, batch_indices, dialog_indices, labels, _, target_variable, mask, max_target_len = training_batch + dialog_lengths_list = [len(x) for x in training_dialogs] + + # Run a training iteration with batch + loss = train(input_variable, dialog_lengths, dialog_lengths_list, utt_lengths, batch_indices, dialog_indices, labels, # input/output arguments + encoder, context_encoder, attack_clf, # network arguments + encoder_optimizer, context_encoder_optimizer, attack_clf_optimizer, # optimization arguments + true_batch_size, clip) # misc arguments + print_loss += loss + + # Print progress + if iteration % print_every == 0: + print_loss_avg = print_loss / print_every + print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, iteration / n_iteration * 100, print_loss_avg)) + print_loss = 0 + + # Evaluate on validation set + if (iteration % validate_every == 0): + print("Validating!") + # put the network components into evaluation mode + encoder.eval() + context_encoder.eval() + attack_clf.eval() + + predictor = Predictor(encoder, context_encoder, attack_clf) + accuracy = validate(val_pairs, encoder, context_encoder, predictor, voc, batch_size, device) + print("Validation set accuracy: {:.2f}%".format(accuracy * 100)) + + # keep track of our best model so far + if accuracy > best_acc: + print("Validation accuracy better than current best; saving model...") + best_acc = accuracy + best_model = { + 'iteration': iteration, + 'en': encoder.state_dict(), + 'ctx': context_encoder.state_dict(), + 'atk_clf': attack_clf.state_dict(), + 'en_opt': encoder_optimizer.state_dict(), + 'ctx_opt': context_encoder_optimizer.state_dict(), + 'atk_clf_opt': attack_clf_optimizer.state_dict(), + 'loss': loss, + 'voc_dict': voc.__dict__, + 'embedding': embedding.state_dict() + } + + # put the network components back into training mode + encoder.train() + context_encoder.train() + attack_clf.train() + + return best_model + +def evaluateDataset(dataset, encoder, context_encoder, predictor, voc, batch_size, device): + # create a batch iterator for the given data + batch_iterator = batchIterator(voc, dataset, batch_size, shuffle=False) + # find out how many iterations we will need to cover the whole dataset + n_iters = len(dataset) // batch_size + int(len(dataset) % batch_size > 0) + output_df = { + "id": [], + "prediction": [], + "score": [] + } + for iteration in range(1, n_iters+1): + batch, batch_dialogs, _, true_batch_size = next(batch_iterator) + # Extract fields from batch + input_variable, dialog_lengths, utt_lengths, batch_indices, dialog_indices, labels, convo_ids, target_variable, mask, max_target_len = batch + dialog_lengths_list = [len(x) for x in batch_dialogs] + # run the model + predictions, scores = evaluateBatch(encoder, context_encoder, predictor, voc, input_variable, + dialog_lengths, dialog_lengths_list, utt_lengths, batch_indices, dialog_indices, + true_batch_size, device) + + # format the output as a dataframe (which we can later re-join with the corpus) + for i in range(true_batch_size): + convo_id = convo_ids[i] + pred = predictions[i].item() + score = scores[i].item() + output_df["id"].append(convo_id) + output_df["prediction"].append(pred) + output_df["score"].append(score) + + print("Iteration: {}; Percent complete: {:.1f}%".format(iteration, iteration / n_iters * 100)) + + return pd.DataFrame(output_df).set_index("id") diff --git a/convokit/forecaster/CRAFTModel.py b/convokit/forecaster/CRAFTModel.py index 99214454..d958f149 100644 --- a/convokit/forecaster/CRAFTModel.py +++ b/convokit/forecaster/CRAFTModel.py @@ -6,17 +6,16 @@ ) import pandas as pd -from convokit.forecaster.CRAFT.CRAFTUtil import loadPrecomputedVoc, batchIterator, CONSTANTS -from .CRAFT.CRAFTNN import initialize_model, makeContextEncoderInput, Predictor +from convokit.forecaster.CRAFT.data import loadPrecomputedVoc, batchIterator +from .CRAFT.model import initialize_model, makeContextEncoderInput, Predictor from .forecasterModel import ForecasterModel import numpy as np import torch.nn.functional as F from torch import optim -from sklearn.model_selection import train_test_split -from typing import Dict +from typing import Dict, Union import os -default_options = { +DEFAULT_CONFIG = { "hidden_size": 500, "encoder_n_layers": 2, "context_encoder_n_layers": 2, @@ -28,475 +27,43 @@ "print_every": 10, "train_epochs": 30, "validation_size": 0.2, - "max_length": 80, - "trained_model_output_filepath": "finetuned_model.tar", + "max_length": 80 +} + +DECISION_THRESHOLDS = { + "craft-wiki-pretrained": 0.570617, + "craft-wiki-finetuned": 0.570617, + "craft-cmv-pretrained": 0.548580, + "craft-cmv-finetuned": 0.548580 } # To understand the separation of concerns for the CRAFT files: -# CRAFT/craftNN.py contains the class implementations needed to initialize the CRAFT Neural Network model -# CRAFT/craftUtil.py contains utility methods for manipulating the data for it to be passed to the CRAFT model +# CRAFT/model.py contains the pytorch modules that comprise the CRAFT neural network +# CRAFT/data.py contains utility methods for manipulating the data for it to be passed to the CRAFT model +# CRAFT/runners.py adapts the scripts for training and inference of a CRAFT model class CRAFTModel(ForecasterModel): """ - CRAFTModel is one of the Forecaster models that can be used with the Forecaster Transformer. - - By default, CRAFTModel will be initialized with default options - - - hidden_size: 500 - - encoder_n_layers: 2 - - context_encoder_n_layers: 2 - - decoder_n_layers: 2 - - dropout: 0.1 - - batch_size (batch size for computation, i.e. how many (context, reply, id) tuples to use per batch of evaluation): 64 - - clip: 50.0 - - learning_rate: 1e-5 - - print_every: 10 - - train_epochs (number of epochs for training): 30 - - validation_size (percentage of training input data to use as validation): 0.2 - - max_length (maximum utterance length in the dataset): 80 - - :param device_type: 'cpu' or 'cuda', default: 'cpu' - :param model_path: filepath to CRAFT model if loading a custom CRAFT model - :param options: configuration options for the neural network: uses default options otherwise. - :param forecast_attribute_name: name of DataFrame column containing predictions, default: "prediction" - :param forecast_prob_attribute_name: name of DataFrame column containing prediction scores, default: "score" + A ConvoKit Forecaster-adherent reimplementation of the CRAFT conversational forecasting model from + the paper "Trouble on the Horizon: Forecasting the Derailment of Online Conversations as they Develop" + (Chang and Danescu-Niculescu-Mizil, 2019). + + Usage note: CRAFT is a neural network model; full end-to-end training of neural networks is considered + outside the scope of ConvoKit, so the ConvoKit CRAFTModel must be initialized with existing weights. + ConvoKit provides weights for the CGA-WIKI and CGA-CMV corpora. If you just want to run a fully-trained + CRAFT model on those corpora (i.e., only transform, no fit), you can use the finetuned weights + (craft-wiki-finetuned and craft-cmv-finetuned, respectively). If you want to take a pretrained model and + finetune it on your own data (i.e., both fit and transform), you can use the pretrained weights + (craft-wiki-pretrained and craft-cmv-pretrained, respectively), which provide trained versions of the + underlying utterance and conversation encoder layers but leave the classification layers at their + random initializations so that they can be fitted to your data. """ def __init__( self, - device_type: str = "cpu", - model_path: str = None, - options: Dict = None, - forecast_attribute_name: str = "prediction", - forecast_feat_name=None, - forecast_prob_attribute_name: str = "pred_score", - forecast_prob_feat_name=None, - ): - super().__init__( - forecast_attribute_name=forecast_attribute_name, - forecast_feat_name=forecast_feat_name, - forecast_prob_attribute_name=forecast_prob_attribute_name, - forecast_prob_feat_name=forecast_prob_feat_name, - ) - assert device_type in ["cuda", "cpu"] - # device: controls GPU usage: 'cuda' to enable GPU, 'cpu' to run on CPU only. - self.device = torch.device(device_type) - self.device_type = device_type - # voc: the vocabulary object (convokit.forecaster.craftUtil.Voc) used by predictor. - # Used to convert text data into numerical input for CRAFT. - self.voc = loadPrecomputedVoc( - "wikiconv", CONSTANTS["WORD2INDEX_URL"], CONSTANTS["INDEX2WORD_URL"] - ) - - if options is None: - self.options = default_options - else: - for k, v in default_options.items(): - if k not in options: - options[k] = v - self.options = options - print("Initializing CRAFT model with options:") - print(self.options) - - if model_path is not None: - if not os.path.isfile(model_path) or not model_path.endswith(".tar"): - print("Could not find CRAFT model tar file at: {}".format(model_path)) - model_path = None - self.predictor: Predictor = initialize_model( - model_path, - self.voc, - self.device, - self.device_type, - self.options["hidden_size"], - self.options["encoder_n_layers"], - self.options["dropout"], - self.options["context_encoder_n_layers"], - ) - - def _evaluate_batch( - self, - predictor, - input_batch, - dialog_lengths, - dialog_lengths_list, - utt_lengths, - batch_indices, - dialog_indices, - true_batch_size, - ): - """ - Helper for _evaluate_dataset. Runs CRAFT evaluation on a single batch; _evaluate_dataset calls this helper iteratively to get results for the entire dataset. - - :param predictor: the trained CRAFT model to use, provided as a PyTorch Model instance. - :param input_batch: the batch to run CRAFT on (produced by convokit.forecaster.craftUtil.batchIterator, formatted as a batch of utterances) - :param dialog_lengths: how many comments are in each conversation in this batch, as a PyTorch Tensor - :param dialog_lengths_list: same as dialog_lengths, but as a Python List - :param utt_lengths: for each conversation, records the number of tokens in each utterance of the conversation - :param batch_indices: used by CRAFT to reconstruct the original dialog batch from the given utterance batch. Records which dialog each utterance originally came from. - :param dialog_indices: used by CRAFT to reconstruct the original dialog batch from the given utterance batch. Records where in the dialog the utterance originally came from. - :param true_batch_size: number of dialogs in the original dialog batch this utterance batch was generated from. - - :return: per-utterance scores and binarized predictions. - """ - # Set device options - input_batch = input_batch.to(self.device) - dialog_lengths = dialog_lengths.to(self.device) - utt_lengths = utt_lengths.to(self.device) - # Predict future attack using predictor - scores = predictor( - input_batch, - dialog_lengths, - dialog_lengths_list, - utt_lengths, - batch_indices, - dialog_indices, - true_batch_size, - self.options["max_length"], - ) - predictions = (scores > 0.5).float() - return predictions, scores - - def _evaluate_dataset(self, predictor, dataset): - """ - Run a trained CRAFT model over an entire dataset in a batched fashion. - - :param predictor: the trained CRAFT model to use, provided as a PyTorch Model instance. - :param dataset: the dataset to evaluate on, formatted as a list of (context, reply, id_of_reply) tuples. - :return: a DataFrame, indexed by utterance ID, of CRAFT scores for each utterance, and the corresponding binary prediction. - """ - # create a batch iterator for the given data - batch_iterator = batchIterator(self.voc, dataset, self.options["batch_size"], shuffle=False) - # find out how many iterations we will need to cover the whole dataset - n_iters = len(dataset) // self.options["batch_size"] + int( - len(dataset) % self.options["batch_size"] > 0 - ) - output_df = { - "id": [], - self.forecast_attribute_name: [], - self.forecast_prob_attribute_name: [], - } - for iteration in range(1, n_iters + 1): - batch, batch_dialogs, true_batch_size = next(batch_iterator) - # Extract fields from batch - ( - input_variable, - dialog_lengths, - utt_lengths, - batch_indices, - dialog_indices, - labels, - batch_ids, - target_variable, - mask, - max_target_len, - ) = batch - dialog_lengths_list = [len(x) for x in batch_dialogs] - # run the model - predictions, scores = self._evaluate_batch( - predictor, - input_variable, - dialog_lengths, - dialog_lengths_list, - utt_lengths, - batch_indices, - dialog_indices, - true_batch_size, - ) - - # format the output as a dataframe (which we can later re-join with the corpus) - for i in range(true_batch_size): - utt_id = batch_ids[i] - pred = predictions[i].item() - score = scores[i].item() - output_df["id"].append(utt_id) - output_df[self.forecast_attribute_name].append(pred) - output_df[self.forecast_prob_attribute_name].append(score) - - print( - "Iteration: {}; Percent complete: {:.1f}%".format( - iteration, iteration / n_iters * 100 - ) - ) - - return pd.DataFrame(output_df).set_index("id") - - def _train_NN( - self, - input_variable, - dialog_lengths, - dialog_lengths_list, - utt_lengths, - batch_indices, - dialog_indices, - labels, # input/output arguments - encoder, - context_encoder, - attack_clf, # network arguments - encoder_optimizer, - context_encoder_optimizer, - attack_clf_optimizer, # optimization arguments - batch_size, - clip, - ): # misc arguments - # Zero gradients - encoder_optimizer.zero_grad() - context_encoder_optimizer.zero_grad() - attack_clf_optimizer.zero_grad() - - # Set device options - input_variable = input_variable.to(self.device) - dialog_lengths = dialog_lengths.to(self.device) - utt_lengths = utt_lengths.to(self.device) - labels = labels.to(self.device) - - # Forward pass through utterance encoder - _, utt_encoder_hidden = encoder(input_variable, utt_lengths) - - # Convert utterance encoder final states to batched dialogs for use by context encoder - context_encoder_input = makeContextEncoderInput( - utt_encoder_hidden, dialog_lengths_list, batch_size, batch_indices, dialog_indices - ) - - # Forward pass through context encoder - context_encoder_outputs, _ = context_encoder(context_encoder_input, dialog_lengths) - - # Forward pass through classifier to get prediction logits - logits = attack_clf(context_encoder_outputs, dialog_lengths) - - # Calculate loss - loss = F.binary_cross_entropy_with_logits(logits, labels) - - # Perform backpropatation - loss.backward() - - # Clip gradients: gradients are modified in place - _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip) - _ = torch.nn.utils.clip_grad_norm_(context_encoder.parameters(), clip) - _ = torch.nn.utils.clip_grad_norm_(attack_clf.parameters(), clip) - - # Adjust model weights - encoder_optimizer.step() - context_encoder_optimizer.step() - attack_clf_optimizer.step() - - return loss.item() - - def _validate(self, predictor, dataset): - # create a batch iterator for the given data - batch_iterator = batchIterator(self.voc, dataset, self.options["batch_size"], shuffle=False) - # find out how many iterations we will need to cover the whole dataset - n_iters = len(dataset) // self.options["batch_size"] + int( - len(dataset) % self.options["batch_size"] > 0 - ) - # containers for full prediction results so we can compute accuracy at the end - all_preds = [] - all_labels = [] - for iteration in range(1, n_iters + 1): - batch, batch_dialogs, true_batch_size = next(batch_iterator) - # Extract fields from batch - ( - input_variable, - dialog_lengths, - utt_lengths, - batch_indices, - dialog_indices, - batch_labels, - batch_ids, - target_variable, - mask, - max_target_len, - ) = batch - dialog_lengths_list = [len(x) for x in batch_dialogs] - # run the model - predictions, scores = self._evaluate_batch( - predictor, - input_variable, - dialog_lengths, - dialog_lengths_list, - utt_lengths, - batch_indices, - dialog_indices, - true_batch_size, - ) - # aggregate results for computing accuracy at the end - all_preds += [p.item() for p in predictions] - all_labels += [l.item() for l in batch_labels] - print( - "Iteration: {}; Percent complete: {:.1f}%".format( - iteration, iteration / n_iters * 100 - ) - ) - - # compute and return the accuracy - return (np.asarray(all_preds) == np.asarray(all_labels)).mean() - - def _train_iters( - self, - train_pairs, - val_pairs, - encoder, - context_encoder, - attack_clf, - encoder_optimizer, - context_encoder_optimizer, - attack_clf_optimizer, - embedding, - n_iteration, - validate_every, + initial_weights: str, + decision_threshold: Union[float, str] = "auto", + config: dict = DEFAULT_CONFIG ): - # create a batch iterator for training data - batch_iterator = batchIterator(self.voc, train_pairs, self.options["batch_size"]) - - # Initializations - print("Initializing ...") - start_iteration = 1 - print_loss = 0 - - # Training loop - print("Training...") - # keep track of best validation accuracy - only save when we have a model that beats the current best - best_acc = 0 - for iteration in range(start_iteration, n_iteration + 1): - training_batch, training_dialogs, true_batch_size = next(batch_iterator) - # Extract fields from batch - ( - input_variable, - dialog_lengths, - utt_lengths, - batch_indices, - dialog_indices, - labels, - batch_ids, - target_variable, - mask, - max_target_len, - ) = training_batch - dialog_lengths_list = [len(x) for x in training_dialogs] - - # Run a training iteration with batch - loss = self._train_NN( - input_variable, - dialog_lengths, - dialog_lengths_list, - utt_lengths, - batch_indices, - dialog_indices, - labels, # input/output arguments - encoder, - context_encoder, - attack_clf, # network arguments - encoder_optimizer, - context_encoder_optimizer, - attack_clf_optimizer, # optimization arguments - true_batch_size, - self.options["clip"], - ) # misc arguments - print_loss += loss - - # Print progress - if iteration % self.options["print_every"] == 0: - print_loss_avg = print_loss / self.options["print_every"] - print( - "Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format( - iteration, iteration / n_iteration * 100, print_loss_avg - ) - ) - print_loss = 0 - - # Evaluate on validation set - if iteration % validate_every == 0: - print("Validating!") - # put the network components into evaluation mode - encoder.eval() - context_encoder.eval() - attack_clf.eval() - - predictor = Predictor(encoder, context_encoder, attack_clf) - accuracy = self._validate(predictor, val_pairs) - print("Validation set accuracy: {:.2f}%".format(accuracy * 100)) - - # keep track of our best model so far - if accuracy > best_acc: - print("Validation accuracy better than current best; saving model...") - best_acc = accuracy - torch.save( - { - "iteration": iteration, - "en": encoder.state_dict(), - "ctx": context_encoder.state_dict(), - "atk_clf": attack_clf.state_dict(), - "en_opt": encoder_optimizer.state_dict(), - "ctx_opt": context_encoder_optimizer.state_dict(), - "atk_clf_opt": attack_clf_optimizer.state_dict(), - "loss": loss, - "voc_dict": self.voc.__dict__, - "embedding": embedding.state_dict(), - }, - self.options["trained_model_output_filepath"], - ) - - # put the network components back into training mode - encoder.train() - context_encoder.train() - attack_clf.train() - - def train(self, id_to_context_reply_label): - ids = list(id_to_context_reply_label) - train_pair_ids, val_pair_ids = train_test_split( - ids, test_size=self.options["validation_size"] - ) - train_pairs = [id_to_context_reply_label[pair_id] for pair_id in train_pair_ids] - val_pairs = [id_to_context_reply_label[pair_id] for pair_id in val_pair_ids] - # Compute the number of training iterations we will need in order to achieve the number of epochs specified in the settings at the start of the notebook - n_iter_per_epoch = len(train_pairs) // self.options["batch_size"] + int( - len(train_pairs) % self.options["batch_size"] == 1 - ) - n_iteration = n_iter_per_epoch * self.options["train_epochs"] - - # Put dropout layers in train mode - self.predictor.encoder.train() - self.predictor.context_encoder.train() - self.predictor.classifier.train() - - # Initialize optimizers - print("Building optimizers...") - encoder_optimizer = optim.Adam( - self.predictor.encoder.parameters(), lr=self.options["learning_rate"] - ) - context_encoder_optimizer = optim.Adam( - self.predictor.context_encoder.parameters(), lr=self.options["learning_rate"] - ) - attack_clf_optimizer = optim.Adam( - self.predictor.classifier.parameters(), lr=self.options["learning_rate"] - ) - - # Run training iterations, validating after every epoch - print("Starting Training!") - print("Will train for {} iterations".format(n_iteration)) - self._train_iters( - train_pairs, - val_pairs, - self.predictor.encoder, - self.predictor.context_encoder, - self.predictor.classifier, - encoder_optimizer, - context_encoder_optimizer, - attack_clf_optimizer, - self.predictor.encoder.embedding, - n_iteration, - n_iter_per_epoch, - ) - - def forecast(self, id_to_context_reply_label): - """ - Compute forecasts and forecast scores for the given dictionary of utterance id to (context, reply) pairs. Return the values in a DataFrame. - - :param id_to_context_reply_label: dict mapping utterance id to (context, reply, label) - :return: a pandas DataFrame - """ - dataset = [ - (context, reply, label, id_) - for id_, (context, reply, label) in id_to_context_reply_label.items() - ] - return self._evaluate_dataset(self.predictor, dataset) + super().__init__() \ No newline at end of file