Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add inference to ci #33

Merged
merged 15 commits into from
Feb 15, 2024
4 changes: 4 additions & 0 deletions .github/workflows/ci-cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,7 @@ jobs:
- name: Check training results
run: |
python ci/check_loss.py

- name: Check inference results
run: |
python ci/check_inference.py
21 changes: 21 additions & 0 deletions ci/check_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import subprocess


if __name__ == "__main__":
    # CI gate: run the fine-tuned model via `modal run` and verify that the
    # generated completion contains a well-formed SQL answer.

    # The training step records its run name here; reuse it to locate the
    # merged LoRA weights under /runs/<run_name>.
    with open(".last_run_name", "r") as f:
        run_name = f.read().strip()

    prompt = """[INST] Using the schema context below, generate a SQL query that answers the question.
CREATE TABLE head (age INTEGER)
How many heads of the departments are older than 56 ? [/INST] """

    p = subprocess.Popen(
        ["modal", "run", "src.inference", "--run-folder", f"/runs/{run_name}", "--prompt", prompt],
        stdout=subprocess.PIPE,
    )
    output = ""

    # Stream stdout line by line so the CI log shows progress in real time,
    # while also accumulating the full output for the final assertion.
    for line in iter(p.stdout.readline, b""):
        decoded = line.decode()  # decode once; reuse for capture and echo
        output += decoded
        print(decoded, end="")  # line already carries its newline

    # Fail fast with the real exit status if `modal run` itself failed —
    # otherwise the assert below would mask the underlying error.
    returncode = p.wait()
    if returncode != 0:
        raise SystemExit(f"modal run failed with exit code {returncode}")

    print("Asserting that the output contains the expected SQL query")
    assert "[SQL] SELECT" in output and "[/SQL]" in output
5 changes: 4 additions & 1 deletion ci/check_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,8 @@
train_loss = float(results["TrainingLoss"].iloc[-1])
val_loss = float(results["ValidationLoss"].iloc[-1])

# Arbitrary threshold
max_loss = 10 if b"Mixtral" in contents else 0.25

print(f"Loss: {train_loss:.2f} (training), {val_loss:.2f} (validation)")
sys.exit(val_loss > 0.25) # Arbitrary threshold
sys.exit(val_loss > max_loss)
9 changes: 7 additions & 2 deletions ci/prep_for_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,15 @@
@click.option("--data")
def main(config: str, data: str):
"""Set the config to train for only one epoch and truncate the dataset."""
train_set_size = 1000
val_set_size = 64
with open(config) as f:
cfg = yaml.safe_load(f.read())

if cfg["sample_packing"]:
train_set_size = 2048
else:
train_set_size = 1024
val_set_size = 64

cfg["val_set_size"] = val_set_size
cfg["num_epochs"] = 1
cfg.pop("eval_steps", None) # Evaluate once at the end of the epoch
Expand Down
2 changes: 1 addition & 1 deletion config/mixtral.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
base_model: mistralai/Mixtral-8x7B-v0.1
base_model: mistralai/Mixtral-8x7B-Instruct-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
Expand Down
18 changes: 11 additions & 7 deletions src/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,14 @@ async def completion(self, input: str):


@stub.local_entrypoint()
def inference_main(run_folder: str):
text = input(
"Enter a prompt (including the prompt template, e.g. [INST] ... [/INST]):\n"
)
print("Loading model ...")
for chunk in Inference(f"{run_folder}/lora-out/merged").completion.remote_gen(text):
print(chunk, end="")
def inference_main(run_folder: str, prompt: str = ""):
    """Generate a completion from the fine-tuned model in `run_folder`.

    If `prompt` is given (e.g. from CI via --prompt), use it directly;
    otherwise ask the user interactively. Streams generated chunks to stdout.
    """
    if not prompt:
        prompt = input(
            "Enter a prompt (including the prompt template, e.g. [INST] ... [/INST]):\n"
        )
        print("Loading model ...")
    # Single streaming loop — the generation path is identical for both
    # the CLI-provided and the interactively entered prompt.
    for chunk in Inference(f"{run_folder}/lora-out/merged").completion.remote_gen(prompt):
        print(chunk, end="")
Loading