From d3722a643ec6f12848042e689d344c066802b2e7 Mon Sep 17 00:00:00 2001
From: BAAI-OpenPlatform <107522723+BAAI-OpenPlatform@users.noreply.github.com>
Date: Wed, 2 Aug 2023 10:36:13 +0800
Subject: [PATCH] Update aquila.py

---
 flagai/model/predictor/aquila.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flagai/model/predictor/aquila.py b/flagai/model/predictor/aquila.py
index d47bba66..df2966db 100755
--- a/flagai/model/predictor/aquila.py
+++ b/flagai/model/predictor/aquila.py
@@ -29,8 +29,8 @@ def aquila_generate(
 
     total_len = min(2048, max_gen_len + max_prompt_size)
 
-    # tokens = torch.full((bsz, total_len), 0).cuda().long()
-    tokens = torch.full((bsz, total_len), 0).to("cuda:5").long()
+    tokens = torch.full((bsz, total_len), 0).cuda().long()
+    #tokens = torch.full((bsz, total_len), 0).to("cuda:5").long()
     for k, t in enumerate(prompt_tokens):
         tokens[k, : len(t)] = t.clone().detach().long()
     input_text_mask = tokens != 0
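
Note (not part of the upstream patch): the commit reverts a leftover debugging change that pinned the token buffer to a hardcoded device index ("cuda:5") back to the default CUDA device. A minimal sketch of an even more portable pattern, with a hypothetical helper name init_token_buffer (not in FlagAI), selects the device at runtime rather than hardcoding it at all:

import torch

def init_token_buffer(prompt_tokens, total_len, device=None):
    # Pick the current default CUDA device if available, else fall back
    # to CPU, instead of hardcoding an index like "cuda:5".
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    bsz = len(prompt_tokens)
    # Allocate the padded token buffer directly on the target device.
    tokens = torch.full((bsz, total_len), 0, dtype=torch.long, device=device)
    for k, t in enumerate(prompt_tokens):
        tokens[k, : len(t)] = t.clone().detach().long()
    # Positions holding real prompt tokens (nonzero) vs. padding.
    input_text_mask = tokens != 0
    return tokens, input_text_mask

Allocating with dtype=torch.long and device=... in one torch.full call also avoids the extra copies implied by chaining .cuda().long() on a CPU tensor.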