-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathrun-gemini.py
103 lines (97 loc) · 3.64 KB
/
run-gemini.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import argparse
import typing
import time
import vertexai
from vertexai.generative_models import GenerativeModel, SafetySetting
import pfgen
def callback(
    tasks: typing.List[typing.Dict[str, str]], params: typing.Dict[str, typing.Any]
) -> typing.Iterator[typing.Optional[str]]:
    """Generate one Gemini (Vertex AI) completion per task.

    Vertex AI is initialised from the environment: ``VERTEXAI_PROJECT``
    (required, must be non-empty) and ``VERTEXAI_LOCATION`` (default
    ``"us-central1"``).

    Args:
        tasks: Task dicts; the prompt text is read from ``task["prompt"]``.
        params: Run parameters. Required keys: ``"model"`` (only the last
            ``/``-separated component is used as the Gemini model name) and
            ``"mode"`` (must be ``"qa"``). Optional keys: ``"max_tokens"``
            (default 500), ``"temperature"`` (default 1.0), ``"top_p"``
            (default 1.0) and ``"multi_choice"`` (default False).

    Yields:
        Exactly one item per task, in order: the generated text, or ``None``
        if the request failed. Errors whose message starts with ``"429"``
        (rate limiting) are retried up to 5 times, 20 seconds apart; any
        other error yields ``None`` immediately.

    Raises:
        KeyError: If ``params`` lacks ``"model"`` or ``"mode"``.
        ValueError: If ``mode`` is not ``"qa"`` or the Vertex AI project /
            location from the environment is empty.
    """
    max_rate_limit_retries = 5
    retry_delay_s = 20

    # "google/gemini-1.5-flash-001" -> "gemini-1.5-flash-001"
    model_name = params["model"].split("/")[-1]
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if params["mode"] != "qa":
        raise ValueError(f"Only 'qa' mode is supported, got {params['mode']!r}")
    project = os.environ.get("VERTEXAI_PROJECT", "")
    location = os.environ.get("VERTEXAI_LOCATION", "us-central1")
    if not project or not location:
        raise ValueError(
            "VERTEXAI_PROJECT must be set (and VERTEXAI_LOCATION, if set, non-empty)"
        )
    vertexai.init(project=project, location=location)
    model = GenerativeModel(model_name)

    # Request settings are identical for every task, so build them once.
    generation_config = {
        "max_output_tokens": params.get("max_tokens", 500),
        "temperature": params.get("temperature", 1.0),
        "top_p": params.get("top_p", 1.0),
    }
    # Block only at the HIGH threshold for each of these four harm categories.
    block_only_high = SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH
    safety_settings = [
        SafetySetting(category=category, threshold=block_only_high)
        for category in (
            SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
        )
    ]
    multi_choice = params.get("multi_choice", False)

    for task in tasks:
        for attempt in range(max_rate_limit_retries + 1):
            text: typing.Optional[str]
            try:
                response = model.generate_content(
                    [task["prompt"]],
                    generation_config=generation_config,
                    safety_settings=safety_settings,
                )
                if multi_choice:
                    # Last part of the first candidate (presumably the final
                    # answer for multi-choice runs; TODO confirm intent).
                    text = response.candidates[0].content.parts[-1].text
                else:
                    text = response.text
            except Exception as e:  # API boundary: log and degrade to None.
                print(f"API Error: {e}")
                if attempt < max_rate_limit_retries and f"{e}".startswith("429"):
                    print(f"Rate limited, retrying after {retry_delay_s} seconds...")
                    time.sleep(retry_delay_s)
                    continue
                text = None
            # Yield outside the `try` so nothing thrown into the generator at
            # this point can be swallowed and cause a second yield for this task.
            yield text
            break
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the Gemini benchmark runner."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="qa",
        # `callback` rejects every mode other than "qa", so offer only that
        # choice and fail fast at argument parsing instead of mid-run.
        choices=["qa"],
        help="Which chat template to use.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="gemini-1.5-flash-001",
        help="Gemini model name.",
    )
    parser.add_argument(
        "--multi-choice",
        action="store_true",
        help="Use multi-choice generation.",
    )
    parser.add_argument(
        "--temperature", type=float, default=1.0, help="Temperature for sampling."
    )
    parser.add_argument(
        "--num-trials", type=int, default=10, help="Number of trials to run."
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=3000,
        help="Maximum number of output tokens per response.",
    )
    return parser


def main(argv: typing.Optional[typing.Sequence[str]] = None) -> None:
    """Parse ``argv`` (``sys.argv[1:]`` when None) and run the benchmark with Gemini."""
    args = build_parser().parse_args(argv)
    pfgen.run_tasks(
        args.mode,
        callback,
        engine="gemini",
        # Normalise to "google/<name>" whether or not a prefix was passed.
        model="google/" + args.model.split("/")[-1],
        multi_choice=args.multi_choice,
        temperature=args.temperature,
        num_trials=args.num_trials,
        max_tokens=args.max_tokens,
    )


if __name__ == "__main__":
    main()