Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
turboderp committed Sep 1, 2023
2 parents 6020a1b + f8e9d7e commit b8541ad
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions example_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,14 @@ def stream():
built_response += stream_text
return stream_text, False, full_prompt, utilized_prompt, built_response

def leftTrimTokens(text: str, desiredLen: int):

encodedText = tokenizer.encode(text)
if encodedText.shape[-1] <= desiredLen:
return text
else:
return tokenizer.decode(encodedText[:, -desiredLen:])[0]

def oneshot_generation(prompt: str, stop_conditions: list, max_new_tokens: int, gen_settings: ExLlamaGenerator.Settings):

begin_stream(prompt, stop_conditions, max_new_tokens, gen_settings)
Expand Down Expand Up @@ -206,7 +214,7 @@ async def streamInfer(request, ws):
begin_stream(prompt, stopToken, maxNew, gs)
while True:
chunk, eos, x, y, builtResp = stream()
await ws.send(json.dumps({'action':'streamInfer',
await ws.send(json.dumps({'action':request["action"],
'request_id':request['request_id'],
'utilContext':utilized_prompt + builtResp,
'response':builtResp}))
Expand All @@ -225,17 +233,25 @@ async def main(websocket, path):
response = await estimateToken(request, websocket)
await websocket.send(json.dumps({'action':action, 'request_id':reqID, 'response':response}))

if action == "echo":
elif action == "echo":
await websocket.send(json.dumps({'action':action, 'request_id':reqID}))

elif action == "oneShotInfer":
fctx, utlctx, res = await oneShotInfer(request, websocket)
await websocket.send(json.dumps({'action':action, 'request_id':reqID,'utilContext':utlctx, 'response':res}))

elif action == "streamInfer":

elif action == "leftTrim":
prompt = request["text"]
desiredLen = int(request["desiredLen"])
processedPrompt = leftTrimTokens(prompt, desiredLen)
await websocket.send(json.dumps({'action':action, 'request_id':reqID, 'response':processedPrompt}))

else:
utlctx, builtResp= await streamInfer(request, websocket)
await websocket.send(json.dumps({'action':action, 'request_id':reqID,'utilContext':utlctx, 'response':builtResp+'</s>'}))



#except Exception as e:
#print({"error": str(e)})

Expand All @@ -247,7 +263,7 @@ async def main(websocket, path):
model_path = glob.glob(st_pattern)[0]
esTokenizer = SentencePieceProcessor(model_file = tokenizer_path)
config = ExLlamaConfig(model_config_path) # create config from config.json
config.set_auto_map('18.8897,18.8897')
config.set_auto_map('17.615,18.8897')
config.model_path = model_path # supply path to model weights file

model = ExLlama(config) # create ExLlama instance and load the weights
Expand Down

0 comments on commit b8541ad

Please sign in to comment.