-
Notifications
You must be signed in to change notification settings - Fork 2.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add PromptGuard to safety_utils #608
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,6 +37,7 @@ def main( | |
enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5 | ||
use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels | ||
enable_llamaguard_content_safety: bool = False, | ||
enable_promptguard_safety: bool = False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. specially for this call, can we set it to true? |
||
**kwargs | ||
): | ||
if prompt_file is not None: | ||
|
@@ -81,6 +82,7 @@ def main( | |
enable_sensitive_topics, | ||
enable_saleforce_content_safety, | ||
enable_llamaguard_content_safety, | ||
enable_promptguard_safety | ||
) | ||
# Safety check of the user prompt | ||
safety_results = [check(dialogs[idx][0]["content"]) for check in safety_checker] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -201,14 +201,55 @@ def __call__(self, output_text, **kwargs): | |
report = result | ||
|
||
return "Llama Guard", is_safe, report | ||
|
||
|
||
# Function to load the PeftModel for performance optimization | ||
class PromptGuardSafetyChecker(object): | ||
|
||
def __init__(self): | ||
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig | ||
model_id = "meta-llama/Prompt-Guard-86M" | ||
|
||
self.tokenizer = AutoTokenizer.from_pretrained(model_id) | ||
self.model = AutoModelForSequenceClassification.from_pretrained(model_id) | ||
|
||
def get_scores(self, text, temperature=1.0, device='cpu'): | ||
from torch.nn.functional import softmax | ||
inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512) | ||
inputs = inputs.to(device) | ||
if len(inputs[0]) > 512: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As max_length is 512 this condition should never be true. Can we instead follow the PromptGuard recommendation and split the text into multiple segments which we apply in parallel (batched)? Especially because of the much bigger context length of Llama 3.1. |
||
warnings.warn( | ||
"Input length is > 512 token. PromptGuard check result could be incorrect." | ||
) | ||
with torch.no_grad(): | ||
logits = self.model(**inputs).logits | ||
scaled_logits = logits / temperature | ||
probabilities = softmax(scaled_logits, dim=-1) | ||
|
||
return { | ||
'jailbreak': probabilities[0, 2].item(), | ||
'indirect_injection': (probabilities[0, 1] + probabilities[0, 2]).item() | ||
} | ||
|
||
def __call__(self, text_for_check, **kwargs): | ||
agent_type = kwargs.get('agent_type', AgentType.USER) | ||
if agent_type == AgentType.AGENT: | ||
return "PromptGuard", True, "PromptGuard is not used for model output so checking not carried out" | ||
sentences = text_for_check.split(".") | ||
running_scores = {'jailbreak':0, 'indirect_injection' :0} | ||
for sentence in sentences: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Its probably more efficient to do this batched as commented above. Lets split the prompt in blocks of 512 (with some overlap) and then feed them batched into the model which will be way more efficient than feeding the sentences one by one. |
||
scores = self.get_scores(sentence) | ||
running_scores['jailbreak'] = max([running_scores['jailbreak'],scores['jailbreak']]) | ||
running_scores['indirect_injection'] = max([running_scores['indirect_injection'],scores['indirect_injection']]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we comment that this is not being used for the user dialog? |
||
is_safe = True if running_scores['jailbreak'] < 0.5 else False | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we set the bar at 0.5? I think 0.8 or 0.9 would be better based on talks with the team. |
||
report = str(running_scores) | ||
return "PromptGuard", is_safe, report | ||
|
||
|
||
# Function to determine which safety checker to use based on the options selected | ||
def get_safety_checker(enable_azure_content_safety, | ||
enable_sensitive_topics, | ||
enable_salesforce_content_safety, | ||
enable_llamaguard_content_safety): | ||
enable_llamaguard_content_safety, | ||
enable_promptguard_safety): | ||
safety_checker = [] | ||
if enable_azure_content_safety: | ||
safety_checker.append(AzureSaftyChecker()) | ||
|
@@ -218,5 +259,7 @@ def get_safety_checker(enable_azure_content_safety, | |
safety_checker.append(SalesforceSafetyChecker()) | ||
if enable_llamaguard_content_safety: | ||
safety_checker.append(LlamaGuardSafetyChecker()) | ||
if enable_promptguard_safety: | ||
safety_checker.append(PromptGuardSafetyChecker()) | ||
return safety_checker | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given the small size of the mode, can we leave it as default true?