Skip to content

Commit

Permalink
feat(Prices): Allow adding new prices with a location_id (#527)
Browse files Browse the repository at this point in the history
  • Loading branch information
raphodn authored Oct 18, 2024
1 parent 1822b8d commit 99f16df
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 89 deletions.
3 changes: 3 additions & 0 deletions open_prices/api/prices/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class Meta:


class PriceCreateSerializer(serializers.ModelSerializer):
location_id = serializers.PrimaryKeyRelatedField(
queryset=Location.objects.all(), source="location", required=False
)
proof_id = serializers.PrimaryKeyRelatedField(
queryset=Proof.objects.all(), source="proof"
)
Expand Down
55 changes: 54 additions & 1 deletion open_prices/api/prices/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from django.test import TestCase
from django.urls import reverse

from open_prices.locations import constants as location_constants
from open_prices.locations.factories import LocationFactory
from open_prices.locations.models import Location
from open_prices.prices import constants as price_constants
from open_prices.prices.factories import PriceFactory
Expand All @@ -29,6 +31,15 @@
"date": "2023-08-30",
}

LOCATION_OSM_NODE_652825274 = {
"type": location_constants.TYPE_OSM,
"osm_id": 652825274,
"osm_type": "NODE",
"osm_name": "Monoprix",
"osm_lat": "45.1805534",
"osm_lon": "5.7153387",
}


class PriceListApiTest(TestCase):
@classmethod
Expand Down Expand Up @@ -318,7 +329,9 @@ def setUpTestData(cls):
"source": "test",
}

def test_price_create(self):
def test_price_create_without_proof(self):
data = self.data.copy()
del data["proof_id"]
# anonymous
response = self.client.post(
self.url, self.data, content_type="application/json"
Expand All @@ -342,6 +355,8 @@ def test_price_create(self):
content_type="application/json",
)
self.assertEqual(response.status_code, 400)

