Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: The knowledge base vector model can still be vectorized even after unauthorized use #2080

Merged
merged 1 commit into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions apps/dataset/serializers/dataset_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from dataset.task import sync_web_dataset, sync_replace_web_dataset
from embedding.models import SearchMode
from embedding.task import embedding_by_dataset, delete_embedding_by_dataset
from setting.models import AuthOperate
from setting.models import AuthOperate, Model
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _

Expand Down Expand Up @@ -792,6 +792,15 @@ def delete(self):
def re_embedding(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
dataset_id = self.data.get('id')
dataset = QuerySet(DataSet).filter(id=dataset_id).first()
embedding_model_id = dataset.embedding_mode_id
dataset_user_id = dataset.user_id
embedding_model = QuerySet(Model).filter(id=embedding_model_id).first()
if embedding_model is None:
raise AppApiException(500, _('Model does not exist'))
if embedding_model.permission_type == 'PRIVATE' and dataset_user_id != embedding_model.user_id:
raise AppApiException(500, _('No permission to use this model') + f"{embedding_model.name}")
ListenerManagement.update_status(QuerySet(Document).filter(dataset_id=self.data.get('id')),
TaskType.EMBEDDING,
State.PENDING)
Expand All @@ -801,7 +810,7 @@ def re_embedding(self, with_valid=True):
ListenerManagement.get_aggregation_document_status_by_dataset_id(self.data.get('id'))()
embedding_model_id = get_embedding_model_id_by_dataset_id(self.data.get('id'))
try:
embedding_by_dataset.delay(self.data.get('id'), embedding_model_id)
embedding_by_dataset.delay(dataset_id, embedding_model_id)
except AlreadyQueued as e:
raise AppApiException(500, _('Failed to send the vectorization task, please try again later!'))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the provided code, there are a few areas that need improvement or adjustments:

  1. Imports: The import statement for setting.models was updated to include only necessary models (AuthOperate and Model). However, it's better to keep all imports within their respective sections for clarity.

  2. Method Redundancy: In the delete method, you have repeated operations. Consider extracting these common parts into a separate function to reduce redundancy.

  3. Error Handling: You can improve error handling by using more descriptive exception messages. For example, instead of using _(_("Failed to send the vectorization task")), you could specify which specific component failed (e.g., "Could not initialize the embedding model service").

  4. Null Checks: Ensure that all checked variables (dataset_ID, embedding_model_id) are valid before proceeding. Adding checks like if variable: would be helpful in catching null values gracefully.

  5. Code Duplication: There is some duplication in checking permissions (both in the re-embedding and deletion methods). These should ideally reside in a single method to maintain consistency.

Here’s an improved version of the relevant part:

# Improved Method Implementation

def _initialize_operations(self):
    # Initialization steps here
    pass  # Replace with actual initialization logic

def re_embedding(self, with_valid=True):
    if with_valid:
        self.is_valid(raise_exception=True)

    self._initialize_operations()

    # Fetching dataset and other required objects
    dataset_id = self.data.get('id')
    dataset = QuerySet(DataSet).filter(id=dataset_id).first()
    embedding_model_id = dataset.embedding_mode_id if dataset else None
    dataset_user_id = dataset.user_id if dataset else None
    embedding_model = (
        QuerySet(Model).filter(id=embedding_model_id).first() if embedding_model_id else None
    )

    if not embedding_model:
        raise AppApiException(
            status_code=500,
            message=_("Embedding model does not exist"),
            hint="Check the ID of the data set.",
        )

    if (
        embedding_model.permission_type == "PRIVATE"
        and dataset_user_id != embedding_model.user_id
    ):
        raise AppApiException(
            status_code=500,
            message=_(
                "No permission to use this private model."
            ) + f" {embedding_model.name}",
        )

    ListenerManagement.update_status(
        QuerySet(Document).filter(dataset_id=self.data.get('id')),
        TaskType.EMBEDDING,
        State.PENDING,
    )
    
    ListenerManagement.get_aggregation_document_status_by_dataset_id(self.data.get('id'))()
    
    embedding_model_id = get_embedding_model_id_by_dataset_id(self.data.get('id'))
    
    try:
        embedding_by_dataset.delay(dataset_id, embedding_model_id)
    except AlreadyQueued as e:
        raise AppApiException(
            status_code=500,
            message=_("Failed to enqueue the vectorization task, please retry later."),
        )

This change refactors out initializations into _initialize_operations, ensures that no object is processed without validating its existence, and provides clear exceptions with meaningful hints when invalid conditions are met. Additionally, it uses consistent naming conventions throughout the codebase where applicable.

Expand Down
13 changes: 11 additions & 2 deletions apps/dataset/serializers/document_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from django.db.models import QuerySet, Count
from django.db.models.functions import Substr, Reverse
from django.http import HttpResponse
from django.utils.translation import gettext_lazy as _, gettext
from drf_yasg import openapi
from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE
from rest_framework import serializers
Expand Down Expand Up @@ -62,8 +63,8 @@
from embedding.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
delete_embedding_by_document, update_embedding_dataset_id, delete_embedding_by_paragraph_ids, \
embedding_by_document_list
from setting.models import Model
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _, gettext

parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle(), ZipParseQAHandle()]
parse_table_handle_list = [CsvSplitTableHandle(), XlsSplitTableHandle(), XlsxSplitTableHandle()]
Expand Down Expand Up @@ -716,6 +717,14 @@ def refresh(self, state_list=None, with_valid=True):
State.REVOKED.value, State.IGNORED.value]
if with_valid:
self.is_valid(raise_exception=True)
dataset = QuerySet(DataSet).filter(id=self.data.get('dataset_id')).first()
embedding_model_id = dataset.embedding_mode_id
dataset_user_id = dataset.user_id
embedding_model = QuerySet(Model).filter(id=embedding_model_id).first()
if embedding_model is None:
raise AppApiException(500, _('Model does not exist'))
if embedding_model.permission_type == 'PRIVATE' and dataset_user_id != embedding_model.user_id:
raise AppApiException(500, _('No permission to use this model') + f"{embedding_model.name}")
document_id = self.data.get("document_id")
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
State.PENDING)
Expand All @@ -728,7 +737,7 @@ def refresh(self, state_list=None, with_valid=True):
TaskType.EMBEDDING,
State.PENDING)
ListenerManagement.get_aggregation_document_status(document_id)()
embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id'))

