Merge pull request #100 from roboflow/fix/dependencies
fix Florence2 dependencies
SkalskiP authored Jan 6, 2025
2 parents 78d0ba0 + 6e6108b commit 778223b
Showing 9 changed files with 134 additions and 574 deletions.
44 changes: 0 additions & 44 deletions .github/workflows/maestro-tests.yml

This file was deleted.

12 changes: 6 additions & 6 deletions maestro/cli/introspection.py
@@ -14,12 +14,12 @@ def find_training_recipes(app: typer.Typer) -> None:
except Exception:
_warn_about_recipe_import_error(model_name="Florence 2")

try:
from maestro.trainer.models.paligemma.entrypoint import paligemma_app

app.add_typer(paligemma_app, name="paligemma")
except Exception:
_warn_about_recipe_import_error(model_name="PaliGemma")
# try:
# from maestro.trainer.models.paligemma.entrypoint import paligemma_app
#
# app.add_typer(paligemma_app, name="paligemma")
# except Exception:
# _warn_about_recipe_import_error(model_name="PaliGemma")


def _warn_about_recipe_import_error(model_name: str) -> None:
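For context, the Florence 2 guard that stays in find_training_recipes follows the same try/except pattern as the PaliGemma block commented out above; its tail is visible in the context lines at the top of this hunk. A minimal sketch, reconstructed rather than copied from the source file (the florence_2_app symbol and the "florence_2" sub-command name are assumptions):

try:
    from maestro.trainer.models.florence_2.entrypoint import florence_2_app

    app.add_typer(florence_2_app, name="florence_2")
except Exception:
    _warn_about_recipe_import_error(model_name="Florence 2")

If the optional Florence 2 dependencies are not installed, the import raises and the CLI emits a warning instead of failing at startup.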
2 changes: 2 additions & 0 deletions maestro/trainer/common/__init__.py
@@ -1,5 +1,7 @@
from maestro.trainer.common.utils.metrics import (
BLEUMetric,
CharacterErrorRateMetric,
MeanAveragePrecisionMetric,
TranslationErrorRateMetric,
WordErrorRateMetric,
)
105 changes: 97 additions & 8 deletions maestro/trainer/common/utils/metrics.py
@@ -11,6 +11,7 @@

import matplotlib.pyplot as plt
import supervision as sv
from evaluate import load
from jiwer import cer, wer
from PIL import Image
from supervision.metrics.mean_average_precision import MeanAveragePrecision
@@ -98,12 +99,13 @@ def compute(self, targets: list[str], predictions: list[str]) -> dict[str, float
"""Computes the WER metric based on the targets and predictions.
Args:
targets (List[str]): The ground truth texts.
predictions (List[str]): The predicted texts.
targets (List[str]): The ground truth texts (references), where each element
represents the reference text for the corresponding prediction.
predictions (List[str]): The predicted texts (hypotheses) to be evaluated.
Returns:
Dict[str, float]: A dictionary of computed WER metrics with metric names as
keys and their values.
Dict[str, float]: A dictionary containing the computed WER score, with the
metric name ("wer") as the key and its value as the score.
"""
if len(targets) != len(predictions):
raise ValueError("The number of targets and predictions must be the same.")
@@ -139,12 +141,13 @@ def compute(self, targets: list[str], predictions: list[str]) -> dict[str, float
"""Computes the CER metric based on the targets and predictions.
Args:
targets (List[str]): The ground truth texts.
predictions (List[str]): The predicted texts.
targets (List[str]): The ground truth texts (references), where each element
represents the reference text for the corresponding prediction.
predictions (List[str]): The predicted texts (hypotheses) to be evaluated.
Returns:
Dict[str, float]: A dictionary of computed CER metrics with metric names as
keys and their values.
Dict[str, float]: A dictionary containing the computed CER score, with the
metric name ("cer") as the key and its value as the score.
"""
if len(targets) != len(predictions):
raise ValueError("The number of targets and predictions must be the same.")
@@ -159,6 +162,92 @@ def compute(self, targets: list[str], predictions: list[str]) -> dict[str, float
return {"cer": average_cer}


class TranslationErrorRateMetric(BaseMetric):
"""A class used to compute the Translation Error Rate (TER) metric.
TER measures the minimum number of edits (insertions, deletions, substitutions, and shifts)
needed to transform a predicted text into its reference text, making it useful for
evaluating machine translation and other text generation tasks.
"""

name = "translation_error_rate"

def __init__(self, case_sensitive: bool = True):
"""Initialize the TER metric.
Args:
case_sensitive (bool, optional): Whether to perform case-sensitive comparison.
Defaults to True.
"""
self.ter = load("ter")
self.case_sensitive = case_sensitive

def describe(self) -> list[str]:
"""Returns a list of metric names that this class will compute.
Returns:
List[str]: A list of metric names.
"""
return ["ter"]

def compute(self, targets: list[str], predictions: list[str]) -> dict[str, float]:
"""Computes the TER metric based on the targets and predictions.
Args:
targets (List[str]): The ground truth texts (references), where each element
represents the reference text for the corresponding prediction.
predictions (List[str]): The predicted texts (hypotheses) to be evaluated.
Returns:
Dict[str, float]: A dictionary containing the computed TER score, with the
metric name ("ter") as the key and its value as the score.
"""
if len(targets) != len(predictions):
raise ValueError("The number of targets and predictions must be the same.")

results = self.ter.compute(predictions=predictions, references=targets, case_sensitive=self.case_sensitive)
return {"ter": results["score"]}


class BLEUMetric(BaseMetric):
"""A class used to compute the BLEU (Bilingual Evaluation Understudy) metric.
BLEU is a popular metric for evaluating the quality of text predictions in natural
language processing tasks, particularly machine translation. It measures the
similarity between the predicted text and one or more reference texts based on
n-gram precision, brevity penalty, and other factors.
"""

bleu = load("bleu")
name = "bleu"

def describe(self) -> list[str]:
"""Returns a list of metric names that this class will compute.
Returns:
List[str]: A list of metric names.
"""
return ["bleu"]

def compute(self, targets: list[str], predictions: list[str]) -> dict[str, float]:
"""Computes the BLEU metric based on the targets and predictions.
Args:
targets (List[str]): The ground truth texts (references), where each element
represents the reference text for the corresponding prediction.
predictions (List[str]): The predicted texts (hypotheses) to be evaluated.
Returns:
Dict[str, float]: A dictionary containing the computed BLEU score, with the
metric name ("bleu") as the key and its value as the score.
"""
if len(targets) != len(predictions):
raise ValueError("The number of targets and predictions must be the same.")

results = self.bleu.compute(predictions=predictions, references=targets)
return {"bleu": results["bleu"]}


class MetricsTracker:
@classmethod
def init(cls, metrics: list[str]) -> MetricsTracker:
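The two new classes follow the same BaseMetric contract as the existing WER/CER metrics: describe() lists the metric keys and compute(targets, predictions) returns a dict keyed by those names. A minimal usage sketch, assuming no-argument constructors for the WER/CER classes (not shown in this diff) and an environment with the evaluate, sacrebleu, and jiwer dependencies declared in the pyproject.toml changes below:

from maestro.trainer.common.utils.metrics import (
    BLEUMetric,
    CharacterErrorRateMetric,
    TranslationErrorRateMetric,
    WordErrorRateMetric,
)

targets = ["a cat sits on the mat", "dogs are running in the park"]
predictions = ["a cat sat on the mat", "dogs run in the park"]

metrics = [
    WordErrorRateMetric(),
    CharacterErrorRateMetric(),
    TranslationErrorRateMetric(case_sensitive=False),
    BLEUMetric(),
]

for metric in metrics:
    # compute() validates that targets and predictions have equal length and
    # returns a dict keyed by the names reported by describe(), e.g. {"ter": ...}.
    print(metric.describe(), metric.compute(targets=targets, predictions=predictions))

Note that BLEUMetric calls load("bleu") in its class body, so merely importing the module already requires the evaluate package, which is exactly the dependency this PR adds.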
4 changes: 4 additions & 0 deletions maestro/trainer/models/florence_2/entrypoint.py
@@ -7,8 +7,10 @@

from maestro.trainer.common.utils.metrics import (
BaseMetric,
BLEUMetric,
CharacterErrorRateMetric,
MeanAveragePrecisionMetric,
TranslationErrorRateMetric,
WordErrorRateMetric,
)
from maestro.trainer.models.florence_2.checkpoints import (
@@ -27,6 +29,8 @@
MeanAveragePrecisionMetric.name: MeanAveragePrecisionMetric,
WordErrorRateMetric.name: WordErrorRateMetric,
CharacterErrorRateMetric.name: CharacterErrorRateMetric,
BLEUMetric.name: BLEUMetric,
TranslationErrorRateMetric.name: TranslationErrorRateMetric,
}


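The two added entries extend the name-to-class registry defined in the Florence 2 entrypoint, presumably used to resolve metrics selected by name. The registry's identifier sits above the visible hunk, so METRIC_CLASSES below is a hypothetical stand-in; the lookup pattern is roughly:

from maestro.trainer.common.utils.metrics import BaseMetric, BLEUMetric, TranslationErrorRateMetric

# Hypothetical name for the dict shown in the hunk above.
METRIC_CLASSES: dict[str, type[BaseMetric]] = {
    BLEUMetric.name: BLEUMetric,                                   # "bleu"
    TranslationErrorRateMetric.name: TranslationErrorRateMetric,   # "translation_error_rate"
}

def resolve_metric(name: str) -> BaseMetric:
    # e.g. resolve_metric("bleu") -> BLEUMetric()
    return METRIC_CLASSES[name]()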
2 changes: 1 addition & 1 deletion maestro/trainer/models/florence_2/inference.py
@@ -19,7 +19,7 @@ def run_predictions(
generated_ids = model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
max_new_tokens=2048,
do_sample=False,
num_beams=3,
)
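The single functional change here doubles the generation budget. A hedged sketch of how this generate call is typically driven for Florence 2 (the checkpoint id, task prompt, and processor handling below are illustrative assumptions, not part of this diff):

from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

checkpoint = "microsoft/Florence-2-base-ft"  # illustrative checkpoint id
model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)

image = Image.open("example.jpg")
inputs = processor(text="<OD>", images=image, return_tensors="pt")

generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    max_new_tokens=2048,  # raised from 1024 in this PR
    do_sample=False,
    num_beams=3,
)
text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

The larger budget matters when the decoded sequence (for example dense detection or OCR output) would otherwise be cut off at 1024 tokens.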
39 changes: 24 additions & 15 deletions pyproject.toml
@@ -4,12 +4,15 @@ build-backend = "setuptools.build_meta"

[project]
name = "maestro"
version = "0.2.0rc5"
version = "0.2.0rc6"
description = "Visual Prompting for Large Multimodal Models (LMMs)"
readme = "README.md"
authors = [
{name = "Roboflow", email = "[email protected]"}
{name = "Piotr Skalski", email = "[email protected]"}
]
maintainers = [
{name = "Piotr Skalski", email = "[email protected]"}
]
readme = "README.md"
license = {file = "LICENSE"}
keywords = ["roboflow","maestro","transformers", "torch", "accelerate", "multimodal", "lmm", "vision", "nlp", "prompting", "vlm"]
requires-python = ">=3.9,<3.13"
@@ -34,18 +37,19 @@ classifiers = [
]

dependencies = [
"supervision>=0.24,<0.27",
# "accelerate>=0.33",
# "sentencepiece>=0.2.0",
# "peft>=0.12,<0.14",
# "einops>=0.8.0",
# "timm>=1.0.9",

# core
"supervision>=0.24.0rc1",
"requests>=2.31.0,<=2.32.3",
"transformers>=4.44.2,<4.48.0",
"torch~=2.4.0",
"accelerate>=0.33,<1.1",
"sentencepiece~=0.2.0",
"peft>=0.12,<0.14",
"flash-attn~=2.6.3; sys_platform != 'darwin'",
"einops~=0.8.0",
"timm~=1.0.9",
"typer~=0.12.5",
"jiwer~=3.0.4",
"typer>=0.12.5",
"jiwer>=3.0.4",
"evaluate>=0.4.3",
"sacrebleu>=2.3.0"
]

[project.urls]
@@ -68,7 +72,12 @@ dev = [
"pre-commit>=3.8,<4.1",
"mypy>=1.11.2,<1.14.0",
"ruff>=0.6.5,<0.8.0",
"tox>=4.18.1,<4.24.0"
"tox>=4.18.1,<4.22.0"
]
florence_2 = [
"transformers>=4.43.0",
"torch>=2.4.0",
"flash-attn>=2.7.0.post2; sys_platform != 'darwin'"
]

[project.scripts]
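The new florence_2 optional-dependency group carries the model-specific requirements (transformers, torch, and flash-attn on non-macOS platforms), while evaluate and sacrebleu join the core list to back the new BLEU and TER metrics. Assuming setuptools exposes the group as an extra under the same name, it would be installed with something like pip install "maestro[florence_2]"; the sys_platform marker skips flash-attn on macOS.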