From 1b1a3857eb4e44174ff6445879592e7fa6f944d2 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 30 Jul 2024 20:25:33 -0700 Subject: [PATCH] New Forecaster framework (#217) Co-authored-by: seanzhangkx8 <106214464+seanzhangkx8@users.noreply.github.com> Checked through all implementation and documentation, run demo and got expected results! --- convokit/convokitConfig.py | 5 + convokit/forecaster/CRAFT/CRAFTNN.py | 241 --- convokit/forecaster/CRAFT/__init__.py | 5 +- .../CRAFT/{CRAFTUtil.py => data.py} | 81 +- convokit/forecaster/CRAFT/model.py | 228 +++ convokit/forecaster/CRAFT/runners.py | 434 +++++ convokit/forecaster/CRAFTModel.py | 667 +++----- convokit/forecaster/forecaster.py | 368 +++-- convokit/forecaster/forecasterModel.py | 44 +- convokit/util.py | 32 +- docs/source/config.rst | 9 +- docs/source/forecaster.rst | 31 +- .../forecaster/CRAFT Forecaster demo.ipynb | 1423 +++++++++++++++++ 13 files changed, 2648 insertions(+), 920 deletions(-) delete mode 100644 convokit/forecaster/CRAFT/CRAFTNN.py rename convokit/forecaster/CRAFT/{CRAFTUtil.py => data.py} (80%) create mode 100644 convokit/forecaster/CRAFT/model.py create mode 100644 convokit/forecaster/CRAFT/runners.py create mode 100644 examples/forecaster/CRAFT Forecaster demo.ipynb diff --git a/convokit/convokitConfig.py b/convokit/convokitConfig.py index 37f3ac37..aed12f06 100644 --- a/convokit/convokitConfig.py +++ b/convokit/convokitConfig.py @@ -7,6 +7,7 @@ "# Default Backend Parameters\n" "db_host: localhost:27017\n" "data_directory: ~/.convokit/saved-corpora\n" + "model_directory: ~/.convokit/saved-models\n" "default_backend: mem" ) @@ -51,6 +52,10 @@ def db_host(self): def data_directory(self): return self.config_contents.get("data_directory", "~/.convokit/saved-corpora") + @property + def model_directory(self): + return self.config_contents.get("model_directory", "~/.convokit/saved-models") + @property def default_backend(self): return self._get_config_from_env_or_file("default_backend", "mem") diff --git a/convokit/forecaster/CRAFT/CRAFTNN.py b/convokit/forecaster/CRAFT/CRAFTNN.py deleted file mode 100644 index 71fa7a3f..00000000 --- a/convokit/forecaster/CRAFT/CRAFTNN.py +++ /dev/null @@ -1,241 +0,0 @@ -try: - import torch -except (ModuleNotFoundError, ImportError) as e: - raise ModuleNotFoundError( - "torch is not currently installed. Run 'pip install convokit[craft]' if you would like to use the CRAFT model." - ) - -from torch import nn -import os -from urllib.request import urlretrieve -from .CRAFTUtil import CONSTANTS - - -class EncoderRNN(nn.Module): - """ - This module represents the utterance encoder component of CRAFT, - responsible for creating vector representations of utterances - """ - - def __init__(self, hidden_size, embedding, n_layers=1, dropout=0): - super(EncoderRNN, self).__init__() - self.n_layers = n_layers - self.hidden_size = hidden_size - self.embedding = embedding - - # Initialize GRU; the input_size and hidden_size params are both set to 'hidden_size' - # because our input size is a word embedding with number of features == hidden_size - self.gru = nn.GRU( - hidden_size, - hidden_size, - n_layers, - dropout=(0 if n_layers == 1 else dropout), - bidirectional=True, - ) - - def forward(self, input_seq, input_lengths, hidden=None): - # Convert word indexes to embeddings - embedded = self.embedding(input_seq) - # Pack padded batch of sequences for RNN module - packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths) - # Forward pass through GRU - outputs, hidden = self.gru(packed, hidden) - # Unpack padding - outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs) - # Sum bidirectional GRU outputs - outputs = outputs[:, :, : self.hidden_size] + outputs[:, :, self.hidden_size :] - # Return output and final hidden state - return outputs, hidden - - -class ContextEncoderRNN(nn.Module): - """This module represents the context encoder component of CRAFT, responsible for creating an order-sensitive vector representation of conversation context""" - - def __init__(self, hidden_size, n_layers=1, dropout=0): - super(ContextEncoderRNN, self).__init__() - self.n_layers = n_layers - self.hidden_size = hidden_size - - # only unidirectional GRU for context encoding - self.gru = nn.GRU( - hidden_size, - hidden_size, - n_layers, - dropout=(0 if n_layers == 1 else dropout), - bidirectional=False, - ) - - def forward(self, input_seq, input_lengths, hidden=None): - # Pack padded batch of sequences for RNN module - packed = torch.nn.utils.rnn.pack_padded_sequence(input_seq, input_lengths) - # Forward pass through GRU - outputs, hidden = self.gru(packed, hidden) - # Unpack padding - outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs) - # return output and final hidden state - return outputs, hidden - - -class SingleTargetClf(nn.Module): - """This module represents the CRAFT classifier head, which takes the context encoding and uses it to make a forecast""" - - def __init__(self, hidden_size, dropout=0.1): - super(SingleTargetClf, self).__init__() - - self.hidden_size = hidden_size - - # initialize classifier - self.layer1 = nn.Linear(hidden_size, hidden_size) - self.layer1_act = nn.LeakyReLU() - self.layer2 = nn.Linear(hidden_size, hidden_size // 2) - self.layer2_act = nn.LeakyReLU() - self.clf = nn.Linear(hidden_size // 2, 1) - self.dropout = nn.Dropout(p=dropout) - - def forward(self, encoder_outputs, encoder_input_lengths): - # from stackoverflow (https://stackoverflow.com/questions/50856936/taking-the-last-state-from-bilstm-bigru-in-pytorch) - # First we unsqueeze seqlengths two times so it has the same number of - # of dimensions as output_forward - # (batch_size) -> (1, batch_size, 1) - lengths = encoder_input_lengths.unsqueeze(0).unsqueeze(2) - # Then we expand it accordingly - # (1, batch_size, 1) -> (1, batch_size, hidden_size) - lengths = lengths.expand((1, -1, encoder_outputs.size(2))) - - # take only the last state of the encoder for each batch - last_outputs = torch.gather(encoder_outputs, 0, lengths - 1).squeeze(dim=0) - # forward pass through hidden layers - layer1_out = self.layer1_act(self.layer1(self.dropout(last_outputs))) - layer2_out = self.layer2_act(self.layer2(self.dropout(layer1_out))) - # compute and return logits - logits = self.clf(self.dropout(layer2_out)).squeeze(dim=1) - return logits - - -class Predictor(nn.Module): - """This helper module encapsulates the CRAFT pipeline, defining the logic of passing an input through each consecutive sub-module.""" - - def __init__(self, encoder, context_encoder, classifier): - super(Predictor, self).__init__() - self.encoder = encoder - self.context_encoder = context_encoder - self.classifier = classifier - - def forward( - self, - input_batch, - dialog_lengths, - dialog_lengths_list, - utt_lengths, - batch_indices, - dialog_indices, - batch_size, - max_length, - ): - # Forward input through encoder model - _, utt_encoder_hidden = self.encoder(input_batch, utt_lengths) - - # Convert utterance encoder final states to batched dialogs for use by context encoder - context_encoder_input = makeContextEncoderInput( - utt_encoder_hidden, dialog_lengths_list, batch_size, batch_indices, dialog_indices - ) - - # Forward pass through context encoder - context_encoder_outputs, context_encoder_hidden = self.context_encoder( - context_encoder_input, dialog_lengths - ) - - # Forward pass through classifier to get prediction logits - logits = self.classifier(context_encoder_outputs, dialog_lengths) - - # Apply sigmoid activation - predictions = torch.sigmoid(logits) - return predictions - - -def makeContextEncoderInput( - utt_encoder_hidden, dialog_lengths, batch_size, batch_indices, dialog_indices -): - """The utterance encoder takes in utterances in combined batches, with no knowledge of which ones go where in which conversation. - Its output is therefore also unordered. We correct this by using the information computed during tensor conversion to regroup - the utterance vectors into their proper conversational order.""" - # first, sum the forward and backward encoder states - utt_encoder_summed = utt_encoder_hidden[-2, :, :] + utt_encoder_hidden[-1, :, :] - # we now have hidden state of shape [utterance_batch_size, hidden_size] - # split it into a list of [hidden_size,] x utterance_batch_size - last_states = [t.squeeze() for t in utt_encoder_summed.split(1, dim=0)] - - # create a placeholder list of tensors to group the states by source dialog - states_dialog_batched = [[None for _ in range(dialog_lengths[i])] for i in range(batch_size)] - - # group the states by source dialog - for hidden_state, batch_idx, dialog_idx in zip(last_states, batch_indices, dialog_indices): - states_dialog_batched[batch_idx][dialog_idx] = hidden_state - - # stack each dialog into a tensor of shape [dialog_length, hidden_size] - states_dialog_batched = [torch.stack(d) for d in states_dialog_batched] - - # finally, condense all the dialog tensors into a single zero-padded tensor - # of shape [max_dialog_length, batch_size, hidden_size] - return torch.nn.utils.rnn.pad_sequence(states_dialog_batched) - - -def initialize_model( - custom_model_path, - voc, - device, - device_type: str, - hidden_size, - encoder_n_layers, - dropout, - context_encoder_n_layers, -): - print("Loading saved parameters...") - if custom_model_path is None: - if not os.path.isfile("model.tar"): - print("\tDownloading trained CRAFT...") - urlretrieve(CONSTANTS["MODEL_URL"], "model.tar") - print("\t...Done!") - custom_model_path = "model.tar" - # If running in a non-GPU environment, you need to tell PyTorch to convert the parameters to CPU tensor format. - # To do so, replace the previous line with the following: - if device_type == "cpu": - checkpoint = torch.load(custom_model_path, map_location=torch.device("cpu")) - elif device_type == "cuda": - checkpoint = torch.load(custom_model_path) - encoder_sd = checkpoint["en"] - context_sd = checkpoint["ctx"] - if "atk_clf" in checkpoint: - attack_clf_sd = checkpoint["atk_clf"] - - embedding_sd = checkpoint["embedding"] - voc.__dict__ = checkpoint["voc_dict"] - - print("Building encoders, decoder, and classifier...") - # Initialize word embeddings - embedding = nn.Embedding(voc.num_words, hidden_size) - embedding.load_state_dict(embedding_sd) - # Initialize utterance and context encoders - encoder: EncoderRNN = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout) - context_encoder: ContextEncoderRNN = ContextEncoderRNN( - hidden_size, context_encoder_n_layers, dropout - ) - encoder.load_state_dict(encoder_sd) - context_encoder.load_state_dict(context_sd) - # Initialize classifier - attack_clf: SingleTargetClf = SingleTargetClf(hidden_size, dropout) - if "atk_clf" in checkpoint: - attack_clf.load_state_dict(attack_clf_sd) - # Use appropriate device - encoder = encoder.to(device) - context_encoder = context_encoder.to(device) - attack_clf = attack_clf.to(device) - print("Models built and ready to go!") - - # Set dropout layers to eval mode - encoder.eval() - context_encoder.eval() - attack_clf.eval() - - # Initialize the pipeline - return Predictor(encoder, context_encoder, attack_clf) diff --git a/convokit/forecaster/CRAFT/__init__.py b/convokit/forecaster/CRAFT/__init__.py index ce3c1739..d79367ea 100644 --- a/convokit/forecaster/CRAFT/__init__.py +++ b/convokit/forecaster/CRAFT/__init__.py @@ -1,2 +1,3 @@ -from .CRAFTUtil import * -from .CRAFTNN import * +from .data import * +from .model import * +from .runners import * diff --git a/convokit/forecaster/CRAFT/CRAFTUtil.py b/convokit/forecaster/CRAFT/data.py similarity index 80% rename from convokit/forecaster/CRAFT/CRAFTUtil.py rename to convokit/forecaster/CRAFT/data.py index 28dd2d6c..15f3e435 100644 --- a/convokit/forecaster/CRAFT/CRAFTUtil.py +++ b/convokit/forecaster/CRAFT/data.py @@ -3,6 +3,7 @@ import nltk import random import itertools +import json try: import torch @@ -13,36 +14,19 @@ from typing import List, Tuple -CONSTANTS = { - "PAD_token": 0, - "SOS_token": 1, - "EOS_token": 2, - "UNK_token": 3, - "WORD2INDEX_URL": "http://zissou.infosci.cornell.edu/convokit/models/craft_wikiconv/word2index.json", - "INDEX2WORD_URL": "http://zissou.infosci.cornell.edu/convokit/models/craft_wikiconv/index2word.json", - "MODEL_URL": "http://zissou.infosci.cornell.edu/convokit/models/craft_wikiconv/craft_full.tar", -} - # Default word tokens PAD_token = 0 # Used for padding short sentences SOS_token = 1 # Start-of-sentence token EOS_token = 2 # End-of-sentence token UNK_token = 3 # Unknown word token -# model download paths -WORD2INDEX_URL = "http://zissou.infosci.cornell.edu/convokit/models/craft_wikiconv/word2index.json" -INDEX2WORD_URL = "http://zissou.infosci.cornell.edu/convokit/models/craft_wikiconv/index2word.json" -MODEL_URL = "http://zissou.infosci.cornell.edu/convokit/models/craft_wikiconv/craft_full.tar" - class Voc: - """A class for representing the vocabulary used by a CRAFT model""" - def __init__(self, name, word2index=None, index2word=None): self.name = name self.trimmed = ( False if not word2index else True - ) # if a precomputed vocab is specified assume the speaker wants to use it as-is + ) # if a precomputed vocab is specified assume the user wants to use it as-is self.word2index = word2index if word2index else {"UNK": UNK_token} self.word2count = {} self.index2word = ( @@ -93,20 +77,6 @@ def trim(self, min_count): self.addWord(word) -# Create a Voc object from precomputed data structures -def loadPrecomputedVoc(corpus_name, word2index_url, index2word_url): - # load the word-to-index lookup map - r = requests.get(word2index_url) - word2index = r.json() - # load the index-to-word lookup map - r = requests.get(index2word_url) - index2word = r.json() - return Voc(corpus_name, word2index, index2word) - - -# Helper functions for preprocessing and tokenizing text - - # Turn a Unicode string to plain ASCII, thanks to # https://stackoverflow.com/a/518232/2809427 def unicodeToAscii(s): @@ -114,7 +84,7 @@ def unicodeToAscii(s): # Tokenize the string using NLTK -def craft_tokenize(voc, text): +def tokenize(voc, text): tokenizer = nltk.tokenize.RegexpTokenizer(pattern=r"\w+|[^\w\s]") # simplify the problem space by considering only ASCII data cleaned_text = unicodeToAscii(text.lower()) @@ -124,13 +94,38 @@ def craft_tokenize(voc, text): return [] tokens = tokenizer.tokenize(cleaned_text) + + # replace out-of-vocabulary tokens for i in range(len(tokens)): if tokens[i] not in voc.word2index: tokens[i] = "UNK" + return tokens -# Helper functions for turning dialog and text sequences into tensors, and manipulating those tensors +# Create a Voc object from precomputed data structures +def loadPrecomputedVoc(corpus_name, word2index_path, index2word_path): + with open(word2index_path) as fp: + word2index = json.load(fp) + with open(index2word_path) as fp: + index2word = json.load(fp) + return Voc(corpus_name, word2index, index2word) + + +# Given a context utterance list from Forecaster, preprocess each utterance's text by tokenizing and truncating. +# Returns the processed dialog entry where text has been replaced with a list of +# tokens, each no longer than MAX_LENGTH - 1 (to leave space for the EOS token) +def processContext(voc, context, is_attack): + processed = [] + for utterance in context.context: + # since the iterative nature of Forecaster may lead us to see the same utterance + # multiple times, we'll cache the tokenized form of the utterance as metadata + # and look it up if it already exists + if "craft_tokens" not in utterance.meta: + utterance.meta["craft_tokens"] = tokenize(voc, utterance.text) + tokens = utterance.meta["craft_tokens"] + processed.append({"tokens": tokens, "is_attack": is_attack, "id": utterance.id}) + return processed def indexesFromSentence(voc, sentence): @@ -141,15 +136,15 @@ def zeroPadding(l, fillvalue=PAD_token): return list(itertools.zip_longest(*l, fillvalue=fillvalue)) -def binaryMatrix(l): +def binaryMatrix(l, value=PAD_token): m = [] for i, seq in enumerate(l): m.append([]) for token in seq: if token == PAD_token: - m[i].append(0) + m[i].append(False) else: - m[i].append(1) + m[i].append(True) return m @@ -189,13 +184,13 @@ def outputVar(l, voc): max_target_len = max([len(indexes) for indexes in indexes_batch]) padList = zeroPadding(indexes_batch) mask = binaryMatrix(padList) - mask = torch.ByteTensor(mask) + mask = torch.BoolTensor(mask) padVar = torch.LongTensor(padList) return padVar, mask, max_target_len # Returns all items for a given batch of pairs -def batch2TrainData(voc, pair_batch: List[Tuple], already_sorted=False): +def batch2TrainData(voc, pair_batch, already_sorted=False): if not already_sorted: pair_batch.sort(key=lambda x: len(x[0]), reverse=True) input_batch, output_batch, label_batch, id_batch = [], [], [], [] @@ -203,10 +198,7 @@ def batch2TrainData(voc, pair_batch: List[Tuple], already_sorted=False): input_batch.append(pair[0]) output_batch.append(pair[1]) label_batch.append(pair[2]) - if len(pair) > 3: - id_batch.append(pair[3]) - else: - id_batch.append(None) + id_batch.append(pair[3]) dialog_lengths = torch.tensor([len(x) for x in input_batch]) input_utterances, batch_indices, dialog_indices = dialogBatch2UtteranceBatch(input_batch) inp, utt_lengths = inputVar(input_utterances, voc) @@ -242,7 +234,8 @@ def batchIterator(voc, source_data, batch_size, shuffle=True): batch.sort(key=lambda x: len(x[0]), reverse=True) # for analysis purposes, get the source dialogs and labels associated with this batch batch_dialogs = [x[0] for x in batch] + batch_labels = [x[2] for x in batch] # convert batch to tensors batch_tensors = batch2TrainData(voc, batch, already_sorted=True) - yield (batch_tensors, batch_dialogs, true_batch_size) + yield (batch_tensors, batch_dialogs, batch_labels, true_batch_size) cur_idx += batch_size diff --git a/convokit/forecaster/CRAFT/model.py b/convokit/forecaster/CRAFT/model.py new file mode 100644 index 00000000..2a4fad06 --- /dev/null +++ b/convokit/forecaster/CRAFT/model.py @@ -0,0 +1,228 @@ +try: + import torch +except (ModuleNotFoundError, ImportError) as e: + raise ModuleNotFoundError( + "torch is not currently installed. Run 'pip install convokit[craft]' if you would like to use the CRAFT model." + ) + +from torch import nn +import torch.nn.functional as F + + +class EncoderRNN(nn.Module): + def __init__(self, hidden_size, embedding, n_layers=1, dropout=0): + super(EncoderRNN, self).__init__() + self.n_layers = n_layers + self.hidden_size = hidden_size + self.embedding = embedding + + # Initialize GRU; the input_size and hidden_size params are both set to 'hidden_size' + # because our input size is a word embedding with number of features == hidden_size + self.gru = nn.GRU( + hidden_size, + hidden_size, + n_layers, + dropout=(0 if n_layers == 1 else dropout), + bidirectional=True, + ) + + def forward(self, input_seq, input_lengths, hidden=None): + # Convert word indexes to embeddings + embedded = self.embedding(input_seq) + # Pack padded batch of sequences for RNN module + packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths.cpu()) + # Forward pass through GRU + outputs, hidden = self.gru(packed, hidden) + # Unpack padding + outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs) + # Sum bidirectional GRU outputs + outputs = outputs[:, :, : self.hidden_size] + outputs[:, :, self.hidden_size :] + # Return output and final hidden state + return outputs, hidden + + +class ContextEncoderRNN(nn.Module): + def __init__(self, hidden_size, n_layers=1, dropout=0): + super(ContextEncoderRNN, self).__init__() + self.n_layers = n_layers + self.hidden_size = hidden_size + + # only unidirectional GRU for context encoding + self.gru = nn.GRU( + hidden_size, + hidden_size, + n_layers, + dropout=(0 if n_layers == 1 else dropout), + bidirectional=False, + ) + + def forward(self, input_seq, input_lengths, hidden=None): + # Pack padded batch of sequences for RNN module + packed = torch.nn.utils.rnn.pack_padded_sequence(input_seq, input_lengths.cpu()) + # Forward pass through GRU + outputs, hidden = self.gru(packed, hidden) + # Unpack padding + outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs) + # return output and final hidden state + return outputs, hidden + + +# Luong attention layer +class Attn(torch.nn.Module): + def __init__(self, method, hidden_size): + super(Attn, self).__init__() + self.method = method + if self.method not in ["dot", "general", "concat"]: + raise ValueError(self.method, "is not an appropriate attention method.") + self.hidden_size = hidden_size + if self.method == "general": + self.attn = torch.nn.Linear(self.hidden_size, hidden_size) + elif self.method == "concat": + self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size) + self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size)) + + def dot_score(self, hidden, encoder_output): + return torch.sum(hidden * encoder_output, dim=2) + + def general_score(self, hidden, encoder_output): + energy = self.attn(encoder_output) + return torch.sum(hidden * energy, dim=2) + + def concat_score(self, hidden, encoder_output): + energy = self.attn( + torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2) + ).tanh() + return torch.sum(self.v * energy, dim=2) + + def forward(self, hidden, encoder_outputs): + # Calculate the attention weights (energies) based on the given method + if self.method == "general": + attn_energies = self.general_score(hidden, encoder_outputs) + elif self.method == "concat": + attn_energies = self.concat_score(hidden, encoder_outputs) + elif self.method == "dot": + attn_energies = self.dot_score(hidden, encoder_outputs) + + # Transpose max_length and batch_size dimensions + attn_energies = attn_energies.t() + + # Return the softmax normalized probability scores (with added dimension) + return F.softmax(attn_energies, dim=1).unsqueeze(1) + + +class LuongAttnDecoderRNN(nn.Module): + def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1): + super(LuongAttnDecoderRNN, self).__init__() + + # Keep for reference + self.attn_model = attn_model + self.hidden_size = hidden_size + self.output_size = output_size + self.n_layers = n_layers + self.dropout = dropout + + # Define layers + self.embedding = embedding + self.embedding_dropout = nn.Dropout(dropout) + self.gru = nn.GRU( + hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout) + ) + self.concat = nn.Linear(hidden_size * 2, hidden_size) + self.out = nn.Linear(hidden_size, output_size) + + self.attn = Attn(attn_model, hidden_size) + + def forward(self, input_step, last_hidden, encoder_outputs): + # Note: we run this one step (word) at a time + # Get embedding of current input word + embedded = self.embedding(input_step) + embedded = self.embedding_dropout(embedded) + # Forward through unidirectional GRU + rnn_output, hidden = self.gru(embedded, last_hidden) + # Calculate attention weights from the current GRU output + attn_weights = self.attn(rnn_output, encoder_outputs) + # Multiply attention weights to encoder outputs to get new "weighted sum" context vector + context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) + # Concatenate weighted context vector and GRU output using Luong eq. 5 + rnn_output = rnn_output.squeeze(0) + context = context.squeeze(1) + concat_input = torch.cat((rnn_output, context), 1) + concat_output = torch.tanh(self.concat(concat_input)) + # Predict next word using Luong eq. 6 + output = self.out(concat_output) + output = F.softmax(output, dim=1) + # Return output and final hidden state + return output, hidden + + +class AttnSingleTargetClf(nn.Module): + def __init__(self, hidden_size, dropout=0.1): + super(AttnSingleTargetClf, self).__init__() + + self.hidden_size = hidden_size + + # initialize attention + self.attn = nn.Linear(hidden_size, 1) + + # initialize classifier + self.layer1 = nn.Linear(hidden_size, hidden_size) + self.layer1_act = nn.LeakyReLU() + self.layer2 = nn.Linear(hidden_size, hidden_size // 2) + self.layer2_act = nn.LeakyReLU() + self.clf = nn.Linear(hidden_size // 2, 1) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, encoder_outputs): + # compute attention weights + self.attn_weights = self.attn(encoder_outputs).squeeze().transpose(0, 1) + # softmax normalize weights + self.attn_weights = F.softmax(self.attn_weights, dim=1).unsqueeze(1) + # transpose context encoder outputs so we can apply batch matrix multiply + encoder_outputs_transp = encoder_outputs.transpose(0, 1) + # compute weighted context vector + context_vec = torch.bmm(self.attn_weights, encoder_outputs_transp).squeeze() + # forward pass through hidden layers + layer1_out = self.layer1_act(self.layer1(self.dropout(context_vec))) + layer2_out = self.layer2_act(self.layer2(self.dropout(layer1_out))) + # compute and return logits + logits = self.clf(self.dropout(layer2_out)).squeeze() + return logits + + +class SingleTargetClf(nn.Module): + """ + Single-target classifier head with no attention layer (predicts only from + the last state vector of the RNN) + """ + + def __init__(self, hidden_size, dropout=0.1): + super(SingleTargetClf, self).__init__() + + self.hidden_size = hidden_size + + # initialize classifier + self.layer1 = nn.Linear(hidden_size, hidden_size) + self.layer1_act = nn.LeakyReLU() + self.layer2 = nn.Linear(hidden_size, hidden_size // 2) + self.layer2_act = nn.LeakyReLU() + self.clf = nn.Linear(hidden_size // 2, 1) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, encoder_outputs, encoder_input_lengths): + # from stackoverflow (https://stackoverflow.com/questions/50856936/taking-the-last-state-from-bilstm-bigru-in-pytorch) + # First we unsqueeze seqlengths two times so it has the same number of + # of dimensions as output_forward + # (batch_size) -> (1, batch_size, 1) + lengths = encoder_input_lengths.unsqueeze(0).unsqueeze(2) + # Then we expand it accordingly + # (1, batch_size, 1) -> (1, batch_size, hidden_size) + lengths = lengths.expand((1, -1, encoder_outputs.size(2))) + + # take only the last state of the encoder for each batch + last_outputs = torch.gather(encoder_outputs, 0, lengths - 1).squeeze() + # forward pass through hidden layers + layer1_out = self.layer1_act(self.layer1(self.dropout(last_outputs))) + layer2_out = self.layer2_act(self.layer2(self.dropout(layer1_out))) + # compute and return logits + logits = self.clf(self.dropout(layer2_out)).squeeze() + return logits diff --git a/convokit/forecaster/CRAFT/runners.py b/convokit/forecaster/CRAFT/runners.py new file mode 100644 index 00000000..0d27157f --- /dev/null +++ b/convokit/forecaster/CRAFT/runners.py @@ -0,0 +1,434 @@ +try: + import torch +except (ModuleNotFoundError, ImportError) as e: + raise ModuleNotFoundError( + "torch is not currently installed. Run 'pip install convokit[craft]' if you would like to use the CRAFT model." + ) + +import numpy as np +import pandas as pd +import torch.nn.functional as F +from torch import nn +from copy import deepcopy + + +class Predictor(nn.Module): + """This helper module encapsulates the CRAFT pipeline, defining the logic of passing an input through each consecutive sub-module.""" + + def __init__(self, encoder, context_encoder, classifier): + super(Predictor, self).__init__() + self.encoder = encoder + self.context_encoder = context_encoder + self.classifier = classifier + + def forward( + self, + input_batch, + dialog_lengths, + dialog_lengths_list, + utt_lengths, + batch_indices, + dialog_indices, + batch_size, + max_length, + ): + # Forward input through encoder model + _, utt_encoder_hidden = self.encoder(input_batch, utt_lengths) + + # Convert utterance encoder final states to batched dialogs for use by context encoder + context_encoder_input = makeContextEncoderInput( + utt_encoder_hidden, dialog_lengths_list, batch_size, batch_indices, dialog_indices + ) + + # Forward pass through context encoder + context_encoder_outputs, context_encoder_hidden = self.context_encoder( + context_encoder_input, dialog_lengths + ) + + # Forward pass through classifier to get prediction logits + logits = self.classifier(context_encoder_outputs, dialog_lengths) + + # Apply sigmoid activation + predictions = F.sigmoid(logits) + return predictions + + +def makeContextEncoderInput( + utt_encoder_hidden, dialog_lengths, batch_size, batch_indices, dialog_indices +): + # first, sum the forward and backward encoder states + utt_encoder_summed = utt_encoder_hidden[-2, :, :] + utt_encoder_hidden[-1, :, :] + # we now have hidden state of shape [utterance_batch_size, hidden_size] + # split it into a list of [hidden_size,] x utterance_batch_size + last_states = [t.squeeze() for t in utt_encoder_summed.split(1, dim=0)] + + # create a placeholder list of tensors to group the states by source dialog + states_dialog_batched = [[None for _ in range(dialog_lengths[i])] for i in range(batch_size)] + + # group the states by source dialog + for hidden_state, batch_idx, dialog_idx in zip(last_states, batch_indices, dialog_indices): + states_dialog_batched[batch_idx][dialog_idx] = hidden_state + + # stack each dialog into a tensor of shape [dialog_length, hidden_size] + states_dialog_batched = [torch.stack(d) for d in states_dialog_batched] + + # finally, condense all the dialog tensors into a single zero-padded tensor + # of shape [max_dialog_length, batch_size, hidden_size] + return torch.nn.utils.rnn.pad_sequence(states_dialog_batched) + + +def train( + input_variable, + dialog_lengths, + dialog_lengths_list, + utt_lengths, + batch_indices, + dialog_indices, + labels, # input/output arguments + encoder, + context_encoder, + attack_clf, # network arguments + encoder_optimizer, + context_encoder_optimizer, + attack_clf_optimizer, # optimization arguments + batch_size, + clip, + device, +): # misc arguments + # Zero gradients + encoder_optimizer.zero_grad() + context_encoder_optimizer.zero_grad() + attack_clf_optimizer.zero_grad() + + # Set device options + input_variable = input_variable.to(device) + dialog_lengths = dialog_lengths.to(device) + utt_lengths = utt_lengths.to(device) + labels = labels.to(device) + + # Forward pass through utterance encoder + _, utt_encoder_hidden = encoder(input_variable, utt_lengths) + + # Convert utterance encoder final states to batched dialogs for use by context encoder + context_encoder_input = makeContextEncoderInput( + utt_encoder_hidden, dialog_lengths_list, batch_size, batch_indices, dialog_indices + ) + + # Forward pass through context encoder + context_encoder_outputs, _ = context_encoder(context_encoder_input, dialog_lengths) + + # Forward pass through classifier to get prediction logits + logits = attack_clf(context_encoder_outputs, dialog_lengths) + + # Calculate loss + loss = F.binary_cross_entropy_with_logits(logits, labels) + + # Perform backpropatation + loss.backward() + + # Clip gradients: gradients are modified in place + _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip) + _ = torch.nn.utils.clip_grad_norm_(context_encoder.parameters(), clip) + _ = torch.nn.utils.clip_grad_norm_(attack_clf.parameters(), clip) + + # Adjust model weights + encoder_optimizer.step() + context_encoder_optimizer.step() + attack_clf_optimizer.step() + + return loss.item() + + +def evaluateBatch( + encoder, + context_encoder, + predictor, + voc, + input_batch, + dialog_lengths, + dialog_lengths_list, + utt_lengths, + batch_indices, + dialog_indices, + batch_size, + device, + max_length, + threshold=0.5, +): + # Set device options + input_batch = input_batch.to(device) + dialog_lengths = dialog_lengths.to(device) + utt_lengths = utt_lengths.to(device) + # Predict future attack using predictor + scores = predictor( + input_batch, + dialog_lengths, + dialog_lengths_list, + utt_lengths, + batch_indices, + dialog_indices, + batch_size, + max_length, + ) + predictions = (scores > threshold).float() + return predictions, scores + + +def validate( + dataset, + encoder, + context_encoder, + predictor, + voc, + batch_size, + device, + max_length, + batch_iterator_func, +): + # create a batch iterator for the given data + batch_iterator = batch_iterator_func(voc, dataset, batch_size, shuffle=False) + # find out how many iterations we will need to cover the whole dataset + n_iters = len(dataset) // batch_size + int(len(dataset) % batch_size > 0) + # containers for full prediction results so we can compute accuracy at the end + all_preds = [] + all_labels = [] + for iteration in range(1, n_iters + 1): + batch, batch_dialogs, _, true_batch_size = next(batch_iterator) + # Extract fields from batch + ( + input_variable, + dialog_lengths, + utt_lengths, + batch_indices, + dialog_indices, + labels, + convo_ids, + target_variable, + mask, + max_target_len, + ) = batch + dialog_lengths_list = [len(x) for x in batch_dialogs] + # run the model + predictions, scores = evaluateBatch( + encoder, + context_encoder, + predictor, + voc, + input_variable, + dialog_lengths, + dialog_lengths_list, + utt_lengths, + batch_indices, + dialog_indices, + true_batch_size, + device, + max_length, + ) + # aggregate results for computing accuracy at the end + all_preds += [p.item() for p in predictions] + all_labels += [l.item() for l in labels] + print( + "Iteration: {}; Percent complete: {:.1f}%".format(iteration, iteration / n_iters * 100) + ) + + # compute and return the accuracy + return (np.asarray(all_preds) == np.asarray(all_labels)).mean() + + +def trainIters( + voc, + pairs, + val_pairs, + encoder, + context_encoder, + attack_clf, + encoder_optimizer, + context_encoder_optimizer, + attack_clf_optimizer, + embedding, + n_iteration, + batch_size, + print_every, + validate_every, + clip, + device, + max_length, + batch_iterator_func, +): + # create a batch iterator for training data + batch_iterator = batch_iterator_func(voc, pairs, batch_size) + + # Initializations + print("Initializing ...") + start_iteration = 1 + print_loss = 0 + + # Training loop + print("Training...") + # keep track of best validation accuracy - only save when we have a model that beats the current best + best_acc = 0 + best_model = None + for iteration in range(start_iteration, n_iteration + 1): + training_batch, training_dialogs, _, true_batch_size = next(batch_iterator) + # Extract fields from batch + ( + input_variable, + dialog_lengths, + utt_lengths, + batch_indices, + dialog_indices, + labels, + _, + target_variable, + mask, + max_target_len, + ) = training_batch + dialog_lengths_list = [len(x) for x in training_dialogs] + + # Run a training iteration with batch + loss = train( + input_variable, + dialog_lengths, + dialog_lengths_list, + utt_lengths, + batch_indices, + dialog_indices, + labels, # input/output arguments + encoder, + context_encoder, + attack_clf, # network arguments + encoder_optimizer, + context_encoder_optimizer, + attack_clf_optimizer, # optimization arguments + true_batch_size, + clip, + device, + ) # misc arguments + print_loss += loss + + # Print progress + if iteration % print_every == 0: + print_loss_avg = print_loss / print_every + print( + "Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format( + iteration, iteration / n_iteration * 100, print_loss_avg + ) + ) + print_loss = 0 + + # Evaluate on validation set + if iteration % validate_every == 0: + print("Validating!") + # put the network components into evaluation mode + encoder.eval() + context_encoder.eval() + attack_clf.eval() + + predictor = Predictor(encoder, context_encoder, attack_clf) + accuracy = validate( + val_pairs, + encoder, + context_encoder, + predictor, + voc, + batch_size, + device, + max_length, + batch_iterator_func, + ) + print("Validation set accuracy: {:.2f}%".format(accuracy * 100)) + + # keep track of our best model so far + if accuracy > best_acc: + print("Validation accuracy better than current best; saving model...") + best_acc = accuracy + best_model = deepcopy( + { + "iteration": iteration, + "en": encoder.state_dict(), + "ctx": context_encoder.state_dict(), + "atk_clf": attack_clf.state_dict(), + "en_opt": encoder_optimizer.state_dict(), + "ctx_opt": context_encoder_optimizer.state_dict(), + "atk_clf_opt": attack_clf_optimizer.state_dict(), + "loss": loss, + "voc_dict": voc.__dict__, + "embedding": embedding.state_dict(), + } + ) + + # put the network components back into training mode + encoder.train() + context_encoder.train() + attack_clf.train() + + return best_model + + +def evaluateDataset( + dataset, + encoder, + context_encoder, + predictor, + voc, + batch_size, + device, + max_length, + batch_iterator_func, + threshold, + pred_col_name, + score_col_name, +): + # create a batch iterator for the given data + batch_iterator = batch_iterator_func(voc, dataset, batch_size, shuffle=False) + # find out how many iterations we will need to cover the whole dataset + n_iters = len(dataset) // batch_size + int(len(dataset) % batch_size > 0) + output_df = {"id": [], pred_col_name: [], score_col_name: []} + for iteration in range(1, n_iters + 1): + batch, batch_dialogs, _, true_batch_size = next(batch_iterator) + # Extract fields from batch + ( + input_variable, + dialog_lengths, + utt_lengths, + batch_indices, + dialog_indices, + labels, + convo_ids, + target_variable, + mask, + max_target_len, + ) = batch + dialog_lengths_list = [len(x) for x in batch_dialogs] + # run the model + predictions, scores = evaluateBatch( + encoder, + context_encoder, + predictor, + voc, + input_variable, + dialog_lengths, + dialog_lengths_list, + utt_lengths, + batch_indices, + dialog_indices, + true_batch_size, + device, + max_length, + threshold, + ) + + # format the output as a dataframe (which we can later re-join with the corpus) + for i in range(true_batch_size): + convo_id = convo_ids[i] + pred = predictions[i].item() + score = scores[i].item() + output_df["id"].append(convo_id) + output_df[pred_col_name].append(pred) + output_df[score_col_name].append(score) + + print( + "Iteration: {}; Percent complete: {:.1f}%".format(iteration, iteration / n_iters * 100) + ) + + return pd.DataFrame(output_df).set_index("id") diff --git a/convokit/forecaster/CRAFTModel.py b/convokit/forecaster/CRAFTModel.py index 99214454..b0937e7f 100644 --- a/convokit/forecaster/CRAFTModel.py +++ b/convokit/forecaster/CRAFTModel.py @@ -6,497 +6,292 @@ ) import pandas as pd -from convokit.forecaster.CRAFT.CRAFTUtil import loadPrecomputedVoc, batchIterator, CONSTANTS -from .CRAFT.CRAFTNN import initialize_model, makeContextEncoderInput, Predictor +from convokit.forecaster.CRAFT.data import loadPrecomputedVoc, processContext, batchIterator +from convokit import download, warn +from convokit.convokitConfig import ConvoKitConfig +from .CRAFT.model import EncoderRNN, ContextEncoderRNN, SingleTargetClf +from .CRAFT.runners import Predictor, trainIters, evaluateDataset from .forecasterModel import ForecasterModel import numpy as np import torch.nn.functional as F -from torch import optim -from sklearn.model_selection import train_test_split -from typing import Dict +from torch import optim, nn +from typing import Dict, Union import os -default_options = { - "hidden_size": 500, - "encoder_n_layers": 2, - "context_encoder_n_layers": 2, - "decoder_n_layers": 2, +# parameters baked into the model design (because the provided models were saved with these parameters); +# these cannot be changed by the user +HIDDEN_SIZE = 500 +ENCODER_N_LAYERS = 2 +CONTEXT_ENCODER_N_LAYERS = 2 +DECODER_N_LAYERS = 2 +MAX_LENGTH = 80 + +# config dict contains parameters that could, in theory, be adjusted without causing things to crash +DEFAULT_CONFIG = { "dropout": 0.1, "batch_size": 64, "clip": 50.0, "learning_rate": 1e-5, "print_every": 10, - "train_epochs": 30, + "finetune_epochs": 30, "validation_size": 0.2, - "max_length": 80, - "trained_model_output_filepath": "finetuned_model.tar", +} + +MODEL_FILENAME_MAP = { + "craft-wiki-pretrained": "craft_pretrained.tar", + "craft-wiki-finetuned": "craft_full.tar", + "craft-cmv-pretrained": "craft_pretrained.tar", + "craft-cmv-finetuned": "craft_full.tar", +} + +DECISION_THRESHOLDS = { + "craft-wiki-pretrained": 0.570617, + "craft-wiki-finetuned": 0.570617, + "craft-cmv-pretrained": 0.548580, + "craft-cmv-finetuned": 0.548580, } # To understand the separation of concerns for the CRAFT files: -# CRAFT/craftNN.py contains the class implementations needed to initialize the CRAFT Neural Network model -# CRAFT/craftUtil.py contains utility methods for manipulating the data for it to be passed to the CRAFT model +# CRAFT/model.py contains the pytorch modules that comprise the CRAFT neural network +# CRAFT/data.py contains utility methods for manipulating the data for it to be passed to the CRAFT model +# CRAFT/runners.py adapts the scripts for training and inference of a CRAFT model class CRAFTModel(ForecasterModel): """ - CRAFTModel is one of the Forecaster models that can be used with the Forecaster Transformer. - - By default, CRAFTModel will be initialized with default options - - - hidden_size: 500 - - encoder_n_layers: 2 - - context_encoder_n_layers: 2 - - decoder_n_layers: 2 - - dropout: 0.1 - - batch_size (batch size for computation, i.e. how many (context, reply, id) tuples to use per batch of evaluation): 64 - - clip: 50.0 - - learning_rate: 1e-5 - - print_every: 10 - - train_epochs (number of epochs for training): 30 - - validation_size (percentage of training input data to use as validation): 0.2 - - max_length (maximum utterance length in the dataset): 80 - - :param device_type: 'cpu' or 'cuda', default: 'cpu' - :param model_path: filepath to CRAFT model if loading a custom CRAFT model - :param options: configuration options for the neural network: uses default options otherwise. - :param forecast_attribute_name: name of DataFrame column containing predictions, default: "prediction" - :param forecast_prob_attribute_name: name of DataFrame column containing prediction scores, default: "score" + A ConvoKit Forecaster-adherent reimplementation of the CRAFT conversational forecasting model from + the paper "Trouble on the Horizon: Forecasting the Derailment of Online Conversations as they Develop" + (Chang and Danescu-Niculescu-Mizil, 2019). + + Usage note: CRAFT is a neural network model; full end-to-end training of neural networks is considered + outside the scope of ConvoKit, so the ConvoKit CRAFTModel must be initialized with existing weights. + ConvoKit provides weights for the CGA-WIKI and CGA-CMV corpora. If you just want to run a fully-trained + CRAFT model on those corpora (i.e., only transform, no fit), you can use the finetuned weights + (craft-wiki-finetuned and craft-cmv-finetuned, respectively). If you want to take a pretrained model and + finetune it on your own data (i.e., both fit and transform), you can use the pretrained weights + (craft-wiki-pretrained and craft-cmv-pretrained, respectively), which provide trained versions of the + underlying utterance and conversation encoder layers but leave the classification layers at their + random initializations so that they can be fitted to your data. + + :param initial_weights: Specifies where to find the saved model to be loaded to initialize CRAFT. To use ConvoKit's provided models, use "craft-wiki-pretrained" for the model pretrained on Wikipedia data, or "craft-wiki-finetuned" for the model already fine-tuned on CGA-WIKI. Replace "wiki" with "cmv" for the Reddit CMV equivalents. Alternatively, if you have a custom model you want to use, you can pass in the full path to the saved PyTorch checkpoint file. + :param vocab_index2word: File containing the mapping from vocabulary index to raw string tokens. If you are using a provided model, you MUST leave this as the default value of "auto" (other values will be ignored and overridden to "auto"). Conversely, if using a custom model, you CANNOT leave this as "auto" and you must provide a full path to the vocabulary file that you made for your custom model. + :param vocab_word2index: File containing the mapping from raw string tokens to vocabulary index. If you are using a provided model, you MUST leave this as the default value of "auto" (other values will be ignored and overridden to "auto"). Conversely, if using a custom model, you CANNOT leave this as "auto" and you must provide a full path to the vocabulary file that you made for your custom model. + :param decision_threshold: Output probability beyond which a forecast should be considered "positive"/"True". Highly recommended to leave this at auto, which will use published values for the provided models, or 0.5 for custom models. + :param torch_device: "cpu" or "cuda" (for GPUs). If you have access to a GPU it is strongly recommended to set this to "cuda"; the default is "cpu" only for compatibility with non-GPU setups. + :param config: Allows overwriting of CRAFT hyperparameters. Strongly recommended to keep this at default unless you know what you're doing! """ def __init__( self, - device_type: str = "cpu", - model_path: str = None, - options: Dict = None, - forecast_attribute_name: str = "prediction", - forecast_feat_name=None, - forecast_prob_attribute_name: str = "pred_score", - forecast_prob_feat_name=None, + initial_weights: str, + vocab_index2word: str = "auto", + vocab_word2index: str = "auto", + decision_threshold: Union[float, str] = "auto", + torch_device: str = "cpu", + config: dict = DEFAULT_CONFIG, ): - super().__init__( - forecast_attribute_name=forecast_attribute_name, - forecast_feat_name=forecast_feat_name, - forecast_prob_attribute_name=forecast_prob_attribute_name, - forecast_prob_feat_name=forecast_prob_feat_name, - ) - assert device_type in ["cuda", "cpu"] - # device: controls GPU usage: 'cuda' to enable GPU, 'cpu' to run on CPU only. - self.device = torch.device(device_type) - self.device_type = device_type - # voc: the vocabulary object (convokit.forecaster.craftUtil.Voc) used by predictor. - # Used to convert text data into numerical input for CRAFT. - self.voc = loadPrecomputedVoc( - "wikiconv", CONSTANTS["WORD2INDEX_URL"], CONSTANTS["INDEX2WORD_URL"] - ) + super().__init__() + + # load the initial weights and store this as the current model + if initial_weights in MODEL_FILENAME_MAP: + # load ConvoKitConfig in order to look up the model storage path + convokitconfig = ConvoKitConfig() + download_dir = os.path.expanduser(convokitconfig.model_directory) + # download the model and its supporting vocabulary objects + base_path = download(initial_weights, data_dir=download_dir) + model_path = os.path.join(base_path, MODEL_FILENAME_MAP[initial_weights]) + # load the vocab, ensuring that we use the download ones + if vocab_index2word != "auto" or vocab_word2index != "auto": + warn( + f"CRAFTModel was initialized using a ConvoKit-provided model {initial_weights} but a custom vocabulary was specified. This is an unsupported configuration; the custom vocabulary will be ignored and the model-provided vocabulary will be loaded." + ) + self._voc = loadPrecomputedVoc( + initial_weights, + os.path.join(base_path, "word2index.json"), + os.path.join(base_path, "index2word.json"), + ) + else: + # assume that initial_weights is a true path to a local model + model_path = initial_weights + # we don't know the vocab for local models, so the user must manually supply one + if vocab_index2word == "auto" or vocab_word2index == "auto": + raise ValueError( + "CRAFTModel was initialized using a path to a custom model; a custom vocabulary also must be specified for this use case ('auto' is not supported)!" + ) + self._voc = loadPrecomputedVoc( + os.path.basename(initial_weights), vocab_word2index, vocab_index2word + ) + self._model = torch.load(model_path, map_location=torch.device(torch_device)) - if options is None: - self.options = default_options + # either take the decision threshold as given or use a predetermined one (default 0.5 if none can be found) + if type(decision_threshold) == float: + self._decision_threshold = decision_threshold else: - for k, v in default_options.items(): - if k not in options: - options[k] = v - self.options = options - print("Initializing CRAFT model with options:") - print(self.options) - - if model_path is not None: - if not os.path.isfile(model_path) or not model_path.endswith(".tar"): - print("Could not find CRAFT model tar file at: {}".format(model_path)) - model_path = None - self.predictor: Predictor = initialize_model( - model_path, - self.voc, - self.device, - self.device_type, - self.options["hidden_size"], - self.options["encoder_n_layers"], - self.options["dropout"], - self.options["context_encoder_n_layers"], - ) + if decision_threshold != "auto": + raise TypeError("CRAFTModel: decision_threshold must be either a float or 'auto'") + self._decision_threshold = DECISION_THRESHOLDS.get(initial_weights, 0.5) - def _evaluate_batch( - self, - predictor, - input_batch, - dialog_lengths, - dialog_lengths_list, - utt_lengths, - batch_indices, - dialog_indices, - true_batch_size, - ): + self._device = torch.device(torch_device) + self._config = config + + def _context_to_craft_data(self, contexts): """ - Helper for _evaluate_dataset. Runs CRAFT evaluation on a single batch; _evaluate_dataset calls this helper iteratively to get results for the entire dataset. - - :param predictor: the trained CRAFT model to use, provided as a PyTorch Model instance. - :param input_batch: the batch to run CRAFT on (produced by convokit.forecaster.craftUtil.batchIterator, formatted as a batch of utterances) - :param dialog_lengths: how many comments are in each conversation in this batch, as a PyTorch Tensor - :param dialog_lengths_list: same as dialog_lengths, but as a Python List - :param utt_lengths: for each conversation, records the number of tokens in each utterance of the conversation - :param batch_indices: used by CRAFT to reconstruct the original dialog batch from the given utterance batch. Records which dialog each utterance originally came from. - :param dialog_indices: used by CRAFT to reconstruct the original dialog batch from the given utterance batch. Records where in the dialog the utterance originally came from. - :param true_batch_size: number of dialogs in the original dialog batch this utterance batch was generated from. - - :return: per-utterance scores and binarized predictions. + Convert context utterances to a list of token-lists using the model's vocabulary object, + maintaining the original temporal ordering """ - # Set device options - input_batch = input_batch.to(self.device) - dialog_lengths = dialog_lengths.to(self.device) - utt_lengths = utt_lengths.to(self.device) - # Predict future attack using predictor - scores = predictor( - input_batch, - dialog_lengths, - dialog_lengths_list, - utt_lengths, - batch_indices, - dialog_indices, - true_batch_size, - self.options["max_length"], - ) - predictions = (scores > 0.5).float() - return predictions, scores - - def _evaluate_dataset(self, predictor, dataset): + pairs = [] + for context in contexts: + convo = context.current_utterance.get_conversation() + label = self.labeler(convo) + processed_context = processContext(self._voc, context, label) + utt = processed_context[-1]["tokens"][: (MAX_LENGTH - 1)] + context_utts = [u["tokens"][: (MAX_LENGTH - 1)] for u in processed_context] + pairs.append((context_utts, utt, label, context.current_utterance.id)) + return pairs + + def _init_craft(self): """ - Run a trained CRAFT model over an entire dataset in a batched fashion. - - :param predictor: the trained CRAFT model to use, provided as a PyTorch Model instance. - :param dataset: the dataset to evaluate on, formatted as a list of (context, reply, id_of_reply) tuples. - :return: a DataFrame, indexed by utterance ID, of CRAFT scores for each utterance, and the corresponding binary prediction. + Initialize the CRAFT layers using the currently saved checkpoints + (these will either be the initial_weights, or what got saved after fit()) """ - # create a batch iterator for the given data - batch_iterator = batchIterator(self.voc, dataset, self.options["batch_size"], shuffle=False) - # find out how many iterations we will need to cover the whole dataset - n_iters = len(dataset) // self.options["batch_size"] + int( - len(dataset) % self.options["batch_size"] > 0 + print("Loading saved parameters...") + encoder_sd = self._model["en"] + context_sd = self._model["ctx"] + try: + attack_clf_sd = self._model["atk_clf"] + except KeyError: + # this happens if we're loading from a non-finetuned initial weights; the classifier layer still needs training + attack_clf_sd = None + embedding_sd = self._model["embedding"] + self._voc.__dict__ = self._model["voc_dict"] + + print("Building encoders, decoder, and classifier...") + # Initialize word embeddings + embedding = nn.Embedding(self._voc.num_words, HIDDEN_SIZE) + embedding.load_state_dict(embedding_sd) + # Initialize utterance and context encoders + encoder = EncoderRNN(HIDDEN_SIZE, embedding, ENCODER_N_LAYERS, self._config["dropout"]) + context_encoder = ContextEncoderRNN( + HIDDEN_SIZE, CONTEXT_ENCODER_N_LAYERS, self._config["dropout"] ) - output_df = { - "id": [], - self.forecast_attribute_name: [], - self.forecast_prob_attribute_name: [], - } - for iteration in range(1, n_iters + 1): - batch, batch_dialogs, true_batch_size = next(batch_iterator) - # Extract fields from batch - ( - input_variable, - dialog_lengths, - utt_lengths, - batch_indices, - dialog_indices, - labels, - batch_ids, - target_variable, - mask, - max_target_len, - ) = batch - dialog_lengths_list = [len(x) for x in batch_dialogs] - # run the model - predictions, scores = self._evaluate_batch( - predictor, - input_variable, - dialog_lengths, - dialog_lengths_list, - utt_lengths, - batch_indices, - dialog_indices, - true_batch_size, - ) - - # format the output as a dataframe (which we can later re-join with the corpus) - for i in range(true_batch_size): - utt_id = batch_ids[i] - pred = predictions[i].item() - score = scores[i].item() - output_df["id"].append(utt_id) - output_df[self.forecast_attribute_name].append(pred) - output_df[self.forecast_prob_attribute_name].append(score) - - print( - "Iteration: {}; Percent complete: {:.1f}%".format( - iteration, iteration / n_iters * 100 - ) - ) - - return pd.DataFrame(output_df).set_index("id") - - def _train_NN( - self, - input_variable, - dialog_lengths, - dialog_lengths_list, - utt_lengths, - batch_indices, - dialog_indices, - labels, # input/output arguments - encoder, - context_encoder, - attack_clf, # network arguments - encoder_optimizer, - context_encoder_optimizer, - attack_clf_optimizer, # optimization arguments - batch_size, - clip, - ): # misc arguments - # Zero gradients - encoder_optimizer.zero_grad() - context_encoder_optimizer.zero_grad() - attack_clf_optimizer.zero_grad() - - # Set device options - input_variable = input_variable.to(self.device) - dialog_lengths = dialog_lengths.to(self.device) - utt_lengths = utt_lengths.to(self.device) - labels = labels.to(self.device) - - # Forward pass through utterance encoder - _, utt_encoder_hidden = encoder(input_variable, utt_lengths) - - # Convert utterance encoder final states to batched dialogs for use by context encoder - context_encoder_input = makeContextEncoderInput( - utt_encoder_hidden, dialog_lengths_list, batch_size, batch_indices, dialog_indices - ) - - # Forward pass through context encoder - context_encoder_outputs, _ = context_encoder(context_encoder_input, dialog_lengths) - - # Forward pass through classifier to get prediction logits - logits = attack_clf(context_encoder_outputs, dialog_lengths) - - # Calculate loss - loss = F.binary_cross_entropy_with_logits(logits, labels) - - # Perform backpropatation - loss.backward() - - # Clip gradients: gradients are modified in place - _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip) - _ = torch.nn.utils.clip_grad_norm_(context_encoder.parameters(), clip) - _ = torch.nn.utils.clip_grad_norm_(attack_clf.parameters(), clip) - - # Adjust model weights - encoder_optimizer.step() - context_encoder_optimizer.step() - attack_clf_optimizer.step() - - return loss.item() - - def _validate(self, predictor, dataset): - # create a batch iterator for the given data - batch_iterator = batchIterator(self.voc, dataset, self.options["batch_size"], shuffle=False) - # find out how many iterations we will need to cover the whole dataset - n_iters = len(dataset) // self.options["batch_size"] + int( - len(dataset) % self.options["batch_size"] > 0 - ) - # containers for full prediction results so we can compute accuracy at the end - all_preds = [] - all_labels = [] - for iteration in range(1, n_iters + 1): - batch, batch_dialogs, true_batch_size = next(batch_iterator) - # Extract fields from batch - ( - input_variable, - dialog_lengths, - utt_lengths, - batch_indices, - dialog_indices, - batch_labels, - batch_ids, - target_variable, - mask, - max_target_len, - ) = batch - dialog_lengths_list = [len(x) for x in batch_dialogs] - # run the model - predictions, scores = self._evaluate_batch( - predictor, - input_variable, - dialog_lengths, - dialog_lengths_list, - utt_lengths, - batch_indices, - dialog_indices, - true_batch_size, - ) - # aggregate results for computing accuracy at the end - all_preds += [p.item() for p in predictions] - all_labels += [l.item() for l in batch_labels] - print( - "Iteration: {}; Percent complete: {:.1f}%".format( - iteration, iteration / n_iters * 100 - ) - ) + encoder.load_state_dict(encoder_sd) + context_encoder.load_state_dict(context_sd) + # Initialize classifier + attack_clf = SingleTargetClf(HIDDEN_SIZE, self._config["dropout"]) + if attack_clf_sd is not None: + attack_clf.load_state_dict(attack_clf_sd) + # Use appropriate device + encoder = encoder.to(self._device) + context_encoder = context_encoder.to(self._device) + attack_clf = attack_clf.to(self._device) + print("Models built and ready to go!") + + return embedding, encoder, context_encoder, attack_clf + + def fit(self, contexts, val_contexts=None): + """ + Fine-tune the CRAFT model, and save the best model according to validation performance. - # compute and return the accuracy - return (np.asarray(all_preds) == np.asarray(all_labels)).mean() + :param contexts: an iterator over context tuples, provided by the Forecaster framework + :param val_contexts: an iterator over context tuples to be used only for validation. IMPORTANT: this is marked Optional only for compatibility with the generic Forecaster API; CRAFT actually REQUIRES a validation set so leaving this parameter at None will raise an error! + """ + # convert the input contexts into CRAFT's data format + train_pairs = self._context_to_craft_data(contexts) + print("Processed", len(train_pairs), "context tuples for model training") + # val_contexts is made Optional to conform to the Forecaster spec, but in reality CRAFT requires a validation set + if val_contexts is None: + raise ValueError("CRAFTModel requires a validation set!") + val_pairs = self._context_to_craft_data(val_contexts) + print("Processed", len(val_pairs), "context tuples for model validation") + + # initialize the CRAFT model with whatever weights we currently have saved + embedding, encoder, context_encoder, attack_clf = self._init_craft() - def _train_iters( - self, - train_pairs, - val_pairs, - encoder, - context_encoder, - attack_clf, - encoder_optimizer, - context_encoder_optimizer, - attack_clf_optimizer, - embedding, - n_iteration, - validate_every, - ): - # create a batch iterator for training data - batch_iterator = batchIterator(self.voc, train_pairs, self.options["batch_size"]) - - # Initializations - print("Initializing ...") - start_iteration = 1 - print_loss = 0 - - # Training loop - print("Training...") - # keep track of best validation accuracy - only save when we have a model that beats the current best - best_acc = 0 - for iteration in range(start_iteration, n_iteration + 1): - training_batch, training_dialogs, true_batch_size = next(batch_iterator) - # Extract fields from batch - ( - input_variable, - dialog_lengths, - utt_lengths, - batch_indices, - dialog_indices, - labels, - batch_ids, - target_variable, - mask, - max_target_len, - ) = training_batch - dialog_lengths_list = [len(x) for x in training_dialogs] - - # Run a training iteration with batch - loss = self._train_NN( - input_variable, - dialog_lengths, - dialog_lengths_list, - utt_lengths, - batch_indices, - dialog_indices, - labels, # input/output arguments - encoder, - context_encoder, - attack_clf, # network arguments - encoder_optimizer, - context_encoder_optimizer, - attack_clf_optimizer, # optimization arguments - true_batch_size, - self.options["clip"], - ) # misc arguments - print_loss += loss - - # Print progress - if iteration % self.options["print_every"] == 0: - print_loss_avg = print_loss / self.options["print_every"] - print( - "Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format( - iteration, iteration / n_iteration * 100, print_loss_avg - ) - ) - print_loss = 0 - - # Evaluate on validation set - if iteration % validate_every == 0: - print("Validating!") - # put the network components into evaluation mode - encoder.eval() - context_encoder.eval() - attack_clf.eval() - - predictor = Predictor(encoder, context_encoder, attack_clf) - accuracy = self._validate(predictor, val_pairs) - print("Validation set accuracy: {:.2f}%".format(accuracy * 100)) - - # keep track of our best model so far - if accuracy > best_acc: - print("Validation accuracy better than current best; saving model...") - best_acc = accuracy - torch.save( - { - "iteration": iteration, - "en": encoder.state_dict(), - "ctx": context_encoder.state_dict(), - "atk_clf": attack_clf.state_dict(), - "en_opt": encoder_optimizer.state_dict(), - "ctx_opt": context_encoder_optimizer.state_dict(), - "atk_clf_opt": attack_clf_optimizer.state_dict(), - "loss": loss, - "voc_dict": self.voc.__dict__, - "embedding": embedding.state_dict(), - }, - self.options["trained_model_output_filepath"], - ) - - # put the network components back into training mode - encoder.train() - context_encoder.train() - attack_clf.train() - - def train(self, id_to_context_reply_label): - ids = list(id_to_context_reply_label) - train_pair_ids, val_pair_ids = train_test_split( - ids, test_size=self.options["validation_size"] - ) - train_pairs = [id_to_context_reply_label[pair_id] for pair_id in train_pair_ids] - val_pairs = [id_to_context_reply_label[pair_id] for pair_id in val_pair_ids] # Compute the number of training iterations we will need in order to achieve the number of epochs specified in the settings at the start of the notebook - n_iter_per_epoch = len(train_pairs) // self.options["batch_size"] + int( - len(train_pairs) % self.options["batch_size"] == 1 + n_iter_per_epoch = len(train_pairs) // self._config["batch_size"] + int( + len(train_pairs) % self._config["batch_size"] == 1 ) - n_iteration = n_iter_per_epoch * self.options["train_epochs"] + n_iteration = n_iter_per_epoch * self._config["finetune_epochs"] # Put dropout layers in train mode - self.predictor.encoder.train() - self.predictor.context_encoder.train() - self.predictor.classifier.train() + encoder.train() + context_encoder.train() + attack_clf.train() # Initialize optimizers print("Building optimizers...") - encoder_optimizer = optim.Adam( - self.predictor.encoder.parameters(), lr=self.options["learning_rate"] - ) + encoder_optimizer = optim.Adam(encoder.parameters(), lr=self._config["learning_rate"]) context_encoder_optimizer = optim.Adam( - self.predictor.context_encoder.parameters(), lr=self.options["learning_rate"] - ) - attack_clf_optimizer = optim.Adam( - self.predictor.classifier.parameters(), lr=self.options["learning_rate"] + context_encoder.parameters(), lr=self._config["learning_rate"] ) + attack_clf_optimizer = optim.Adam(attack_clf.parameters(), lr=self._config["learning_rate"]) # Run training iterations, validating after every epoch print("Starting Training!") print("Will train for {} iterations".format(n_iteration)) - self._train_iters( + best_model = trainIters( + self._voc, train_pairs, val_pairs, - self.predictor.encoder, - self.predictor.context_encoder, - self.predictor.classifier, + encoder, + context_encoder, + attack_clf, encoder_optimizer, context_encoder_optimizer, attack_clf_optimizer, - self.predictor.encoder.embedding, + embedding, n_iteration, + self._config["batch_size"], + self._config["print_every"], n_iter_per_epoch, + self._config["clip"], + self._device, + MAX_LENGTH, + batchIterator, ) - def forecast(self, id_to_context_reply_label): + # save the resulting checkpoints so we can load them later during transform + self._model = best_model + + def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name): """ - Compute forecasts and forecast scores for the given dictionary of utterance id to (context, reply) pairs. Return the values in a DataFrame. + Run a fine-tuned CRAFT model on the provided data - :param id_to_context_reply_label: dict mapping utterance id to (context, reply, label) - :return: a pandas DataFrame + :param contexts: context tuples from the Forecaster framework + :param forecast_attribute_name: Forecaster will use this to look up the table column containing your model's discretized predictions (see output specification below) + :param forecast_prob_attribute_name: Forecaster will use this to look up the table column containing your model's raw forecast probabilities (see output specification below) + + :return: a Pandas DataFrame, with one row for each context, indexed by the ID of that context's current utterance. Contains two columns, one with raw probabilities named according to forecast_prob_attribute_name, and one with discretized (binary) forecasts named according to forecast_attribute_name """ - dataset = [ - (context, reply, label, id_) - for id_, (context, reply, label) in id_to_context_reply_label.items() - ] - return self._evaluate_dataset(self.predictor, dataset) + # convert the input contexts into CRAFT's data format + test_pairs = self._context_to_craft_data(contexts) + print("Processed", len(test_pairs), "context tuples for model evaluation") + + # initialize the CRAFT model with whatever weights we currently have saved + embedding, encoder, context_encoder, attack_clf = self._init_craft() + + # Set dropout layers to eval mode + encoder.eval() + context_encoder.eval() + attack_clf.eval() + + # Initialize the pipeline + predictor = Predictor(encoder, context_encoder, attack_clf) + + # Run the pipeline! + forecasts_df = evaluateDataset( + test_pairs, + encoder, + context_encoder, + predictor, + self._voc, + self._config["batch_size"], + self._device, + MAX_LENGTH, + batchIterator, + self._decision_threshold, + forecast_attribute_name, + forecast_prob_attribute_name, + ) + + return forecasts_df diff --git a/convokit/forecaster/forecaster.py b/convokit/forecaster/forecaster.py index 7d649d49..346c0110 100644 --- a/convokit/forecaster/forecaster.py +++ b/convokit/forecaster/forecaster.py @@ -1,160 +1,144 @@ -from convokit.model import Corpus, Conversation, Utterance -from typing import Callable, Optional -from convokit import Transformer -from .cumulativeBoW import CumulativeBoW +from convokit import Corpus, Conversation, Utterance, Transformer +from typing import Callable, Optional, Union, Any, List, Iterator +from collections import namedtuple from .forecasterModel import ForecasterModel import pandas as pd +import numpy as np +from matplotlib import pyplot as plt + +# Define a namedtuple template to represent conversational context tuples +ContextTuple = namedtuple( + "ContextTuple", ["context", "current_utterance", "future_context", "conversation_id"] +) class Forecaster(Transformer): """ - Implements basic Forecaster behavior. - - :param forecaster_model: ForecasterModel to use, e.g. cumulativeBoW or CRAFT - :param forecast_mode: 'future' or 'past'. 'future' (the default behavior) annotates each utterance with a forecast score using all context up to and including that utterance (i.e., a prediction of the future state of the conversation after this utterance). 'past' annotates each utterance with a forecast score using all context prior to that utterance (i.e., what the model believed this utterance would look like prior to actually seeing it) - :param convo_structure: conversations in expected corpus are 'branched' or 'linear', default: "branched" - :param text_func: optional function for extracting the text of the utterance, default: uses utterance's text attribute - :param label_func: callable function for getting the utterance's forecast label (True or False); only used in training - :param use_last_only: if forecast_mode is 'past' and use_last_only is True, for each dialog, use only the context-reply pair where the reply is the last utterance in the dialog - :param skip_broken_convos: if True and convo_structure is 'branched', exclude all conversations that have broken reply-to structures, default: True + A wrapper class that provides a consistent, Transformer-style interface to any conversational forecasting model. + From a user perspective, this makes it easy to apply forecasting models to ConvoKit corpora and evaluate them without + having to know a lot about the inner workings of conversational forecasting, and to swap between different kinds of + models without having to change a lot of code. From a developer perspective, this provides a prebuilt foundation upon which + new conversational forecasting models can be easily developed, as the Forecaster class handles to complicated work of + iterating over conversational contexts in temporal fashion, allowing the developer to focus only on writing the code to handle each conversational context. + + :param forecaster_model: An instance of a ForecasterModel subclass that implements the conversational forecasting model you want to use. ConvoKit provides CRAFT and BERT implementations. + :param labeler: A function that specifies where/how to find the label for any given conversation. Alternatively, a string can be provided, in which case it will be interpreted as the name of a Conversation metadata field containing the label. + :param context_preprocessor: An optional function that allows simple preprocessing of conversational contexts. Note that this should NOT be used to perform any restructuring or feature engineering on the data (that work is considered the exclusive purview of the underlying ForecasterModel); instead, it is intended to perform simple Corpus-specific data cleaning steps (i.e., removing utterances that lack key metadata required by the model) :param forecast_attribute_name: metadata feature name to use in annotation for forecast result, default: "forecast" :param forecast_prob_attribute_name: metadata feature name to use in annotation for forecast result probability, default: "forecast_prob" """ def __init__( self, - forecaster_model: ForecasterModel = None, - forecast_mode: str = "future", - convo_structure: str = "branched", - text_func=lambda utt: utt.text, - label_func: Callable[[Utterance], bool] = lambda utt: True, - use_last_only: bool = False, - skip_broken_convos: bool = True, + forecaster_model: ForecasterModel, + labeler: Union[Callable[[Conversation], int], str], + context_preprocessor: Optional[Callable[[List[Utterance]], List[Utterance]]] = None, forecast_attribute_name: str = "forecast", forecast_prob_attribute_name: str = "forecast_prob", ): - assert convo_structure in ["branched", "linear"] - self.convo_structure = convo_structure - - if forecaster_model is None: - print( - "No model passed to Forecaster. Initializing default forecaster model: Cumulative Bag-of-words..." - ) - self.forecaster_model = CumulativeBoW( - forecast_attribute_name=forecast_attribute_name, - forecast_prob_attribute_name=forecast_prob_attribute_name, - ) + self.forecaster_model = forecaster_model + if type(labeler) == str: + # assume the string is the name of a conversation metadata field containing the label + self.labeler = lambda c: int(c.meta[labeler]) else: - self.forecaster_model = forecaster_model - self.forecast_mode = forecast_mode - self.label_func = label_func - self.text_func = text_func - self.use_last_only = use_last_only - self.skip_broken_convos = skip_broken_convos + self.labeler = labeler + self.context_preprocessor = context_preprocessor self.forecast_attribute_name = forecast_attribute_name self.forecast_prob_attribute_name = forecast_prob_attribute_name - def _get_context_reply_label_dict( - self, corpus: Corpus, convo_selector, utt_excluder, include_label=True - ): - """ - Returns a dict mapping reply id to (context, reply, label). + # also give the underlying ForecasterModel access to the labeler function + self.forecaster_model.labeler = self.labeler - If self.forecast_mode == 'future': return a dict mapping the leaf utt id to the path from root utt to leaf utt + def _create_context_iterator( + self, + corpus: Corpus, + context_selector: Callable[[ContextTuple], bool], + include_future_context: bool = False, + ) -> Iterator[ContextTuple]: """ - dialogs = [] - if self.convo_structure == "branched": - for convo in corpus.iter_conversations(convo_selector): - try: - for path in convo.get_root_to_leaf_paths(): - path = [utt for utt in path if not utt_excluder(utt)] - if len(path) == 1: - continue - dialogs.append(path) - except ValueError as e: - if not self.skip_broken_convos: - raise e - - elif self.convo_structure == "linear": - for convo in corpus.iter_conversations(convo_selector): - utts = convo.get_chronological_utterance_list( - selector=lambda x: not utt_excluder(x) - ) - if len(utts) == 1: + Helper function that generates an iterator over conversational contexts that satisfy the provided context selector, + across the entire corpus + """ + for convo in corpus.iter_conversations(): + # contexts are iterated in chronological order, representing the idea that conversational forecasting models + # must make an updated forecast every time a new utterance is posted + chronological_utts = convo.get_chronological_utterance_list() + for i in range(len(chronological_utts)): + current_utt = chronological_utts[i] + # context is all utterances up to and including the most recent utterance + context = chronological_utts[: (i + 1)] + # if a preprocessor is given, run it first to get the "clean" version of the context + if self.context_preprocessor is not None: + context = self.context_preprocessor(context) + if include_future_context: + if i == len(chronological_utts) - 1: + # not to be confused with future_context=None, which indicates that include_future_context was false; + # this special value indicates that include_future_context is true but there is no future context + # (because we are at the end of the conversation) + future_context = [] + else: + future_context = [chronological_utts[(i + 1) :]] + else: + future_context = None + # pack the full context tuple + context_tuple = ContextTuple(context, current_utt, future_context, convo.id) + # the current context tuple should be skipped if it does not satisfy the given selector, + # or the context is empty (which may happen as a result of preprocessing) + if len(context_tuple.context) == 0 or not context_selector(context_tuple): continue - dialogs.append(utts) - - id_to_context_reply_label = dict() - - # this flag determines whether the dictionary entry for each utterance ID should include that - # utterance in the context (True corresponds to "future" behavior). This needs to be always - # False when include_label = True, since include_label assumes that the label comes from the - # utterance after the last utterance in the context. This override logic won't affect - # forecast_mode however, since that argument only applies to transform() while include_label - # is only True when called from fit() - include_current = (self.forecast_mode == "future") and (not include_label) - - for dialog in dialogs: - if self.use_last_only: - reply = self.text_func(dialog[-1]) - context = [ - self.text_func(utt) for utt in (dialog if include_current else dialog[:-1]) - ] - label = self.label_func(dialog[-1]) if include_label else None - id_to_context_reply_label[dialog[-1].id] = (context, reply, label) - else: - for idx in range(0 if include_current else 1, len(dialog)): - reply = self.text_func(dialog[idx]) - label = self.label_func(dialog[idx]) if include_label else None - reply_id = dialog[idx].id - context = [ - self.text_func(utt) - for utt in (dialog[: (idx + 1)] if include_current else dialog[:idx]) - ] - id_to_context_reply_label[reply_id] = ( - (context, reply, label) if include_label else (context, reply, None) - ) - - return id_to_context_reply_label + # if the current context was not skipped, it is next in the iterator + yield context_tuple def fit( self, corpus: Corpus, - y=None, - selector: Callable[[Conversation], bool] = lambda convo: True, - ignore_utterances: Callable[[Utterance], bool] = lambda utt: False, + context_selector: Callable[[ContextTuple], bool] = lambda context: True, + val_context_selector: Optional[Callable[[ContextTuple], bool]] = None, ): """ - Train the ForecasterModel on the given corpus. + Wrapper method for training the underlying conversational forecasting model. Forecaster itself does not implement any actual training logic. + Instead, it handles the job of selecting and iterating over context tuples. The resulting iterator is presented as a parameter to the fit + method of the underlying model, which can process the tuples however it sees fit. Within each tuple, context is unstructured - it contains all + utterances temporally preceding the most recent utterance, plus that most recent utterance itself, but does not impose any particular structure + beyond that, allowing each conversational forecasting model to decide how it wants to define “context”. + + :param corpus: The Corpus containing the data to train on + :param context_selector: A function that takes in a context tuple and returns a boolean indicator of whether it should be included in training data. This can be used to both select data based on splits (i.e. keep only those in the “train” split) and to specify special behavior of what contexts are looked at in training (i.e. in CRAFT where only the last context, directly preceding the toxic comment, is used in training). + :param val_context_selector: An optional function that mirrors context_selector but is used to create a separate held-out validation set - :param corpus: target Corpus - :param selector: a (lambda) function that takes a Conversation and returns a bool: True if the Conversation is to be included in the fitting step. By default, includes all Conversations. - :param ignore_utterances: a (lambda) function that takes an Utterance and returns a bool: True if the Utterance should be excluded from the Conversation in the fitting step. By default, all Utterances are included. :return: fitted Forecaster Transformer """ - id_to_context_reply_label = self._get_context_reply_label_dict( - corpus, selector, ignore_utterances, include_label=True + contexts = self._create_context_iterator( + corpus, context_selector, include_future_context=True ) - self.forecaster_model.train(id_to_context_reply_label) + val_contexts = None + if val_context_selector is not None: + val_contexts = self._create_context_iterator( + corpus, val_context_selector, include_future_context=True + ) + self.forecaster_model.fit(contexts, val_contexts) + + return self def transform( self, corpus: Corpus, - selector: Callable[[Conversation], bool] = lambda convo: True, - ignore_utterances: Callable[[Utterance], bool] = lambda utt: False, + context_selector: Callable[[ContextTuple], bool] = lambda context: True, ) -> Corpus: """ - Annotate the corpus utterances with forecast and forecast score information + Wrapper method for applying the underlying conversational forecasting model to make forecasts over the Conversations in a given Corpus. + Like the fit method, this simply acts to create an iterator over context tuples to be transformed, and forwards the iterator to the + underlying conversational forecasting model to do the actual forecasting. + + :param corpus: the Corpus containing the data to run on + :param context_selector: A function that takes in a context tuple and returns a boolean indicator of whether it should be included. Excluded contexts will simply not have a forecast. - :param corpus: target Corpus - :param selector: a (lambda) function that takes a Conversation and returns a bool: True if the Conversation is to be included in the transformation step. By default, includes all Conversations. - :param ignore_utterances: a (lambda) function that takes an Utterance and returns a bool: True if the Utterance should be excluded from the Conversation in the transformation step. By default, all Utterances are included. :return: annotated Corpus """ - id_to_context_reply_label = self._get_context_reply_label_dict( - corpus, selector, ignore_utterances, include_label=False + contexts = self._create_context_iterator(corpus, context_selector) + forecast_df = self.forecaster_model.transform( + contexts, self.forecast_attribute_name, self.forecast_prob_attribute_name ) - forecast_df = self.forecaster_model.forecast(id_to_context_reply_label) for utt in corpus.iter_utterances(): if utt.id in forecast_df.index: @@ -175,60 +159,122 @@ def transform( def fit_transform( self, corpus: Corpus, - y=None, - selector: Callable[[Conversation], bool] = lambda convo: True, - ignore_utterances: Callable[[Utterance], bool] = lambda utt: False, + context_selector: Callable[[ContextTuple], bool] = lambda context: True, ) -> Corpus: - self.fit(corpus, selector=selector, ignore_utterances=ignore_utterances) - return self.transform(corpus, selector=selector, ignore_utterances=ignore_utterances) + """ + Convenience method for running fit and transform on the same data + + :param corpus: the Corpus containing the data to run on + :param context_selector: A function that takes in a context tuple and returns a boolean indicator of whether it should be included. Excluded contexts will simply not have a forecast. + + :return: annotated Corpus + """ + self.fit(corpus, context_selector) + return self.transform(corpus, context_selector) + + def _draw_horizon_plot( + self, corpus: Corpus, selector: Callable[[Conversation], bool] = lambda convo: True + ): + """ + Draw the "forecast horizon" plot showing how far before the end of the conversation the first forecast is made + (for true positives). Note this is not always an especially meaningful plot, if the Corpus being used includes + to-be-forecasted events earlier in the conversation and not at the end, but it works for datasets like + CGA-CMV where the event is defined to be after the end of the included utterances. + """ + comments_until_end = {} + for convo in corpus.iter_conversations(): + if selector(convo) and self.labeler(convo) == 1: + for i, utt in enumerate(convo.get_chronological_utterance_list()): + prediction = utt.meta.get(self.forecast_attribute_name) + if prediction is not None and prediction > 0: + comments_until_end[convo.id] = ( + len(convo.get_chronological_utterance_list()) - i + ) + break + comments_until_end_vals = list(comments_until_end.values()) + plt.hist( + comments_until_end_vals, bins=range(1, np.max(comments_until_end_vals)), density=True + ) + plt.xlabel( + "Number of comments between index of first positive forecast and end of conversation" + ) + plt.ylabel("Percent of convesations") + plt.show() + return comments_until_end def summarize( - self, - corpus: Corpus, - selector: Callable[[Conversation], bool] = lambda convo: True, - ignore_utterances: Callable[[Utterance], bool] = lambda utt: False, - exclude_na=True, + self, corpus: Corpus, selector: Callable[[Conversation], bool] = lambda convo: True ): """ - Returns a DataFrame of utterances and their forecasts (and forecast probabilities) + Compute and display conversation-level performance metrics over a Corpus that has already been annotated by transform - :param corpus: target Corpus - :param exclude_na: whether to drop NaN results - :param selector: a (lambda) function that takes a Conversation and returns a bool: True if the Conversation is to be included in the summary step. By default, includes all Conversations. - :param ignore_utterances: a (lambda) function that takes an Utterance and returns a bool: True if the Utterance should be excluded from the Conversation in the summary step. By default, all Utterances are included. - :return: a pandas DataFrame + :param corpus: the Corpus containing the forecasts to evaluate + :param selector: A filtering function to limit the conversations the metrics are computed over. Note that unlike the context_selectors used in fit and transform, this selector operates on conversations (since evaluation is conversation-level). """ - utt_forecast_prob = [] - for convo in corpus.iter_conversations(selector): - for utt in convo.iter_utterances(lambda x: not ignore_utterances(x)): - utt_forecast_prob.append( - ( - utt.id, - utt.meta[self.forecast_attribute_name], - utt.meta[self.forecast_prob_attribute_name], - ) + conversational_forecasts_df = { + "conversation_id": [], + "label": [], + "score": [], + "forecast": [], + } + for convo in corpus.iter_conversations(): + if selector(convo): + conversational_forecasts_df["conversation_id"].append(convo.id) + conversational_forecasts_df["label"].append(self.labeler(convo)) + forecasts = np.asarray( + [ + utt.meta[self.forecast_attribute_name] + for utt in convo.iter_utterances() + if utt.meta.get(self.forecast_attribute_name, None) is not None + ] ) - forecast_df = ( - pd.DataFrame( - utt_forecast_prob, - columns=["utt_id", self.forecast_attribute_name, self.forecast_prob_attribute_name], - ) - .set_index("utt_id") - .sort_values(self.forecast_prob_attribute_name, ascending=False) + forecast_scores = np.asarray( + [ + utt.meta[self.forecast_prob_attribute_name] + for utt in convo.iter_utterances() + if utt.meta.get(self.forecast_prob_attribute_name, None) is not None + ] + ) + conversational_forecasts_df["score"].append(np.max(forecast_scores)) + conversational_forecasts_df["forecast"].append(np.max(forecasts)) + conversational_forecasts_df = pd.DataFrame(conversational_forecasts_df).set_index( + "conversation_id" ) - if exclude_na: - forecast_df = forecast_df.dropna() - return forecast_df - def get_model(self): - """ - Get the forecaster model object - """ - return self.forecaster_model + acc = ( + conversational_forecasts_df["label"] == conversational_forecasts_df["forecast"] + ).mean() + tp = ( + (conversational_forecasts_df["label"] == 1) + & (conversational_forecasts_df["forecast"] == 1) + ).sum() + fp = ( + (conversational_forecasts_df["label"] == 0) + & (conversational_forecasts_df["forecast"] == 1) + ).sum() + tn = ( + (conversational_forecasts_df["label"] == 0) + & (conversational_forecasts_df["forecast"] == 0) + ).sum() + fn = ( + (conversational_forecasts_df["label"] == 1) + & (conversational_forecasts_df["forecast"] == 0) + ).sum() + p = tp / (tp + fp) + r = tp / (tp + fn) + fpr = fp / (fp + tn) + f1 = 2 / (((tp + fp) / tp) + ((tp + fn) / tp)) + metrics = {"Accuracy": acc, "Precision": p, "Recall": r, "FPR": fpr, "F1": f1} - def set_model(self, forecaster_model): - """ - Set the forecaster model - :return: - """ - self.forecaster_model = forecaster_model + print(pd.Series(metrics)) + + comments_until_end = self._draw_horizon_plot(corpus, selector) + comments_until_end_vals = list(comments_until_end.values()) + print( + "Horizon statistics (# of comments between first positive forecast and conversation end):" + ) + print( + f"Mean = {np.mean(comments_until_end_vals)}, Median = {np.median(comments_until_end_vals)}" + ) + + return conversational_forecasts_df, metrics diff --git a/convokit/forecaster/forecasterModel.py b/convokit/forecaster/forecasterModel.py index b5392fdf..0051ff32 100644 --- a/convokit/forecaster/forecasterModel.py +++ b/convokit/forecaster/forecasterModel.py @@ -1,31 +1,43 @@ from abc import ABC, abstractmethod +from typing import Callable class ForecasterModel(ABC): - def __init__( - self, - forecast_attribute_name: str = "prediction", - forecast_prob_attribute_name: str = "score", - ): - """ + """ + An abstract class defining an interface that Forecaster can call into to invoke a conversational forecasting algorithm. + The “contract” between Forecaster and ForecasterModel means that ForecasterModel can expect to receive conversational data + in a consistent format, defined above. + """ - :param forecast_attribute_name: name for DataFrame column containing predictions, default: "prediction" - :param forecast_prob_attribute_name: name for column containing prediction scores, default: "score" - """ - self.forecast_attribute_name = forecast_attribute_name - self.forecast_prob_attribute_name = forecast_prob_attribute_name + def __init__(self): + self._labeler = None + + @property + def labeler(self): + return self._labeler + + @labeler.setter + def labeler(self, value: Callable): + self._labeler = value @abstractmethod - def train(self, id_to_context_reply_label): + def fit(self, contexts, val_contexts=None): """ - Train the Forecaster Model with the context-reply-label tuples + Train this conversational forecasting model on the given data + + :param contexts: an iterator over context tuples + :param val_contexts: an optional second iterator over context tuples to be used as a separate held-out validation set. Concrete ForecasterModel implementations may choose to ignore this, or conversely even enforce its presence. """ pass @abstractmethod - def forecast(self, id_to_context_reply_label): + def transform(self, contexts, forecast_attribute_name, forecast_prob_attribute_name): """ - Use the Forecaster Model to compute forecasts and scores - for given context-reply pairs and return a results dataframe + Apply this trained conversational forecasting model to the given data, and return its forecasts + in the form of a DataFrame indexed by (current) utterance ID + + :param contexts: an iterator over context tuples + + :return: a Pandas DataFrame, with one row for each context, indexed by the ID of that context's current utterance. Contains two columns, one with raw probabilities named according to forecast_prob_attribute_name, and one with discretized (binary) forecasts named according to forecast_attribute_name. Subclass implementations of ForecasterModel MUST adhere to this return value specification! """ pass diff --git a/convokit/util.py b/convokit/util.py index f7452c5d..1026f7c9 100644 --- a/convokit/util.py +++ b/convokit/util.py @@ -87,6 +87,7 @@ def download( cur_version = dataset_config["cur_version"] DatasetURLs = dataset_config["DatasetURLs"] + ModelURLs = dataset_config["ModelURLS"] if name.startswith("subreddit"): subreddit_name = name.split("-", maxsplit=1)[1] @@ -158,7 +159,14 @@ def download( # name not in downloaded or \ # (use_newest_version and name in cur_version and # downloaded[name] < cur_version[name]): - if name.endswith("-motifs"): + if name in ModelURLs: + for url in ModelURLs[name]: + full_name = name + url[url.rfind("/") :] + model_file_path = dataset_path + url[url.rfind("/") :] + if not os.path.exists(os.path.dirname(model_file_path)): + os.makedirs(os.path.dirname(model_file_path)) + _download_helper(model_file_path, url, verbose, full_name, downloadeds_path) + elif name.endswith("-motifs"): for url in DatasetURLs[name]: full_name = name + url[url.rfind("/") :] if full_name not in downloaded: @@ -236,12 +244,14 @@ def download_local(name: str, data_dir: str): def _download_helper( dataset_path: str, url: str, verbose: bool, name: str, downloadeds_path: str ) -> None: + is_corpus = False if ( url.lower().endswith(".corpus") or url.lower().endswith(".corpus.zip") or url.lower().endswith(".zip") ): dataset_path += ".zip" + is_corpus = True with urllib.request.urlopen(url) as response, open(dataset_path, "wb") as out_file: if verbose: @@ -269,16 +279,18 @@ def _download_helper( if verbose: print("Done") - with open(downloadeds_path, "a") as f: - fn = os.path.join( - os.path.dirname(dataset_path), name - ) # os.path.join(os.path.dirname(data), name) - f.write( - "{}$#${}$#${}\n".format( - name, os.path.realpath(os.path.dirname(dataset_path) + "/"), corpus_version(fn) + # for Corpus objects only: check the Corpus version + if is_corpus: + with open(downloadeds_path, "a") as f: + fn = os.path.join( + os.path.dirname(dataset_path), name + ) # os.path.join(os.path.dirname(data), name) + f.write( + "{}$#${}$#${}\n".format( + name, os.path.realpath(os.path.dirname(dataset_path) + "/"), corpus_version(fn) + ) ) - ) - # f.write(name + "\n") + # f.write(name + "\n") def corpus_version(filename: str) -> int: diff --git a/docs/source/config.rst b/docs/source/config.rst index 9b076f35..6534a5ca 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -2,7 +2,8 @@ Configurations =================== After you import convokit for the first time, a default configuration file will be generated in ~/.convokit/config.yml. -There are currently three variables: -db_host: -data_directory: -default_backend: +There are currently four variables: +- db_host: database localhost port, default to be "localhost:27017". +- data_directory: local directory for downloaded corpuses, default to be "~/.convokit/saved-corpora". +- model_directory: local directory for downloaded models, default to be "~/.convokit/saved-models". +- default_backend: default ConvoKit backend choice, can be "mem" or "db", default to be "mem". For more information, check `Storage Options `_ diff --git a/docs/source/forecaster.rst b/docs/source/forecaster.rst index f963e5d9..376deca6 100644 --- a/docs/source/forecaster.rst +++ b/docs/source/forecaster.rst @@ -1,13 +1,32 @@ Forecaster ========== -Conversational Forecasting Transformers. +The Forecaster class provides a generic interface to *conversational forecasting models*, a class of models designed to computationally capture the trajectory +of conversations in order to predict future events. Though individual conversational forecasting models can get quite complex, the Forecaster API abstracts +away the implementation details into a standard fit-transform interface. + +For end users of Forecaster: see the demo notebook which `uses Forecaster to fine-tune the CRAFT forecasting model on the CGA-CMV corpus `_ + +For developers of conversational forecasting models: Forecaster also represents a common framework for conversational forecasting +that you can use, in conjunction with other ML/NLP ecosystems like PyTorch and Huggingface, to streamline the development of your models! +You can create your conversational forecasting model as a subclass of ForecasterModel, which can then be directly "plugged in" to the +Forecaster wrapper which will provide a standard fit-transform interface to your model. At runtime, Forecaster will feed a temporally-ordered +stream of conversational data to your ForecasterModel in the form of "context tuples". Context tuples are generated in chronological order, +simulating the notion that the model is following the conversation as it develops in real time and generating a new prediction every time a +new utterance appears (e.g., in a social media setting, every time a new comment is posted). Each context tuple, in turn, is defined as a +NamedTuple with the following fields: + +* ``context``: a chronological list of Utterances up to and including the most recent Utterance at the time this context was generated. Beyond the chronological ordering, no structure of any kind is imposed on the Utterances, so developers of conversational forecasting models are free to perform any structuring of their own that they desire (so yes, if you want, you can build conversational graphs on top of the provided context!) +* ``current_utterance``: the most recent utterance at the time this context tuple was generated. In the vast majority of cases, this will be identical to the last utterance in the context, except in cases where that utterance might have gotten filtered out of the context by the preprocessor (in those cases, current_utterance still reflects the "missing" most recent utterance, in order to provide a reference point for where we currently are in the conversation) +* ``future_context``: during **training only** (i.e., in the fit function), the context tuple also includes this additional field that lists all future Utterances; that is, all Utterances chronologically after the current utterance (or an empty list if this Utterance is the last one). This is meant only to help with data preprocessing and selection during training; for example, CRAFT trains only on the last context in each conversation, so we need to look at future_context to know whether we are at the end of the conversation. It **should not be used as input to the model**, as that would be "cheating" - in fact, to enforce this, future_context is **not available during evaluation** (i.e. in the transform function) so that any model that improperly made use of future_context would crash during evaluation! +* ``conversation_id``: the Conversation that this context-reply pair came from. ForecasterModel also has access to Forecaster's labeler function and can use that together with the conversation_id to look up the label + +Illustrative example, a conversation containing utterances ``[a, b, c, d]`` (in temporal order) will produce the following four context tuples, in this exact order: +#. ``(context=[a], current_utterance=a, future_context=[b,c,d])`` +#. ``(context=[a,b], current_utterance=b, future_context=[c,d])`` +#. ``(context=[a,b,c], current_utterance=c, future_context=[d])`` +#. ``(context=[a,b,c,d], current_utterance=d, future_context=[])`` -Refer to :doc:`CRAFT Model ` and :doc:`Cumulative Bag-of-Words ` for the models that Forecaster can be loaded with. - -Example usage: `CRAFT forecasting of conversational derailment `_. - -Example usage: `Forecasting of conversational derailment using a cumulative bag-of-words model `_. .. automodule:: convokit.forecaster.forecaster :members: diff --git a/examples/forecaster/CRAFT Forecaster demo.ipynb b/examples/forecaster/CRAFT Forecaster demo.ipynb new file mode 100644 index 00000000..f5575582 --- /dev/null +++ b/examples/forecaster/CRAFT Forecaster demo.ipynb @@ -0,0 +1,1423 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a76caad0-a440-43cf-bfcd-af0ad6d68042", + "metadata": {}, + "source": [ + "# ConvoKit Forecaster framework: CRAFT demo\n", + "\n", + "The `Forecaster` class provides a generic interface to *conversational forecasting models*, a class of models designed to computationally capture the trajectory of conversations in order to predict future events. Though individual conversational forecasting models can get quite complex, the `Forecaster` API abstracts away the implementation details into a standard fit-transform interface. To demonstrate the power of this framework, this notebook walks through an example of fine-tuning the CRAFT conversational forecasting model (Chang and Danescu-Niculescu-Mizil, 2019) on the CGA-CMV corpus. You will see how the `Forecaster` API allows us to load the data, select training, validation, and testing samples, train the CRAFT model, and perform evaluation - replicating the original paper's full pipeline (minus pre-training, which is considered outside the scope of ConvoKit) all in only a few lines of code!\n", + "\n", + "Let's start by importing the necessary ConvoKit classes and functions, and loading the CGA-CMV corpus." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "326d6337-43c1-48c4-90de-d907b5fa32b6", + "metadata": {}, + "outputs": [], + "source": [ + "from convokit import download, Corpus, Forecaster, CRAFTModel\n", + "from functools import partial" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "3c530c49-80d9-455d-a062-cb31a7514d85", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading conversations-gone-awry-cmv-corpus to /reef/kz88/convokit/download_corpus/conversations-gone-awry-cmv-corpus\n", + "Downloading conversations-gone-awry-cmv-corpus from http://zissou.infosci.cornell.edu/convokit/datasets/conversations-gone-awry-cmv-corpus/conversations-gone-awry-cmv-corpus.zip (51.5MB)... Done\n" + ] + } + ], + "source": [ + "corpus = Corpus(filename=download(\"conversations-gone-awry-cmv-corpus\", data_dir=\"YOUR_DATA_DIRECTORY\"))" + ] + }, + { + "cell_type": "markdown", + "id": "a4d27b1a-3d1f-4039-b10f-21b6dd230c10", + "metadata": {}, + "source": [ + "## Define selectors for the Forecaster\n", + "\n", + "Core to the flexibility of the `Forecaster` framework is the concept of *selectors*. \n", + "\n", + "To capture the temporal dimension of the conversational forecasting task, `Forecaster` iterates through conversations in chronological utterance order, at each step presenting to the backend forecasting model a \"context tuple\" containing both the comment itself and the full \"context\" preceding that comment. As a general framework, `Forecaster` on its own does not try to make any further assumptions about what \"context\" should contain or look like; it simply presents context as a chronologically ordered list of all utterances up to and including the current one. \n", + "\n", + "But in practice, we often want to be pickier about what we mean by \"context\". At a basic level, we might want to select only specific contexts during training versus during evaluation. The simplest version of this is the desire to split the conversations by training and testing splits, but more specifically, we might also want to select only certain contexts within a conversation. This is necessary for CRAFT training, which works by taking only the chronologically last context (i.e., all utterances up to and not including the toxic comment, or up to the end of the conversation) as a labeled training instance. This is where selectors come in! A selector is a user-provided function that takes in a context and returns a boolean representing whether or not that context should be used. You can provide separate selectors for `fit` and `transform`, and `fit` also takes in a second selector that you can use to define validation data.\n", + "\n", + "Here we show how to implement the necessary selectors for CRAFT." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "6bfdb2b9-757a-4a44-89dc-a6e3b6249ac8", + "metadata": {}, + "outputs": [], + "source": [ + "def generic_fit_selector(context_tuple, split):\n", + " \"\"\"\n", + " We use this generic function for both training and validation data.\n", + " In both cases, its job is to select only those contexts for which the\n", + " FUTURE context is empty. This is in accordance with how CRAFT was\n", + " originally trained on CGA-CMV, taking the last context from each\n", + " conversation (\"last\" defined as being up to and including the chronologically\n", + " last utterance as recorded in the corpus)\n", + " \"\"\"\n", + " matches_split = (context_tuple.current_utterance.get_conversation().meta[\"split\"] == split)\n", + " is_end = (len(context_tuple.future_context) == 0)\n", + " return (matches_split and is_end)\n", + "\n", + "def transform_selector(context_tuple):\n", + " \"\"\"\n", + " For transform we only need to check that the conversation is in the test split\n", + " \"\"\"\n", + " return (context_tuple.current_utterance.get_conversation().meta[\"split\"] == \"test\")" + ] + }, + { + "cell_type": "markdown", + "id": "9614aff5-843e-4b3b-b03f-6f57f8e76b8a", + "metadata": {}, + "source": [ + "## Initialize the Forecaster and CRAFTModel backend\n", + "\n", + "Now the rest of the process is pretty straightforward! We simply need to:\n", + "1. Initialize a backend `ForecasterModel` for the `Forecaster` to use, in this case we choose ConvoKit's implementation of CRAFT.\n", + "2. Initialize a `Forecaster` instance to wrap that `ForecasterModel` in a generic fit-transform API" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9ee62d7f", + "metadata": {}, + "outputs": [], + "source": [ + "# We define the constant DEVICE to specify whether we want to run in GPU mode or CPU mode. As CRAFT is a neural model, GPU mode\n", + "# (activated with the value \"cuda\") is preferred. But if your machine lacks a GPU, you can change the value to \"cpu\" to enable\n", + "# CPU mode (noting that it will be slower)\n", + "DEVICE = \"cuda\"" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "653febf7-ef59-4e8b-8ce2-39b4a33c3747", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading craft-cmv-pretrained to /reef/kz88/convokit/download_model/craft-cmv-pretrained\n", + "Downloading craft-cmv-pretrained/craft_pretrained.tar from https://zissou.infosci.cornell.edu/convokit/models/craft_cmv/craft_pretrained.tar (974.6MB)... Done\n", + "Downloading craft-cmv-pretrained/index2word.json from https://zissou.infosci.cornell.edu/convokit/models/craft_cmv/index2word.json (1.0MB)... Done\n", + "Downloading craft-cmv-pretrained/word2index.json from https://zissou.infosci.cornell.edu/convokit/models/craft_cmv/word2index.json (928.0KB)... Done\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/reef/kz88/convokit/testing/lib/python3.11/site-packages/convokit/forecaster/CRAFTModel.py:124: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " self._model = torch.load(model_path, map_location=torch.device(torch_device))\n" + ] + } + ], + "source": [ + "craft = CRAFTModel(\"craft-cmv-pretrained\", torch_device=DEVICE)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c971ec31-964b-4df0-a7bb-d064ae62c5a3", + "metadata": {}, + "outputs": [], + "source": [ + "craft_forecaster = Forecaster(craft, \"has_removed_comment\")" + ] + }, + { + "cell_type": "markdown", + "id": "b840f526-dafd-4022-b5a1-90adecbd1591", + "metadata": {}, + "source": [ + "## Fine-tune the model using Forecaster.fit\n", + "\n", + "And now, just like any other ConvoKit Transformer, model training is done simply by calling `fit` (note how we pass in the selectors we previously defined!)..." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "8590e7b3-f633-4844-baea-18bee961eebd", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processed 4106 context tuples for model training\n", + "Processed 1368 context tuples for model validation\n", + "Loading saved parameters...\n", + "Building encoders, decoder, and classifier...\n", + "Models built and ready to go!\n", + "Building optimizers...\n", + "Starting Training!\n", + "Will train for 1920 iterations\n", + "Initializing ...\n", + "Training...\n", + "Iteration: 10; Percent complete: 0.5%; Average loss: 0.6930\n", + "Iteration: 20; Percent complete: 1.0%; Average loss: 0.6927\n", + "Iteration: 30; Percent complete: 1.6%; Average loss: 0.6931\n", + "Iteration: 40; Percent complete: 2.1%; Average loss: 0.6928\n", + "Iteration: 50; Percent complete: 2.6%; Average loss: 0.6932\n", + "Iteration: 60; Percent complete: 3.1%; Average loss: 0.6927\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 54.68%\n", + "Validation accuracy better than current best; saving model...\n", + "Iteration: 70; Percent complete: 3.6%; Average loss: 0.6925\n", + "Iteration: 80; Percent complete: 4.2%; Average loss: 0.6922\n", + "Iteration: 90; Percent complete: 4.7%; Average loss: 0.6914\n", + "Iteration: 100; Percent complete: 5.2%; Average loss: 0.6918\n", + "Iteration: 110; Percent complete: 5.7%; Average loss: 0.6911\n", + "Iteration: 120; Percent complete: 6.2%; Average loss: 0.6915\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 57.46%\n", + "Validation accuracy better than current best; saving model...\n", + "Iteration: 130; Percent complete: 6.8%; Average loss: 0.6924\n", + "Iteration: 140; Percent complete: 7.3%; Average loss: 0.6909\n", + "Iteration: 150; Percent complete: 7.8%; Average loss: 0.6903\n", + "Iteration: 160; Percent complete: 8.3%; Average loss: 0.6905\n", + "Iteration: 170; Percent complete: 8.9%; Average loss: 0.6901\n", + "Iteration: 180; Percent complete: 9.4%; Average loss: 0.6909\n", + "Iteration: 190; Percent complete: 9.9%; Average loss: 0.6896\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 57.68%\n", + "Validation accuracy better than current best; saving model...\n", + "Iteration: 200; Percent complete: 10.4%; Average loss: 0.6904\n", + "Iteration: 210; Percent complete: 10.9%; Average loss: 0.6888\n", + "Iteration: 220; Percent complete: 11.5%; Average loss: 0.6887\n", + "Iteration: 230; Percent complete: 12.0%; Average loss: 0.6893\n", + "Iteration: 240; Percent complete: 12.5%; Average loss: 0.6874\n", + "Iteration: 250; Percent complete: 13.0%; Average loss: 0.6864\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 57.68%\n", + "Iteration: 260; Percent complete: 13.5%; Average loss: 0.6862\n", + "Iteration: 270; Percent complete: 14.1%; Average loss: 0.6865\n", + "Iteration: 280; Percent complete: 14.6%; Average loss: 0.6851\n", + "Iteration: 290; Percent complete: 15.1%; Average loss: 0.6827\n", + "Iteration: 300; Percent complete: 15.6%; Average loss: 0.6821\n", + "Iteration: 310; Percent complete: 16.1%; Average loss: 0.6838\n", + "Iteration: 320; Percent complete: 16.7%; Average loss: 0.6796\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 59.36%\n", + "Validation accuracy better than current best; saving model...\n", + "Iteration: 330; Percent complete: 17.2%; Average loss: 0.6776\n", + "Iteration: 340; Percent complete: 17.7%; Average loss: 0.6787\n", + "Iteration: 350; Percent complete: 18.2%; Average loss: 0.6744\n", + "Iteration: 360; Percent complete: 18.8%; Average loss: 0.6713\n", + "Iteration: 370; Percent complete: 19.3%; Average loss: 0.6653\n", + "Iteration: 380; Percent complete: 19.8%; Average loss: 0.6668\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 60.16%\n", + "Validation accuracy better than current best; saving model...\n", + "Iteration: 390; Percent complete: 20.3%; Average loss: 0.6713\n", + "Iteration: 400; Percent complete: 20.8%; Average loss: 0.6594\n", + "Iteration: 410; Percent complete: 21.4%; Average loss: 0.6544\n", + "Iteration: 420; Percent complete: 21.9%; Average loss: 0.6555\n", + "Iteration: 430; Percent complete: 22.4%; Average loss: 0.6486\n", + "Iteration: 440; Percent complete: 22.9%; Average loss: 0.6357\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 63.30%\n", + "Validation accuracy better than current best; saving model...\n", + "Iteration: 450; Percent complete: 23.4%; Average loss: 0.6403\n", + "Iteration: 460; Percent complete: 24.0%; Average loss: 0.6375\n", + "Iteration: 470; Percent complete: 24.5%; Average loss: 0.6197\n", + "Iteration: 480; Percent complete: 25.0%; Average loss: 0.6118\n", + "Iteration: 490; Percent complete: 25.5%; Average loss: 0.6064\n", + "Iteration: 500; Percent complete: 26.0%; Average loss: 0.6035\n", + "Iteration: 510; Percent complete: 26.6%; Average loss: 0.6061\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 64.91%\n", + "Validation accuracy better than current best; saving model...\n", + "Iteration: 520; Percent complete: 27.1%; Average loss: 0.5963\n", + "Iteration: 530; Percent complete: 27.6%; Average loss: 0.5713\n", + "Iteration: 540; Percent complete: 28.1%; Average loss: 0.5619\n", + "Iteration: 550; Percent complete: 28.6%; Average loss: 0.5665\n", + "Iteration: 560; Percent complete: 29.2%; Average loss: 0.5564\n", + "Iteration: 570; Percent complete: 29.7%; Average loss: 0.5412\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 65.42%\n", + "Validation accuracy better than current best; saving model...\n", + "Iteration: 580; Percent complete: 30.2%; Average loss: 0.5506\n", + "Iteration: 590; Percent complete: 30.7%; Average loss: 0.5674\n", + "Iteration: 600; Percent complete: 31.2%; Average loss: 0.5170\n", + "Iteration: 610; Percent complete: 31.8%; Average loss: 0.5094\n", + "Iteration: 620; Percent complete: 32.3%; Average loss: 0.5063\n", + "Iteration: 630; Percent complete: 32.8%; Average loss: 0.5252\n", + "Iteration: 640; Percent complete: 33.3%; Average loss: 0.5031\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 66.01%\n", + "Validation accuracy better than current best; saving model...\n", + "Iteration: 650; Percent complete: 33.9%; Average loss: 0.5073\n", + "Iteration: 660; Percent complete: 34.4%; Average loss: 0.4794\n", + "Iteration: 670; Percent complete: 34.9%; Average loss: 0.4498\n", + "Iteration: 680; Percent complete: 35.4%; Average loss: 0.4728\n", + "Iteration: 690; Percent complete: 35.9%; Average loss: 0.4763\n", + "Iteration: 700; Percent complete: 36.5%; Average loss: 0.4636\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 65.57%\n", + "Iteration: 710; Percent complete: 37.0%; Average loss: 0.4624\n", + "Iteration: 720; Percent complete: 37.5%; Average loss: 0.4821\n", + "Iteration: 730; Percent complete: 38.0%; Average loss: 0.4278\n", + "Iteration: 740; Percent complete: 38.5%; Average loss: 0.4257\n", + "Iteration: 750; Percent complete: 39.1%; Average loss: 0.4064\n", + "Iteration: 760; Percent complete: 39.6%; Average loss: 0.4220\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 66.23%\n", + "Validation accuracy better than current best; saving model...\n", + "Iteration: 770; Percent complete: 40.1%; Average loss: 0.4294\n", + "Iteration: 780; Percent complete: 40.6%; Average loss: 0.4430\n", + "Iteration: 790; Percent complete: 41.1%; Average loss: 0.4081\n", + "Iteration: 800; Percent complete: 41.7%; Average loss: 0.3874\n", + "Iteration: 810; Percent complete: 42.2%; Average loss: 0.4166\n", + "Iteration: 820; Percent complete: 42.7%; Average loss: 0.3547\n", + "Iteration: 830; Percent complete: 43.2%; Average loss: 0.3526\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 65.86%\n", + "Iteration: 840; Percent complete: 43.8%; Average loss: 0.3744\n", + "Iteration: 850; Percent complete: 44.3%; Average loss: 0.3777\n", + "Iteration: 860; Percent complete: 44.8%; Average loss: 0.3474\n", + "Iteration: 870; Percent complete: 45.3%; Average loss: 0.3519\n", + "Iteration: 880; Percent complete: 45.8%; Average loss: 0.3252\n", + "Iteration: 890; Percent complete: 46.4%; Average loss: 0.3426\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 65.79%\n", + "Iteration: 900; Percent complete: 46.9%; Average loss: 0.3051\n", + "Iteration: 910; Percent complete: 47.4%; Average loss: 0.3458\n", + "Iteration: 920; Percent complete: 47.9%; Average loss: 0.3052\n", + "Iteration: 930; Percent complete: 48.4%; Average loss: 0.2959\n", + "Iteration: 940; Percent complete: 49.0%; Average loss: 0.2871\n", + "Iteration: 950; Percent complete: 49.5%; Average loss: 0.2716\n", + "Iteration: 960; Percent complete: 50.0%; Average loss: 0.2974\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 65.94%\n", + "Iteration: 970; Percent complete: 50.5%; Average loss: 0.2821\n", + "Iteration: 980; Percent complete: 51.0%; Average loss: 0.2534\n", + "Iteration: 990; Percent complete: 51.6%; Average loss: 0.2481\n", + "Iteration: 1000; Percent complete: 52.1%; Average loss: 0.2411\n", + "Iteration: 1010; Percent complete: 52.6%; Average loss: 0.2254\n", + "Iteration: 1020; Percent complete: 53.1%; Average loss: 0.2376\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 65.42%\n", + "Iteration: 1030; Percent complete: 53.6%; Average loss: 0.2693\n", + "Iteration: 1040; Percent complete: 54.2%; Average loss: 0.2334\n", + "Iteration: 1050; Percent complete: 54.7%; Average loss: 0.2027\n", + "Iteration: 1060; Percent complete: 55.2%; Average loss: 0.2216\n", + "Iteration: 1070; Percent complete: 55.7%; Average loss: 0.2104\n", + "Iteration: 1080; Percent complete: 56.2%; Average loss: 0.2178\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 65.79%\n", + "Iteration: 1090; Percent complete: 56.8%; Average loss: 0.2004\n", + "Iteration: 1100; Percent complete: 57.3%; Average loss: 0.1887\n", + "Iteration: 1110; Percent complete: 57.8%; Average loss: 0.1888\n", + "Iteration: 1120; Percent complete: 58.3%; Average loss: 0.1541\n", + "Iteration: 1130; Percent complete: 58.9%; Average loss: 0.1757\n", + "Iteration: 1140; Percent complete: 59.4%; Average loss: 0.1515\n", + "Iteration: 1150; Percent complete: 59.9%; Average loss: 0.1530\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 65.42%\n", + "Iteration: 1160; Percent complete: 60.4%; Average loss: 0.1731\n", + "Iteration: 1170; Percent complete: 60.9%; Average loss: 0.1979\n", + "Iteration: 1180; Percent complete: 61.5%; Average loss: 0.1446\n", + "Iteration: 1190; Percent complete: 62.0%; Average loss: 0.1310\n", + "Iteration: 1200; Percent complete: 62.5%; Average loss: 0.1198\n", + "Iteration: 1210; Percent complete: 63.0%; Average loss: 0.1191\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 66.01%\n", + "Iteration: 1220; Percent complete: 63.5%; Average loss: 0.1401\n", + "Iteration: 1230; Percent complete: 64.1%; Average loss: 0.1560\n", + "Iteration: 1240; Percent complete: 64.6%; Average loss: 0.1022\n", + "Iteration: 1250; Percent complete: 65.1%; Average loss: 0.1002\n", + "Iteration: 1260; Percent complete: 65.6%; Average loss: 0.1108\n", + "Iteration: 1270; Percent complete: 66.1%; Average loss: 0.1040\n", + "Iteration: 1280; Percent complete: 66.7%; Average loss: 0.1078\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 65.42%\n", + "Iteration: 1290; Percent complete: 67.2%; Average loss: 0.1085\n", + "Iteration: 1300; Percent complete: 67.7%; Average loss: 0.1144\n", + "Iteration: 1310; Percent complete: 68.2%; Average loss: 0.0813\n", + "Iteration: 1320; Percent complete: 68.8%; Average loss: 0.0656\n", + "Iteration: 1330; Percent complete: 69.3%; Average loss: 0.0924\n", + "Iteration: 1340; Percent complete: 69.8%; Average loss: 0.0810\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 65.57%\n", + "Iteration: 1350; Percent complete: 70.3%; Average loss: 0.0795\n", + "Iteration: 1360; Percent complete: 70.8%; Average loss: 0.1059\n", + "Iteration: 1370; Percent complete: 71.4%; Average loss: 0.0635\n", + "Iteration: 1380; Percent complete: 71.9%; Average loss: 0.0664\n", + "Iteration: 1390; Percent complete: 72.4%; Average loss: 0.0652\n", + "Iteration: 1400; Percent complete: 72.9%; Average loss: 0.0820\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 65.20%\n", + "Iteration: 1410; Percent complete: 73.4%; Average loss: 0.0509\n", + "Iteration: 1420; Percent complete: 74.0%; Average loss: 0.0781\n", + "Iteration: 1430; Percent complete: 74.5%; Average loss: 0.0520\n", + "Iteration: 1440; Percent complete: 75.0%; Average loss: 0.0534\n", + "Iteration: 1450; Percent complete: 75.5%; Average loss: 0.0359\n", + "Iteration: 1460; Percent complete: 76.0%; Average loss: 0.0341\n", + "Iteration: 1470; Percent complete: 76.6%; Average loss: 0.0411\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 65.64%\n", + "Iteration: 1480; Percent complete: 77.1%; Average loss: 0.0554\n", + "Iteration: 1490; Percent complete: 77.6%; Average loss: 0.0728\n", + "Iteration: 1500; Percent complete: 78.1%; Average loss: 0.0547\n", + "Iteration: 1510; Percent complete: 78.6%; Average loss: 0.0460\n", + "Iteration: 1520; Percent complete: 79.2%; Average loss: 0.0487\n", + "Iteration: 1530; Percent complete: 79.7%; Average loss: 0.0293\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 65.42%\n", + "Iteration: 1540; Percent complete: 80.2%; Average loss: 0.0403\n", + "Iteration: 1550; Percent complete: 80.7%; Average loss: 0.0526\n", + "Iteration: 1560; Percent complete: 81.2%; Average loss: 0.0306\n", + "Iteration: 1570; Percent complete: 81.8%; Average loss: 0.0497\n", + "Iteration: 1580; Percent complete: 82.3%; Average loss: 0.0310\n", + "Iteration: 1590; Percent complete: 82.8%; Average loss: 0.0381\n", + "Iteration: 1600; Percent complete: 83.3%; Average loss: 0.0325\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 65.35%\n", + "Iteration: 1610; Percent complete: 83.9%; Average loss: 0.0358\n", + "Iteration: 1620; Percent complete: 84.4%; Average loss: 0.0210\n", + "Iteration: 1630; Percent complete: 84.9%; Average loss: 0.0243\n", + "Iteration: 1640; Percent complete: 85.4%; Average loss: 0.0193\n", + "Iteration: 1650; Percent complete: 85.9%; Average loss: 0.0218\n", + "Iteration: 1660; Percent complete: 86.5%; Average loss: 0.0271\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 65.28%\n", + "Iteration: 1670; Percent complete: 87.0%; Average loss: 0.0281\n", + "Iteration: 1680; Percent complete: 87.5%; Average loss: 0.0336\n", + "Iteration: 1690; Percent complete: 88.0%; Average loss: 0.0186\n", + "Iteration: 1700; Percent complete: 88.5%; Average loss: 0.0217\n", + "Iteration: 1710; Percent complete: 89.1%; Average loss: 0.0187\n", + "Iteration: 1720; Percent complete: 89.6%; Average loss: 0.0231\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 65.28%\n", + "Iteration: 1730; Percent complete: 90.1%; Average loss: 0.0157\n", + "Iteration: 1740; Percent complete: 90.6%; Average loss: 0.0183\n", + "Iteration: 1750; Percent complete: 91.1%; Average loss: 0.0261\n", + "Iteration: 1760; Percent complete: 91.7%; Average loss: 0.0149\n", + "Iteration: 1770; Percent complete: 92.2%; Average loss: 0.0140\n", + "Iteration: 1780; Percent complete: 92.7%; Average loss: 0.0156\n", + "Iteration: 1790; Percent complete: 93.2%; Average loss: 0.0126\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 64.55%\n", + "Iteration: 1800; Percent complete: 93.8%; Average loss: 0.0225\n", + "Iteration: 1810; Percent complete: 94.3%; Average loss: 0.0203\n", + "Iteration: 1820; Percent complete: 94.8%; Average loss: 0.0171\n", + "Iteration: 1830; Percent complete: 95.3%; Average loss: 0.0155\n", + "Iteration: 1840; Percent complete: 95.8%; Average loss: 0.0148\n", + "Iteration: 1850; Percent complete: 96.4%; Average loss: 0.0108\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 65.13%\n", + "Iteration: 1860; Percent complete: 96.9%; Average loss: 0.0103\n", + "Iteration: 1870; Percent complete: 97.4%; Average loss: 0.0136\n", + "Iteration: 1880; Percent complete: 97.9%; Average loss: 0.0098\n", + "Iteration: 1890; Percent complete: 98.4%; Average loss: 0.0078\n", + "Iteration: 1900; Percent complete: 99.0%; Average loss: 0.0089\n", + "Iteration: 1910; Percent complete: 99.5%; Average loss: 0.0134\n", + "Iteration: 1920; Percent complete: 100.0%; Average loss: 0.0097\n", + "Validating!\n", + "Iteration: 1; Percent complete: 4.5%\n", + "Iteration: 2; Percent complete: 9.1%\n", + "Iteration: 3; Percent complete: 13.6%\n", + "Iteration: 4; Percent complete: 18.2%\n", + "Iteration: 5; Percent complete: 22.7%\n", + "Iteration: 6; Percent complete: 27.3%\n", + "Iteration: 7; Percent complete: 31.8%\n", + "Iteration: 8; Percent complete: 36.4%\n", + "Iteration: 9; Percent complete: 40.9%\n", + "Iteration: 10; Percent complete: 45.5%\n", + "Iteration: 11; Percent complete: 50.0%\n", + "Iteration: 12; Percent complete: 54.5%\n", + "Iteration: 13; Percent complete: 59.1%\n", + "Iteration: 14; Percent complete: 63.6%\n", + "Iteration: 15; Percent complete: 68.2%\n", + "Iteration: 16; Percent complete: 72.7%\n", + "Iteration: 17; Percent complete: 77.3%\n", + "Iteration: 18; Percent complete: 81.8%\n", + "Iteration: 19; Percent complete: 86.4%\n", + "Iteration: 20; Percent complete: 90.9%\n", + "Iteration: 21; Percent complete: 95.5%\n", + "Iteration: 22; Percent complete: 100.0%\n", + "Validation set accuracy: 65.13%\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "craft_forecaster.fit(corpus, \n", + " partial(generic_fit_selector, split=\"train\"), \n", + " val_context_selector=partial(generic_fit_selector, split=\"val\"))" + ] + }, + { + "cell_type": "markdown", + "id": "3238e632-4e1b-4ba1-938a-b9aeef07f0cf", + "metadata": {}, + "source": [ + "## Run the fitted model on the test set and perform evaluation\n", + "\n", + "...and inference is done simply by calling `transform`! (again, note the selector)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "4d3e4892-f6ef-4b6f-a9e4-d318f8cb96b9", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processed 8466 context tuples for model evaluation\n", + "Loading saved parameters...\n", + "Building encoders, decoder, and classifier...\n", + "Models built and ready to go!\n", + "Iteration: 1; Percent complete: 0.8%\n", + "Iteration: 2; Percent complete: 1.5%\n", + "Iteration: 3; Percent complete: 2.3%\n", + "Iteration: 4; Percent complete: 3.0%\n", + "Iteration: 5; Percent complete: 3.8%\n", + "Iteration: 6; Percent complete: 4.5%\n", + "Iteration: 7; Percent complete: 5.3%\n", + "Iteration: 8; Percent complete: 6.0%\n", + "Iteration: 9; Percent complete: 6.8%\n", + "Iteration: 10; Percent complete: 7.5%\n", + "Iteration: 11; Percent complete: 8.3%\n", + "Iteration: 12; Percent complete: 9.0%\n", + "Iteration: 13; Percent complete: 9.8%\n", + "Iteration: 14; Percent complete: 10.5%\n", + "Iteration: 15; Percent complete: 11.3%\n", + "Iteration: 16; Percent complete: 12.0%\n", + "Iteration: 17; Percent complete: 12.8%\n", + "Iteration: 18; Percent complete: 13.5%\n", + "Iteration: 19; Percent complete: 14.3%\n", + "Iteration: 20; Percent complete: 15.0%\n", + "Iteration: 21; Percent complete: 15.8%\n", + "Iteration: 22; Percent complete: 16.5%\n", + "Iteration: 23; Percent complete: 17.3%\n", + "Iteration: 24; Percent complete: 18.0%\n", + "Iteration: 25; Percent complete: 18.8%\n", + "Iteration: 26; Percent complete: 19.5%\n", + "Iteration: 27; Percent complete: 20.3%\n", + "Iteration: 28; Percent complete: 21.1%\n", + "Iteration: 29; Percent complete: 21.8%\n", + "Iteration: 30; Percent complete: 22.6%\n", + "Iteration: 31; Percent complete: 23.3%\n", + "Iteration: 32; Percent complete: 24.1%\n", + "Iteration: 33; Percent complete: 24.8%\n", + "Iteration: 34; Percent complete: 25.6%\n", + "Iteration: 35; Percent complete: 26.3%\n", + "Iteration: 36; Percent complete: 27.1%\n", + "Iteration: 37; Percent complete: 27.8%\n", + "Iteration: 38; Percent complete: 28.6%\n", + "Iteration: 39; Percent complete: 29.3%\n", + "Iteration: 40; Percent complete: 30.1%\n", + "Iteration: 41; Percent complete: 30.8%\n", + "Iteration: 42; Percent complete: 31.6%\n", + "Iteration: 43; Percent complete: 32.3%\n", + "Iteration: 44; Percent complete: 33.1%\n", + "Iteration: 45; Percent complete: 33.8%\n", + "Iteration: 46; Percent complete: 34.6%\n", + "Iteration: 47; Percent complete: 35.3%\n", + "Iteration: 48; Percent complete: 36.1%\n", + "Iteration: 49; Percent complete: 36.8%\n", + "Iteration: 50; Percent complete: 37.6%\n", + "Iteration: 51; Percent complete: 38.3%\n", + "Iteration: 52; Percent complete: 39.1%\n", + "Iteration: 53; Percent complete: 39.8%\n", + "Iteration: 54; Percent complete: 40.6%\n", + "Iteration: 55; Percent complete: 41.4%\n", + "Iteration: 56; Percent complete: 42.1%\n", + "Iteration: 57; Percent complete: 42.9%\n", + "Iteration: 58; Percent complete: 43.6%\n", + "Iteration: 59; Percent complete: 44.4%\n", + "Iteration: 60; Percent complete: 45.1%\n", + "Iteration: 61; Percent complete: 45.9%\n", + "Iteration: 62; Percent complete: 46.6%\n", + "Iteration: 63; Percent complete: 47.4%\n", + "Iteration: 64; Percent complete: 48.1%\n", + "Iteration: 65; Percent complete: 48.9%\n", + "Iteration: 66; Percent complete: 49.6%\n", + "Iteration: 67; Percent complete: 50.4%\n", + "Iteration: 68; Percent complete: 51.1%\n", + "Iteration: 69; Percent complete: 51.9%\n", + "Iteration: 70; Percent complete: 52.6%\n", + "Iteration: 71; Percent complete: 53.4%\n", + "Iteration: 72; Percent complete: 54.1%\n", + "Iteration: 73; Percent complete: 54.9%\n", + "Iteration: 74; Percent complete: 55.6%\n", + "Iteration: 75; Percent complete: 56.4%\n", + "Iteration: 76; Percent complete: 57.1%\n", + "Iteration: 77; Percent complete: 57.9%\n", + "Iteration: 78; Percent complete: 58.6%\n", + "Iteration: 79; Percent complete: 59.4%\n", + "Iteration: 80; Percent complete: 60.2%\n", + "Iteration: 81; Percent complete: 60.9%\n", + "Iteration: 82; Percent complete: 61.7%\n", + "Iteration: 83; Percent complete: 62.4%\n", + "Iteration: 84; Percent complete: 63.2%\n", + "Iteration: 85; Percent complete: 63.9%\n", + "Iteration: 86; Percent complete: 64.7%\n", + "Iteration: 87; Percent complete: 65.4%\n", + "Iteration: 88; Percent complete: 66.2%\n", + "Iteration: 89; Percent complete: 66.9%\n", + "Iteration: 90; Percent complete: 67.7%\n", + "Iteration: 91; Percent complete: 68.4%\n", + "Iteration: 92; Percent complete: 69.2%\n", + "Iteration: 93; Percent complete: 69.9%\n", + "Iteration: 94; Percent complete: 70.7%\n", + "Iteration: 95; Percent complete: 71.4%\n", + "Iteration: 96; Percent complete: 72.2%\n", + "Iteration: 97; Percent complete: 72.9%\n", + "Iteration: 98; Percent complete: 73.7%\n", + "Iteration: 99; Percent complete: 74.4%\n", + "Iteration: 100; Percent complete: 75.2%\n", + "Iteration: 101; Percent complete: 75.9%\n", + "Iteration: 102; Percent complete: 76.7%\n", + "Iteration: 103; Percent complete: 77.4%\n", + "Iteration: 104; Percent complete: 78.2%\n", + "Iteration: 105; Percent complete: 78.9%\n", + "Iteration: 106; Percent complete: 79.7%\n", + "Iteration: 107; Percent complete: 80.5%\n", + "Iteration: 108; Percent complete: 81.2%\n", + "Iteration: 109; Percent complete: 82.0%\n", + "Iteration: 110; Percent complete: 82.7%\n", + "Iteration: 111; Percent complete: 83.5%\n", + "Iteration: 112; Percent complete: 84.2%\n", + "Iteration: 113; Percent complete: 85.0%\n", + "Iteration: 114; Percent complete: 85.7%\n", + "Iteration: 115; Percent complete: 86.5%\n", + "Iteration: 116; Percent complete: 87.2%\n", + "Iteration: 117; Percent complete: 88.0%\n", + "Iteration: 118; Percent complete: 88.7%\n", + "Iteration: 119; Percent complete: 89.5%\n", + "Iteration: 120; Percent complete: 90.2%\n", + "Iteration: 121; Percent complete: 91.0%\n", + "Iteration: 122; Percent complete: 91.7%\n", + "Iteration: 123; Percent complete: 92.5%\n", + "Iteration: 124; Percent complete: 93.2%\n", + "Iteration: 125; Percent complete: 94.0%\n", + "Iteration: 126; Percent complete: 94.7%\n", + "Iteration: 127; Percent complete: 95.5%\n", + "Iteration: 128; Percent complete: 96.2%\n", + "Iteration: 129; Percent complete: 97.0%\n", + "Iteration: 130; Percent complete: 97.7%\n", + "Iteration: 131; Percent complete: 98.5%\n", + "Iteration: 132; Percent complete: 99.2%\n", + "Iteration: 133; Percent complete: 100.0%\n" + ] + } + ], + "source": [ + "corpus = craft_forecaster.transform(corpus, transform_selector)" + ] + }, + { + "cell_type": "markdown", + "id": "9f517cfb-3a1d-4279-bf09-96695938d30c", + "metadata": {}, + "source": [ + "Finally, to get a human-readable interpretation of model performance, we can use `summarize` to generate a table of standard performance metrics. It also returns a table of conversation-level predictions in case you want to do more complex analysis!" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c7fa3d49-39c1-4dd4-9f8e-ca31f421ac5a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy 0.626462\n", + "Precision 0.610473\n", + "Recall 0.698830\n", + "FPR 0.445906\n", + "F1 0.651670\n", + "dtype: float64\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkgAAAGwCAYAAABSN5pGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAABVM0lEQVR4nO3deVxU9f4/8NewzICsArIpgQsJKoKJItpVSwrTb4pZuZDgkt66igtlapmolYCpkVeuZre0flfSzCXTwojcUtxAXBGVQFxYXAIEE4X5/P7w4ck5rGMDB/X1fDzmkfM5n/mc9znMnHl15iwqIYQAEREREUmMlC6AiIiIqKlhQCIiIiKSYUAiIiIikmFAIiIiIpJhQCIiIiKSYUAiIiIikmFAIiIiIpIxUbqAh5VWq8Xly5dhZWUFlUqldDlERERUD0II3LhxA66urjAyqnk/EQPSA7p8+TLc3NyULoOIiIgewIULF9CqVasapzMgPSArKysAd1ewtbW1wtUQERFRfZSUlMDNzU36Hq8JA9IDuvezmrW1NQMSERHRQ6auw2N4kDYRERGRDAMSERERkQwDEhEREZEMAxIRERGRDAMSERERkQwDEhEREZEMAxIRERGRDAMSERERkQwDEhEREZEMAxIRERGRDAMSERERkQwDEhEREZEMAxIRERGRDAMSERERkQwDEhEREZGMidIFUNPmMXOb0iVIcmIGKl0CERE9JrgHiYiIiEiGAYmIiIhIhgGJiIiISIYBiYiIiEiGAYmIiIhIhgGJiIiISIYBiYiIiEiGAYmIiIhIhgGJiIiISIYBiYiIiEiGAYmIiIhIhgGJiIiISIYBiYiIiEiGAYmIiIhIhgGJiIiISEbxgBQfHw8PDw+YmZkhICAABw8erLHvyZMnMXToUHh4eEClUiEuLq5Kn3vT5I+JEydKffr27Vtl+htvvNEQi0dEREQPIUUD0rp16xAZGYmoqCikpaXB19cXwcHBKCwsrLb/zZs30aZNG8TExMDZ2bnaPocOHUJeXp70SEpKAgC88sorOv3Gjx+v02/hwoWGXTgiIiJ6aCkakJYsWYLx48djzJgx6NChA1asWIFmzZrhyy+/rLZ/t27d8PHHH2P48OHQaDTV9mnRogWcnZ2lx9atW9G2bVv06dNHp1+zZs10+llbWxt8+YiIiOjhpFhAun37NlJTUxEUFPRXMUZGCAoKQkpKisHm8b///Q9jx46FSqXSmbZmzRo4ODigU6dOmDVrFm7evFnrWOXl5SgpKdF5EBER0aPJRKkZX716FZWVlXByctJpd3JywunTpw0yj82bN6OoqAijR4/WaR85ciTc3d3h6uqKY8eOYcaMGcjMzMTGjRtrHCs6Ohrz5s0zSF1ERETUtCkWkBrDF198gRdeeAGurq467RMmTJD+7ePjAxcXF/Tr1w9ZWVlo27ZttWPNmjULkZGR0vOSkhK4ubk1TOFERESkKMUCkoODA4yNjVFQUKDTXlBQUOMB2Po4f/48fvnll1r3Ct0TEBAAADh37lyNAUmj0dR43BMRERE9WhQ7BkmtVqNr165ITk6W2rRaLZKTkxEYGPi3x1+1ahUcHR0xcODAOvump6cDAFxcXP72fImIiOjhp+hPbJGRkQgPD4e/vz+6d++OuLg4lJWVYcyYMQCAsLAwtGzZEtHR0QDuHnR96tQp6d+XLl1Ceno6LC0t0a5dO2lcrVaLVatWITw8HCYmuouYlZWFhIQEDBgwAPb29jh27BimTZuG3r17o3Pnzo205ERERNSUKRqQhg0bhitXrmDOnDnIz8+Hn58fEhMTpQO3c3NzYWT0106uy5cvo0uXLtLzRYsWYdGiRejTpw927twptf/yyy/Izc3F2LFjq8xTrVbjl19+kcKYm5sbhg4ditmzZzfcghIREdFDRSWEEEoX8TAqKSmBjY0NiouLH+lrKHnM3KZ0CZKcmLp/LiUiIqpNfb+/Fb/VCBEREVFTw4BEREREJMOARERERCTDgEREREQkw4BEREREJPNI32rkYdWUzhwjIiJ6HHEPEhEREZEMAxIRERGRDAMSERERkQwDEhEREZEMAxIRERGRDAMSERERkQwDEhEREZEMAxIRERGRDAMSERERkQwDEhEREZEMAxIRERGRDAMSERERkQwDEhEREZEMAxIRERGRDAMSERERkQwDEhEREZEMAxIRERGRDAMSERERkQwDEhEREZEMAxIRERGRDAMSERERkQwDEhEREZEMAxIRERGRDAMSERERkQwDEhEREZEMAxIRERGRDAMSERERkQwDEhEREZEMAxIRERGRDAMSERERkQwDEhEREZGM4gEpPj4eHh4eMDMzQ0BAAA4ePFhj35MnT2Lo0KHw8PCASqVCXFxclT5z586FSqXSeXh5een0uXXrFiZOnAh7e3tYWlpi6NChKCgoMPSiERER0UNK0YC0bt06REZGIioqCmlpafD19UVwcDAKCwur7X/z5k20adMGMTExcHZ2rnHcjh07Ii8vT3r89ttvOtOnTZuGH374AevXr8euXbtw+fJlvPTSSwZdNiIiInp4mSg58yVLlmD8+PEYM2YMAGDFihXYtm0bvvzyS8ycObNK/27duqFbt24AUO30e0xMTGoMUMXFxfjiiy+QkJCAZ599FgCwatUqeHt7Y//+/ejRo0e1rysvL0d5ebn0vKSkpH4LSURERA8dxfYg3b59G6mpqQgKCvqrGCMjBAUFISUl5W+NffbsWbi6uqJNmzYIDQ1Fbm6uNC01NRV37tzRma+XlxeeeOKJWucbHR0NGxsb6eHm5va3aiQiIqKmS7GAdPXqVVRWVsLJyUmn3cnJCfn5+Q88bkBAAFavXo3ExEQsX74c2dnZ+Mc//oEbN24AAPLz86FWq2Fra6vXfGfNmoXi4mLpceHChQeukYiIiJo2RX9iawgvvPCC9O/OnTsjICAA7u7u+PbbbzFu3LgHHlej0UCj0RiiRCIiImriFNuD5ODgAGNj4ypnjxUUFNR6ALa+bG1t8eSTT+LcuXMAAGdnZ9y+fRtFRUUNOl8iIiJ6eCkWkNRqNbp27Yrk5GSpTavVIjk5GYGBgQabT2lpKbKysuDi4gIA6Nq1K0xNTXXmm5mZidzcXIPOl4iIiB5eiv7EFhkZifDwcPj7+6N79+6Ii4tDWVmZdFZbWFgYWrZsiejoaAB3D+w+deqU9O9Lly4hPT0dlpaWaNeuHQDg7bffxosvvgh3d3dcvnwZUVFRMDY2xogRIwAANjY2GDduHCIjI2FnZwdra2tEREQgMDCwxjPYiIiI6PGiaEAaNmwYrly5gjlz5iA/Px9+fn5ITEyUDtzOzc2FkdFfO7kuX76MLl26SM8XLVqERYsWoU+fPti5cycA4OLFixgxYgSuXbuGFi1a4Omnn8b+/fvRokUL6XWffPIJjIyMMHToUJSXlyM4OBj/+c9/GmehiYiIqMlTCSGE0kU8jEpKSmBjY4Pi4mJYW1sbdGyPmdsMOt6jIidmoNIlEBHRQ66+39+K32qEiIiIqKlhQCIiIiKSYUAiIiIikmFAIiIiIpJhQCIiIiKSYUAiIiIikmFAIiIiIpJhQCIiIiKSYUAiIiIikmFAIiIiIpJhQCIiIiKSYUAiIiIiktE7IH311VfYtu2vm6m+8847sLW1Rc+ePXH+/HmDFkdERESkBL0D0oIFC2Bubg4ASElJQXx8PBYuXAgHBwdMmzbN4AUSERERNTYTfV9w4cIFtGvXDgCwefNmDB06FBMmTECvXr3Qt29fQ9dHRERE1Oj03oNkaWmJa9euAQB+/vlnPPfccwAAMzMz/Pnnn4atjoiIiEgBeu9Beu655/D666+jS5cuOHPmDAYMGAAAOHnyJDw8PAxdHxEREVGj03sPUnx8PAIDA3HlyhVs2LAB9vb2AIDU1FSMGDHC4AUSERERNTa99yDZ2tpi2bJlVdrnzZtnkIKIiIiIlKZ3QAKAoqIiHDx4EIWFhdBqtVK7SqXCqFGjDFYcERERkRL0Dkg//PADQkNDUVpaCmtra6hUKmkaAxIRERE9CvQOSG+99RbGjh2LBQsWoFmzZg1RE1G1PGZuq7tTI8iJGah0CURE1MD0Pkj70qVLmDx5MsMRERERPbL0DkjBwcE4fPhwQ9RCRERE1CTo/RPbwIEDMX36dJw6dQo+Pj4wNTXVmT5o0CCDFUdERESkBL0D0vjx4wEA8+fPrzJNpVKhsrLy71dFREREpCC9A9L9p/UTERERPYr0PgaJiIiI6FH3QAFp165dePHFF9GuXTu0a9cOgwYNwp49ewxdGxEREZEi9A5I//vf/xAUFIRmzZph8uTJmDx5MszNzdGvXz8kJCQ0RI1EREREjUolhBD6vMDb2xsTJkzAtGnTdNqXLFmCzz//HBkZGQYtsKkqKSmBjY0NiouLYW1tbdCxm8oFEal6vFAkEdHDq77f33rvQfr999/x4osvVmkfNGgQsrOz9R2OiIiIqMnROyC5ubkhOTm5Svsvv/wCNzc3gxRFREREpKQHuhfb5MmTkZ6ejp49ewIA9u7di9WrV+PTTz81eIFEREREjU3vgPTmm2/C2dkZixcvxrfffgvg7nFJ69atw+DBgw1eIBEREVFj0zsgAcCQIUMwZMgQQ9dCRERE1CQofqHI+Ph4eHh4wMzMDAEBATh48GCNfU+ePImhQ4fCw8MDKpUKcXFxVfpER0ejW7dusLKygqOjI0JCQpCZmanTp2/fvlCpVDqPN954w9CLRkRERA+peu1BsrOzw5kzZ+Dg4IDmzZtDpVLV2Pf69ev1nvm6desQGRmJFStWICAgAHFxcQgODkZmZiYcHR2r9L958ybatGmDV155pcplBu7ZtWsXJk6ciG7duqGiogLvvvsunn/+eZw6dQoWFhZSv/Hjx+vcT65Zs2b1rpseb03pMgy85AARUcOoV0D65JNPYGVlJf27toCkjyVLlmD8+PEYM2YMAGDFihXYtm0bvvzyS8ycObNK/27duqFbt24AUO10AEhMTNR5vnr1ajg6OiI1NRW9e/eW2ps1awZnZ2eDLAcRERE9WuoVkMLDw6V/jx492iAzvn37NlJTUzFr1iypzcjICEFBQUhJSTHIPACguLgYwN29YPdbs2YN/ve//8HZ2Rkvvvgi3n///Vr3IpWXl6O8vFx6XlJSYrAaiYiIqGnR+yBtY2Nj5OXlVfkJ7Nq1a3B0dERlZWW9xrl69SoqKyvh5OSk0+7k5ITTp0/rW1a1tFotpk6dil69eqFTp05S+8iRI+Hu7g5XV1ccO3YMM2bMQGZmJjZu3FjjWNHR0Zg3b55B6iIiIqKmTe+AVNOdScrLy6FWq/92QYY0ceJEnDhxAr/99ptO+4QJE6R/+/j4wMXFBf369UNWVhbatm1b7VizZs1CZGSk9LykpIQXxiQiInpE1TsgLV26FACgUqnw3//+F5aWltK0yspK7N69G15eXvWesYODA4yNjVFQUKDTXlBQYJBjgyZNmoStW7di9+7daNWqVa19AwICAADnzp2rMSBpNBpoNJq/XRcRERE1ffUOSJ988gmAu3uQVqxYAWNjY2maWq2Gh4cHVqxYUe8Zq9VqdO3aFcnJyQgJCQFw9yex5ORkTJo0qd7jyAkhEBERgU2bNmHnzp1o3bp1na9JT08HALi4uDzwfImIiOjRUe+AdO9GtM888ww2btyI5s2b/+2ZR0ZGIjw8HP7+/ujevTvi4uJQVlYmndUWFhaGli1bIjo6GsDdA7tPnTol/fvSpUtIT0+HpaUl2rVrB+Duz2oJCQn4/vvvYWVlhfz8fACAjY0NzM3NkZWVhYSEBAwYMAD29vY4duwYpk2bht69e6Nz585/e5mIiIjo4af3MUg7duww2MyHDRuGK1euYM6cOcjPz4efnx8SExOlA7dzc3NhZPTXtSwvX76MLl26SM8XLVqERYsWoU+fPti5cycAYPny5QDuXgzyfqtWrcLo0aOhVqvxyy+/SGHMzc0NQ4cOxezZsw22XERERPRwU4majrquxcWLF7Flyxbk5ubi9u3bOtOWLFlisOKaspKSEtjY2KC4uBjW1tYGHbspXYiQmjZeKJKISD/1/f7Wew9ScnIyBg0ahDZt2uD06dPo1KkTcnJyIITAU0899beKJiIiImoK9L4X26xZs/D222/j+PHjMDMzw4YNG3DhwgX06dMHr7zySkPUSERERNSo9A5IGRkZCAsLAwCYmJjgzz//hKWlJebPn4/Y2FiDF0hERETU2PQOSBYWFtJxRy4uLsjKypKmXb161XCVERERESlE72OQevTogd9++w3e3t4YMGAA3nrrLRw/fhwbN25Ejx49GqJGIiIiokald0BasmQJSktLAQDz5s1DaWkp1q1bB09Pz8fmDDYiIiJ6tOkdkNq0aSP928LCQq+rZxMRERE9DPQ+BunChQu4ePGi9PzgwYOYOnUqVq5cadDCiIiIiJSid0AaOXKkdDXt/Px8BAUF4eDBg3jvvfcwf/58gxdIRERE1Nj0DkgnTpxA9+7dAQDffvstfHx8sG/fPqxZswarV682dH1EREREjU7vgHTnzh1oNBoAwC+//IJBgwYBALy8vJCXl2fY6oiIiIgUoHdA6tixI1asWIE9e/YgKSkJ/fv3B3D3RrL29vYGL5CIiIiosekdkGJjY/HZZ5+hb9++GDFiBHx9fQEAW7ZskX56IyIiInqY6X2af9++fXH16lWUlJSgefPmUvuECRPQrFkzgxZHREREpAS99yABgBACqamp+Oyzz3Djxg0AgFqtZkAiIiKiR4Lee5DOnz+P/v37Izc3F+Xl5XjuuedgZWWF2NhYlJeX88KRRERE9NDTew/SlClT4O/vjz/++APm5uZS+5AhQ5CcnGzQ4oiIiIiUoPcepD179mDfvn1Qq9U67R4eHrh06ZLBCiMiIiJSit57kLRaLSorK6u0X7x4EVZWVgYpioiIiEhJegek559/HnFxcdJzlUqF0tJSREVFYcCAAYasjYiIiEgRev/EtnjxYgQHB6NDhw64desWRo4cibNnz8LBwQHffPNNQ9RIRERE1Kj0DkitWrXC0aNHsXbtWhw7dgylpaUYN24cQkNDdQ7aJiIiInpY6R2Qbt26BTMzM7z22msNUQ8RERGR4vQ+BsnR0RHh4eFISkqCVqttiJqIiIiIFKV3QPrqq69w8+ZNDB48GC1btsTUqVNx+PDhhqiNiIiISBF6B6QhQ4Zg/fr1KCgowIIFC3Dq1Cn06NEDTz75JObPn98QNRIRERE1qge6FxsAWFlZYcyYMfj5559x7NgxWFhYYN68eYasjYiIiEgRDxyQbt26hW+//RYhISF46qmncP36dUyfPt2QtREREREpQu+z2LZv346EhARs3rwZJiYmePnll/Hzzz+jd+/eDVEfERERUaPTOyANGTIE//d//4evv/4aAwYMgKmpaUPURURERKQYvQNSQUEB77lGREREjzS9A5KVlRW0Wi3OnTuHwsLCKtdC4k9tRERE9LDTOyDt378fI0eOxPnz5yGE0JmmUqlQWVlpsOKIiIiIlKB3QHrjjTfg7++Pbdu2wcXFBSqVqiHqIiIiIlKM3gHp7Nmz+O6779CuXbuGqIeIiIhIcXpfBykgIADnzp1riFqIiIiImgS99yBFRETgrbfeQn5+Pnx8fKqc5t+5c2eDFUdERESkBL33IA0dOhQZGRkYO3YsunXrBj8/P3Tp0kX6r77i4+Ph4eEBMzMzBAQE4ODBgzX2PXnyJIYOHQoPDw+oVCrExcU90Ji3bt3CxIkTYW9vD0tLSwwdOhQFBQV6105ERESPJr0DUnZ2dpXH77//Lv1XH+vWrUNkZCSioqKQlpYGX19fBAcHo7CwsNr+N2/eRJs2bRATEwNnZ+cHHnPatGn44YcfsH79euzatQuXL1/GSy+9pFftRERE9OhSCfm5+o0oICAA3bp1w7JlywAAWq0Wbm5uiIiIwMyZM2t9rYeHB6ZOnYqpU6fqNWZxcTFatGiBhIQEvPzyywCA06dPw9vbGykpKejRo0e18ysvL0d5ebn0vKSkBG5ubiguLoa1tfWDroLql23mNoOOR4+unJiBSpdARPRQKSkpgY2NTZ3f3w90s9qsrCxEREQgKCgIQUFBmDx5MrKysvQa4/bt20hNTUVQUNBfxRgZISgoCCkpKQ9SVr3GTE1NxZ07d3T6eHl54Yknnqh1vtHR0bCxsZEebm5uD1QjERERNX16B6Tt27ejQ4cOOHjwIDp37ozOnTvjwIED6NixI5KSkuo9ztWrV1FZWQknJyeddicnJ+Tn5+tbVr3HzM/Ph1qthq2trV7znTVrFoqLi6XHhQsXHqhGIiIiavr0Pott5syZmDZtGmJiYqq0z5gxA88995zBimtKNBoNNBqN0mUQERFRI9B7D1JGRgbGjRtXpX3s2LE4depUvcdxcHCAsbFxlbPHCgoKajwA2xBjOjs74/bt2ygqKjLYfImIiOjRondAatGiBdLT06u0p6enw9HRsd7jqNVqdO3aFcnJyVKbVqtFcnIyAgMD9S2r3mN27doVpqamOn0yMzORm5v7wPMlIiKiR4veP7GNHz8eEyZMwO+//46ePXsCAPbu3YvY2FhERkbqNVZkZCTCw8Ph7++P7t27Iy4uDmVlZRgzZgwAICwsDC1btkR0dDSAuwdh39tLdfv2bVy6dAnp6emwtLSUbn1S15g2NjYYN24cIiMjYWdnB2tra0RERCAwMLDGM9iIiIjo8aJ3QHr//fdhZWWFxYsXY9asWQAAV1dXzJ07F5MnT9ZrrGHDhuHKlSuYM2cO8vPz4efnh8TEROkg69zcXBgZ/bWT6/LlyzoXo1y0aBEWLVqEPn36YOfOnfUaEwA++eQTGBkZYejQoSgvL0dwcDD+85//6LsqiIiI6BH1t66DdOPGDQCAlZWVwQp6WNT3OgoPgtdBovridZCIiPRT3+9vvfcgZWdno6KiAp6enjrB6OzZszA1NYWHh8cDFUxERETUVOh9kPbo0aOxb9++Ku0HDhzA6NGjDVETERERkaL0DkhHjhxBr169qrT36NGj2rPbiIiIiB42egcklUolHXt0v+LiYlRWVhqkKCIiIiIl6R2QevfujejoaJ0wVFlZiejoaDz99NMGLY6IiIhICXofpB0bG4vevXujffv2+Mc//gEA2LNnD0pKSvDrr78avEAiIiKixqb3HqQOHTrg2LFjePXVV1FYWIgbN24gLCwMp0+fRqdOnRqiRiIiIqJGpfceJODuhSEXLFhg6FqIiIiImgS99yARERERPeoYkIiIiIhkGJCIiIiIZOoVkLZs2YI7d+40dC1ERERETUK9AtKQIUNQVFQEADA2NkZhYWFD1kRERESkqHoFpBYtWmD//v0AACEEVCpVgxZFREREpKR6neb/xhtvYPDgwVCpVFCpVHB2dq6xL283QtR4PGZuU7qEJicnZqDSJRDRI6BeAWnu3LkYPnw4zp07h0GDBmHVqlWwtbVt4NKIiIiIlFHvC0V6eXnBy8sLUVFReOWVV9CsWbOGrIuIiIhIMXpfSTsqKgoAcOXKFWRmZgIA2rdvjxYtWhi2MiIiIiKF6H0dpJs3b2Ls2LFwdXVF79690bt3b7i6umLcuHG4efNmQ9RIRERE1Kj0DkjTpk3Drl27sGXLFhQVFaGoqAjff/89du3ahbfeeqshaiQiIiJqVHr/xLZhwwZ899136Nu3r9Q2YMAAmJub49VXX8Xy5csNWR8RERFRo3ugn9icnJyqtDs6OvInNiIiInok6B2QAgMDERUVhVu3bkltf/75J+bNm4fAwECDFkdERESkBL1/Yvv0008RHByMVq1awdfXFwBw9OhRmJmZYfv27QYvkIiIiKix6R2QOnXqhLNnz2LNmjU4ffo0AGDEiBEIDQ2Fubm5wQskIiIiamx6ByQAaNasGcaPH2/oWoiIiIiaBL2PQSIiIiJ61DEgEREREckwIBERERHJMCARERERyegdkNq0aYNr165VaS8qKkKbNm0MUhQRERGRkvQOSDk5OaisrKzSXl5ejkuXLhmkKCIiIiIl1fs0/y1btkj/3r59O2xsbKTnlZWVSE5OhoeHh0GLIyIiIlJCvQNSSEgIAEClUiE8PFxnmqmpKTw8PLB48WKDFkdERESkhHoHJK1WCwBo3bo1Dh06BAcHhwYrioiIiEhJel9JOzs7uyHqICIiImoyHug0/+TkZLz77rt4/fXXMXbsWJ3Hg4iPj4eHhwfMzMwQEBCAgwcP1tp//fr18PLygpmZGXx8fPDjjz/qTFepVNU+Pv74Y6mPh4dHlekxMTEPVD8RERE9WvQOSPPmzcPzzz+P5ORkXL16FX/88YfOQ1/r1q1DZGQkoqKikJaWBl9fXwQHB6OwsLDa/vv27cOIESMwbtw4HDlyBCEhIQgJCcGJEyekPnl5eTqPL7/8EiqVCkOHDtUZa/78+Tr9IiIi9K6fiIiIHj0qIYTQ5wUuLi5YuHAhRo0aZZACAgIC0K1bNyxbtgzA3WOd3NzcEBERgZkzZ1bpP2zYMJSVlWHr1q1SW48ePeDn54cVK1ZUO4+QkBDcuHEDycnJUpuHhwemTp2KqVOnPlDdJSUlsLGxQXFxMaytrR9ojJp4zNxm0PGIHic5MQOVLoGImrD6fn/rvQfp9u3b6Nmz598q7v6xUlNTERQU9FdBRkYICgpCSkpKta9JSUnR6Q8AwcHBNfYvKCjAtm3bMG7cuCrTYmJiYG9vjy5duuDjjz9GRUVFjbWWl5ejpKRE50FERESPJr0D0uuvv46EhASDzPzq1auorKyEk5OTTruTkxPy8/OrfU1+fr5e/b/66itYWVnhpZde0mmfPHky1q5dix07duCf//wnFixYgHfeeafGWqOjo2FjYyM93Nzc6rOIRERE9BDS+yy2W7duYeXKlfjll1/QuXNnmJqa6kxfsmSJwYozhC+//BKhoaEwMzPTaY+MjJT+3blzZ6jVavzzn/9EdHQ0NBpNlXFmzZql85qSkhKGJCIiokeU3gHp2LFj8PPzAwCdA6OBu2eP6cPBwQHGxsYoKCjQaS8oKICzs3O1r3F2dq53/z179iAzMxPr1q2rs5aAgABUVFQgJycH7du3rzJdo9FUG5yIiIjo0aN3QNqxY4fBZq5Wq9G1a1ckJydLV+rWarVITk7GpEmTqn1NYGAgkpOTdQ6uTkpKQmBgYJW+X3zxBbp27QpfX986a0lPT4eRkREcHR0faFmIiIjo0aF3QLrn3LlzyMrKQu/evWFubg4hhN57kIC7P3WFh4fD398f3bt3R1xcHMrKyjBmzBgAQFhYGFq2bIno6GgAwJQpU9CnTx8sXrwYAwcOxNq1a3H48GGsXLlSZ9ySkhKsX7++2tufpKSk4MCBA3jmmWdgZWWFlJQUTJs2Da+99hqaN2/+AGuDiIiIHiV6B6Rr167h1VdfxY4dO6BSqXD27Fm0adMG48aNQ/PmzfW+H9uwYcNw5coVzJkzB/n5+fDz80NiYqJ0IHZubi6MjP46lrxnz55ISEjA7Nmz8e6778LT0xObN29Gp06ddMZdu3YthBAYMWJElXlqNBqsXbsWc+fORXl5OVq3bo1p06bpHGNEREREjy+9r4MUFhaGwsJC/Pe//4W3tzeOHj2KNm3aYPv27YiMjMTJkycbqtYmhddBImqaeB0kIqpNfb+/9d6D9PPPP2P79u1o1aqVTrunpyfOnz+vf6VERERETYze10EqKytDs2bNqrRfv36dZ3kRERHRI0HvgPSPf/wDX3/9tfRcpVJBq9Vi4cKFeOaZZwxaHBEREZES9P6JbeHChejXrx8OHz6M27dv45133sHJkydx/fp17N27tyFqJCIiImpUeu9B6tSpE86cOYOnn34agwcPRllZGV566SUcOXIEbdu2bYgaiYiIiBrVA10HycbGBu+9956hayEiIiJqEvTeg7Rq1SqsX7++Svv69evx1VdfGaQoIiIiIiXpHZCio6Ph4OBQpd3R0RELFiwwSFFEREREStI7IOXm5qJ169ZV2t3d3ZGbm2uQooiIiIiUpHdAcnR0xLFjx6q0Hz16FPb29gYpioiIiEhJegekESNGYPLkydixYwcqKytRWVmJX3/9FVOmTMHw4cMbokYiIiKiRqX3WWwffPABcnJy0K9fP5iY3H25VqtFWFgYj0EiIiKiR4JeAUkIgfz8fKxevRoffvgh0tPTYW5uDh8fH7i7uzdUjURERESNSu+A1K5dO5w8eRKenp7w9PRsqLqIiIiIFKPXMUhGRkbw9PTEtWvXGqoeIiIiIsXpfZB2TEwMpk+fjhMnTjREPURERESK0/sg7bCwMNy8eRO+vr5Qq9UwNzfXmX79+nWDFUdERESkBL0DUlxcXAOUQURERNR06B2QwsPDG6IOIiIioiZD72OQACArKwuzZ8/GiBEjUFhYCAD46aefcPLkSYMWR0RERKQEvQPSrl274OPjgwMHDmDjxo0oLS0FcPdWI1FRUQYvkIiIiKix6R2QZs6ciQ8//BBJSUlQq9VS+7PPPov9+/cbtDgiIiIiJegdkI4fP44hQ4ZUaXd0dMTVq1cNUhQRERGRkvQOSLa2tsjLy6vSfuTIEbRs2dIgRREREREpSe+ANHz4cMyYMQP5+flQqVTQarXYu3cv3n77bYSFhTVEjURERESNSu+AtGDBAnh5ecHNzQ2lpaXo0KEDevfujZ49e2L27NkNUSMRERFRo9L7OkhqtRqff/455syZg+PHj6O0tBRdunThjWuJiIjokVHvgKTVavHxxx9jy5YtuH37Nvr164eoqKgqtxohIlKSx8xtSpcgyYkZqHQJRPSA6v0T20cffYR3330XlpaWaNmyJT799FNMnDixIWsjIiIiUkS9A9LXX3+N//znP9i+fTs2b96MH374AWvWrIFWq23I+oiIiIgaXb0DUm5uLgYMGCA9DwoKgkqlwuXLlxukMCIiIiKl1DsgVVRUwMzMTKfN1NQUd+7cMXhRREREREqq90HaQgiMHj0aGo1Gart16xbeeOMNWFhYSG0bN240bIVEREREjazeASk8PLxK22uvvWbQYoiIiIiagnoHpFWrVjVkHURERERNht5X0iYiIiJ61DWJgBQfHw8PDw+YmZkhICAABw8erLX/+vXr4eXlBTMzM/j4+ODHH3/UmT569GioVCqdR//+/XX6XL9+HaGhobC2toatrS3GjRuH0tJSgy8bERERPXwUD0jr1q1DZGQkoqKikJaWBl9fXwQHB6OwsLDa/vv27cOIESMwbtw4HDlyBCEhIQgJCcGJEyd0+vXv3x95eXnS45tvvtGZHhoaipMnTyIpKQlbt27F7t27MWHChAZbTiIiInp4qIQQQskCAgIC0K1bNyxbtgzA3VuauLm5ISIiAjNnzqzSf9iwYSgrK8PWrVulth49esDPzw8rVqwAcHcPUlFRETZv3lztPDMyMtChQwccOnQI/v7+AIDExEQMGDAAFy9ehKura511l5SUwMbGBsXFxbC2ttZ3sWvVlG6VQEQPjrcaIWp66vv9regepNu3byM1NRVBQUFSm5GREYKCgpCSklLta1JSUnT6A0BwcHCV/jt37oSjoyPat2+PN998E9euXdMZw9bWVgpHwN0LXxoZGeHAgQPVzre8vBwlJSU6DyIiIno0KRqQrl69isrKSjg5Oem0Ozk5IT8/v9rX5Ofn19m/f//++Prrr5GcnIzY2Fjs2rULL7zwAiorK6UxHB0ddcYwMTGBnZ1djfONjo6GjY2N9HBzc9N7eYmIiOjhUO/T/B8mw4cPl/7t4+ODzp07o23btti5cyf69ev3QGPOmjULkZGR0vOSkhKGJCIiokeUonuQHBwcYGxsjIKCAp32goICODs7V/saZ2dnvfoDQJs2beDg4IBz585JY8gPAq+oqMD169drHEej0cDa2lrnQURERI8mRfcgqdVqdO3aFcnJyQgJCQFw9yDt5ORkTJo0qdrXBAYGIjk5GVOnTpXakpKSEBgYWON8Ll68iGvXrsHFxUUao6ioCKmpqejatSsA4Ndff4VWq0VAQIBhFo6IHntN6YQLHjBOpB/FT/OPjIzE559/jq+++goZGRl48803UVZWhjFjxgAAwsLCMGvWLKn/lClTkJiYiMWLF+P06dOYO3cuDh8+LAWq0tJSTJ8+Hfv370dOTg6Sk5MxePBgtGvXDsHBwQAAb29v9O/fH+PHj8fBgwexd+9eTJo0CcOHD6/XGWxERET0aFP8GKRhw4bhypUrmDNnDvLz8+Hn54fExETpQOzc3FwYGf2V43r27ImEhATMnj0b7777Ljw9PbF582Z06tQJAGBsbIxjx47hq6++QlFREVxdXfH888/jgw8+0LnR7po1azBp0iT069cPRkZGGDp0KJYuXdq4C09ERERNkuLXQXpY8TpIRPQw4U9sRHc9FNdBIiIiImqKGJCIiIiIZBiQiIiIiGQYkIiIiIhkGJCIiIiIZBiQiIiIiGQYkIiIiIhkGJCIiIiIZBiQiIiIiGQYkIiIiIhkGJCIiIiIZBiQiIiIiGQYkIiIiIhkGJCIiIiIZBiQiIiIiGQYkIiIiIhkGJCIiIiIZBiQiIiIiGQYkIiIiIhkGJCIiIiIZBiQiIiIiGQYkIiIiIhkGJCIiIiIZBiQiIiIiGRMlC6AiIgeHx4ztyldgiQnZqDSJVATxj1IRERERDIMSEREREQyDEhEREREMjwGiYjoMdCUjv0hehhwDxIRERGRDAMSERERkQwDEhEREZEMAxIRERGRDAMSERERkQwDEhEREZEMAxIRERGRDAMSERERkUyTCEjx8fHw8PCAmZkZAgICcPDgwVr7r1+/Hl5eXjAzM4OPjw9+/PFHadqdO3cwY8YM+Pj4wMLCAq6urggLC8Ply5d1xvDw8IBKpdJ5xMTENMjyERER0cNF8YC0bt06REZGIioqCmlpafD19UVwcDAKCwur7b9v3z6MGDEC48aNw5EjRxASEoKQkBCcOHECAHDz5k2kpaXh/fffR1paGjZu3IjMzEwMGjSoyljz589HXl6e9IiIiGjQZSUiIqKHg0oIIZQsICAgAN26dcOyZcsAAFqtFm5uboiIiMDMmTOr9B82bBjKysqwdetWqa1Hjx7w8/PDihUrqp3HoUOH0L17d5w/fx5PPPEEgLt7kKZOnYqpU6fWq87y8nKUl5dLz0tKSuDm5obi4mJYW1vXd3HrhbcEICJqeDkxA5UugRRQUlICGxubOr+/Fd2DdPv2baSmpiIoKEhqMzIyQlBQEFJSUqp9TUpKik5/AAgODq6xPwAUFxdDpVLB1tZWpz0mJgb29vbo0qULPv74Y1RUVNQ4RnR0NGxsbKSHm5tbPZaQiIiIHkaK3qz26tWrqKyshJOTk067k5MTTp8+Xe1r8vPzq+2fn59fbf9bt25hxowZGDFihE5SnDx5Mp566inY2dlh3759mDVrFvLy8rBkyZJqx5k1axYiIyOl5/f2IBEREdGjR9GA1NDu3LmDV199FUIILF++XGfa/WGnc+fOUKvV+Oc//4no6GhoNJoqY2k0mmrbiYiI6NGj6E9sDg4OMDY2RkFBgU57QUEBnJ2dq32Ns7NzvfrfC0fnz59HUlJSnccJBQQEoKKiAjk5OfovCBERET1SFA1IarUaXbt2RXJystSm1WqRnJyMwMDAal8TGBio0x8AkpKSdPrfC0dnz57FL7/8Ant7+zprSU9Ph5GRERwdHR9waYiIiOhRofhPbJGRkQgPD4e/vz+6d++OuLg4lJWVYcyYMQCAsLAwtGzZEtHR0QCAKVOmoE+fPli8eDEGDhyItWvX4vDhw1i5ciWAu+Ho5ZdfRlpaGrZu3YrKykrp+CQ7Ozuo1WqkpKTgwIEDeOaZZ2BlZYWUlBRMmzYNr732Gpo3b67MiiAiIqImQ/GANGzYMFy5cgVz5sxBfn4+/Pz8kJiYKB2InZubCyOjv3Z09ezZEwkJCZg9ezbeffddeHp6YvPmzejUqRMA4NKlS9iyZQsAwM/PT2deO3bsQN++faHRaLB27VrMnTsX5eXlaN26NaZNm6ZzXBIRERE9vhS/DtLDqr7XUXgQvA4SEVHD43WQHk8PxXWQiIiIiJoiBiQiIiIiGQYkIiIiIhkGJCIiIiIZBiQiIiIiGQYkIiIiIhkGJCIiIiIZBiQiIiIiGQYkIiIiIhkGJCIiIiIZBiQiIiIiGQYkIiIiIhkGJCIiIiIZBiQiIiIiGQYkIiIiIhkGJCIiIiIZE6ULICIiUoLHzG1Kl9Ak5cQMVLqEJoF7kIiIiIhkGJCIiIiIZBiQiIiIiGR4DBIRERFJmsqxWUofC8U9SEREREQyDEhEREREMgxIRERERDIMSEREREQyDEhEREREMgxIRERERDIMSEREREQyDEhEREREMgxIRERERDIMSEREREQyDEhEREREMgxIRERERDIMSEREREQyDEhEREREMgxIRERERDJNIiDFx8fDw8MDZmZmCAgIwMGDB2vtv379enh5ecHMzAw+Pj748ccfdaYLITBnzhy4uLjA3NwcQUFBOHv2rE6f69evIzQ0FNbW1rC1tcW4ceNQWlpq8GUjIiKih4/iAWndunWIjIxEVFQU0tLS4Ovri+DgYBQWFlbbf9++fRgxYgTGjRuHI0eOICQkBCEhIThx4oTUZ+HChVi6dClWrFiBAwcOwMLCAsHBwbh165bUJzQ0FCdPnkRSUhK2bt2K3bt3Y8KECQ2+vERERNT0qYQQQskCAgIC0K1bNyxbtgwAoNVq4ebmhoiICMycObNK/2HDhqGsrAxbt26V2nr06AE/Pz+sWLECQgi4urrirbfewttvvw0AKC4uhpOTE1avXo3hw4cjIyMDHTp0wKFDh+Dv7w8ASExMxIABA3Dx4kW4urrWWXdJSQlsbGxQXFwMa2trQ6wKicfMbQYdj4iI6GGTEzOwQcat7/e3SYPMvZ5u376N1NRUzJo1S2ozMjJCUFAQUlJSqn1NSkoKIiMjddqCg4OxefNmAEB2djby8/MRFBQkTbexsUFAQABSUlIwfPhwpKSkwNbWVgpHABAUFAQjIyMcOHAAQ4YMqTLf8vJylJeXS8+Li4sB3F3RhqYtv2nwMYmIiB4mDfH9ev+4de0fUjQgXb16FZWVlXByctJpd3JywunTp6t9TX5+frX98/Pzpen32mrr4+joqDPdxMQEdnZ2Uh+56OhozJs3r0q7m5tbTYtHRERED8gmrmHHv3HjBmxsbGqcrmhAepjMmjVLZ8+VVqvF9evXYW9vD5VKpWBlDaekpARubm64cOGCwX9GfFhxnVSP66UqrpOquE6qx/VSVUOuEyEEbty4UefhNIoGJAcHBxgbG6OgoECnvaCgAM7OztW+xtnZudb+9/5bUFAAFxcXnT5+fn5SH/lB4BUVFbh+/XqN89VoNNBoNDpttra2tS/gI8La2pofWhmuk+pxvVTFdVIV10n1uF6qaqh1Utueo3sUPYtNrVaja9euSE5Oltq0Wi2Sk5MRGBhY7WsCAwN1+gNAUlKS1L9169ZwdnbW6VNSUoIDBw5IfQIDA1FUVITU1FSpz6+//gqtVouAgACDLR8RERE9nBT/iS0yMhLh4eHw9/dH9+7dERcXh7KyMowZMwYAEBYWhpYtWyI6OhoAMGXKFPTp0weLFy/GwIEDsXbtWhw+fBgrV64EAKhUKkydOhUffvghPD090bp1a7z//vtwdXVFSEgIAMDb2xv9+/fH+PHjsWLFCty5cweTJk3C8OHD63UGGxERET3aFA9Iw4YNw5UrVzBnzhzk5+fDz88PiYmJ0kHWubm5MDL6a0dXz549kZCQgNmzZ+Pdd9+Fp6cnNm/ejE6dOkl93nnnHZSVlWHChAkoKirC008/jcTERJiZmUl91qxZg0mTJqFfv34wMjLC0KFDsXTp0sZb8IeARqNBVFRUlZ8WH2dcJ9XjeqmK66QqrpPqcb1U1RTWieLXQSIiIiJqahS/kjYRERFRU8OARERERCTDgEREREQkw4BEREREJMOARFVER0ejW7dusLKygqOjI0JCQpCZmal0WU1KTEyMdEmJx9mlS5fw2muvwd7eHubm5vDx8cHhw4eVLktRlZWVeP/999G6dWuYm5ujbdu2+OCDD+q879OjZPfu3XjxxRfh6uoKlUol3SvzHiEE5syZAxcXF5ibmyMoKAhnz55VpthGVNt6uXPnDmbMmAEfHx9YWFjA1dUVYWFhuHz5snIFN4K63iv3e+ONN6BSqRAXF9cotTEgURW7du3CxIkTsX//fiQlJeHOnTt4/vnnUVZWpnRpTcKhQ4fw2WefoXPnzkqXoqg//vgDvXr1gqmpKX766SecOnUKixcvRvPmzZUuTVGxsbFYvnw5li1bhoyMDMTGxmLhwoX497//rXRpjaasrAy+vr6Ij4+vdvrChQuxdOlSrFixAgcOHICFhQWCg4Nx69atRq60cdW2Xm7evIm0tDS8//77SEtLw8aNG5GZmYlBgwYpUGnjqeu9cs+mTZuwf//+xr1WoSCqQ2FhoQAgdu3apXQpirtx44bw9PQUSUlJok+fPmLKlClKl6SYGTNmiKefflrpMpqcgQMHirFjx+q0vfTSSyI0NFShipQFQGzatEl6rtVqhbOzs/j444+ltqKiIqHRaMQ333yjQIXKkK+X6hw8eFAAEOfPn2+cohRW0zq5ePGiaNmypThx4oRwd3cXn3zySaPUwz1IVKfi4mIAgJ2dncKVKG/ixIkYOHAggoKClC5FcVu2bIG/vz9eeeUVODo6okuXLvj888+VLktxPXv2RHJyMs6cOQMAOHr0KH777Te88MILClfWNGRnZyM/P1/nM2RjY4OAgACkpKQoWFnTU1xcDJVK9djc97M6Wq0Wo0aNwvTp09GxY8dGnbfiV9Kmpk2r1WLq1Kno1auXztXKH0dr165FWloaDh06pHQpTcLvv/+O5cuXIzIyEu+++y4OHTqEyZMnQ61WIzw8XOnyFDNz5kyUlJTAy8sLxsbGqKysxEcffYTQ0FClS2sS8vPzAUC6W8I9Tk5O0jQCbt26hRkzZmDEiBGP9Q1sY2NjYWJigsmTJzf6vBmQqFYTJ07EiRMn8NtvvyldiqIuXLiAKVOmICkpSeeWNY8zrVYLf39/LFiwAADQpUsXnDhxAitWrHisA9K3336LNWvWICEhAR07dkR6ejqmTp0KV1fXx3q9UP3duXMHr776KoQQWL58udLlKCY1NRWffvop0tLSoFKpGn3+/ImNajRp0iRs3boVO3bsQKtWrZQuR1GpqakoLCzEU089BRMTE5iYmGDXrl1YunQpTExMUFlZqXSJjc7FxQUdOnTQafP29kZubq5CFTUN06dPx8yZMzF8+HD4+Phg1KhRmDZtmnTD7ceds7MzAKCgoECnvaCgQJr2OLsXjs6fP4+kpKTHeu/Rnj17UFhYiCeeeELa7p4/fx5vvfUWPDw8Gnz+3INEVQghEBERgU2bNmHnzp1o3bq10iUprl+/fjh+/LhO25gxY+Dl5YUZM2bA2NhYocqU06tXryqXfzhz5gzc3d0VqqhpuHnzps4NtgHA2NgYWq1WoYqaltatW8PZ2RnJycnw8/MDAJSUlODAgQN48803lS1OYffC0dmzZ7Fjxw7Y29srXZKiRo0aVeV4z+DgYIwaNQpjxoxp8PkzIFEVEydOREJCAr7//ntYWVlJxwXY2NjA3Nxc4eqUYWVlVeUYLAsLC9jb2z+2x2ZNmzYNPXv2xIIFC/Dqq6/i4MGDWLlyJVauXKl0aYp68cUX8dFHH+GJJ55Ax44dceTIESxZsgRjx45VurRGU1painPnzknPs7OzkZ6eDjs7OzzxxBOYOnUqPvzwQ3h6eqJ169Z4//334erqipCQEOWKbgS1rRcXFxe8/PLLSEtLw9atW1FZWSlte+3s7KBWq5Uqu0HV9V6Rh0RTU1M4Ozujffv2DV9co5wrRw8VANU+Vq1apXRpTcrjfpq/EEL88MMPolOnTkKj0QgvLy+xcuVKpUtSXElJiZgyZYp44oknhJmZmWjTpo147733RHl5udKlNZodO3ZUuw0JDw8XQtw91f/9998XTk5OQqPRiH79+onMzExli24Eta2X7OzsGre9O3bsULr0BlPXe0WuMU/zVwnxGF3elYiIiKgeeJA2ERERkQwDEhEREZEMAxIRERGRDAMSERERkQwDEhEREZEMAxIRERGRDAMSERERkQwDEhEREZHMYxeQcnJyoFKpkJ6ernQpktOnT6NHjx4wMzOT7k1Ef9/q1atha2urdBmKUqlU2Lx5898ao6msx/p8ToQQmDBhAuzs7KTPed++fTF16tRGrbWhjB49us7bcezcuRMqlQpFRUUNWsvevXvh4+MDU1PTR/4WIQ1Jye+k+ryf6nLz5k0MHToU1tbWjfK+a0yNHpBGjx4NlUqFmJgYnfbNmzdDpVI1djlNQlRUFCwsLJCZmYnk5GSly2kSmsqXMmCYkKGUvLw8vPDCC0qXYRD1+ZwkJiZi9erV2Lp1K/Ly8tCpUyds3LgRH3zwwd+ad1N5D3z66adYvXq19Ly68NezZ0/k5eXBxsamQWuJjIyEn58fsrOzdWp62DSlbc3D6KuvvsKePXuwb9++RnnfNSZF9iCZmZkhNjYWf/zxhxKzbxC3b99+4NdmZWXh6aefhru7+2N/92YyLGdnZ2g0GqXLMIj6fE6ysrLg4uKCnj17wtnZGSYmJrCzs4OVlVWN4/6dz25js7GxqfPLXK1Ww9nZucH/hzMrKwvPPvssWrVq9cAB42Fa91S9rKwseHt7o1OnTo3yvmtUjXLHt/uEh4eL//u//xNeXl5i+vTpUvumTZvE/eVERUUJX19fndd+8sknwt3dXWeswYMHi48++kg4OjoKGxsbMW/ePHHnzh3x9ttvi+bNm4uWLVuKL7/8UnrNvRsCfvPNNyIwMFBoNBrRsWNHsXPnTp15HT9+XPTv319YWFgIR0dH8dprr4krV65I0/v06SMmTpwopkyZIuzt7UXfvn2rXd7Kykoxb9480bJlS6FWq4Wvr6/46aefpOmQ3aAvKiqqxnFiY2NF27ZthVqtFm5ubuLDDz+Uph87dkw888wzwszMTNjZ2Ynx48eLGzduGGRdrVu3Tjz99NPCzMxM+Pv7i8zMTHHw4EHRtWtXYWFhIfr37y8KCwt16v3888+Fl5eX0Gg0on379iI+Pr7KuBs2bBB9+/YV5ubmonPnzmLfvn1CiOpvXnhvvcTHx4t27doJjUYjHB0dxdChQ6tdX0IIsWrVKmFjYyM2bdokveb5558Xubm5Ov02b94sunTpIjQajWjdurWYO3euuHPnjhDi7o0R76/D3d1dFBUVCSMjI3Ho0CHpb9O8eXMREBAgjfn//t//E61atZKe5+bmildeeUXY2NiI5s2bi0GDBons7GyDrbOaABCbNm3Sa4xVq1YJNzc3YW5uLkJCQsSiRYuEjY1NvdfZvHnzhIuLi7h69arUf8CAAaJv376isrKy2joN8TkJDw+v8rcSoupNhd3d3cX8+fPFqFGjhJWVlQgPDxfl5eVi4sSJwtnZWWg0GvHEE0+IBQsWSP2rG1euvtuWnTt3im7dugm1Wi2cnZ3FjBkzpHUnhBDr168XnTp1kj7L/fr1E6WlpdIyDh48uNrlBSCys7Olz88ff/whiouLhZmZmfjxxx91ati4caOwtLQUZWVlQoj6vT/ly3n/496NrOtatpq2m3Vtb+va/r3zzjvC09NTmJubi9atW4vZs2eL27dvS9PT09NF3759haWlpbCyshJPPfWUOHToUK3bGrlz586JQYMGCUdHR2FhYSH8/f1FUlKSTh93d3fx0UcfiTFjxghLS0vh5uYmPvvsM50+Bw4cEH5+fkKj0YiuXbuKjRs3CgDiyJEj1c5XCCFu3bol3nrrLeHq6iqaNWsmunfvrnMD23vbusTEROHl5SUsLCxEcHCwuHz5stSnoqJCTJs2TdjY2Ag7Ozsxffp0ERYWJr2favLdd9+JDh06CLVaLdzd3cWiRYukaX369NFZd3369KlxnC1btgh/f3+h0WiEvb29CAkJkaZdv35djBo1Stja2gpzc3PRv39/cebMmXov3/bt24VGoxF//PGHzjwnT54snnnmGen5nj17pO+yVq1aiYiICOmzJUTVbYMiAWnw4MFi48aNwszMTFy4cEEI8eABycrKSkycOFGcPn1afPHFFwKACA4OFh999JE4c+aM+OCDD4Spqak0n3sf7latWonvvvtOnDp1Srz++uvCyspK2qD/8ccfokWLFmLWrFkiIyNDpKWlieeee05nRffp00dYWlqK6dOni9OnT4vTp09Xu7xLliwR1tbW4ptvvhGnT58W77zzjjA1NZX++Hl5eaJjx47irbfeEnl5eTqh5n7vvPOOaN68uVi9erU4d+6c2LNnj/j888+FEEKUlpYKFxcX8dJLL4njx4+L5ORk0bp1a527If+ddeXl5SUSExPFqVOnRI8ePUTXrl1F3759xW+//SbS0tJEu3btxBtvvCHN63//+59wcXERGzZsEL///rvYsGGDsLOzE6tXr64y7tatW0VmZqZ4+eWXhbu7u7hz544oLy8XcXFxwtraWuTl5Unr5dChQ8LY2FgkJCSInJwckZaWJj799NPq32ji7ofK1NRU+Pv7i3379onDhw+L7t27i549e0p9du/eLaytrcXq1atFVlaW+Pnnn4WHh4eYO3euEEKIwsJC6QsgLy9PCoJPPfWU+Pjjj4UQdze+dnZ2Qq1WS3+/119/XYSGhgohhLh9+7bw9vYWY8eOFceOHROnTp0SI0eOFO3bt5fu8P5311lNqgtItY2xf/9+YWRkJGJjY0VmZqb49NNPha2trU5AqmudVVRUiMDAQGkDuGzZMmFrayvOnz9fY52G+JwUFRWJ+fPni1atWun8raoLSNbW1mLRokXi3Llz4ty5c+Ljjz8Wbm5uYvfu3SInJ0fs2bNHJCQk1PoekKvPtuXixYuiWbNm4l//+pfIyMgQmzZtEg4ODtKX8uXLl4WJiYlYsmSJyM7OFseOHRPx8fHS8t4fkIqKikRgYKAYP3689DmpqKjQCUhCCPHyyy+L1157TafWoUOHSm31eX/er6KiQuTl5Qlra2sRFxcn8vLyxM2bN+tctnt/C/l2sz7b29q2f0II8cEHH4i9e/eK7OxssWXLFuHk5CRiY2Ol6R07dhSvvfaayMjIEGfOnBHffvutSE9Pr3FbU5309HSxYsUKcfz4cXHmzBkxe/ZsYWZmpvO+dnd3F3Z2diI+Pl6cPXtWREdHCyMjI+n74caNG6JFixZi5MiR4sSJE+KHH34Qbdq0qTMgvf7666Jnz55i9+7d0vtVo9FIn49727qgoCBx6NAhkZqaKry9vcXIkSOlMWJjY0Xz5s3Fhg0bxKlTp8S4ceOElZVVrQHp8OHDwsjISMyfP19kZmaKVatWCXNzcykQX7t2TYwfP14EBgaKvLw8ce3atWrH2bp1qzA2NhZz5swRp06dEunp6dL/gAghxKBBg4S3t7fYvXu3SE9PF8HBwaJdu3ZSyK1r+SoqKoSTk5P473//K40pbzt37pywsLAQn3zyiThz5ozYu3ev6NKlixg9erTO3+/+bYNiAUkIIXr06CHGjh0rhHjwgOTu7q7zf6Xt27cX//jHP6TnFRUVwsLCQnzzzTdCiL82YjExMVKfO3fuiFatWkkfqA8++EA8//zzOvO+cOGCACAyMzOFEHc/6F26dKlzeV1dXcVHH32k09atWzfxr3/9S3ru6+tb4/+1CCFESUmJ0Gg0OhuE+61cuVI0b95cJwlv27ZNGBkZifz8fCHE31tX97/pvvnmGwFAJCcnS23R0dGiffv20vO2bdtKXy73fPDBByIwMLDGcU+ePCkAiIyMDCHEX//HcL8NGzYIa2trUVJSUuO6ut+qVasEALF//36pLSMjQwAQBw4cEEII0a9fP50PqhB39/64uLhIz+8PGfdERkaKgQMHCiGEiIuLE8OGDdPZ69GuXTuxcuVKabz27dsLrVYrvb68vFyYm5uL7du3CyEMs86qU11Aqm2MESNGiAEDBuiMMWzYMJ2/RX3WWVZWlrCyshIzZswQ5ubmYs2aNTXWKIRhPidCVN1GCFF9QLr//16FECIiIkI8++yzOn+j+1X3HpCrz7bl3XffrfJeiI+PF5aWlqKyslKkpqYKACInJ6faedy//axu2YQQVQLSpk2bdPYW3durdO+9Wp/3Z3VsbGykL8r6LNu9euXbzbq2t3Vt/6rz8ccfi65du0rPrayspP/ZkKtuW1NfHTt2FP/+97+l5+7u7jphVKvVCkdHR7F8+XIhhBCfffaZsLe3F3/++afUZ/ny5bUGpPPnzwtjY2Nx6dIlnfZ+/fqJWbNmScsAQJw7d06aHh8fL5ycnKTnLi4uYuHChdLze+/N2gLSyJEjxXPPPafTNn36dNGhQwfp+ZQpU2rdcySEEIGBgdL/MMqdOXNGABB79+6V2q5evSrMzc3Ft99+W+/lmzJlinj22Wel5/K9SuPGjRMTJkzQmfeePXuEkZGR9PeQbxsUPYstNjYWX331FTIyMh54jI4dO8LI6K/FcHJygo+Pj/Tc2NgY9vb2KCws1HldYGCg9G8TExP4+/tLdRw9ehQ7duyApaWl9PDy8gJw9/fWe7p27VprbSUlJbh8+TJ69eql096rVy+9ljkjIwPl5eXo169fjdN9fX1hYWGhMw+tVovMzEyp7UHXVefOnXVeA0DndU5OTtJrysrKkJWVhXHjxumsvw8//FBn3cnHdXFxAYAq877fc889B3d3d7Rp0wajRo3CmjVrcPPmzRr7A3f/tt26dZOee3l5wdbWVudvPX/+fJ1ax48fj7y8vFrH7tOnD3777TdUVlZi165d6Nu3L/r27YudO3fi8uXLOHfuHPr27SvN49y5c7CyspLmYWdnh1u3biErK6tB11l1ahsjIyMDAQEBOv3v/6zUd521adMGixYtQmxsLAYNGoSRI0fWWI+hPif68Pf313k+evRopKeno3379pg8eTJ+/vnnBx67tm1LRkYGAgMDdY7T6NWrF0pLS3Hx4kX4+vqiX79+8PHxwSuvvILPP//8bx+rOWDAAJiammLLli0AgA0bNsDa2hpBQUEA6n5/1lddy3aPfLtZ1/a2ru0fAKxbtw69evWCs7MzLC0tMXv2bOTm5krTIyMj8frrryMoKAgxMTF6Ldc9paWlePvtt+Ht7Q1bW1tYWloiIyNDZz6A7udLpVLB2dlZ5/PVuXNnmJmZSX3kny+548ePo7KyEk8++aTOOtq1a5fOcjRr1gxt27aVnru4uEjzLS4uRl5ens5n+957szYZGRnVfi7Pnj2LysrKWl97v/T09Fq/v0xMTHRqs7e3R/v27XU+/7UtHwCEhoZK218AWLNmDQYOHCgdH3f06FGsXr1aZx0GBwdDq9UiOztbGuf+dWJS7yVsAL1790ZwcDBmzZqF0aNH60wzMjKCEEKn7c6dO1XGMDU11XmuUqmqbdNqtfWuq7S0FC+++CJiY2OrTLv3hQJAJ5A0JHNzc4OM86Dr6v4+9zZ+8rZ7ryktLQUAfP7551W+aI2Njesct7a/k5WVFdLS0rBz5078/PPPmDNnDubOnYtDhw498EGipaWlmDdvHl566aUq0+7fiMn17t0bN27cQFpaGnbv3o0FCxbA2dkZMTEx8PX1haurKzw9PaV5dO3aFWvWrKkyTosWLRp0nVXn745R33W2e/duGBsbIycnBxUVFTAxUXRzo0P+2X3qqaeQnZ2Nn376Cb/88gteffVVBAUF4bvvvmvUuoyNjZGUlIR9+/bh559/xr///W+89957OHDgAFq3bv1AY6rVarz88stISEjA8OHDkZCQgGHDhkl/j7ren4YmX/d1bW9///33WsdLSUlBaGgo5s2bh+DgYNjY2GDt2rVYvHix1Gfu3LkYOXIktm3bhp9++glRUVFYu3YthgwZUu+63377bSQlJWHRokVo164dzM3N8fLLL1c50Pzvfv/IlZaWwtjYGKmpqVW2B5aWlrXOV/4dqhRDfIfVtXzdunVD27ZtsXbtWrz55pvYtGmTztmVpaWl+Oc//4nJkydXGfuJJ56Q/n3/+1Px6yDFxMTghx9+QEpKik57ixYtkJ+fr7MCDHmdiP3790v/rqioQGpqKry9vQHc3ViePHkSHh4eaNeunc5Dn1BkbW0NV1dX7N27V6d979696NChQ73H8fT0hLm5eY2nNnt7e+Po0aMoKyvTmYeRkRHat29f7/kYgpOTE1xdXfH7779XWXf6bODVanW1/4diYmKCoKAgLFy4EMeOHUNOTg5+/fXXGsepqKjA4cOHpeeZmZkoKirS+VtnZmZWqbVdu3bS3jZTU9Mqtdja2qJz585YtmwZTE1N4eXlhd69e+PIkSPYunUr+vTpI/V96qmncPbsWTg6OlaZh42NjcHWmSF4e3vjwIEDOm33f1aA+q2zdevWYePGjdi5cydyc3NrPc3eUJ+Tv8va2hrDhg3D559/jnXr1mHDhg24fv06gOrfAzWpbdvi7e2NlJQUne3a3r17YWVlhVatWgG4u+Hv1asX5s2bhyNHjkCtVmPTpk3Vzqumz4lcaGgoEhMTcfLkSfz6668IDQ2VptX1/qyv+ixbdera3ta1/du3bx/c3d3x3nvvwd/fH56enjh//nyVfk8++SSmTZuGn3/+GS+99BJWrVoFoP7rcO/evRg9ejSGDBkCHx8fODs7Iycnp87X3c/b2xvHjh3DrVu3pDb550uuS5cuqKysRGFhYZX14+zsXK/52tjYwMXFReezfe+9WVe91X0un3zyySphrTadO3eu9furoqJCp7Zr164hMzNT789/aGgo1qxZgx9++AFGRkYYOHCgNO2pp57CqVOnqt1uqdXqasdTPCD5+PggNDQUS5cu1Wnv27cvrly5goULFyIrKwvx8fH46aefDDbf+Ph4bNq0CadPn8bEiRPxxx9/YOzYsQCAiRMn4vr16xgxYgQOHTqErKwsbN++HWPGjNFrtyIATJ8+HbGxsVi3bh0yMzMxc+ZMpKenY8qUKfUew8zMDDNmzMA777yDr7/+GllZWdi/fz+++OILAHffFGZmZggPD8eJEyewY8cOREREYNSoUdJPYo1p3rx5iI6OxtKlS3HmzBkcP34cq1atwpIlS+o9hoeHB0pLS5GcnIyrV6/i5s2b2Lp1K5YuXYr09HScP38eX3/9NbRaba0h0NTUFBEREThw4ABSU1MxevRo9OjRA927dwcAzJkzB19//TXmzZuHkydPIiMjA2vXrsXs2bN1aklOTkZ+fr7Ozx19+/bFmjVrpDBkZ2cHb29vrFu3TicghYaGwsHBAYMHD8aePXuQnZ2NnTt3YvLkydJPD4ZYZ4YwefJkJCYmYtGiRTh79iyWLVuGxMREnT51rbOLFy/izTffRGxsLJ5++mmsWrUKCxYsqPWLwBCfk79jyZIl+Oabb3D69GmcOXMG69evh7Ozs7Rnsqb3QHVq27b861//woULFxAREYHTp0/j+++/R1RUFCIjI2FkZIQDBw5gwYIFOHz4MHJzc7Fx40ZcuXJFClhyHh4eOHDgAHJycnD16tUa91T07t0bzs7OCA0NRevWrXX2VNbn/VkfdS1bTera3ta1/fP09ERubi7Wrl2LrKwsLF26VCdQ/vnnn5g0aRJ27tyJ8+fPY+/evTh06JC0Tqvb1lTH09MTGzduRHp6Oo4ePYqRI0fqvWdo5MiRUKlUGD9+PE6dOoUff/wRixYtqvU1Tz75JEJDQxEWFoaNGzciOzsbBw8eRHR0NLZt21bveU+ZMgUxMTHYvHkzTp8+jX/96191XtTxrbfeQnJyMj744AOcOXMGX331FZYtW4a333673vMF7l7D7JtvvkFUVBQyMjJw/PhxaY+hp6cnBg8ejPHjx+O3337D0aNH8dprr6Fly5YYPHiwXvMJDQ1FWloaPvroI7z88ss6lziZMWMG9u3bh0mTJiE9PR1nz57F999/j0mTJtU8YK1HVjUA+UGGQtw9uFGtVgt5OcuXLxdubm7CwsJChIWFiY8++qja0/zvV91Bi+7u7uKTTz6R5gVAJCQkiO7duwu1Wi06dOggfv31V53XnDlzRgwZMkQ67dDLy0tMnTpVOgCxuvlUp7KyUsydO1e0bNlSmJqaVjl9WYj6HXxaWVkpPvzwQ+Hu7i5MTU11TkMWov6n+T/Iurr/4EH5AaBCVH+Q45o1a4Sfn59Qq9WiefPmonfv3mLjxo01jvvHH38IADqnrr7xxhvC3t5eOvV2z549ok+fPqJ58+bSKerr1q2rcZ3dq2vDhg2iTZs2QqPRiKCgoCpnUyUmJoqePXsKc3NzYW1tLbp37y4dYC3E3dNT27VrJ0xMTHTef/dOLLh3AKYQdw8UBFDlrMa8vDwRFhYmHBwchEajEW3atBHjx48XxcXFBl1ncqjmIO26xvjiiy9Eq1athLm5uXjxxRerPc2/pnWm1WpFv379RHBwsM7BuhEREaJt27Y1niFkqM9JfQ/Svvcev2flypXCz89PWFhYCGtra9GvXz+RlpYmTa/pPXC/+m5bajsV/tSpUyI4OFi0aNFCaDQa8eSTT+ocBCz/HGdmZooePXoIc3Pzak/zv98777wjAIg5c+ZUqb0+7085+UHadS2bEDVvN+va3ta1/Zs+fbqwt7cXlpaWYtiwYeKTTz6R3rPl5eVi+PDhws3NTajVauHq6iomTZqkc6C0fFtTnezsbPHMM88Ic3Nz4ebmJpYtW1av95b8fZuSkiJ8fX2FWq0Wfn5+YsOGDXWexXb79m0xZ84c4eHhIUxNTYWLi4sYMmSIOHbsmBCi+m2w/MSnO3fuiClTpghra2tha2srIiMj9TrN/956v3f27j31OUhbiLsn2dzbvjk4OIiXXnpJmnbvNH8bGxthbm4ugoODqz3Nv7blu6d79+4CQJXPnRBCHDx4UDz33HPC0tJSWFhYiM6dO+ucHCL/+6mEaCI/UhIRPcRycnLQunVrHDlyhLcMInoEKP4TGxEREVFTw4BEREREJMOf2IiIiIhkuAeJiIiISIYBiYiIiEiGAYmIiIhIhgGJiIiISIYBiYiIiEiGAYmIiIhIhgGJiIiISIYBiYiIiEjm/wOP3P0O3lkUSQAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Horizon statistics (# of comments between first positive forecast and conversation end):\n", + "Mean = 4.353556485355648, Median = 4.0\n" + ] + }, + { + "data": { + "text/plain": [ + "( label score forecast\n", + " conversation_id \n", + " cus26gy 1 0.830510 1.0\n", + " cus37h0 1 0.657098 1.0\n", + " cus142u 0 0.481921 0.0\n", + " cus19ml 0 0.442551 0.0\n", + " cusxft0 1 0.450714 0.0\n", + " ... ... ... ...\n", + " e8qli0i 0 0.447431 0.0\n", + " e8qm4aj 0 0.336835 0.0\n", + " e8ql8ii 0 0.645406 1.0\n", + " e8qzjei 1 0.962892 1.0\n", + " e8r00ko 0 0.694327 1.0\n", + " \n", + " [1368 rows x 3 columns],\n", + " {'Accuracy': 0.6264619883040936,\n", + " 'Precision': 0.6104725415070242,\n", + " 'Recall': 0.6988304093567251,\n", + " 'FPR': 0.44590643274853803,\n", + " 'F1': 0.6516700749829584})" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "craft_forecaster.summarize(corpus, lambda c: c.meta['split'] == \"test\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27037a7c-4bde-4d37-80af-5501ae50cd81", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe991f40-db35-45f6-8fab-09754e0665c0", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44798811-1344-4dac-8c64-272654bba318", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8ed3672-86e2-426e-827f-b59971e99cb5", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}