SUT.py
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
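"""MLPerf LoadGen System Under Test (SUT) wrappers that forward queries to a
Text Generation Inference (TGI) endpoint via huggingface_hub's InferenceClient.

`Server` streams the generation so first-token latency can be reported to
LoadGen; `Offline` sends non-streaming requests, sorted by input length.
"""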
import array
import os
import mlperf_loadgen as lg
import pandas as pd
import threading
import struct
from huggingface_hub import InferenceClient

# Abort the whole process if any worker thread raises, so a failed request
# cannot silently stall the benchmark.
def except_hook(args):
    print(f"Thread failed with error: {args.exc_value}")
    os._exit(1)


threading.excepthook = except_hook

def load_dataset(dataset_path):
    """Return one (token_count, comma-separated token ids) tuple per sample."""
    tok_input = pd.read_pickle(dataset_path)['tok_input'].tolist()
    ret = []
    for sample in tok_input:
        ret.append((len(sample), ','.join(str(token) for token in sample)))
    return ret

class Dataset():
    def __init__(self, total_sample_count, dataset_path):
        self.data = load_dataset(dataset_path)
        self.count = min(len(self.data), total_sample_count)

    # All samples stay in host memory for the whole run, so LoadGen's
    # load/unload callbacks are no-ops.
    def LoadSamplesToRam(self, sample_list):
        pass

    def UnloadSamplesFromRam(self, sample_list):
        pass

    def __del__(self):
        pass

class SUT_base():
    def __init__(self, args):
        self.data_object = Dataset(dataset_path=args.dataset_path,
                                   total_sample_count=args.total_sample_count)
        self.qsl = lg.ConstructQSL(self.data_object.count, args.total_sample_count,
                                   self.data_object.LoadSamplesToRam,
                                   self.data_object.UnloadSamplesFromRam)
        # Bounds the number of in-flight requests against the TGI endpoint.
        self.tgi_semaphore = threading.Semaphore(args.max_num_threads)
        self.client = InferenceClient(args.sut_server)
        self.gen_tok_lens = []

    def flush_queries(self):
        pass

class Server(SUT_base):
    def __init__(self, args):
        SUT_base.__init__(self, args)

    def tgi_request(self, sample):
        _, str_input = self.data_object.data[sample.index]
        # Stream tokens so the first generated token can be reported to
        # LoadGen as soon as it arrives.
        res_stream = self.client.text_generation(
            str_input, max_new_tokens=1024, stream=True, details=True)
        out = []
        for res_token_id, res_token in enumerate(res_stream):
            res_token = res_token.token
            if res_token_id == 0:
                arr = array.array('B', struct.pack('L', res_token.id))
                buf_info = arr.buffer_info()
                lg.FirstTokenComplete([lg.QuerySampleResponse(
                    sample.id, buf_info[0], buf_info[1] * arr.itemsize, 1)])
            out.append(res_token.id)
        # Report the full token sequence once generation finishes.
        arr = array.array('B', struct.pack('L' * len(out), *out))
        buf_info = arr.buffer_info()
        lg.QuerySamplesComplete([lg.QuerySampleResponse(
            sample.id, buf_info[0], buf_info[1] * arr.itemsize, len(out))])
        self.gen_tok_lens.append(len(out))
        self.tgi_semaphore.release()

    def issue_queries(self, query_samples):
        # One worker thread per query; the semaphore caps concurrency.
        for sample in query_samples:
            self.tgi_semaphore.acquire()
            threading.Thread(target=self.tgi_request, args=[sample]).start()

class Offline(SUT_base):
    def __init__(self, args):
        SUT_base.__init__(self, args)

    def tgi_request(self, sample):
        _, str_input = self.data_object.data[sample.index]
        res = self.client.text_generation(
            str_input, max_new_tokens=1024, stream=False, details=True)
        out = [token.id for token in res.details.tokens]
        arr = array.array('B', struct.pack('L' * len(out), *out))
        buf_info = arr.buffer_info()
        lg.QuerySamplesComplete([lg.QuerySampleResponse(
            sample.id, buf_info[0], buf_info[1] * arr.itemsize, len(out))])
        self.gen_tok_lens.append(len(out))
        self.tgi_semaphore.release()

    def issue_queries(self, query_samples):
        # Dispatch queries in order of input token count; one worker thread
        # per query, with the semaphore capping concurrency.
        query_samples.sort(key=lambda s: self.data_object.data[s.index][0])
        for sample in query_samples:
            self.tgi_semaphore.acquire()
            threading.Thread(target=self.tgi_request, args=[sample]).start()
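
For reference, a minimal driver sketch follows; it is not part of SUT.py. The attribute names on `args` are the ones this file reads (dataset_path, total_sample_count, max_num_threads, sut_server), but the values shown are placeholders, and the real harness's flags, scenario, and log settings may differ.

import types

import mlperf_loadgen as lg

from SUT import Offline

# Placeholder arguments; an actual harness would build these (e.g. with argparse).
args = types.SimpleNamespace(
    dataset_path="dataset.pkl",          # hypothetical pickle with a 'tok_input' column
    total_sample_count=1000,             # hypothetical sample count
    max_num_threads=64,                  # cap on concurrent TGI requests
    sut_server="http://localhost:8080",  # TGI endpoint URL
)

sut = Offline(args)
lg_sut = lg.ConstructSUT(sut.issue_queries, sut.flush_queries)

settings = lg.TestSettings()
settings.scenario = lg.TestScenario.Offline
settings.mode = lg.TestMode.PerformanceOnly

lg.StartTest(lg_sut, sut.qsl, settings)

lg.DestroySUT(lg_sut)
lg.DestroyQSL(sut.qsl)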