Skip to content

Commit

Permalink
Allow string pairs as input to text classification serving
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Jan 26, 2025
1 parent 65bc636 commit 2dfad85
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
8 changes: 8 additions & 0 deletions lib/bumblebee/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,14 @@ defmodule Bumblebee.Shared do
end
end

def validate_string_or_pairs(input) do
case input do
input when is_binary(input) -> {:ok, input}
{left, right} when is_binary(left) and is_binary(right) -> {:ok, input}
_other -> {:error, "expected a string or a pair of strings, got: #{inspect(input)}"}
end
end

@doc """
Validates that the input is a single value and not a batch.
"""
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/text.ex
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ defmodule Bumblebee.Text do
defdelegate translation(model_info, tokenizer, generation_config, opts \\ []),
to: Bumblebee.Text.Translation

@type text_classification_input :: String.t()
@type text_classification_input :: String.t() | {String.t(), String.t()}
@type text_classification_output :: %{predictions: list(text_classification_prediction())}
@type text_classification_prediction :: %{score: number(), label: String.t()}

Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/text/text_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ defmodule Bumblebee.Text.TextClassification do
|> Nx.Serving.batch_size(batch_size)
|> Nx.Serving.process_options(batch_keys: batch_keys)
|> Nx.Serving.client_preprocessing(fn input ->
{texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1)
{texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string_or_pairs/1)

inputs =
Nx.with_default_backend(Nx.BinaryBackend, fn ->
Expand Down

0 comments on commit 2dfad85

Please sign in to comment.