Skip to content

Commit

Permalink
Ran linter. Cleaned up print log statements.
Browse files Browse the repository at this point in the history
  • Loading branch information
JSv4 committed May 27, 2024
1 parent 2c40859 commit aa5ad70
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 45 deletions.
15 changes: 5 additions & 10 deletions config/graphql/mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1338,10 +1338,7 @@ class Arguments:
@staticmethod
@login_required
def mutate(root, info, model):
language_model = LanguageModel(
model=model,
creator=info.context.user
)
language_model = LanguageModel(model=model, creator=info.context.user)
language_model.save()
set_permissions_for_obj_to_user(
info.context.user, language_model, [PermissionTypes.CRUD]
Expand All @@ -1364,7 +1361,7 @@ def mutate(root, info, name, description):
owner=info.context.user,
name=name,
description=description,
creator=info.context.user
creator=info.context.user,
)
fieldset.save()
set_permissions_for_obj_to_user(
Expand Down Expand Up @@ -1401,9 +1398,7 @@ def mutate(
limit_to_label=None,
instructions=None,
):
fieldset = Fieldset.objects.get(
pk=from_global_id(fieldset_id)[1]
)
fieldset = Fieldset.objects.get(pk=from_global_id(fieldset_id)[1])
language_model = LanguageModel.objects.get(
pk=from_global_id(language_model_id)[1]
)
Expand All @@ -1416,7 +1411,7 @@ def mutate(
instructions=instructions,
language_model=language_model,
agentic=agentic,
creator=info.context.user
creator=info.context.user,
)
column.save()
set_permissions_for_obj_to_user(
Expand Down Expand Up @@ -1445,7 +1440,7 @@ def mutate(root, info, corpus_id, name, fieldset_id):
name=name,
fieldset=fieldset,
owner=info.context.user,
creator=info.context.user
creator=info.context.user,
)
extract.save()
set_permissions_for_obj_to_user(
Expand Down
28 changes: 13 additions & 15 deletions opencontractserver/tasks/extract_tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import json
import logging

from celery import shared_task
from django.db import transaction
Expand All @@ -8,8 +8,10 @@
from opencontractserver.annotations.models import Annotation
from opencontractserver.extracts.models import Extract, Row
from opencontractserver.types.enums import PermissionTypes
from opencontractserver.utils.permissioning import set_permissions_for_obj_to_user
from opencontractserver.utils.embeddings import calculate_embedding_for_text
from opencontractserver.utils.permissioning import set_permissions_for_obj_to_user

logger = logging.getLogger(__name__)


# Mock these functions for now
Expand All @@ -18,14 +20,13 @@ def agent_fetch_my_definitions(annot):


def extract_for_query(annots, query, output_type):
print(f"Ran extract_for_query")
return None

Check warning on line 23 in opencontractserver/tasks/extract_tasks.py

View check run for this annotation

Codecov / codecov/patch

opencontractserver/tasks/extract_tasks.py#L23

Added line #L23 was not covered by tests


@shared_task
def run_extract(extract_id, user_id):

print(f"Run extract for extract {extract_id}")
logger.info(f"Run extract for extract {extract_id}")

extract = Extract.objects.get(pk=extract_id)

Expand All @@ -37,29 +38,27 @@ def run_extract(extract_id, user_id):
fieldset = extract.fieldset

document_ids = corpus.documents.all().values_list("id", flat=True)
print(f"Document ids: {document_ids}")

for document_id in document_ids:
for column in fieldset.columns.all():

print(f"Processing column {column} for doc {document_id}")

with transaction.atomic():
row = Row.objects.create(
extract=extract,
column=column,
data_definition=column.output_type,
creator_id=user_id
creator_id=user_id,
)
set_permissions_for_obj_to_user(user_id, row, [PermissionTypes.CRUD])

try:
print(f"run_extract() - processing column {column} for {document_id}")
logger.debug(
f"run_extract() - processing column {column} for {document_id}"
)
row.started = timezone.now()
row.save()

output_type = eval(column.output_type)
print(f"output_type: {output_type}")

annotations = Annotation.objects.filter(
document_id=document_id, embedding__isnull=False
Expand All @@ -71,7 +70,6 @@ def run_extract(extract_id, user_id):
)

match_text = column.match_text or column.query
print(f"Match_text: {match_text}")

if match_text:

Expand All @@ -84,18 +82,18 @@ def run_extract(extract_id, user_id):
)[:5]

if column.agentic:
annotations = annotations.union(agent_fetch_my_definitions(annotations))
annotations = annotations.union(
agent_fetch_my_definitions(annotations)
)

print(f"Prepare to extract_for_query annotations {annotations} / column {column.query} / {output_type}")
val = extract_for_query(annotations, column.query, output_type)
print(f"Extracted value: {val}")

row.data = {"data": val}
row.completed = timezone.now()
row.save()

except Exception as e:
print(f"Ran into error: {e}")
logger.error(f"run_extract() - Ran into error: {e}")
row.stacktrace = f"Error processing: {e}"
row.failed = timezone.now()
row.save()

Check warning on line 99 in opencontractserver/tasks/extract_tasks.py

View check run for this annotation

Codecov / codecov/patch

opencontractserver/tasks/extract_tasks.py#L95-L99

Added lines #L95 - L99 were not covered by tests
11 changes: 6 additions & 5 deletions opencontractserver/tests/test_extract_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ def __init__(self, user):
class ExtractsMutationTestCase(TestCase):
def setUp(self):
self.user = User.objects.create_user(
username="testuser",
password="testpassword"
username="testuser", password="testpassword"
)
self.client = Client(schema, context_value=TestContext(self.user))

Expand Down Expand Up @@ -77,12 +76,14 @@ def test_create_fieldset_mutation(self):
)

