Skip to content

Commit

Permalink
query builder wip
Browse files Browse the repository at this point in the history
  • Loading branch information
lbeurerkellner committed May 9, 2024
1 parent fdb773c commit d64ade8
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
10 changes: 6 additions & 4 deletions src/lmql/api/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,17 @@ def set_decoder(self, decoder='argmax', **kwargs):
self.decoder = (decoder, kwargs)
return self

def set_prompt(self, prompt="What is the capital of France? [ANSWER]"):
def set_prompt(self, prompt: str):
if self.prompt is not None:
raise ValueError("You cannot set multiple prompt values. Please set the entire prompt with a single set_prompt call.")
self.prompt = prompt
return self

def set_model(self, model="gpt2"):
def set_model(self, model: str):
self.model = model
return self

def set_where(self, where="len(TOKENS(ANSWER)) < 10"):
def set_where(self, where: str):
"""
Add a where clause to the query
If a where clause already exists, the new clause is appended with an 'and'
Expand All @@ -46,7 +48,7 @@ def set_where(self, where="len(TOKENS(ANSWER)) < 10"):
self.where = where if self.where is None else f"{self.where} and {where}"
return self

def set_distribution(self, variable="ANSWER", expr='["A", "B"]'):
def set_distribution(self, variable: str, expr: str):
self.distribution_expr = (variable, expr)
return self

Expand Down
8 changes: 4 additions & 4 deletions src/lmql/tests/test_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ def test_query_builder():
query = (lmql.QueryBuilder()
.set_decoder('argmax')
.set_prompt('What is the capital of France? [ANSWER]')
.set_model('gpt2')
.set_model('local:gpt2')
.set_where('len(TOKENS(ANSWER)) < 10')
.set_where('len(TOKENS(ANSWER)) > 2')
.build())

expected = 'argmax "What is the capital of France? [ANSWER]" from "gpt2" where len(TOKENS(ANSWER)) < 10 and len(TOKENS(ANSWER)) > 2'
expected = 'argmax "What is the capital of France? [ANSWER]" from "local:gpt2" where len(TOKENS(ANSWER)) < 10 and len(TOKENS(ANSWER)) > 2'

assert expected==query.query_string, f"Expected: {expected}, got: {query.query_string}"
out = query.run_sync()
Expand All @@ -25,11 +25,11 @@ def test_query_builder_with_dist():
query = (lmql.QueryBuilder()
.set_decoder('argmax')
.set_prompt('What is the capital of France? [ANSWER]')
.set_model('gpt2')
.set_model('local:gpt2')
.set_distribution('ANSWER', '["Paris", "London"]')
.build())

expected = 'argmax "What is the capital of France? [ANSWER]" from "gpt2" distribution ANSWER in ["Paris", "London"]'
expected = 'argmax "What is the capital of France? [ANSWER]" from "local:gpt2" distribution ANSWER in ["Paris", "London"]'

assert expected==query.query_string, f"Expected: {expected}, got: {query.query_string}"
out = query.run_sync()
Expand Down

0 comments on commit d64ade8

Please sign in to comment.