Skip to content

Commit

Permalink
lint
Browse files (browse the repository at this point in the history)
  • Loading branch information
SeanLee97 committed Oct 19, 2024
1 parent c371424 commit ca9e956
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions angle_emb/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,8 @@ def get_pooling(outputs: torch.Tensor,
:param outputs: torch.Tensor. Model outputs (without pooling)
:param inputs: Dict. Model inputs
:param pooling_strategy: str. Pooling strategy ['cls', 'cls_avg', 'cls_max', 'last', 'avg', 'mean', 'max', 'all', index]
:param pooling_strategy: str.
Pooling strategy [`cls`, `cls_avg`, `cls_max`, `last`, `avg`, `mean`, `max`, `all`, int]
:param padding_side: str. Padding side of the tokenizer (`left` or `right`).
It can be obtained from `tokenizer.padding_side`.
"""
Expand All @@ -272,7 +273,8 @@ def get_pooling(outputs: torch.Tensor,
sequence_lengths = -1 if padding_side == 'left' else inputs["attention_mask"].sum(dim=1) - 1
outputs = outputs[torch.arange(batch_size, device=outputs.device), sequence_lengths]
elif pooling_strategy in ['avg', 'mean']:
outputs = torch.sum(outputs * inputs["attention_mask"][:, :, None], dim=1) / inputs["attention_mask"].sum(dim=1).unsqueeze(1)
outputs = torch.sum(
outputs * inputs["attention_mask"][:, :, None], dim=1) / inputs["attention_mask"].sum(dim=1).unsqueeze(1)
elif pooling_strategy == 'max':
outputs, _ = torch.max(outputs * inputs["attention_mask"][:, :, None], dim=1)
elif pooling_strategy == 'all':
Expand All @@ -283,7 +285,8 @@ def get_pooling(outputs: torch.Tensor,
outputs = outputs[:, int(pooling_strategy)]
else:
raise NotImplementedError(
'please specify pooling_strategy from [`cls`, `cls_avg`, `cls_max`, `last`, `avg`, `mean`, `max`, `all`, int]')
'please specify pooling_strategy from '
'[`cls`, `cls_avg`, `cls_max`, `last`, `avg`, `mean`, `max`, `all`, int]')
return outputs


Expand Down Expand Up @@ -689,7 +692,8 @@ class Pooler:
Using Pooler to obtain sentence embeddings.
:param model: PreTrainedModel
:param pooling_strategy: Optional[str]. Currently support [`cls`, `cls_avg`, `cls_max`, `last`, `avg`, `mean`, `max`, `all`, int]. Default None.
:param pooling_strategy: Optional[str].
Currently support [`cls`, `cls_avg`, `cls_max`, `last`, `avg`, `mean`, `max`, `all`, int]. Default None.
:param padding_side: Optional[str]. `left` or `right`. Default None.
:param is_llm: bool. Default False
"""
Expand Down

0 comments on commit ca9e956

Please sign in to comment.