def test_price_create_with_proof(self):
# empty proof
response = self.client.post(
self.url,
Expand Down Expand Up @@ -382,16 +397,19 @@ def test_price_create(self):
self.assertTrue("source" not in response.data)
self.assertEqual(response.data["owner"], self.user_session.user.user_id)
# with proof, product & location
self.assertTrue("proof_id" in response.data)
self.assertEqual(response.data["proof"]["id"], self.user_proof.id)
self.assertEqual(
response.data["proof"]["price_count"], 0
) # not yet incremented
self.assertEqual(Proof.objects.get(id=self.user_proof.id).price_count, 1)
self.assertTrue("product_id" in response.data)
self.assertEqual(response.data["product"]["code"], "8001505005707")
self.assertEqual(
response.data["product"]["price_count"], 0
) # not yet incremented
self.assertEqual(Product.objects.get(code="8001505005707").price_count, 1)
self.assertTrue("location_id" in response.data)
self.assertEqual(response.data["location"]["osm_id"], 652825274)
self.assertEqual(
response.data["location"]["price_count"], 0
Expand All @@ -402,6 +420,41 @@ def test_price_create(self):
p = Price.objects.last()
self.assertEqual(p.source, "API") # default value

def test_price_create_with_location_id(self):
location_osm = LocationFactory(**LOCATION_OSM_NODE_652825274)
location_online = LocationFactory(type=location_constants.TYPE_ONLINE)
# with location_id, location_osm_id & location_osm_type: OK
response = self.client.post(
self.url,
{**self.data, "location_id": location_osm.id},
headers={"Authorization": f"Bearer {self.user_session.token}"},
content_type="application/json",
)
self.assertEqual(response.status_code, 201)
self.assertEqual(response.data["location"]["id"], location_osm.id)
# with just location_id (OSM): NOK
data = self.data.copy()
del data["location_osm_id"]
del data["location_osm_type"]
response = self.client.post(
self.url,
{**data, "location_id": location_osm.id},
headers={"Authorization": f"Bearer {self.user_session.token}"},
content_type="application/json",
)
self.assertEqual(response.status_code, 400)
# with just location_id (ONLINE): OK
data = self.data.copy()
del data["location_osm_id"]
del data["location_osm_type"]
response = self.client.post(
self.url,
{**data, "location_id": location_online.id},
headers={"Authorization": f"Bearer {self.user_session.token}"},
content_type="application/json",
)
self.assertEqual(response.status_code, 201)

def test_price_create_with_app_name(self):
for app_name in ["", "test app"]:
response = self.client.post(
Expand Down
98 changes: 75 additions & 23 deletions open_prices/prices/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,19 @@ class Price(models.Model):
"origins_tags",
"location_osm_id",
"location_osm_type",
"location_id", # extra field (optional)
"proof_id", # extra field
]
DUPLICATE_LOCATION_FIELDS = [
"location_osm_id",
"location_osm_type",
]
DUPLICATE_PROOF_FIELDS = [
"location_osm_id",
"location_osm_type",
"date",
"currency",
] # "owner"

product_code = models.CharField(blank=True, null=True)
product_name = models.CharField(blank=True, null=True)
Expand Down Expand Up @@ -311,33 +322,77 @@ def clean(self, *args, **kwargs):
"Should not be in the future",
)
# location rules
# - allow passing a location_id
# - location_osm_id should be set if location_osm_type is set
# - location_osm_type should be set if location_osm_id is set
if self.location_osm_id:
if not self.location_osm_type:
validation_errors = utils.add_validation_error(
validation_errors,
"location_osm_type",
"Should be set if `location_osm_id` is filled",
)
if self.location_osm_type:
if not self.location_osm_id:
validation_errors = utils.add_validation_error(
validation_errors,
"location_osm_id",
"Should be set if `location_osm_type` is filled",
)
elif self.location_osm_id in [True, "true", "false", "none", "null"]:
if self.location_id:
location = None
from open_prices.locations.models import Location

try:
location = Location.objects.get(id=self.location_id)
except Location.DoesNotExist:
validation_errors = utils.add_validation_error(
validation_errors,
"location_osm_id",
"Should not be a boolean or an invalid string",
"location",
"Location not found",
)

if location:
if location.type == location_constants.TYPE_ONLINE:
if self.location_osm_id:
validation_errors = utils.add_validation_error(
validation_errors,
"location_osm_id",
"Should not be set if `location_id` is filled",
)
if self.location_osm_type:
validation_errors = utils.add_validation_error(
validation_errors,
"location_osm_type",
"Should not be set if `location_id` is filled",
)
elif location.type == location_constants.TYPE_OSM:
for LOCATION_FIELD in Price.DUPLICATE_LOCATION_FIELDS:
location_field_value = getattr(
self.location, LOCATION_FIELD.replace("location_", "")
)
if location_field_value:
price_field_value = getattr(self, LOCATION_FIELD)
if str(location_field_value) != str(price_field_value):
validation_errors = utils.add_validation_error(
validation_errors,
"location",
f"Location {LOCATION_FIELD} ({location_field_value}) does not match the price {LOCATION_FIELD} ({price_field_value})",
)

else:
if self.location_osm_id:
if not self.location_osm_type:
validation_errors = utils.add_validation_error(
validation_errors,
"location_osm_type",
"Should be set if `location_osm_id` is filled",
)
if self.location_osm_type:
if not self.location_osm_id:
validation_errors = utils.add_validation_error(
validation_errors,
"location_osm_id",
"Should be set if `location_osm_type` is filled",
)
elif self.location_osm_id in [True, "true", "false", "none", "null"]:
validation_errors = utils.add_validation_error(
validation_errors,
"location_osm_id",
"Should not be a boolean or an invalid string",
)
# proof rules
# - proof must exist and belong to the price owner
# - some proof fields should be the same as the price fields
# - receipt_quantity can only be set for receipts (default to 1)
if self.proof_id:
proof = None
from open_prices.proofs.models import Proof

try:
Expand All @@ -356,12 +411,9 @@ def clean(self, *args, **kwargs):
"proof",
"Proof does not belong to the current user",
)
if proof.type in [
proof_constants.TYPE_RECEIPT,
proof_constants.TYPE_PRICE_TAG,
]:
for PROOF_FIELD in Proof.DUPLICATE_PRICE_FIELDS:
proof_field_value = getattr(self.proof, PROOF_FIELD)
if proof.type in proof_constants.TYPE_SINGLE_SHOP_LIST:
for PROOF_FIELD in Price.DUPLICATE_PROOF_FIELDS:
proof_field_value = getattr(proof, PROOF_FIELD)
if proof_field_value:
price_field_value = getattr(self, PROOF_FIELD)
if str(proof_field_value) != str(price_field_value):
Expand Down
Loading

0 comments on commit 99f16df

Please sign in to comment.