Skip to content

Commit

Permalink
recreate to_dict and add relations
Browse files Browse the repository at this point in the history
  • Loading branch information
Benedikt Fuchs committed Jun 19, 2023
1 parent 00a317c commit 142a5f0
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 15 deletions.
41 changes: 27 additions & 14 deletions flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,14 @@ def __len__(self) -> int:
def embedding(self):
return self.get_embedding()

def to_dict(self, tag_type: Optional[str] = None):
return {
"text": self.text,
"start_pos": self.start_position,
"end_pos": self.end_position,
"labels": [label.to_dict() for label in self.get_labels(tag_type)],
}


class Relation(_PartOfSentence):
def __new__(self, first: Span, second: Span):
Expand Down Expand Up @@ -664,6 +672,15 @@ def end_position(self) -> int:
def embedding(self):
pass

def to_dict(self, tag_type: Optional[str] = None):
return {
"from_text": self.first.text,
"to_text": self.second.text,
"from_idx": self.first.tokens[0].idx - 1,
"to_idx": self.second.tokens[0].idx - 1,
"labels": [label.to_dict() for label in self.get_labels(tag_type)],
}


class Sentence(DataPoint):
"""A Sentence is a list of tokens and is used to represent a sentence or text fragment."""
Expand Down Expand Up @@ -760,17 +777,17 @@ def __init__(
def unlabeled_identifier(self):
return f'Sentence[{len(self)}]: "{self.text}"'

def get_relations(self, type: str) -> List[Relation]:
def get_relations(self, label_type: Optional[str] = None) -> List[Relation]:
relations: List[Relation] = []
for label in self.get_labels(type):
for label in self.get_labels(label_type):
if isinstance(label.data_point, Relation):
relations.append(label.data_point)
return relations

def get_spans(self, type: str) -> List[Span]:
def get_spans(self, label_type: Optional[str] = None) -> List[Span]:
spans: List[Span] = []
for potential_span in self._known_spans.values():
if isinstance(potential_span, Span) and potential_span.has_label(type):
if isinstance(potential_span, Span) and (label_type is None or potential_span.has_label(label_type)):
spans.append(potential_span)
return sorted(spans)

Expand Down Expand Up @@ -937,16 +954,12 @@ def to_original_text(self) -> str:
).strip()

def to_dict(self, tag_type: Optional[str] = None):
labels = []

if tag_type:
labels = [label.to_dict() for label in self.get_labels(tag_type)]
return {"text": self.to_original_text(), tag_type: labels}

if self.labels:
labels = [label.to_dict() for label in self.labels]

return {"text": self.to_original_text(), "all labels": labels}
return {
"text": self.to_original_text(),
"labels": [label.to_dict() for label in self.get_labels(tag_type) if label.data_point is self],
"entities": [span.to_dict() for span in self.get_spans(tag_type)],
"relations": [relation.to_dict() for relation in self.get_relations(tag_type)],
}

def get_span(self, start: int, stop: int):
span_slice = slice(start, stop)
Expand Down
2 changes: 1 addition & 1 deletion flair/models/relation_classifier_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def _valid_entities(self, sentence: Sentence) -> Iterator[_Entity]:
:return: Valid entities as `_Entity`
"""
for label_type, valid_labels in self.entity_label_types.items():
for entity_span in sentence.get_spans(type=label_type):
for entity_span in sentence.get_spans(label_type=label_type):
entity_label: Label = entity_span.get_label(label_type=label_type)

# Only use entities labelled with the specified labels for each label type
Expand Down

0 comments on commit 142a5f0

Please sign in to comment.