-
Notifications
You must be signed in to change notification settings - Fork 0
/
runpod_tests.py
65 lines (53 loc) · 1.65 KB
/
runpod_tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import runpod
from base64 import b64encode
from rich import print
from dotenv import load_dotenv
from os import getenv
from json import dumps
# Module-level setup. Order matters: the .env file must be loaded before the
# environment is read, and the API key must be set before the endpoint client
# is created.
load_dotenv()
# Both values come from the environment; getenv returns None when a variable
# is unset, in which case the runpod calls below are expected to fail.
runpod.api_key = getenv("RUNPOD_API_KEY")
endpoint_id = getenv("RUNPOD_ENDPOINT_ID")
# Shared endpoint client used by sync_call() and async_call().
endpoint = runpod.Endpoint(endpoint_id)
# Passed to job.output() in async_call(); presumably a timeout in seconds --
# TODO confirm against the runpod SDK.
TIMEOUT = 200
def sync_call(data):
    """Send *data* to the endpoint through ``run_sync`` and return its result.

    Args:
        data: Payload placed under the ``"input"`` key of the request.

    Returns:
        Whatever ``endpoint.run_sync`` returns for the request.
    """
    print("Start sync call...")
    request = {"input": data}
    return endpoint.run_sync(request)
def async_call(data):
    """Submit *data* to the endpoint with ``run`` and return a result getter.

    The job is submitted immediately and its status is printed once. The
    output is not fetched here: the caller decides when to collect it by
    invoking the returned callable.

    Args:
        data: Payload placed under the ``"input"`` key of the request.

    Returns:
        A zero-argument callable that returns ``job.output(TIMEOUT)``.
    """
    print("Start async call...")
    request = {"input": data}
    job = endpoint.run(request)
    print(job.status())

    def fetch_output():
        # TIMEOUT is read when the getter is called, not when it is created.
        return job.output(TIMEOUT)

    return fetch_output
def load_img(filename):
    """Read ``input/<filename>`` and return its content as a base64 string.

    Args:
        filename: Name of the file, relative to the ``input`` directory
            (itself relative to the current working directory).

    Returns:
        The file's bytes, base64-encoded and decoded to a ``str``.
    """
    path = "input/" + filename
    with open(path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = b64encode(raw_bytes)
    return encoded.decode("utf-8")
def save_output(filename, result):
    """Write *result* as indented JSON to ``output/<filename>.txt``.

    The target directory is created when missing, so a result that was just
    obtained from a slow remote call is not lost to a ``FileNotFoundError``
    on a fresh checkout. Serialization happens before the file is opened, so
    a non-serializable *result* does not leave an empty or truncated file.

    Args:
        filename: Base name of the output file, without the ``.txt`` suffix.
        result: JSON-serializable object to store. Non-ASCII characters are
            written as-is (UTF-8), not escaped.

    Raises:
        TypeError: If *result* cannot be serialized by ``json.dumps``.
    """
    # Local import: the module itself only imports ``getenv`` from ``os``.
    import os

    path = "output/" + filename + ".txt"
    # Serialize first so a failure here never touches an existing file.
    text = dumps(result, indent=2, ensure_ascii=False)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)
if __name__ == "__main__":
    # First request: a bare {"schema": True} payload, printed as returned.
    # NOTE(review): presumably asks the handler to describe its input schema
    # -- confirm against the endpoint's handler.
    fetch_schema = async_call({"schema": True})
    print(fetch_schema())

    # Each case pairs the image files to load from input/ with the prompt
    # sent alongside them.
    inputs = [
        (
            ["01.jpg", "02.jpg"],
            "What is the TTC price to pay mentionned in this 2 pages bill?",
        ),
        (["book.jpg"], "What is the title of this book?"),
        (
            ["test.png"],
            "What do you see on this image? Don't censor yourself, describe it accurately",
        ),
        (["weird.jpg"], "Explain this image"),
    ]

    # The same parameters are attached to every case.
    params = {"temperature": 0.3, "top_p": 0.9}

    # All cases are sent together in a single request under the "data" key.
    data = {
        "data": [
            {
                "images_base64": [load_img(name) for name in image_names],
                "prompt": prompt,
                "params": params,
            }
            for image_names, prompt in inputs
        ]
    }

    fetch_result = async_call(data)
    result = fetch_result()
    print(result)
    save_output("test", result)