From 3079f3c4c44faaa2395e5b6aa65a07ce6bb0a3aa Mon Sep 17 00:00:00 2001 From: Jonas Dippel Date: Thu, 3 Aug 2023 08:30:11 +0200 Subject: [PATCH] added only extraction of class token as an option --- thingsvision/core/extraction/torch.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/thingsvision/core/extraction/torch.py b/thingsvision/core/extraction/torch.py index bd9209d..b30548d 100644 --- a/thingsvision/core/extraction/torch.py +++ b/thingsvision/core/extraction/torch.py @@ -31,6 +31,11 @@ def __init__( self.activations = {} self.hook_handle = None + if 'extract_cls_token' in self.model_parameters: + self.extract_cls_token = self.model_parameters['extract_cls_token'] + else: + self.extract_cls_token = False + if not self.model: self.load_model() self.prepare_inference() @@ -68,6 +73,11 @@ def hook(model, input, output) -> None: act = output[0] else: act = output + + if self.extract_cls_token: + # we extract only the representations of the first token (cls token) + act = act[:, 0] + try: self.activations[name] = act.clone().detach() except AttributeError: