Skip to content

Commit

Permalink
Implement tokenization
Browse files Browse the repository at this point in the history
  • Loading branch information
Fabian Degen committed Jan 6, 2025
1 parent 8b2fdba commit 4492392
Showing 1 changed file with 41 additions and 5 deletions.
46 changes: 41 additions & 5 deletions transformer_lens/HookedEncoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,12 @@ def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs):
@overload
def forward(
self,
input: Int[torch.Tensor, "batch pos"],
input: Union[
str,
List[str],
Int[torch.Tensor, "batch pos"],
Float[torch.Tensor, "batch pos d_model"],
],
return_type: Literal["logits"],
task: str = "MLM",
token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
Expand All @@ -103,7 +108,12 @@ def forward(
@overload
def forward(
self,
input: Int[torch.Tensor, "batch pos"],
input: Union[
str,
List[str],
Int[torch.Tensor, "batch pos"],
Float[torch.Tensor, "batch pos d_model"],
],
return_type: Literal[None],
task: str = "MLM",
token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
Expand All @@ -113,9 +123,14 @@ def forward(

def forward(
self,
input: Int[torch.Tensor, "batch pos"],
input: Union[
str,
List[str],
Int[torch.Tensor, "batch pos"],
Float[torch.Tensor, "batch pos d_model"],
],
return_type: Optional[str] = "logits",
task: str = "MLM",
task: str = None,
token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
) -> Optional[Float[torch.Tensor, "batch pos d_vocab"]]:
Expand All @@ -131,7 +146,28 @@ def forward(
if return_type == None:
return None

tokens = input
if task is None:
logging.warning("Task not provided, defaulting to masked language modelling (MLM)")
task = "MLM"

if isinstance(input, str) or isinstance(input, list):
assert self.tokenizer is not None, "Must provide a tokenizer if input is a string"
if task == "NSP" and (isinstance(input, str) or len(input) != 2):
raise ValueError(
"Next sentence prediction task requires exactly two sentences, please provide a list of strings with each sentence as an element."
)
encodings = self.tokenizer(
input, return_tensors="pt", padding=True, truncation=True, max_length=self.cfg.n_ctx
)
tokens = encodings.input_ids
token_type_ids = encodings.token_type_ids
one_zero_attention_mask = encodings.attention_mask
else:
if task == "NSP" and token_type_ids is None:
raise ValueError(
"You are using the NSP task without specifying token_type_ids. This means that the model will treat the input as a single sequence which will lead to incorrect results."
)
tokens = input

if tokens.device.type != self.cfg.device:
tokens = tokens.to(self.cfg.device)
Expand Down

0 comments on commit 4492392

Please sign in to comment.