Skip to content

Commit

Permalink
fix: The knowledge base vector model can still be vectorized even aft…
Browse files Browse the repository at this point in the history
…er unauthorized use (#2080)
  • Loading branch information
shaohuzhang1 authored Jan 23, 2025
1 parent 40cfa33 commit 5ebcad7
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
13 changes: 11 additions & 2 deletions apps/dataset/serializers/dataset_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from dataset.task import sync_web_dataset, sync_replace_web_dataset
from embedding.models import SearchMode
from embedding.task import embedding_by_dataset, delete_embedding_by_dataset
from setting.models import AuthOperate
from setting.models import AuthOperate, Model
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _

Expand Down Expand Up @@ -792,6 +792,15 @@ def delete(self):
def re_embedding(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
dataset_id = self.data.get('id')
dataset = QuerySet(DataSet).filter(id=dataset_id).first()
embedding_model_id = dataset.embedding_mode_id
dataset_user_id = dataset.user_id
embedding_model = QuerySet(Model).filter(id=embedding_model_id).first()
if embedding_model is None:
raise AppApiException(500, _('Model does not exist'))
if embedding_model.permission_type == 'PRIVATE' and dataset_user_id != embedding_model.user_id:
raise AppApiException(500, _('No permission to use this model') + f"{embedding_model.name}")
ListenerManagement.update_status(QuerySet(Document).filter(dataset_id=self.data.get('id')),
TaskType.EMBEDDING,
State.PENDING)
Expand All @@ -801,7 +810,7 @@ def re_embedding(self, with_valid=True):
ListenerManagement.get_aggregation_document_status_by_dataset_id(self.data.get('id'))()
embedding_model_id = get_embedding_model_id_by_dataset_id(self.data.get('id'))
try:
embedding_by_dataset.delay(self.data.get('id'), embedding_model_id)
embedding_by_dataset.delay(dataset_id, embedding_model_id)
except AlreadyQueued as e:
raise AppApiException(500, _('Failed to send the vectorization task, please try again later!'))

Expand Down
13 changes: 11 additions & 2 deletions apps/dataset/serializers/document_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from django.db.models import QuerySet, Count
from django.db.models.functions import Substr, Reverse
from django.http import HttpResponse
from django.utils.translation import gettext_lazy as _, gettext
from drf_yasg import openapi
from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE
from rest_framework import serializers
Expand Down Expand Up @@ -62,8 +63,8 @@
from embedding.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
delete_embedding_by_document, update_embedding_dataset_id, delete_embedding_by_paragraph_ids, \
embedding_by_document_list
from setting.models import Model
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _, gettext

parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle(), ZipParseQAHandle()]
parse_table_handle_list = [CsvSplitTableHandle(), XlsSplitTableHandle(), XlsxSplitTableHandle()]
Expand Down Expand Up @@ -716,6 +717,14 @@ def refresh(self, state_list=None, with_valid=True):
State.REVOKED.value, State.IGNORED.value]
if with_valid:
self.is_valid(raise_exception=True)
dataset = QuerySet(DataSet).filter(id=self.data.get('dataset_id')).first()
embedding_model_id = dataset.embedding_mode_id
dataset_user_id = dataset.user_id
embedding_model = QuerySet(Model).filter(id=embedding_model_id).first()
if embedding_model is None:
raise AppApiException(500, _('Model does not exist'))
if embedding_model.permission_type == 'PRIVATE' and dataset_user_id != embedding_model.user_id:
raise AppApiException(500, _('No permission to use this model') + f"{embedding_model.name}")
document_id = self.data.get("document_id")
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
State.PENDING)
Expand All @@ -728,7 +737,7 @@ def refresh(self, state_list=None, with_valid=True):
TaskType.EMBEDDING,
State.PENDING)
ListenerManagement.get_aggregation_document_status(document_id)()
embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id'))

try:
embedding_by_document.delay(document_id, embedding_model_id, state_list)
except AlreadyQueued as e:
Expand Down

0 comments on commit 5ebcad7

Please sign in to comment.