From bc1af1fa84f6d0ecd7000f75f65d3f7127121fd5 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Tue, 5 Nov 2024 20:36:34 +0000 Subject: [PATCH] Skeleton ShieldGemma class --- keras_hub/src/models/gemma/gemma_presets.py | 6 +- keras_hub/src/models/gemma/shieldgemma.py | 80 +++++++++++++++++++ .../src/models/gemma/shieldgemma_test.py | 0 3 files changed, 83 insertions(+), 3 deletions(-) create mode 100644 keras_hub/src/models/gemma/shieldgemma.py create mode 100644 keras_hub/src/models/gemma/shieldgemma_test.py diff --git a/keras_hub/src/models/gemma/gemma_presets.py b/keras_hub/src/models/gemma/gemma_presets.py index ecd22db5b0..48a6c34c24 100644 --- a/keras_hub/src/models/gemma/gemma_presets.py +++ b/keras_hub/src/models/gemma/gemma_presets.py @@ -206,7 +206,7 @@ "metadata": { "description": "2 billion parameter, 26-layer, ShieldGemma model.", "params": 2614341888, - "official_name": "Gemma", + "official_name": "ShieldGemma", "path": "gemma", "model_card": "https://www.kaggle.com/models/google/shieldgemma", }, @@ -216,7 +216,7 @@ "metadata": { "description": "9 billion parameter, 42-layer, ShieldGemma model.", "params": 9241705984, - "official_name": "Gemma", + "official_name": "ShieldGemma", "path": "gemma", "model_card": "https://www.kaggle.com/models/google/shieldgemma", }, @@ -226,7 +226,7 @@ "metadata": { "description": "27 billion parameter, 42-layer, ShieldGemma model.", "params": 27227128320, - "official_name": "Gemma", + "official_name": "ShieldGemma", "path": "gemma", "model_card": "https://www.kaggle.com/models/google/shieldgemma", }, diff --git a/keras_hub/src/models/gemma/shieldgemma.py b/keras_hub/src/models/gemma/shieldgemma.py new file mode 100644 index 0000000000..94a1165df3 --- /dev/null +++ b/keras_hub/src/models/gemma/shieldgemma.py @@ -0,0 +1,80 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.gemma import gemma_causal_lm +from keras_hub.src.models.task import Task + + +class ShieldGemmaViolationProbaility(keras.layers.Layer): + """Relative probabilities for the 'Yes' (violating) and 'No' tokens.""" + + def __init__(self, yes_token_idx, no_token_idx, **kw): + super().__init__(**kw) + self.yes_token_idx = yes_token_idx + self.no_token_idx = no_token_idx + + def call(self, logits, padding_mask): + last_prompt_index = keras.ops.cast( + keras.ops.sum(padding_mask, axis=1) - 1, "int32" + ) + last_logits = keras.ops.take(logits, last_prompt_index, axis=1)[:, 0] + yes_logits = last_logits[:, self.yes_token_idx] + no_logits = last_logits[:, self.no_token_idx] + yes_no_logits = keras.ops.stack((yes_logits, no_logits), axis=1) + return keras.ops.softmax(yes_no_logits, axis=1) + + +@keras_hub_export("keras_hub.models.ShieldGemma") +class ShieldGemma(Task): + """A ShieldGemma model for safety content moderation, built on Gemma 2. + + ShieldGemma is a Gemma 2 variant fine-tuned to detect and predict violations + of four harm types—Harrassment, Hate Speech, Dangerous Content, and + Sexual Content—in text content from a user or model. Architecturally, + the weights are the same as any other Gemma 2 class, but the prediction is + augmented with a final layer that returns the probability that the provided + content violates the harm type specified in the prompt. The probability is + computed as the relative probabilities of the `Yes` (violating) and `No` + (non-violating) tokens at the final prompt token, i.e., is the next most + likley token a yes or a no. + + Links: + + * https://arxiv.org/abs/2407.21772 + * https://ai.google.dev/gemma/docs/shieldgemma/model_card + * https://ai.google.dev/responsible/docs/safeguards/shieldgemma + * https://www.kaggle.com/models/google/shieldgemma + + Args: + gemma: A `keras_hub.models.GemmaCausalLM` initialized with ShieldGemma + weights. + + Examples: + + Coming soon. + """ + + backbone_cls = gemma_causal_lm.GemmaCausalLM.backbone_cls + preprocessor_cls = gemma_causal_lm.GemmaCausalLM.preprocessor_cls + + def __init__(self, gemma: gemma_causal_lm.GemmaCausalLM, **kwargs): + # === Layers === + self.gemma = gemma + self.backbone = self.gemma.backbone + self.preprocessor = self.gemma.preprocessor + self.yes_no_layer = ShieldGemmaViolationProbaility( + yes_token_idx=self.preprocessor.tokenizer.token_to_id("Yes"), + no_token_idx=self.preprocessor.tokenizer.token_to_id("No"), + ) + + # === Functional Model === + inputs = self.gemma.input + logits = self.gemma(inputs) + outputs = self.yes_no_layer(logits, inputs["padding_mask"]) + super().__init__(inputs=inputs, outputs=outputs, **kwargs) + + @classmethod + def from_preset(cls, **kwargs): + """Instantiate a `keras_hub.models.ShieldGemma` from a model preset.""" + gemma = gemma_causal_lm.GemmaCausalLM.from_preset(**kwargs) + return cls(gemma) diff --git a/keras_hub/src/models/gemma/shieldgemma_test.py b/keras_hub/src/models/gemma/shieldgemma_test.py new file mode 100644 index 0000000000..e69de29bb2