From 3484f91acefb9832845a97f11d7c3a445e646465 Mon Sep 17 00:00:00 2001
From: Kunal Tiwary
Date: Mon, 31 Jul 2023 11:37:24 +0530
Subject: [PATCH 1/3] Added forgot password API, asr prediction json population
 API

---
 backend/functions/tasks.py   | 209 ++++++++++++++++++++++++++++++++++-
 backend/functions/urls.py    |   4 +
 backend/functions/utils.py   | 135 +++++++++++++++++++++-
 backend/functions/views.py   |  56 ++++++++++
 backend/users/models.py      |  54 ++++++++-
 backend/users/serializers.py |  16 +++
 backend/users/urls.py        |   5 +
 backend/users/views.py       |  69 ++++++++++++
 8 files changed, 542 insertions(+), 6 deletions(-)

diff --git a/backend/functions/tasks.py b/backend/functions/tasks.py
index fca736313..20dc5bee3 100644
--- a/backend/functions/tasks.py
+++ b/backend/functions/tasks.py
@@ -6,6 +6,7 @@ from .utils import (
     get_batch_translations,
     get_batch_ocr_predictions,
+    get_batch_asr_predictions,
 )
 from django.db import transaction, DataError, IntegrityError
 from dataset.models import DatasetInstance
@@ -361,21 +362,66 @@ def generate_ocr_prediction_json(
     try:
         ocr_data_items = dataset_models.OCRDocument.objects.filter(
             instance_id=dataset_instance_id
-        ).values_list("id", "image_url", "ocr_prediction_json")
+        ).values_list(
+            "id",
+            "metadata_json",
+            "draft_data_json",
+            "file_type",
+            "file_url",
+            "image_url",
+            "page_number",
+            "language",
+            "ocr_type",
+            "ocr_domain",
+            "ocr_transcribed_json",
+            "ocr_prediction_json",
+            "image_details_json",
+            "parent_data",
+        )
     except Exception as e:
         ocr_data_items = []

     # converting the dataset_instance to pandas dataframe.
     ocr_data_items_df = pd.DataFrame(
         ocr_data_items,
-        columns=["id", "image_url", "ocr_prediction_json"],
+        columns=[
+            "id",
+            "metadata_json",
+            "draft_data_json",
+            "file_type",
+            "file_url",
+            "image_url",
+            "page_number",
+            "language",
+            "ocr_type",
+            "ocr_domain",
+            "ocr_transcribed_json",
+            "ocr_prediction_json",
+            "image_details_json",
+            "parent_data",
+        ],
     )

     # Check if the dataframe is empty
     if ocr_data_items_df.shape[0] == 0:
         raise Exception("The OCR data is empty.")

-    required_columns = {"id", "image_url", "ocr_prediction_json"}
+    required_columns = {
+        "id",
+        "metadata_json",
+        "draft_data_json",
+        "file_type",
+        "file_url",
+        "image_url",
+        "page_number",
+        "language",
+        "ocr_type",
+        "ocr_domain",
+        "ocr_transcribed_json",
+        "ocr_prediction_json",
+        "image_details_json",
+        "parent_data",
+    }
     if not required_columns.issubset(ocr_data_items_df.columns):
         missing_columns = required_columns - set(ocr_data_items_df.columns)
         raise ValueError(
             f"The following required columns are missing: {missing_columns}"
         )
@@ -408,8 +454,19 @@
                 ocr_document = dataset_models.OCRDocument(
                     instance_id_id=dataset_instance_id,
                     id=curr_id,
+                    metadata_json=row["metadata_json"],
+                    draft_data_json=row["draft_data_json"],
+                    file_type=row["file_type"],
+                    file_url=row["file_url"],
                     image_url=image_url,
+                    page_number=row["page_number"],
+                    language=row["language"],
+                    ocr_type=row["ocr_type"],
+                    ocr_domain=row["ocr_domain"],
+                    ocr_transcribed_json=row["ocr_transcribed_json"],
                     ocr_prediction_json=ocr_predictions_json,
+                    image_details_json=row["image_details_json"],
+                    parent_data=row["parent_data"],
                 )
                 with transaction.atomic():
                     ocr_document.save()
@@ -430,6 +487,152 @@
     return f"{success_count} out of {total_count} populated"


+@shared_task(bind=True)
+def generate_asr_prediction_json(
+    self, dataset_instance_id, api_type, automate_missing_data_items
+):
+    """Function to generate ASR prediction data and save it to the same data item.
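+
+    The entries written to prediction_json follow the shape produced by
+    get_batch_asr_predictions in functions/utils.py; the values below are
+    illustrative only:
+        [{"speaker_id": 0, "start": "1.500", "end": "4.250", "text": "..."}]
+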
+    Args:
+        dataset_instance_id (int): ID of the dataset instance.
+        api_type (str): Type of API to be used for ASR prediction. (default: dhruva_asr)
+            Currently supported - [dhruva_asr]
+        automate_missing_data_items (bool): If True, generate predictions only for
+            data items that do not have predictions yet.
+    """
+    # Fetching the data items for the given dataset instance.
+    success_count, total_count = 0, 0
+    try:
+        asr_data_items = dataset_models.SpeechConversation.objects.filter(
+            instance_id=dataset_instance_id
+        ).values_list(
+            "id",
+            "metadata_json",
+            "draft_data_json",
+            "domain",
+            "scenario",
+            "speaker_count",
+            "speakers_json",
+            "language",
+            "transcribed_json",
+            "machine_transcribed_json",
+            "audio_url",
+            "audio_duration",
+            "reference_raw_transcript",
+            "prediction_json",
+            "parent_data",
+        )
+    except Exception as e:
+        asr_data_items = []
+
+    # converting the dataset_instance to pandas dataframe.
+    asr_data_items_df = pd.DataFrame(
+        asr_data_items,
+        columns=[
+            "id",
+            "metadata_json",
+            "draft_data_json",
+            "domain",
+            "scenario",
+            "speaker_count",
+            "speakers_json",
+            "language",
+            "transcribed_json",
+            "machine_transcribed_json",
+            "audio_url",
+            "audio_duration",
+            "reference_raw_transcript",
+            "prediction_json",
+            "parent_data",
+        ],
+    )
+
+    # Check if the dataframe is empty
+    if asr_data_items_df.shape[0] == 0:
+        raise Exception("The ASR data is empty.")
+
+    required_columns = {
+        "id",
+        "metadata_json",
+        "draft_data_json",
+        "domain",
+        "scenario",
+        "speaker_count",
+        "speakers_json",
+        "language",
+        "transcribed_json",
+        "machine_transcribed_json",
+        "audio_url",
+        "audio_duration",
+        "reference_raw_transcript",
+        "prediction_json",
+        "parent_data",
+    }
+    if not required_columns.issubset(asr_data_items_df.columns):
+        missing_columns = required_columns - set(asr_data_items_df.columns)
+        raise ValueError(
+            f"The following required columns are missing: {missing_columns}"
+        )
+
+    # Update the prediction_json field for each row in the DataFrame
+    for index, row in asr_data_items_df.iterrows():
+        curr_id = row["id"]
+        if "audio_url" not in row:
+            print(f"The ASR item with {curr_id} has missing audio_url.")
+            continue
+        audio_url = row["audio_url"]
+        language = row["language"]
+
+        # Skip data items that already have predictions when only the missing
+        # ones are to be populated.
+        if automate_missing_data_items and row["prediction_json"]:
+            continue
+        total_count += 1
+        asr_predictions = get_batch_asr_predictions(
+            curr_id, audio_url, api_type, language
+        )
+        if asr_predictions["status"] == "Success":
+            success_count += 1
+            prediction_json = asr_predictions["output"]
+
+            # Updating the prediction_json column and saving in SpeechConversation dataset with the new asr predictions
+            try:
+                asr_data_items_df.at[index, "prediction_json"] = prediction_json
+                asr_document = dataset_models.SpeechConversation(
+                    instance_id_id=dataset_instance_id,
+                    id=curr_id,
+                    metadata_json=row["metadata_json"],
+                    draft_data_json=row["draft_data_json"],
+                    domain=row["domain"],
+                    scenario=row["scenario"],
+                    speaker_count=row["speaker_count"],
+                    speakers_json=row["speakers_json"],
+                    language=row["language"],
+                    transcribed_json=row["transcribed_json"],
+                    machine_transcribed_json=row["machine_transcribed_json"],
+                    audio_url=audio_url,
+                    audio_duration=row["audio_duration"],
+                    reference_raw_transcript=row["reference_raw_transcript"],
+                    prediction_json=prediction_json,
+                    parent_data=row["parent_data"],
+                )
+                with transaction.atomic():
+                    asr_document.save()
+            except IntegrityError as e:
+                # Handling unique constraint violations or other data integrity issues
+                print(f"Error while saving dataset id- {curr_id}, IntegrityError: {e}")
+            except DataError as e:
+                # Handling data-related issues like incorrect data types, etc.
+                print(f"Error while saving dataset id- {curr_id}, DataError: {e}")
+            except Exception as e:
+                # Handling other unexpected exceptions.
+                print(f"Error while saving dataset id- {curr_id}, Error message: {e}")
+
+        else:
+            print(
+                f"The {api_type} API has not generated predictions for data item with id-{curr_id}"
+            )
+    return f"{success_count} out of {total_count} populated"
+
+
 @shared_task(bind=True)
 def populate_draft_data_json(self, pk, fields_list):
     try:
diff --git a/backend/functions/urls.py b/backend/functions/urls.py
index feb97c14e..f2b582137 100644
--- a/backend/functions/urls.py
+++ b/backend/functions/urls.py
@@ -23,6 +23,10 @@
         "schedule_ocr_prediction_json_population",
         schedule_ocr_prediction_json_population,
     ),
+    path(
+        "schedule_asr_prediction_json_population",
+        schedule_asr_prediction_json_population,
+    ),
 ]

 # urlpatterns = format_suffix_patterns(urlpatterns)
diff --git a/backend/functions/utils.py b/backend/functions/utils.py
index 4a0e19b44..07df9516a 100644
--- a/backend/functions/utils.py
+++ b/backend/functions/utils.py
@@ -1,5 +1,7 @@
 import json
 import os
+import re
+
 import requests
 from dataset import models as dataset_models
 from google.cloud import translate_v2 as translate
@@ -15,6 +17,7 @@
     LANG_NAME_TO_CODE_AZURE,
 )
 from google.cloud import vision
+from users.utils import LANG_NAME_TO_CODE_ULCA

 try:
     from utils.azure_translate import translator_object
@@ -461,7 +464,7 @@ def get_batch_ocr_predictions(id, image_url, api_type):
-    """Function to get the translation for the input sentences using various APIs.
+    """Function to get the OCR predictions for the images using various APIs.

     Args:
         id (int): id of the dataset instance
@@ -469,7 +472,7 @@
         api_type (str): Type of API to be used for translation.

     Returns:
-        dict: Dictionary containing the translated sentences or error message.
+        dict: Dictionary containing the predictions or error message.
     """
     # checking the API type
     if api_type == "google":
@@ -557,3 +560,131 @@ def ocr_format_conversion(ocr_prediction):
             ocr_prediction["height"],
         ) = (x, y, width, height)
     return ocr_prediction
+
+
+def get_batch_asr_predictions(id, audio_url, api_type, language):
+    """Function to get the predictions for the input voice notes using various APIs.
+
+    Args:
+        id (int): id of the dataset instance
+        audio_url (str): URL of the audio whose transcript is to be predicted.
+        api_type (str): Type of API to be used for ASR prediction.
+        language (str): Language of the audio.
+
+    Returns:
+        dict: Dictionary containing predictions or error message.
+    """
+    # checking the API type
+    if api_type == "dhruva_asr":
+        asr_predictions = get_batch_asr_predictions_using_dhruva_asr(
+            id, audio_url, language
+        )
+    else:
+        raise ValueError(f"{api_type} is an invalid API type")
+
+    if asr_predictions != "":
+        return {"status": "Success", "output": asr_predictions}
+    else:
+        return {"status": "Failure", "output": asr_predictions}
+
+
+def get_batch_asr_predictions_using_dhruva_asr(cur_id, audio_url, language):
+    url = os.getenv("ASR_DHRUVA_URL")
+    header = {"Authorization": os.getenv("ASR_DHRUVA_AUTHORIZATION")}
+    if language == "Hindi":
+        serviceId = "ai4bharat/conformer-hi-gpu--t4"
+        languageCode = LANG_NAME_TO_CODE_ULCA[language]
+    elif language == "English":
+        serviceId = "ai4bharat/whisper-medium-en--gpu--t4"
+        languageCode = LANG_NAME_TO_CODE_ULCA[language]
+    elif language in [
+        "Bengali",
+        "Gujarati",
+        "Marathi",
+        "Odia",
+        "Punjabi",
+        "Sanskrit",
+        "Urdu",
+    ]:
+        serviceId = "ai4bharat/conformer-multilingual-indo_aryan-gpu--t4"
+        languageCode = LANG_NAME_TO_CODE_ULCA[language]
+    elif language in ["Kannada", "Malayalam", "Tamil", "Telugu"]:
+        serviceId = "ai4bharat/conformer-multilingual-dravidian-gpu--t4"
+        languageCode = LANG_NAME_TO_CODE_ULCA[language]
+    else:
+        print(f"We don't support predictions for {language} language")
+        return ""
+    ds = {
+        "config": {
+            "serviceId": serviceId,
+            "language": {"sourceLanguage": languageCode},
+            "transcriptionFormat": {"value": "srt"},
+        },
+        "audio": [{"audioUri": f"{audio_url}"}],
+    }
+    try:
+        response = requests.post(url, headers=header, json=ds, timeout=180)
+        response_json = response.json()
+        input_string = response_json["output"][0]["source"]
+    except requests.exceptions.Timeout:
+        print(f"The request took too long and timed out for id- {cur_id}.")
+        return ""
+    except requests.exceptions.RequestException as e:
+        print(f"An error occurred for id- {cur_id}: {e}")
+        return ""
+    start_time, end_time, texts = asr_extract_start_end_times_and_texts(
+        "\n" + input_string
+    )
+    if (
+        len(start_time) != len(end_time)
+        or len(start_time) != len(texts)
+        or len(end_time) != len(texts)
+    ):
+        print(f"Improper predictions for asr data item, id - {cur_id}")
+        return ""
+    prediction_json = []
+    for i in range(len(start_time)):
+        prediction_json_for_each_entry = {
+            "speaker_id": 0,
+            "start": start_time[i],
+            "end": end_time[i],
+            "text": texts[i],
+        }
+        prediction_json.append(prediction_json_for_each_entry)
+    return prediction_json
+
+
+# extracting data from the results obtained
+def asr_extract_start_end_times_and_texts(input_str):
+    input_list = input_str.split("\n")
+    timestamps = []
+    texts = []
+    time_idx, text_idx = 2, 3
+    while text_idx < len(input_list):
+        timestamps.append(input_list[time_idx])
+        time_idx += 4
+        texts.append(input_list[text_idx])
+        text_idx += 4
+    start_time, end_time = asr_convert_start_end_times(timestamps)
+    return start_time, end_time, texts
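+
+# Illustration of the extraction above with a made-up two-segment SRT payload
+# (the caller prepends "\n", so timestamps sit at indices 2, 6, ... of the
+# split and texts at indices 3, 7, ...):
+#   input_str = "\n1\n00:00:00,000 --> 00:00:02,500\nhello\n\n2\n00:00:02,500 --> 00:00:05,000\nworld"
+#   -> timestamps = ["00:00:00,000 --> 00:00:02,500", "00:00:02,500 --> 00:00:05,000"]
+#   -> texts = ["hello", "world"]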
""" # checking the API type if api_type == "google": @@ -557,3 +560,131 @@ def ocr_format_conversion(ocr_prediction): ocr_prediction["height"], ) = (x, y, width, height) return ocr_prediction + + +def get_batch_asr_predictions(id, audio_url, api_type, language): + """Function to get the predictions for the input voice notes using various APIs. + + Args: + id (int): id of the dataset instance + audio_url (str): Image whose text is to be predicted. + api_type (str): Type of API to be used for translation. + + Returns: + dict: Dictionary containing predictions or error message. + """ + # checking the API type + if api_type == "dhruva_asr": + asr_predictions = get_batch_asr_predictions_using_dhruva_asr( + id, audio_url, language + ) + else: + raise ValueError(f"{api_type} is an invalid API type") + + if asr_predictions != "": + return {"status": "Success", "output": asr_predictions} + else: + return {"status": "Failure", "output": asr_predictions} + + +def get_batch_asr_predictions_using_dhruva_asr(cur_id, audio_url, language): + url = os.getenv("ASR_DHRUVA_URL") + header = {"Authorization": os.getenv("ASR_DHRUVA_AUTHORIZATION")} + if language == "Hindi": + serviceId = "ai4bharat/conformer-hi-gpu--t4" + languageCode = LANG_NAME_TO_CODE_ULCA[language] + elif language == "English": + serviceId = "ai4bharat/whisper-medium-en--gpu--t4" + languageCode = LANG_NAME_TO_CODE_ULCA[language] + elif language in [ + "Bengali", + "Gujarati", + "Marathi", + "Odia", + "Punjabi", + "Sanskrit", + "Urdu", + ]: + serviceId = "ai4bharat/conformer-multilingual-indo_aryan-gpu--t4" + languageCode = LANG_NAME_TO_CODE_ULCA[language] + elif language in ["Kannada", "Malayalam", "Tamil", "Telugu"]: + serviceId = "ai4bharat/conformer-multilingual-dravidian-gpu--t4" + languageCode = LANG_NAME_TO_CODE_ULCA[language] + else: + print(f"We don't support predictions for {language} language") + return "" + ds = { + "config": { + "serviceId": serviceId, + "language": {"sourceLanguage": languageCode}, + "transcriptionFormat": {"value": "srt"}, + }, + "audio": [{"audioUri": f"{audio_url}"}], + } + try: + response = requests.post(url, headers=header, json=ds, timeout=180) + response_json = response.json() + input_string = response_json["output"][0]["source"] + except requests.exceptions.Timeout: + print(f"The request took too long and timed out for id- {cur_id}.") + return "" + except requests.exceptions.RequestException as e: + print(f"An error occurred for id- {cur_id}: {e}") + return "" + start_time, end_time, texts = asr_extract_start_end_times_and_texts( + "\n" + input_string + ) + if ( + len(start_time) != len(end_time) + or len(start_time) != len(texts) + or len(end_time) != len(texts) + ): + print(f"Improper predictions for asr data item, id - {cur_id}") + return "" + prediction_json = [] + for i in range(len(start_time)): + prediction_json_for_each_entry = { + "speaker_id": 0, + "start": start_time[i], + "end": end_time[i], + "text": texts[i], + } + prediction_json.append(prediction_json_for_each_entry) + return prediction_json + + +# extracting data from the results obtained +def asr_extract_start_end_times_and_texts(input_str): + input_list = input_str.split("\n") + timestamps = [] + texts = [] + time_idx, text_idx = 2, 3 + while text_idx < len(input_list): + timestamps.append(input_list[time_idx]) + time_idx += 4 + texts.append(input_list[text_idx]) + text_idx += 4 + start_time, end_time = asr_convert_start_end_times(timestamps) + return start_time, end_time, texts + + +# converting starting and ending timings +def 
diff --git a/backend/functions/views.py b/backend/functions/views.py
index 284111f27..2eaea9259 100644
--- a/backend/functions/views.py
+++ b/backend/functions/views.py
@@ -22,6 +22,7 @@
     sentence_text_translate_and_save_translation_pairs,
     populate_draft_data_json,
     generate_ocr_prediction_json,
+    generate_asr_prediction_json,
 )
 from .utils import (
     check_conversation_translation_function_inputs,
@@ -563,3 +564,58 @@
     ret_dict = {"message": "draft_data_json population started"}
     ret_status = status.HTTP_200_OK
     return Response(ret_dict, status=ret_status)
+
+
+@api_view(["POST"])
+def schedule_asr_prediction_json_population(request):
+    """
+    Schedules an ASR prediction population job for a given dataset instance and API type.
+
+    Request Body
+    {
+        "dataset_instance_id": <int>,
+        "organization_id": <int>,
+        "api_type": <str>,
+        "automate_missing_data_items": <bool>
+    }
+
+    Response Body
+    {
+        "message": <str>,
+        "result": <str>,
+        "status": DjangoStatusCode
+    }
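+
+    Example Request (the values shown are illustrative):
+    {
+        "dataset_instance_id": 1,
+        "api_type": "dhruva_asr",
+        "automate_missing_data_items": true
+    }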
+    """
+    # Check if the user is the organization owner
+    result = check_if_particular_organization_owner(request)
+    if result["status"] in [status.HTTP_403_FORBIDDEN, status.HTTP_404_NOT_FOUND]:
+        return Response({"error": result["error"]}, status=result["status"])
+
+    # Fetching request data
+    try:
+        dataset_instance_id = request.data["dataset_instance_id"]
+    except KeyError:
+        return Response(
+            {"error": "Please send a dataset_instance_id"},
+            status=status.HTTP_400_BAD_REQUEST,
+        )
+    try:
+        api_type = request.data["api_type"]
+    except KeyError:
+        api_type = "dhruva_asr"
+    try:
+        automate_missing_data_items = request.data["automate_missing_data_items"]
+    except KeyError:
+        automate_missing_data_items = True
+
+    # Calling a function asynchronously to create ASR predictions.
+    generate_asr_prediction_json.delay(
+        dataset_instance_id=dataset_instance_id,
+        api_type=api_type,
+        automate_missing_data_items=automate_missing_data_items,
+    )
+
+    # Returning response
+    ret_dict = {"message": "Generating ASR Predictions"}
+    ret_status = status.HTTP_200_OK
+    return Response(ret_dict, status=ret_status)
diff --git a/backend/users/models.py b/backend/users/models.py
index d9c46eab9..dfdd793e7 100644
--- a/backend/users/models.py
+++ b/backend/users/models.py
@@ -1,3 +1,15 @@
+import os
+from smtplib import (
+    SMTPAuthenticationError,
+    SMTPException,
+    SMTPRecipientsRefused,
+    SMTPServerDisconnected,
+)
+import socket
+import jwt
+from datetime import datetime, timedelta
+
+from django.core.mail import send_mail
 from django.db import models
 from django.contrib.postgres.fields import ArrayField
 from django.contrib.auth.base_user import AbstractBaseUser
@@ -5,6 +17,8 @@
 from django.utils import timezone

 from organizations.models import Organization
+from shoonya_backend import settings
+from dotenv import load_dotenv

 from .utils import hash_upload
 from .managers import UserManager
@@ -36,7 +50,7 @@
     ("Telugu", "Telugu"),
     ("Urdu", "Urdu"),
 )
-
+load_dotenv()
 # Create your models here.
 # class Language(models.Model):
 #     language = models.CharField(
@@ -226,3 +240,41 @@ def is_organization_owner(self):

     def is_admin(self):
         return self.role == User.ADMIN
+
+    def send_mail_to_change_password(self, email, key):
+        sent_token = self.generate_reset_token(key)
+        prefix = os.getenv("FRONTEND_URL_FOR_RESET_PASSWORD")
+        link = f"{prefix}/#/forget-password/confirm/{key}/{sent_token}"
+        try:
+            send_mail(
+                "Reset password link for Shoonya",
+                f"Hello! Please click on the following link to reset your password - {link}",
+                settings.DEFAULT_FROM_EMAIL,
+                [email],
+            )
+        except SMTPAuthenticationError:
+            raise Exception(
+                "Failed to authenticate with the SMTP server. Check your email settings."
+            )
+        except (
+            SMTPException,
+            socket.gaierror,
+            SMTPRecipientsRefused,
+            SMTPServerDisconnected,
+        ) as e:
+            raise Exception("Failed to send the email. Please try again later.")
+
+    def generate_reset_token(self, user_id):
+        # Setting token expiration time (2 hours)
+        expiration_time = datetime.utcnow() + timedelta(hours=2)
+        secret_key = os.getenv("SECRET_KEY")
+
+        # Creating the payload containing user ID and expiration time
+        payload = {
+            "user_id": user_id,
+            "exp": expiration_time,
+        }
+
+        # Signing the payload with a secret key
+        token = jwt.encode(payload, secret_key, algorithm="HS256")
+        return token
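+
+
+# A minimal sketch of the verification counterpart of generate_reset_token,
+# mirroring the check done in users/views.py (assumes the same SECRET_KEY env var):
+#
+#     decoded = jwt.decode(token, os.getenv("SECRET_KEY"), algorithms=["HS256"])
+#     user_id = decoded["user_id"]
+#
+# jwt.decode raises jwt.ExpiredSignatureError once the 2-hour window has passed.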
diff --git a/backend/users/serializers.py b/backend/users/serializers.py
index 16f7dfd1b..127787242 100644
--- a/backend/users/serializers.py
+++ b/backend/users/serializers.py
@@ -133,3 +133,19 @@ class UserEmailSerializer(serializers.ModelSerializer):
     class Meta:
         model = User
         fields = ["email"]
+
+
+class ChangePasswordWithoutOldPassword(serializers.Serializer):
+    new_password = serializers.CharField(max_length=128, write_only=True, required=True)
+
+    def validation_checks(self, instance, data):
+        try:
+            password_validation.validate_password(data["new_password"], instance)
+        except password_validation.ValidationError as e:
+            return " ".join(e.messages)
+        return "Validation successful"
+
+    def save(self, instance, validated_data):
+        instance.set_password(validated_data.get("new_password"))
+        instance.save()
+        return instance
diff --git a/backend/users/urls.py b/backend/users/urls.py
index d8a04fcbe..2b744c2bb 100644
--- a/backend/users/urls.py
+++ b/backend/users/urls.py
@@ -23,4 +23,9 @@
 urlpatterns = [
     path("", include(router.urls)),
     path("auth/jwt/create", AuthViewSet.as_view({"post": "login"}), name="login"),
+    path(
+        "auth/reset_password",
+        AuthViewSet.as_view({"post": "reset_password"}),
+        name="reset_password",
+    ),
 ]
diff --git a/backend/users/views.py b/backend/users/views.py
index a6d770de3..c92ec813b 100644
--- a/backend/users/views.py
+++ b/backend/users/views.py
@@ -3,6 +3,8 @@
 import secrets
 import string
 from wsgiref.util import request_uri
+import jwt
+from jwt import DecodeError, ExpiredSignatureError, InvalidSignatureError
 from rest_framework import viewsets, status
 import re
 from rest_framework.response import Response
@@ -17,6 +19,7 @@
     UserUpdateSerializer,
     LanguageSerializer,
     ChangePasswordSerializer,
+    ChangePasswordWithoutOldPassword,
 )
 from organizations.models import Invite, Organization
 from organizations.serializers import InviteGenerationSerializer
@@ -47,8 +50,10 @@
 from workspaces.views import WorkspaceCustomViewSet
 from .utils import generate_random_string, get_role_name
 from rest_framework_simplejwt.tokens import RefreshToken
+from dotenv import load_dotenv

 regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b"
+load_dotenv()


 class InviteViewSet(viewsets.ViewSet):
@@ -323,6 +328,70 @@ def login(self, request, *args, **kwargs):
             status=status.HTTP_200_OK,
         )

+    @permission_classes([AllowAny])
+    @swagger_auto_schema(request_body=ChangePasswordWithoutOldPassword)
+    @action(
+        detail=True,
+        methods=["post"],
+        url_path="reset_password",
+        url_name="reset_password",
+    )
+    def reset_password(self, request, *args, **kwargs):
+        """
+        Reset password functionality: mails a reset link when only an email is
+        sent, and changes the password when a uid, token and new_password are sent.
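+
+        Request Body (step 1, mails the reset link; the address is illustrative):
+        {"email": "user@example.com"}
+
+        Request Body (step 2, sets the new password using the mailed uid and token):
+        {"uid": 1, "token": "<jwt-from-the-emailed-link>", "new_password": "..."}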
+        """
+        if not request.data.get("new_password"):
+            try:
+                email = request.data.get("email")
+                user = User.objects.get(email=email)
+            except User.DoesNotExist:
+                return Response(
+                    {"message": "Incorrect email, User not found"},
+                    status=status.HTTP_404_NOT_FOUND,
+                )
+            key = user.id
+            user.send_mail_to_change_password(email, key)
+            return Response(
+                {
+                    "message": "Please check your registered email and click on the link to reset your password.",
+                },
+                status=status.HTTP_200_OK,
+            )
+        else:
+            received_token = request.data.get("token")
+            user_id = request.data.get("uid")
+            if not received_token or not user_id:
+                return Response(
+                    {"message": "Insufficient details"},
+                    status=status.HTTP_400_BAD_REQUEST,
+                )
+            user = User.objects.get(id=user_id)
+            try:
+                secret_key = os.getenv("SECRET_KEY")
+                decodedToken = jwt.decode(
+                    received_token, secret_key, algorithms=["HS256"]
+                )
+            except ExpiredSignatureError:
+                raise Exception(
+                    "The password reset link has expired. Please request a new link."
+                )
+            except (DecodeError, InvalidSignatureError):
+                raise Exception(
+                    "Invalid token. Please make sure the token is correct and try again."
+                )
+
+            serializer = ChangePasswordWithoutOldPassword(user, request.data)
+            serializer.is_valid(raise_exception=True)
+
+            validation_response = serializer.validation_checks(
+                user, request.data
+            )  # checks for min_length, whether password is similar to user details etc.
+            if validation_response != "Validation successful":
+                return Response(
+                    {"message": validation_response},
+                    status=status.HTTP_400_BAD_REQUEST,
+                )
+
+            user = serializer.save(user, request.data)
+            return Response({"message": "Password changed."}, status=status.HTTP_200_OK)
+

 class UserViewSet(viewsets.ViewSet):
     permission_classes = (IsAuthenticated,)

From 9922794338277ee880b67b9938fb428e1bbe8403 Mon Sep 17 00:00:00 2001
From: Kunal Tiwary
Date: Tue, 1 Aug 2023 14:17:18 +0530
Subject: [PATCH 2/3] Added batch sampling specifications to pull new data
 items endpoint

---
 backend/projects/views.py | 66 +++++++++++++++++++++++++++------------
 1 file changed, 42 insertions(+), 24 deletions(-)

diff --git a/backend/projects/views.py b/backend/projects/views.py
index 8bd47f853..4272ab506 100644
--- a/backend/projects/views.py
+++ b/backend/projects/views.py
@@ -3565,33 +3565,51 @@ def pull_new_items(self, request, pk=None, *args, **kwargs):
         """
         try:
             project = Project.objects.get(pk=pk)
-            if project.sampling_mode != FULL:
-                ret_dict = {"message": "Sampling Mode is not FULL!"}
+            if project.sampling_mode != BATCH and project.sampling_mode != FULL:
+                ret_dict = {"message": "Sampling Mode is neither FULL nor BATCH!"}
                 ret_status = status.HTTP_403_FORBIDDEN
+                return Response(ret_dict, status=ret_status)
+            # Get serializer with the project user data
+            try:
+                serializer = ProjectUsersSerializer(project, many=False)
+            except User.DoesNotExist:
+                ret_dict = {"message": "User does not exist!"}
+                ret_status = status.HTTP_404_NOT_FOUND
+                return Response(ret_dict, status=ret_status)
+            # Get project instance and check how many items to pull
+            project_type = project.project_type
+            ids_to_exclude = Task.objects.filter(project_id__exact=project)
+            items = filter_data_items(
+                project_type,
+                list(project.dataset_id.all()),
+                project.filter_string,
+                ids_to_exclude,
+            )
+            if not items:
+                ret_dict = {"message": "No items to pull into the dataset."}
+                ret_status = status.HTTP_404_NOT_FOUND
             else:
-                # Get serializer with the project user data
-                try:
-                    serializer = ProjectUsersSerializer(project, many=False)
-                except User.DoesNotExist:
-                    ret_dict = {"message": "User does not exist!"}
-                    ret_status = status.HTTP_404_NOT_FOUND
-                # Get project instance and check how many items to pull
-                project_type = project.project_type
-                ids_to_exclude = Task.objects.filter(project_id__exact=project)
-                items = filter_data_items(
-                    project_type,
-                    list(project.dataset_id.all()),
-                    project.filter_string,
-                    ids_to_exclude,
-                )
-                if items:
-                    # Pull new data items in to the project asynchronously
-                    add_new_data_items_into_project.delay(project_id=pk, items=items)
-                    ret_dict = {"message": "Adding new tasks to the project."}
-                    ret_status = status.HTTP_200_OK
+                if project.sampling_mode == BATCH:
+                    try:
+                        batch_size = project.sampling_parameters_json["batch_size"]
+                        batch_number = project.sampling_parameters_json["batch_number"]
+                    except KeyError:
+                        raise Exception("Sampling parameters are not present")
+                    if not isinstance(batch_number, list):
+                        batch_number = [batch_number]
+                    sampled_items = []
+                    for batch_num in batch_number:
+                        sampled_items += items[
+                            batch_size * (batch_num - 1): batch_size * batch_num
+                        ]
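+                    # For illustration: batch_size=100 and batch_number=[1, 3]
+                    # selects items[0:100] + items[200:300] (batches are 1-indexed).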
                 else:
-                    ret_dict = {"message": "No items to pull into the dataset."}
-                    ret_status = status.HTTP_404_NOT_FOUND
+                    sampled_items = items
+                # Pull new data items into the project asynchronously
+                add_new_data_items_into_project.delay(
+                    project_id=pk, items=sampled_items
+                )
+                ret_dict = {"message": "Adding new tasks to the project."}
+                ret_status = status.HTTP_200_OK
         except Project.DoesNotExist:
             ret_dict = {"message": "Project does not exist!"}
             ret_status = status.HTTP_404_NOT_FOUND

From 5c0bd0255a88a768316526e52705c3a3b1c6a984 Mon Sep 17 00:00:00 2001
From: Kunal Tiwary
Date: Tue, 1 Aug 2023 14:19:37 +0530
Subject: [PATCH 3/3] black formatted

---
 backend/projects/views.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backend/projects/views.py b/backend/projects/views.py
index 4272ab506..31d1bea7f 100644
--- a/backend/projects/views.py
+++ b/backend/projects/views.py
@@ -3600,7 +3600,7 @@
                     sampled_items = []
                     for batch_num in batch_number:
                         sampled_items += items[
-                            batch_size * (batch_num - 1): batch_size * batch_num
+                            batch_size * (batch_num - 1) : batch_size * batch_num
                         ]