def test_create_column_mutation(self):
language_model = LanguageModel.objects.create(model="TestModel", creator=self.user)
language_model = LanguageModel.objects.create(
model="TestModel", creator=self.user
)
fieldset = Fieldset.objects.create(
owner=self.user,
name="TestFieldset",
description="Test description",
creator=self.user
creator=self.user,
)

mutation = """
Expand Down Expand Up @@ -121,7 +122,7 @@ def test_start_extract_mutation(self):
owner=self.user,
name="TestFieldset",
description="Test description",
creator=self.user
creator=self.user,
)

mutation = """
Expand Down
9 changes: 4 additions & 5 deletions opencontractserver/tests/test_extract_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,13 @@ def setUp(self):
self.client = Client(schema, context_value=TestContext(self.user))

self.language_model = LanguageModel.objects.create(
model="TestModel",
creator=self.user
model="TestModel", creator=self.user
)
self.fieldset = Fieldset.objects.create(
owner=self.user,
name="TestFieldset",
description="Test description",
creator=self.user
creator=self.user,
)
self.column = Column.objects.create(
creator=self.user,
Expand All @@ -52,14 +51,14 @@ def setUp(self):
name="TestExtract",
fieldset=self.fieldset,
owner=self.user,
creator=self.user
creator=self.user,
)
self.row = Row.objects.create(
extract=self.extract,
column=self.column,
data={"data": "TestData"},
data_definition="str",
creator=self.user
creator=self.user,
)

def test_language_model_query(self):
Expand Down
15 changes: 5 additions & 10 deletions opencontractserver/tests/test_extract_tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
from unittest.mock import patch

from django.contrib.auth import get_user_model
Expand Down Expand Up @@ -33,30 +32,29 @@ def setUp(self):
)

self.language_model = LanguageModel.objects.create(
model="TestModel",
creator=self.user
model="TestModel", creator=self.user
)
self.fieldset = Fieldset.objects.create(
owner=self.user,
name="TestFieldset",
description="Test description",
creator=self.user
creator=self.user,
)
self.column = Column.objects.create(
fieldset=self.fieldset,
query="TestQuery",
output_type="str",
language_model=self.language_model,
agentic=True,
creator=self.user
creator=self.user,
)
self.corpus = Corpus.objects.create(title="TestCorpus", creator=self.user)
self.extract = Extract.objects.create(
corpus=self.corpus,
name="TestExtract",
fieldset=self.fieldset,
owner=self.user,
creator=self.user
creator=self.user,
)

pdf_file = ContentFile(
Expand Down Expand Up @@ -89,10 +87,7 @@ def test_run_extract_task(
self.extract.refresh_from_db()
self.assertIsNotNone(self.extract.started)

row = Row.objects.filter(
extract=self.extract,
column=self.column
).first()
row = Row.objects.filter(extract=self.extract, column=self.column).first()
self.assertIsNotNone(row)
self.assertEqual(row.data, {"data": "Mocked extracted data"})
self.assertEqual(row.data_definition, "str")
Expand Down

0 comments on commit aa5ad70

Please sign in to comment.