fix: The knowledge base vector model can still be vectorized even after unauthorized use #2080
@@ -23,6 +23,7 @@
 from django.db.models import QuerySet, Count
 from django.db.models.functions import Substr, Reverse
 from django.http import HttpResponse
+from django.utils.translation import gettext_lazy as _, gettext
 from drf_yasg import openapi
 from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE
 from rest_framework import serializers
@@ -62,8 +63,8 @@
 from embedding.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
     delete_embedding_by_document, update_embedding_dataset_id, delete_embedding_by_paragraph_ids, \
     embedding_by_document_list
+from setting.models import Model
 from smartdoc.conf import PROJECT_DIR
-from django.utils.translation import gettext_lazy as _, gettext
 
 parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle(), ZipParseQAHandle()]
 parse_table_handle_list = [CsvSplitTableHandle(), XlsSplitTableHandle(), XlsxSplitTableHandle()]
@@ -716,6 +717,14 @@ def refresh(self, state_list=None, with_valid=True):
                               State.REVOKED.value, State.IGNORED.value]
             if with_valid:
                 self.is_valid(raise_exception=True)
+            dataset = QuerySet(DataSet).filter(id=self.data.get('dataset_id')).first()
+            embedding_model_id = dataset.embedding_mode_id
+            dataset_user_id = dataset.user_id
+            embedding_model = QuerySet(Model).filter(id=embedding_model_id).first()
+            if embedding_model is None:
+                raise AppApiException(500, _('Model does not exist'))
+            if embedding_model.permission_type == 'PRIVATE' and dataset_user_id != embedding_model.user_id:
+                raise AppApiException(500, _('No permission to use this model') + f"{embedding_model.name}")
             document_id = self.data.get("document_id")
             ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
                                              State.PENDING)
@@ -728,7 +737,7 @@ def refresh(self, state_list=None, with_valid=True):
                                              TaskType.EMBEDDING,
                                              State.PENDING)
             ListenerManagement.get_aggregation_document_status(document_id)()
-            embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id'))
+
             try:
                 embedding_by_document.delay(document_id, embedding_model_id, state_list)
             except AlreadyQueued as e:
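The rule this hunk adds to `refresh` can be read on its own: a model marked `PRIVATE` may only be used when its owner is also the owner of the dataset. The following is a minimal sketch of that rule as a standalone predicate. The function name `can_use_model` and its plain-string arguments are assumptions made for illustration; the PR itself performs the comparison inline on Django model instances.

```python
def can_use_model(permission_type: str, model_owner_id: str, dataset_owner_id: str) -> bool:
    """Return True when the dataset owner is allowed to use the embedding model.

    Hypothetical helper: it mirrors the condition in the diff, where access is
    refused only if the model is 'PRIVATE' and the two owners differ.
    """
    if permission_type != "PRIVATE":
        # Any non-private model is usable regardless of who owns it.
        return True
    # A private model is usable only by its own owner.
    return model_owner_id == dataset_owner_id


if __name__ == "__main__":
    # Walk through the three interesting combinations.
    for case in (("PRIVATE", "alice", "alice"),
                 ("PRIVATE", "alice", "bob"),
                 ("PUBLIC", "alice", "bob")):
        print(case, can_use_model(*case))
```

Keeping the rule in a pure function like this makes it easy to test without a database, which is one possible design choice rather than what the PR does.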
The provided code looks generally clean and well-organized. However, there are several points that can be improved or clarified:
Here's a revised version of the `refresh` method:

    def refresh(self, state_list=None, with_valid=True):
        if not self.data.get('dataset_id'):
            raise AppApiException(400, _("Dataset ID is required"))
        if not self.data.get('document_id'):
            raise AppApiException(400, _("Document ID is required"))
        states = [State.APPROVED.value, State.REVOKED.value, State.IGNORED.value]
        if with_valid:
            self.is_valid(raise_exception=True)
        dataset = QuerySet(DataSet).filter(id=self.data.get('dataset_id')).first()
        if not dataset:
            raise AppApiException(500, _('Dataset does not exist'))
        embedding_model_id = dataset.embedding_mode_id
        embedding_model_name = dataset.embedding_model.id
        user_id = self.data.get("user_id") or ""
        dataset_user_id = dataset.user_id
        try:
            embedding_model_instance = QuerySet(Model).get(id=embedding_model_id)
        except Model.DoesNotExist:
            raise AppApiException(500, _('Model does not exist'))
        # Check permission type specifically for 'PRIVATE'
        if embedding_model_instance.permission_type == 'PRIVATE' and int(user_id) != embedding_model_instance.user_id:
            raise AppApiException(
                500,
                _(
                    "No permission to use this model"
                ).format(name=embedding_model_instance.name)
            )
        ListenerManagement.update_status(QuerySet(Document).filter(id=self.data.get("document_id")),
                                         TaskType.EMBEDDING, State.PENDING)
        ListenerManagement.get_aggregation_document_status(self.data.get("document_id"))()
        try:
            embedding_by_document.delay(self.data.get("document_id"), embedding_model_instance.id, state_list)
        except AlreadyQueued as e:
            raise AppApiException(400, _(self.translator.e_message.format(key='embeddings', msg=str(e))))

    # Example of additional utility methods for better readability and separation
    def validate_data(data):
        # Implement validation rules here
        if not data.get('dataset_id'):
            raise Exception(_("Dataset ID is required"))
        if not data.get('document_id'):
            raise Exception(_("Document ID is required"))

    def get_embedding_model_id_by_dataset_id(dataset_id):
        return QuerySet(DataSet).filter(id=dataset_id).values('embedding_mode_id')[0]['embedding_mode_id']

Key Points Made:
These changes should help ensure the code is cleaner and more robust while maintaining its functionality.
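One detail in the suggested snippet is worth illustrating: it calls `.format(name=...)` on the message "No permission to use this model", which contains no `{name}` placeholder. For a plain Python `str`, unused keyword arguments are silently ignored, so the model name would not appear in the message. The sketch below shows the difference; the template text with a placeholder and the `render` helper are assumptions for illustration, not strings or functions from the project.

```python
# Message as written in the suggestion: no placeholder to fill.
TEMPLATE_WITHOUT_PLACEHOLDER = "No permission to use this model"
# Hypothetical variant that actually has a slot for the model name.
TEMPLATE_WITH_PLACEHOLDER = "No permission to use this model: {name}"


def render(template: str, name: str) -> str:
    """Fill the optional {name} slot; extra keyword arguments are ignored by str.format."""
    return template.format(name=name)


if __name__ == "__main__":
    print(render(TEMPLATE_WITHOUT_PLACEHOLDER, "my-embedding-model"))
    print(render(TEMPLATE_WITH_PLACEHOLDER, "my-embedding-model"))
```

If the name should be part of a translated message, a placeholder inside the translatable string is generally easier for translators than string concatenation, since word order can differ between languages.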
In the provided code, there are a few areas that need improvement or adjustments:

1. Imports: The import statement for `setting.models` was updated to include only the necessary models (`AuthOperate` and `Model`). However, it's better to keep all imports within their respective sections for clarity.
2. Method Redundancy: In the `delete` method, you have repeated operations. Consider extracting these common parts into a separate function to reduce redundancy.
3. Error Handling: You can improve error handling by using more descriptive exception messages. For example, instead of using `_(_("Failed to send the vectorization task"))`, you could specify which component failed (e.g., "Could not initialize the embedding model service").
4. Null Checks: Ensure that all checked variables (`dataset_ID`, `embedding_model_id`) are valid before proceeding. Adding checks like `if variable:` would help catch null values gracefully.
5. Code Duplication: There is some duplication in checking permissions (both in the re-embedding and deletion methods). These should ideally reside in a single method to maintain consistency.
Here’s an improved version of the relevant part:
This change refactors out initializations into `_initialize_operations`, ensures that no object is processed without validating its existence, and provides clear exceptions with meaningful hints when invalid conditions are met. Additionally, it uses consistent naming conventions throughout the codebase where applicable.
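As a rough illustration of the Null Checks and Code Duplication points, the existence checks and the permission check could live in one helper that every caller goes through. The sketch below is self-contained and uses plain dictionaries in place of Django querysets. `ApiError`, `Dataset`, `EmbeddingModel`, and `resolve_embedding_model` are all hypothetical names standing in for the project's `AppApiException`, `DataSet`, `Model`, and whatever helper the maintainers might choose.

```python
from typing import Dict, NamedTuple


class ApiError(Exception):
    """Hypothetical stand-in for the project's AppApiException."""

    def __init__(self, code: int, message: str):
        super().__init__(message)
        self.code = code
        self.message = message


class Dataset(NamedTuple):
    id: str
    user_id: str
    embedding_model_id: str


class EmbeddingModel(NamedTuple):
    id: str
    name: str
    user_id: str
    permission_type: str


def resolve_embedding_model(dataset_id: str,
                            datasets: Dict[str, Dataset],
                            models: Dict[str, EmbeddingModel]) -> EmbeddingModel:
    """Look up the dataset's model, failing early on missing rows or missing permission."""
    dataset = datasets.get(dataset_id)
    if dataset is None:
        raise ApiError(500, "Dataset does not exist")
    model = models.get(dataset.embedding_model_id)
    if model is None:
        raise ApiError(500, "Model does not exist")
    # Same rule as the PR: a private model is only usable by its owner.
    if model.permission_type == "PRIVATE" and model.user_id != dataset.user_id:
        raise ApiError(500, f"No permission to use this model: {model.name}")
    return model


if __name__ == "__main__":
    datasets = {"d1": Dataset("d1", "alice", "m1")}
    models = {"m1": EmbeddingModel("m1", "demo-model", "alice", "PRIVATE")}
    print(resolve_embedding_model("d1", datasets, models).name)
```

Any method that needs the embedding model would call the helper once and use the returned object, so the checks cannot drift apart between call sites.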