From 99f16dfa1e6d4573d8b7c9243384fdcffad3c194 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Odini?= Date: Sat, 19 Oct 2024 01:36:37 +0200 Subject: [PATCH] feat(Prices): Allow adding new prices with a location_id (#527) --- open_prices/api/prices/serializers.py | 3 + open_prices/api/prices/tests.py | 55 ++++++++- open_prices/prices/models.py | 98 ++++++++++++---- open_prices/prices/tests.py | 161 +++++++++++++++++--------- open_prices/proofs/models.py | 15 +-- 5 files changed, 243 insertions(+), 89 deletions(-) diff --git a/open_prices/api/prices/serializers.py b/open_prices/api/prices/serializers.py index 6ee0ff88..970bdee6 100644 --- a/open_prices/api/prices/serializers.py +++ b/open_prices/api/prices/serializers.py @@ -37,6 +37,9 @@ class Meta: class PriceCreateSerializer(serializers.ModelSerializer): + location_id = serializers.PrimaryKeyRelatedField( + queryset=Location.objects.all(), source="location", required=False + ) proof_id = serializers.PrimaryKeyRelatedField( queryset=Proof.objects.all(), source="proof" ) diff --git a/open_prices/api/prices/tests.py b/open_prices/api/prices/tests.py index aab5aa2f..99c946a5 100644 --- a/open_prices/api/prices/tests.py +++ b/open_prices/api/prices/tests.py @@ -3,6 +3,8 @@ from django.test import TestCase from django.urls import reverse +from open_prices.locations import constants as location_constants +from open_prices.locations.factories import LocationFactory from open_prices.locations.models import Location from open_prices.prices import constants as price_constants from open_prices.prices.factories import PriceFactory @@ -29,6 +31,15 @@ "date": "2023-08-30", } +LOCATION_OSM_NODE_652825274 = { + "type": location_constants.TYPE_OSM, + "osm_id": 652825274, + "osm_type": "NODE", + "osm_name": "Monoprix", + "osm_lat": "45.1805534", + "osm_lon": "5.7153387", +} + class PriceListApiTest(TestCase): @classmethod @@ -318,7 +329,9 @@ def setUpTestData(cls): "source": "test", } - def test_price_create(self): + def test_price_create_without_proof(self): + data = self.data.copy() + del data["proof_id"] # anonymous response = self.client.post( self.url, self.data, content_type="application/json" @@ -342,6 +355,8 @@ def test_price_create(self): content_type="application/json", ) self.assertEqual(response.status_code, 400) + + def test_price_create_with_proof(self): # empty proof response = self.client.post( self.url, @@ -382,16 +397,19 @@ def test_price_create(self): self.assertTrue("source" not in response.data) self.assertEqual(response.data["owner"], self.user_session.user.user_id) # with proof, product & location + self.assertTrue("proof_id" in response.data) self.assertEqual(response.data["proof"]["id"], self.user_proof.id) self.assertEqual( response.data["proof"]["price_count"], 0 ) # not yet incremented self.assertEqual(Proof.objects.get(id=self.user_proof.id).price_count, 1) + self.assertTrue("product_id" in response.data) self.assertEqual(response.data["product"]["code"], "8001505005707") self.assertEqual( response.data["product"]["price_count"], 0 ) # not yet incremented self.assertEqual(Product.objects.get(code="8001505005707").price_count, 1) + self.assertTrue("location_id" in response.data) self.assertEqual(response.data["location"]["osm_id"], 652825274) self.assertEqual( response.data["location"]["price_count"], 0 @@ -402,6 +420,41 @@ def test_price_create(self): p = Price.objects.last() self.assertEqual(p.source, "API") # default value + def test_price_create_with_location_id(self): + location_osm = LocationFactory(**LOCATION_OSM_NODE_652825274) + location_online = LocationFactory(type=location_constants.TYPE_ONLINE) + # with location_id, location_osm_id & location_osm_type: OK + response = self.client.post( + self.url, + {**self.data, "location_id": location_osm.id}, + headers={"Authorization": f"Bearer {self.user_session.token}"}, + content_type="application/json", + ) + self.assertEqual(response.status_code, 201) + self.assertEqual(response.data["location"]["id"], location_osm.id) + # with just location_id (OSM): NOK + data = self.data.copy() + del data["location_osm_id"] + del data["location_osm_type"] + response = self.client.post( + self.url, + {**data, "location_id": location_osm.id}, + headers={"Authorization": f"Bearer {self.user_session.token}"}, + content_type="application/json", + ) + self.assertEqual(response.status_code, 400) + # with just location_id (ONLINE): OK + data = self.data.copy() + del data["location_osm_id"] + del data["location_osm_type"] + response = self.client.post( + self.url, + {**data, "location_id": location_online.id}, + headers={"Authorization": f"Bearer {self.user_session.token}"}, + content_type="application/json", + ) + self.assertEqual(response.status_code, 201) + def test_price_create_with_app_name(self): for app_name in ["", "test app"]: response = self.client.post( diff --git a/open_prices/prices/models.py b/open_prices/prices/models.py index 94ed1cc6..7a4bff3c 100644 --- a/open_prices/prices/models.py +++ b/open_prices/prices/models.py @@ -66,8 +66,19 @@ class Price(models.Model): "origins_tags", "location_osm_id", "location_osm_type", + "location_id", # extra field (optional) "proof_id", # extra field ] + DUPLICATE_LOCATION_FIELDS = [ + "location_osm_id", + "location_osm_type", + ] + DUPLICATE_PROOF_FIELDS = [ + "location_osm_id", + "location_osm_type", + "date", + "currency", + ] # "owner" product_code = models.CharField(blank=True, null=True) product_name = models.CharField(blank=True, null=True) @@ -311,33 +322,77 @@ def clean(self, *args, **kwargs): "Should not be in the future", ) # location rules + # - allow passing a location_id # - location_osm_id should be set if location_osm_type is set # - location_osm_type should be set if location_osm_id is set - if self.location_osm_id: - if not self.location_osm_type: - validation_errors = utils.add_validation_error( - validation_errors, - "location_osm_type", - "Should be set if `location_osm_id` is filled", - ) - if self.location_osm_type: - if not self.location_osm_id: - validation_errors = utils.add_validation_error( - validation_errors, - "location_osm_id", - "Should be set if `location_osm_type` is filled", - ) - elif self.location_osm_id in [True, "true", "false", "none", "null"]: + if self.location_id: + location = None + from open_prices.locations.models import Location + + try: + location = Location.objects.get(id=self.location_id) + except Location.DoesNotExist: validation_errors = utils.add_validation_error( validation_errors, - "location_osm_id", - "Should not be a boolean or an invalid string", + "location", + "Location not found", ) + + if location: + if location.type == location_constants.TYPE_ONLINE: + if self.location_osm_id: + validation_errors = utils.add_validation_error( + validation_errors, + "location_osm_id", + "Should not be set if `location_id` is filled", + ) + if self.location_osm_type: + validation_errors = utils.add_validation_error( + validation_errors, + "location_osm_type", + "Should not be set if `location_id` is filled", + ) + elif location.type == location_constants.TYPE_OSM: + for LOCATION_FIELD in Price.DUPLICATE_LOCATION_FIELDS: + location_field_value = getattr( + self.location, LOCATION_FIELD.replace("location_", "") + ) + if location_field_value: + price_field_value = getattr(self, LOCATION_FIELD) + if str(location_field_value) != str(price_field_value): + validation_errors = utils.add_validation_error( + validation_errors, + "location", + f"Location {LOCATION_FIELD} ({location_field_value}) does not match the price {LOCATION_FIELD} ({price_field_value})", + ) + + else: + if self.location_osm_id: + if not self.location_osm_type: + validation_errors = utils.add_validation_error( + validation_errors, + "location_osm_type", + "Should be set if `location_osm_id` is filled", + ) + if self.location_osm_type: + if not self.location_osm_id: + validation_errors = utils.add_validation_error( + validation_errors, + "location_osm_id", + "Should be set if `location_osm_type` is filled", + ) + elif self.location_osm_id in [True, "true", "false", "none", "null"]: + validation_errors = utils.add_validation_error( + validation_errors, + "location_osm_id", + "Should not be a boolean or an invalid string", + ) # proof rules # - proof must exist and belong to the price owner # - some proof fields should be the same as the price fields # - receipt_quantity can only be set for receipts (default to 1) if self.proof_id: + proof = None from open_prices.proofs.models import Proof try: @@ -356,12 +411,9 @@ def clean(self, *args, **kwargs): "proof", "Proof does not belong to the current user", ) - if proof.type in [ - proof_constants.TYPE_RECEIPT, - proof_constants.TYPE_PRICE_TAG, - ]: - for PROOF_FIELD in Proof.DUPLICATE_PRICE_FIELDS: - proof_field_value = getattr(self.proof, PROOF_FIELD) + if proof.type in proof_constants.TYPE_SINGLE_SHOP_LIST: + for PROOF_FIELD in Price.DUPLICATE_PROOF_FIELDS: + proof_field_value = getattr(proof, PROOF_FIELD) if proof_field_value: price_field_value = getattr(self, PROOF_FIELD) if str(proof_field_value) != str(price_field_value): diff --git a/open_prices/prices/tests.py b/open_prices/prices/tests.py index 0a52f271..421b6b4f 100644 --- a/open_prices/prices/tests.py +++ b/open_prices/prices/tests.py @@ -245,6 +245,8 @@ def test_price_date_validation(self): self.assertRaises(ValidationError, PriceFactory, date=DATE_NOT_OK) def test_price_location_validation(self): + location_osm = LocationFactory() + location_online = LocationFactory(type=location_constants.TYPE_ONLINE) # both location_osm_id & location_osm_type not set PriceFactory(location_osm_id=None, location_osm_type=None) # location_osm_id @@ -272,84 +274,116 @@ def test_price_location_validation(self): location_osm_id=652825274, location_osm_type=LOCATION_OSM_TYPE_NOT_OK, ) - # location online - location_online = LocationFactory(type=location_constants.TYPE_ONLINE) - PriceFactory(location_id=location_online.id) + # location unknown + self.assertRaises( + ValidationError, + PriceFactory, + location_id=999, + location_osm_id=None, + location_osm_type=None, + ) + # cannot mix location_id & location_osm_id/type + self.assertRaises( + ValidationError, + PriceFactory, + location_id=location_osm.id, + location_osm_id=None, # needed + location_osm_type=None, # needed + ) + self.assertRaises( + ValidationError, + PriceFactory, + location_id=location_online.id, + location_osm_id=LOCATION_OSM_ID_OK, # should be None + ) + # location_id ok + PriceFactory( + location_id=location_osm.id, + location_osm_id=location_osm.osm_id, + location_osm_type=location_osm.osm_type, + ) + PriceFactory( + location_id=location_online.id, location_osm_id=None, location_osm_type=None + ) def test_price_proof_validation(self): - self.user_session = SessionFactory() - self.user_proof_receipt = ProofFactory( + user_session = SessionFactory() + user_proof_receipt = ProofFactory( type=proof_constants.TYPE_RECEIPT, location_osm_id=652825274, location_osm_type=location_constants.OSM_TYPE_NODE, date="2024-06-30", currency="EUR", - owner=self.user_session.user.user_id, + owner=user_session.user.user_id, ) - self.proof_2 = ProofFactory() + proof_2 = ProofFactory() # proof not set - PriceFactory(proof=None, owner=self.user_proof_receipt.owner) + PriceFactory(proof_id=None, owner=user_proof_receipt.owner) + # proof unknown + self.assertRaises( + ValidationError, PriceFactory, proof_id=999, owner=user_proof_receipt.owner + ) # same price & proof fields PriceFactory( - proof=self.user_proof_receipt, - location_osm_id=self.user_proof_receipt.location_osm_id, - location_osm_type=self.user_proof_receipt.location_osm_type, - date=self.user_proof_receipt.date, - currency=self.user_proof_receipt.currency, - owner=self.user_proof_receipt.owner, + proof_id=user_proof_receipt.id, + location_osm_id=user_proof_receipt.location_osm_id, + location_osm_type=user_proof_receipt.location_osm_type, + date=user_proof_receipt.date, + currency=user_proof_receipt.currency, + owner=user_proof_receipt.owner, ) # different price & proof owner self.assertRaises( ValidationError, PriceFactory, - proof=self.proof_2, # different - location_osm_id=self.user_proof_receipt.location_osm_id, - location_osm_type=self.user_proof_receipt.location_osm_type, - date=self.user_proof_receipt.date, - currency=self.user_proof_receipt.currency, - owner=self.user_proof_receipt.owner, + proof_id=proof_2.id, # different + location_osm_id=user_proof_receipt.location_osm_id, + location_osm_type=user_proof_receipt.location_osm_type, + date=user_proof_receipt.date, + currency=user_proof_receipt.currency, + owner=user_proof_receipt.owner, ) # proof location_osm_id & location_osm_type self.assertRaises( ValidationError, PriceFactory, - proof=self.user_proof_receipt, + proof_id=user_proof_receipt.id, location_osm_id=5, # different location_osm_id - location_osm_type=self.user_proof_receipt.location_osm_type, - date=self.user_proof_receipt.date, - currency=self.user_proof_receipt.currency, - owner=self.user_proof_receipt.owner, + location_osm_type=user_proof_receipt.location_osm_type, + date=user_proof_receipt.date, + currency=user_proof_receipt.currency, + owner=user_proof_receipt.owner, ) self.assertRaises( ValidationError, PriceFactory, - proof=self.user_proof_receipt, - location_osm_id=self.user_proof_receipt.location_osm_id, + proof_id=user_proof_receipt.id, + location_osm_id=user_proof_receipt.location_osm_id, location_osm_type="WAY", # different location_osm_type - date=self.user_proof_receipt.date, - currency=self.user_proof_receipt.currency, - owner=self.user_proof_receipt.owner, + date=user_proof_receipt.date, + currency=user_proof_receipt.currency, + owner=user_proof_receipt.owner, ) # proof date & currency self.assertRaises( ValidationError, PriceFactory, - proof=self.user_proof_receipt, - location_osm_id=self.user_proof_receipt.location_osm_id, - location_osm_type=self.user_proof_receipt.location_osm_type, + proof_id=user_proof_receipt.id, + location_osm_id=user_proof_receipt.location_osm_id, + location_osm_type=user_proof_receipt.location_osm_type, date="2024-07-01", # different date - currency=self.user_proof_receipt.currency, - owner=self.user_proof_receipt.owner, + currency=user_proof_receipt.currency, + owner=user_proof_receipt.owner, ) self.assertRaises( ValidationError, PriceFactory, - proof=self.user_proof_receipt, - location_osm_id=self.user_proof_receipt.location_osm_id, - location_osm_type=self.user_proof_receipt.location_osm_type, - date=self.user_proof_receipt.date, + proof_id=user_proof_receipt.id, + location_osm_id=user_proof_receipt.location_osm_id, + location_osm_type=user_proof_receipt.location_osm_type, + date=user_proof_receipt.date, currency="USD", # different currency - owner=self.user_proof_receipt.owner, + owner=user_proof_receipt.owner, ) # receipt_quantity for RECEIPT_QUANTITY_NOT_OK in [-5, 0]: @@ -357,23 +391,23 @@ def test_price_proof_validation(self): self.assertRaises( ValidationError, PriceFactory, - proof=self.user_proof_receipt, - location_osm_id=self.user_proof_receipt.location_osm_id, - location_osm_type=self.user_proof_receipt.location_osm_type, - date=self.user_proof_receipt.date, - currency=self.user_proof_receipt.currency, - owner=self.user_proof_receipt.owner, + proof_id=user_proof_receipt.id, + location_osm_id=user_proof_receipt.location_osm_id, + location_osm_type=user_proof_receipt.location_osm_type, + date=user_proof_receipt.date, + currency=user_proof_receipt.currency, + owner=user_proof_receipt.owner, receipt_quantity=RECEIPT_QUANTITY_NOT_OK, ) for RECEIPT_QUANTITY_OK in [None, 1, 2]: with self.subTest(RECEIPT_QUANTITY_OK=RECEIPT_QUANTITY_OK): PriceFactory( - proof=self.user_proof_receipt, - location_osm_id=self.user_proof_receipt.location_osm_id, - location_osm_type=self.user_proof_receipt.location_osm_type, - date=self.user_proof_receipt.date, - currency=self.user_proof_receipt.currency, - owner=self.user_proof_receipt.owner, + proof_id=user_proof_receipt.id, + location_osm_id=user_proof_receipt.location_osm_id, + location_osm_type=user_proof_receipt.location_osm_type, + date=user_proof_receipt.date, + currency=user_proof_receipt.currency, + owner=user_proof_receipt.owner, receipt_quantity=RECEIPT_QUANTITY_OK, ) @@ -384,7 +418,7 @@ def test_price_count_increment(self): location = LocationFactory() product = ProductFactory() PriceFactory( - proof=user_proof_1, + proof_id=user_proof_1.id, location_osm_id=location.osm_id, location_osm_type=location.osm_type, product_code=product.code, @@ -397,7 +431,7 @@ def test_price_count_increment(self): self.assertEqual(Location.objects.get(id=location.id).price_count, 1) self.assertEqual(Product.objects.get(id=product.id).price_count, 1) PriceFactory( - proof=user_proof_2, + proof_id=user_proof_2.id, location_osm_id=location.osm_id, location_osm_type=location.osm_type, product_code=product.code, @@ -411,6 +445,23 @@ def test_price_count_increment(self): self.assertEqual(Product.objects.get(id=product.id).price_count, 2) +class PriceModelUpdateTest(TestCase): + def test_price_update(self): + user_session = SessionFactory() + user_proof = ProofFactory(owner=user_session.user.user_id) + location = LocationFactory() + product = ProductFactory() + price = PriceFactory( + proof_id=user_proof.id, + location_osm_id=location.osm_id, + location_osm_type=location.osm_type, + product_code=product.code, + owner=user_session.user.user_id, + ) + price.price = 5 + price.save() + + class PriceModelDeleteTest(TestCase): def test_price_count_decrement(self): user_session = SessionFactory() @@ -418,7 +469,7 @@ def test_price_count_decrement(self): location = LocationFactory() product = ProductFactory() price = PriceFactory( - proof=user_proof, + proof_id=user_proof.id, location_osm_id=location.osm_id, location_osm_type=location.osm_type, product_code=product.code, diff --git a/open_prices/proofs/models.py b/open_prices/proofs/models.py index c03bf2fe..e502fdcc 100644 --- a/open_prices/proofs/models.py +++ b/open_prices/proofs/models.py @@ -48,12 +48,6 @@ class Proof(models.Model): "receipt_price_total", ] CREATE_FIELDS = UPDATE_FIELDS + ["location_osm_id", "location_osm_type"] - DUPLICATE_PRICE_FIELDS = [ - "location_osm_id", - "location_osm_type", - "date", - "currency", - ] # "owner" FIX_PRICE_FIELDS = ["location", "date", "currency"] file_path = models.CharField(blank=True, null=True) @@ -212,14 +206,15 @@ def update_location(self, location_osm_id, location_osm_type): self.location_osm_id = location_osm_id self.location_osm_type = location_osm_type self.save() + self.refresh_from_db() + new_location = self.location # update proof's prices location for price in self.prices.all(): - price.location_osm_id = location_osm_id - price.location_osm_type = location_osm_type + price.location = self.location + price.location_osm_id = self.location_osm_id + price.location_osm_type = self.location_osm_type price.save() # update old & new location price counts - self.refresh_from_db() - new_location = self.location if old_location: old_location.update_price_count() if new_location: