From 5ebcad7cde83ffe30327df8b73e3f591d2d1bd35 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Thu, 23 Jan 2025 10:53:12 +0800 Subject: [PATCH] fix: The knowledge base vector model can still be vectorized even after unauthorized use (#2080) --- apps/dataset/serializers/dataset_serializers.py | 13 +++++++++++-- apps/dataset/serializers/document_serializers.py | 13 +++++++++++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index fcf126e7cb4..3f51915adc8 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -47,7 +47,7 @@ from dataset.task import sync_web_dataset, sync_replace_web_dataset from embedding.models import SearchMode from embedding.task import embedding_by_dataset, delete_embedding_by_dataset -from setting.models import AuthOperate +from setting.models import AuthOperate, Model from smartdoc.conf import PROJECT_DIR from django.utils.translation import gettext_lazy as _ @@ -792,6 +792,15 @@ def delete(self): def re_embedding(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) + dataset_id = self.data.get('id') + dataset = QuerySet(DataSet).filter(id=dataset_id).first() + embedding_model_id = dataset.embedding_mode_id + dataset_user_id = dataset.user_id + embedding_model = QuerySet(Model).filter(id=embedding_model_id).first() + if embedding_model is None: + raise AppApiException(500, _('Model does not exist')) + if embedding_model.permission_type == 'PRIVATE' and dataset_user_id != embedding_model.user_id: + raise AppApiException(500, _('No permission to use this model') + f"{embedding_model.name}") ListenerManagement.update_status(QuerySet(Document).filter(dataset_id=self.data.get('id')), TaskType.EMBEDDING, State.PENDING) @@ -801,7 +810,7 @@ def re_embedding(self, with_valid=True): ListenerManagement.get_aggregation_document_status_by_dataset_id(self.data.get('id'))() embedding_model_id = get_embedding_model_id_by_dataset_id(self.data.get('id')) try: - embedding_by_dataset.delay(self.data.get('id'), embedding_model_id) + embedding_by_dataset.delay(dataset_id, embedding_model_id) except AlreadyQueued as e: raise AppApiException(500, _('Failed to send the vectorization task, please try again later!')) diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index a4ba688abbd..2a664882850 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -23,6 +23,7 @@ from django.db.models import QuerySet, Count from django.db.models.functions import Substr, Reverse from django.http import HttpResponse +from django.utils.translation import gettext_lazy as _, gettext from drf_yasg import openapi from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE from rest_framework import serializers @@ -62,8 +63,8 @@ from embedding.task.embedding import embedding_by_document, delete_embedding_by_document_list, \ delete_embedding_by_document, update_embedding_dataset_id, delete_embedding_by_paragraph_ids, \ embedding_by_document_list +from setting.models import Model from smartdoc.conf import PROJECT_DIR -from django.utils.translation import gettext_lazy as _, gettext parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle(), ZipParseQAHandle()] parse_table_handle_list = [CsvSplitTableHandle(), XlsSplitTableHandle(), XlsxSplitTableHandle()] @@ -716,6 +717,14 @@ def refresh(self, state_list=None, with_valid=True): State.REVOKED.value, State.IGNORED.value] if with_valid: self.is_valid(raise_exception=True) + dataset = QuerySet(DataSet).filter(id=self.data.get('dataset_id')).first() + embedding_model_id = dataset.embedding_mode_id + dataset_user_id = dataset.user_id + embedding_model = QuerySet(Model).filter(id=embedding_model_id).first() + if embedding_model is None: + raise AppApiException(500, _('Model does not exist')) + if embedding_model.permission_type == 'PRIVATE' and dataset_user_id != embedding_model.user_id: + raise AppApiException(500, _('No permission to use this model') + f"{embedding_model.name}") document_id = self.data.get("document_id") ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.PENDING) @@ -728,7 +737,7 @@ def refresh(self, state_list=None, with_valid=True): TaskType.EMBEDDING, State.PENDING) ListenerManagement.get_aggregation_document_status(document_id)() - embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id')) + try: embedding_by_document.delay(document_id, embedding_model_id, state_list) except AlreadyQueued as e: