diff --git a/config/graphql/mutations.py b/config/graphql/mutations.py index e03e4846..c7c892d7 100644 --- a/config/graphql/mutations.py +++ b/config/graphql/mutations.py @@ -1338,10 +1338,7 @@ class Arguments: @staticmethod @login_required def mutate(root, info, model): - language_model = LanguageModel( - model=model, - creator=info.context.user - ) + language_model = LanguageModel(model=model, creator=info.context.user) language_model.save() set_permissions_for_obj_to_user( info.context.user, language_model, [PermissionTypes.CRUD] @@ -1364,7 +1361,7 @@ def mutate(root, info, name, description): owner=info.context.user, name=name, description=description, - creator=info.context.user + creator=info.context.user, ) fieldset.save() set_permissions_for_obj_to_user( @@ -1401,9 +1398,7 @@ def mutate( limit_to_label=None, instructions=None, ): - fieldset = Fieldset.objects.get( - pk=from_global_id(fieldset_id)[1] - ) + fieldset = Fieldset.objects.get(pk=from_global_id(fieldset_id)[1]) language_model = LanguageModel.objects.get( pk=from_global_id(language_model_id)[1] ) @@ -1416,7 +1411,7 @@ def mutate( instructions=instructions, language_model=language_model, agentic=agentic, - creator=info.context.user + creator=info.context.user, ) column.save() set_permissions_for_obj_to_user( @@ -1445,7 +1440,7 @@ def mutate(root, info, corpus_id, name, fieldset_id): name=name, fieldset=fieldset, owner=info.context.user, - creator=info.context.user + creator=info.context.user, ) extract.save() set_permissions_for_obj_to_user( diff --git a/opencontractserver/tasks/extract_tasks.py b/opencontractserver/tasks/extract_tasks.py index d078782e..c7b14868 100644 --- a/opencontractserver/tasks/extract_tasks.py +++ b/opencontractserver/tasks/extract_tasks.py @@ -1,4 +1,4 @@ -import json +import logging from celery import shared_task from django.db import transaction @@ -8,8 +8,10 @@ from opencontractserver.annotations.models import Annotation from opencontractserver.extracts.models import Extract, Row from opencontractserver.types.enums import PermissionTypes -from opencontractserver.utils.permissioning import set_permissions_for_obj_to_user from opencontractserver.utils.embeddings import calculate_embedding_for_text +from opencontractserver.utils.permissioning import set_permissions_for_obj_to_user + +logger = logging.getLogger(__name__) # Mock these functions for now @@ -18,14 +20,13 @@ def agent_fetch_my_definitions(annot): def extract_for_query(annots, query, output_type): - print(f"Ran extract_for_query") return None @shared_task def run_extract(extract_id, user_id): - print(f"Run extract for extract {extract_id}") + logger.info(f"Run extract for extract {extract_id}") extract = Extract.objects.get(pk=extract_id) @@ -37,29 +38,27 @@ def run_extract(extract_id, user_id): fieldset = extract.fieldset document_ids = corpus.documents.all().values_list("id", flat=True) - print(f"Document ids: {document_ids}") for document_id in document_ids: for column in fieldset.columns.all(): - print(f"Processing column {column} for doc {document_id}") - with transaction.atomic(): row = Row.objects.create( extract=extract, column=column, data_definition=column.output_type, - creator_id=user_id + creator_id=user_id, ) set_permissions_for_obj_to_user(user_id, row, [PermissionTypes.CRUD]) try: - print(f"run_extract() - processing column {column} for {document_id}") + logger.debug( + f"run_extract() - processing column {column} for {document_id}" + ) row.started = timezone.now() row.save() output_type = eval(column.output_type) - print(f"output_type: {output_type}") annotations = Annotation.objects.filter( document_id=document_id, embedding__isnull=False @@ -71,7 +70,6 @@ def run_extract(extract_id, user_id): ) match_text = column.match_text or column.query - print(f"Match_text: {match_text}") if match_text: @@ -84,18 +82,18 @@ def run_extract(extract_id, user_id): )[:5] if column.agentic: - annotations = annotations.union(agent_fetch_my_definitions(annotations)) + annotations = annotations.union( + agent_fetch_my_definitions(annotations) + ) - print(f"Prepare to extract_for_query annotations {annotations} / column {column.query} / {output_type}") val = extract_for_query(annotations, column.query, output_type) - print(f"Extracted value: {val}") row.data = {"data": val} row.completed = timezone.now() row.save() except Exception as e: - print(f"Ran into error: {e}") + logger.error(f"run_extract() - Ran into error: {e}") row.stacktrace = f"Error processing: {e}" row.failed = timezone.now() row.save() diff --git a/opencontractserver/tests/test_extract_mutations.py b/opencontractserver/tests/test_extract_mutations.py index cd25e2c0..94974e8f 100644 --- a/opencontractserver/tests/test_extract_mutations.py +++ b/opencontractserver/tests/test_extract_mutations.py @@ -20,8 +20,7 @@ def __init__(self, user): class ExtractsMutationTestCase(TestCase): def setUp(self): self.user = User.objects.create_user( - username="testuser", - password="testpassword" + username="testuser", password="testpassword" ) self.client = Client(schema, context_value=TestContext(self.user)) @@ -77,12 +76,14 @@ def test_create_fieldset_mutation(self): ) def test_create_column_mutation(self): - language_model = LanguageModel.objects.create(model="TestModel", creator=self.user) + language_model = LanguageModel.objects.create( + model="TestModel", creator=self.user + ) fieldset = Fieldset.objects.create( owner=self.user, name="TestFieldset", description="Test description", - creator=self.user + creator=self.user, ) mutation = """ @@ -121,7 +122,7 @@ def test_start_extract_mutation(self): owner=self.user, name="TestFieldset", description="Test description", - creator=self.user + creator=self.user, ) mutation = """ diff --git a/opencontractserver/tests/test_extract_queries.py b/opencontractserver/tests/test_extract_queries.py index d64eb4e0..abb41850 100644 --- a/opencontractserver/tests/test_extract_queries.py +++ b/opencontractserver/tests/test_extract_queries.py @@ -29,14 +29,13 @@ def setUp(self): self.client = Client(schema, context_value=TestContext(self.user)) self.language_model = LanguageModel.objects.create( - model="TestModel", - creator=self.user + model="TestModel", creator=self.user ) self.fieldset = Fieldset.objects.create( owner=self.user, name="TestFieldset", description="Test description", - creator=self.user + creator=self.user, ) self.column = Column.objects.create( creator=self.user, @@ -52,14 +51,14 @@ def setUp(self): name="TestExtract", fieldset=self.fieldset, owner=self.user, - creator=self.user + creator=self.user, ) self.row = Row.objects.create( extract=self.extract, column=self.column, data={"data": "TestData"}, data_definition="str", - creator=self.user + creator=self.user, ) def test_language_model_query(self): diff --git a/opencontractserver/tests/test_extract_tasks.py b/opencontractserver/tests/test_extract_tasks.py index 17d4caff..a3c71c61 100644 --- a/opencontractserver/tests/test_extract_tasks.py +++ b/opencontractserver/tests/test_extract_tasks.py @@ -1,4 +1,3 @@ -import json from unittest.mock import patch from django.contrib.auth import get_user_model @@ -33,14 +32,13 @@ def setUp(self): ) self.language_model = LanguageModel.objects.create( - model="TestModel", - creator=self.user + model="TestModel", creator=self.user ) self.fieldset = Fieldset.objects.create( owner=self.user, name="TestFieldset", description="Test description", - creator=self.user + creator=self.user, ) self.column = Column.objects.create( fieldset=self.fieldset, @@ -48,7 +46,7 @@ def setUp(self): output_type="str", language_model=self.language_model, agentic=True, - creator=self.user + creator=self.user, ) self.corpus = Corpus.objects.create(title="TestCorpus", creator=self.user) self.extract = Extract.objects.create( @@ -56,7 +54,7 @@ def setUp(self): name="TestExtract", fieldset=self.fieldset, owner=self.user, - creator=self.user + creator=self.user, ) pdf_file = ContentFile( @@ -89,10 +87,7 @@ def test_run_extract_task( self.extract.refresh_from_db() self.assertIsNotNone(self.extract.started) - row = Row.objects.filter( - extract=self.extract, - column=self.column - ).first() + row = Row.objects.filter(extract=self.extract, column=self.column).first() self.assertIsNotNone(row) self.assertEqual(row.data, {"data": "Mocked extracted data"}) self.assertEqual(row.data_definition, "str")