try:
embedding_by_document.delay(document_id, embedding_model_id, state_list)
except AlreadyQueued as e:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The provided code looks generally clean and well-organized. However, there are several points that can be improved or clarified:

  1. Use of gettext_lazy: It's good practice to use gettext_lazy for translations, especially in Django views. This ensures that translations are handled correctly when using the application during multiple requests.

  2. Code DRY (Don't Repeat Yourself): The function refresh() does similar tasks repeatedly. Consider extracting common logic into helper methods or functions to make the code more maintainable.

  3. Validation Checks: Ensure all required fields are validated before proceeding with further operations. For example, you might want to add checks for document_id, state_list, etc.

  4. Error Handling: Improve error handling by providing more informative messages and ensuring consistent formatting across exception blocks.

  5. Logging: Add logging statements at appropriate places to help debug issues if necessary.

Here's a revised version of the refresh() method with some improvements:

def refresh(self, state_list=None, with_valid=True):
    if not self.data.get('dataset_id'):
        raise AppApiException(400, _("Dataset ID is required"))
    
    if not self.data.get('document_id'):
        raise AppApiException(400, _("Document ID is required"))

    states = [State.APPROVED.value, State.REVOKED.value, State.IGNORED.value]
    if with_valid:
        self.is_valid(raise_exception=True)

    dataset = QuerySet(DataSet).filter(id=self.data.get('dataset_id')).first()
    if not dataset:
        raise AppApiException(500, _('Dataset does not exist'))

    embedding_model_id = dataset.embedding_mode_id
    embedding_model_name = dataset.embedding_model.id
    user_id = self.data.get("user_id") or ""

    dataset_user_id = dataset.user_id

    try:
        embedding_model_instance = QuerySet(Model).get(id=embedding_model_id)
    except Model.DoesNotExist:
        raise AppApiException(500, _('Model does not exist'))

    # Check permission type specifically for 'PRIVATE'
    if embedding_model_instance.permission_type == 'PRIVATE' and int(user_id) != embedding_model_instance.user_id:
        raise AppApiException(
            500,
            _(
                "No permission to use this model"
            ).format(name=embedding_model_instance.name)
        )

    ListenerManagement.update_status(QuerySet(Document).filter(id=self.data.get("document_id")), 
                                   TaskType.EMBEDDING, State.PENDING)
    ListenerManagement.get_aggregation_document_status(self.data.get("document_id"))()

    try:
        embedding_by_document.delay(self.data.get("document_id"), embedding_model_instance.id, state_list)
    except AlreadyQueued as e:
        raiseAppApiException(400, _(self.translator.e_message.format(key='embeddings', msg=str(e)))

# Example of additional utility methods for better readability and separation
def validate_data(data):
    # Implement validation rules here
    if not data.get('dataset_id'):
        raise Exception(_("Dataset ID is required"))
    if not data.get('document_id'):
        raise Exception(_("Document ID is required"))

def get_embedding_model_id_by_dataset_id(dataset_id):
    return QuerySet(DataSet).filter(id=dataset_id).values('embedding_mode_id')[0]['embedding_mode_id']

Key Points Made:

  1. Required Field Validation:

    • Added validations for dataset_id and document_id.
  2. Consistent Error Messages:

    • Used _(...) for translation strings, which is recommended for internationalization.
  3. Improved Code Readability:

    • Extracted common logic into a separate method (validate_data) to improve modularity.
    • Simplified error message construction where applicable.

These changes should help ensure the code is cleaner and more robust while maintaining its functionality.

Expand Down
Loading