Skip to content

Commit

Permalink
added only extraction of class token as an option
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasd4 committed Aug 3, 2023
1 parent 7245361 commit 3079f3c
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions thingsvision/core/extraction/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def __init__(
self.activations = {}
self.hook_handle = None

if 'extract_cls_token' in self.model_parameters:
self.extract_cls_token = self.model_parameters['extract_cls_token']
else:
self.extract_cls_token = False

if not self.model:
self.load_model()
self.prepare_inference()
Expand Down Expand Up @@ -68,6 +73,11 @@ def hook(model, input, output) -> None:
act = output[0]
else:
act = output

if self.extract_cls_token:
# we extract only the representations of the first token (cls token)
act = act[:, 0]

try:
self.activations[name] = act.clone().detach()
except AttributeError:
Expand Down

0 comments on commit 3079f3c

Please sign in to comment.