llama_guard_2.py
72 lines (56 loc) · 2.16 KB
import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from guardbench import benchmark


def moderate(
    conversations: list[list[dict[str, str]]],
    tokenizer: AutoTokenizer,
    model: AutoModelForCausalLM,
    safe_token_id: int,
    unsafe_token_id: int,
) -> list[float]:
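    """Return, for each conversation, the probability that the model labels it unsafe."""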
    # Llama Guard does not support conversations starting with the assistant
    for i, x in enumerate(conversations):
        if x[0]["role"] == "assistant":
            conversations[i] = x[1:]

    # Apply the chat template
    input_ids = [tokenizer.apply_chat_template(x) for x in conversations]
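    # Convert the token id lists to a single tensor; this assumes every
    # conversation in the batch tokenizes to the same length, so batch_size > 1
    # may require padding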
    input_ids = torch.tensor(input_ids, device=model.device)

    # Generate output
    output = model.generate(
        input_ids=input_ids,
        max_new_tokens=5,
        output_scores=True,
        return_dict_in_generate=True,
        pad_token_id=0,
    )

    # Take the logits for the first generated token of each prompt
    logits = output.scores[0][:, [safe_token_id, unsafe_token_id]]

    # Compute the "unsafe" probabilities
    return torch.softmax(logits, dim=-1)[:, 1].tolist()


def main(device: str, datasets: list[str], batch_size: int) -> None:
    model_id = "meta-llama/Meta-Llama-Guard-2-8B"

    # Load the tokenizer and the model in bfloat16 for inference
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    model = model.to(device)
    model = model.eval()

    benchmark(
        moderate=moderate,
        model_name="Llama Guard 2",
        batch_size=batch_size,
        datasets=datasets,
        # Moderate kwargs
        tokenizer=tokenizer,
        model=model,
        # Note: if the tokenizer prepends a BOS token, encode("safe")[0] would be
        # the BOS id; encode("safe", add_special_tokens=False)[0] avoids that
        safe_token_id=tokenizer.encode("safe")[0],
        unsafe_token_id=tokenizer.encode("unsafe")[0],
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda", type=str, help="Device")
    parser.add_argument("--datasets", nargs="+", default="all", help="Datasets")
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size")
    args = parser.parse_args()

    # Run inference without tracking gradients
    with torch.no_grad():
        main(args.device, args.datasets, args.batch_size)
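

# Example invocation (assumes a CUDA-capable GPU and access to the
# meta-llama/Meta-Llama-Guard-2-8B weights on the Hugging Face Hub):
#
#   python llama_guard_2.py --device cuda --datasets all --batch_size 1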