From 2dfad85773689938a381a7ecea531ace2b6c814d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Sun, 26 Jan 2025 12:45:47 +0700 Subject: [PATCH] Allow string pairs as input to text classification serving --- lib/bumblebee/shared.ex | 8 ++++++++ lib/bumblebee/text.ex | 2 +- lib/bumblebee/text/text_classification.ex | 2 +- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/lib/bumblebee/shared.ex b/lib/bumblebee/shared.ex index d5b81f19..36dc135f 100644 --- a/lib/bumblebee/shared.ex +++ b/lib/bumblebee/shared.ex @@ -210,6 +210,14 @@ defmodule Bumblebee.Shared do end end + def validate_string_or_pairs(input) do + case input do + input when is_binary(input) -> {:ok, input} + {left, right} when is_binary(left) and is_binary(right) -> {:ok, input} + _other -> {:error, "expected a string or a pair of strings, got: #{inspect(input)}"} + end + end + @doc """ Validates that the input is a single value and not a batch. """ diff --git a/lib/bumblebee/text.ex b/lib/bumblebee/text.ex index b11f47a7..770a2192 100644 --- a/lib/bumblebee/text.ex +++ b/lib/bumblebee/text.ex @@ -288,7 +288,7 @@ defmodule Bumblebee.Text do defdelegate translation(model_info, tokenizer, generation_config, opts \\ []), to: Bumblebee.Text.Translation - @type text_classification_input :: String.t() + @type text_classification_input :: String.t() | {String.t(), String.t()} @type text_classification_output :: %{predictions: list(text_classification_prediction())} @type text_classification_prediction :: %{score: number(), label: String.t()} diff --git a/lib/bumblebee/text/text_classification.ex b/lib/bumblebee/text/text_classification.ex index 22f0541c..be054ae8 100644 --- a/lib/bumblebee/text/text_classification.ex +++ b/lib/bumblebee/text/text_classification.ex @@ -74,7 +74,7 @@ defmodule Bumblebee.Text.TextClassification do |> Nx.Serving.batch_size(batch_size) |> Nx.Serving.process_options(batch_keys: batch_keys) |> Nx.Serving.client_preprocessing(fn input -> - {texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1) + {texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string_or_pairs/1) inputs = Nx.with_default_backend(Nx.BinaryBackend, fn ->