Skip to content

Commit

Permalink
Test extract mutations working again and added tests for the new add …
Browse files Browse the repository at this point in the history
…and remove contract mutations.
  • Loading branch information
JSv4 committed Jun 2, 2024
1 parent 4d923f0 commit 7dae89e
Show file tree
Hide file tree
Showing 23 changed files with 883 additions and 69 deletions.
144 changes: 132 additions & 12 deletions config/graphql/mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,11 @@ def mutate(root, info, document_id, corpus_id):
ok = False
export = None

return StartDocumentExport(ok=ok, message=message, export=export)
return StartDocumentExport(
ok=ok,
message=message,
export=export
)


class UploadAnnotatedDocument(graphene.Mutation):
Expand Down Expand Up @@ -1362,7 +1366,6 @@ class Arguments:
@login_required
def mutate(root, info, name, description):
fieldset = Fieldset(
owner=info.context.user,
name=name,
description=description,
creator=info.context.user,
Expand Down Expand Up @@ -1405,6 +1408,7 @@ class Arguments:
instructions = graphene.String()
language_model_id = graphene.ID(required=True)
agentic = graphene.Boolean(required=True)
name = graphene.String(required=True)

ok = graphene.Boolean()
message = graphene.String()
Expand All @@ -1415,6 +1419,7 @@ class Arguments:
def mutate(
root,
info,
name,
fieldset_id,
query,
output_type,
Expand Down Expand Up @@ -1452,48 +1457,67 @@ class IOSettings:
lookup_field = "id"

class Arguments:
id = graphene.String(required=True)
id = graphene.ID(required=True)


class StartExtract(graphene.Mutation):
class Arguments:
extract_global_id = graphene.ID(required=True)
extract_id = graphene.ID(required=True)

ok = graphene.Boolean()
message = graphene.String()

@staticmethod
@login_required
def mutate(root, info, extract_global_id):
extract_id = from_global_id(extract_global_id)[1]
def mutate(root, info, extract_id):

# Start celery task to process extract
run_extract.s(extract_id, info.context.user.id).apply_async()
run_extract.s(from_global_id(extract_id)[1], info.context.user.id).apply_async()

return StartExtract(ok=True, message="STARTED!")


class CreateExtract(graphene.Mutation):
"""
Create a new extract. If fieldset_id is provided, attach existing fieldset.
Otherwise, a new fieldset is created. If no name is provided, fieldset name has
form "[Extract name] Fieldset"
"""
class Arguments:
corpus_id = graphene.ID(required=True)
name = graphene.String(required=True)
fieldset_id = graphene.ID(required=True)
fieldset_id = graphene.ID(required=False)
fieldset_name = graphene.String(required=False)
fieldset_description = graphene.String(required=False)

ok = graphene.Boolean()
msg = graphene.String()
obj = graphene.Field(ExtractType)

@staticmethod
@login_required
def mutate(root, info, corpus_id, name, fieldset_id):
def mutate(root, info, corpus_id, name, fieldset_id=None, fieldset_name=None, fieldset_description=None):

corpus = Corpus.objects.get(pk=from_global_id(corpus_id)[1])
fieldset = Fieldset.objects.get(pk=from_global_id(fieldset_id)[1])

if fieldset_id is not None:
fieldset = Fieldset.objects.get(pk=from_global_id(fieldset_id)[1])
else:
if fieldset_name is None:
fieldset_name = f"{name} Fieldset"
fieldset = Fieldset.objects.create(
name=fieldset_name,
description=fieldset_description,
creator=info.context.user,
)
set_permissions_for_obj_to_user(
info.context.user, fieldset, [PermissionTypes.CRUD]
)

extract = Extract(
corpus=corpus,
name=name,
fieldset=fieldset,
owner=info.context.user,
creator=info.context.user,
)
extract.save()
Expand All @@ -1507,7 +1531,7 @@ def mutate(root, info, corpus_id, name, fieldset_id):
class UpdateExtractMutation(DRFMutation):
class IOSettings:
lookup_field = "id"
pk_fields = ["corpus", "fieldset", "owner"]
pk_fields = ["corpus", "fieldset", "creator"]
serializer = ExtractSerializer
model = Extract
graphene_model = ExtractType
Expand All @@ -1520,6 +1544,100 @@ class Arguments:
label_set = graphene.String(required=False)


class AddDocumentsToExtract(DRFMutation):

class Arguments:
document_ids = graphene.List(
graphene.ID,
required=True,
description="List of ids of the documents to add to extract.",
)
extract_id = graphene.ID(required=True, description="Id of corpus to add docs to.")

ok = graphene.Boolean()
message = graphene.String()

@login_required
def mutate(root, info, extract_id, document_ids):

ok = False

try:
user = info.context.user

extract = Extract.objects.get(
Q(pk=from_global_id(extract_id)[1])
& (Q(creator=user) | Q(is_public=True))
)

if extract.finished is not None:
raise ValueError(f"Extract {extract_id} already finished... it cannot be edited.")

doc_pks = list(
map(lambda graphene_id: from_global_id(graphene_id)[1], document_ids)
)
doc_objs = Document.objects.filter(
Q(pk__in=doc_pks) & (Q(creator=user) | Q(is_public=True))
)

extract.documents.add(*doc_objs)

ok = True
message = "Success"

except Exception as e:
message = f"Error assigning docs to corpus: {e}"

return AddDocumentsToExtract(message=message, ok=ok)


class RemoveDocumentsFromExtract(graphene.Mutation):
class Arguments:
extract_id = graphene.ID(
required=True, description="ID of extract to remove documents from."
)
document_ids_to_remove = graphene.List(
graphene.ID,
required=True,
description="List of ids of the docs to remove from extract.",
)

ok = graphene.Boolean()
message = graphene.String()

@login_required
def mutate(root, info, extract_id, document_ids_to_remove):

ok = False

try:
user = info.context.user
extract = Extract.objects.get(
Q(pk=from_global_id(extract_id)[1])
& (Q(creator=user) | Q(is_public=True))
)

if extract.finished is not None:
raise ValueError(f"Extract {extract_id} already finished... it cannot be edited.")

doc_pks = list(
map(
lambda graphene_id: from_global_id(graphene_id)[1],
document_ids_to_remove,
)
)

extract_docs = extract.documents.filter(pk__in=doc_pks)
extract.documents.remove(*extract_docs)
ok = True
message = "Success"

except Exception as e:
message = f"Error on removing docs: {e}"

return RemoveDocumentsFromExtract(message=message, ok=ok)


class DeleteExtract(DRFDeletion):
class IOSettings:
model = Extract
Expand Down Expand Up @@ -1608,3 +1726,5 @@ class Mutation(graphene.ObjectType):
start_extract = StartExtract.Field()
delete_extract = DeleteExtract.Field() # TODO - test
update_extract = UpdateExtractMutation.Field()
add_docs_to_extract = AddDocumentsToExtract.Field()
remove_docs_from_extract = RemoveDocumentsFromExtract.Field()
16 changes: 8 additions & 8 deletions config/graphql/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ def resolve_fieldset(self, info, **kwargs):
return Fieldset.objects.get(Q(id=django_pk) & Q(is_public=True))
else:
return Fieldset.objects.get(
Q(id=django_pk) & (Q(owner=info.context.user) | Q(is_public=True))
Q(id=django_pk) & (Q(creator=info.context.user) | Q(is_public=True))
)

fieldsets = DjangoFilterConnectionField(
Expand All @@ -692,7 +692,7 @@ def resolve_fieldsets(self, info, **kwargs):
return Fieldset.objects.filter(Q(is_public=True))
else:
return Fieldset.objects.filter(
Q(owner=info.context.user) | Q(is_public=True)
Q(creator=info.context.user) | Q(is_public=True)
)

column = relay.Node.Field(ColumnType)
Expand All @@ -707,7 +707,7 @@ def resolve_column(self, info, **kwargs):
else:
return Column.objects.get(
Q(id=django_pk)
& (Q(fieldset__owner=info.context.user) | Q(is_public=True))
& (Q(fieldset__creator=info.context.user) | Q(is_public=True))
)

columns = DjangoFilterConnectionField(ColumnType, filterset_class=ColumnFilter)
Expand All @@ -720,7 +720,7 @@ def resolve_columns(self, info, **kwargs):
return Column.objects.filter(Q(is_public=True))
else:
return Column.objects.filter(
Q(fieldset__owner=info.context.user) | Q(is_public=True)
Q(fieldset__creator=info.context.user) | Q(is_public=True)
)

extract = relay.Node.Field(ExtractType)
Expand All @@ -734,7 +734,7 @@ def resolve_extract(self, info, **kwargs):
return Extract.objects.get(Q(id=django_pk) & Q(is_public=True))
else:
return Extract.objects.get(
Q(id=django_pk) & (Q(owner=info.context.user) | Q(is_public=True))
Q(id=django_pk) & (Q(creator=info.context.user) | Q(is_public=True))
)

extracts = DjangoFilterConnectionField(ExtractType, filterset_class=ExtractFilter)
Expand All @@ -747,7 +747,7 @@ def resolve_extracts(self, info, **kwargs):
return Extract.objects.filter(Q(is_public=True))
else:
return Extract.objects.filter(
Q(owner=info.context.user) | Q(is_public=True)
Q(creator=info.context.user) | Q(is_public=True)
)

datacell = relay.Node.Field(DatacellType)
Expand All @@ -762,7 +762,7 @@ def resolve_datacell(self, info, **kwargs):
else:
return Datacell.objects.get(
Q(id=django_pk)
& (Q(extract__owner=info.context.user) | Q(is_public=True))
& (Q(extract__creator=info.context.user) | Q(is_public=True))
)

datacells = DjangoFilterConnectionField(
Expand All @@ -777,5 +777,5 @@ def resolve_datacells(self, info, **kwargs):
return Datacell.objects.filter(Q(is_public=True))
else:
return Datacell.objects.filter(
Q(extract__owner=info.context.user) | Q(is_public=True)
Q(extract__creator=info.context.user) | Q(is_public=True)
)
3 changes: 1 addition & 2 deletions config/graphql/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ class Meta:
"corpus",
"name",
"fieldset",
"owner",
"owner_id",
"creator",
"creator_id",
"created",
Expand All @@ -61,6 +59,7 @@ class Meta:
model = Column
fields = [
"id",
"name",
"fieldset",
"fieldset_id",
"language_model",
Expand Down
Loading

0 comments on commit 7dae89e

Please sign in to comment.