Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(Proofs): on update, also update proof prices #538

Merged
merged 2 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 28 additions & 24 deletions open_prices/prices/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class Price(models.Model):
"location_osm_type",
]
DUPLICATE_PROOF_FIELDS = [
# "location_id",
"location_osm_id",
"location_osm_type",
"date",
Expand Down Expand Up @@ -325,6 +326,7 @@ def clean(self, *args, **kwargs):
# - allow passing a location_id
# - location_osm_id should be set if location_osm_type is set
# - location_osm_type should be set if location_osm_id is set
# - some location fields should match the price fields (on create)
if self.location_id:
location = None
from open_prices.locations.models import Location
Expand Down Expand Up @@ -353,18 +355,19 @@ def clean(self, *args, **kwargs):
"Can only be set if location type is OSM",
)
elif location.type == location_constants.TYPE_OSM:
for LOCATION_FIELD in Price.DUPLICATE_LOCATION_FIELDS:
location_field_value = getattr(
self.location, LOCATION_FIELD.replace("location_", "")
)
if location_field_value:
price_field_value = getattr(self, LOCATION_FIELD)
if str(location_field_value) != str(price_field_value):
validation_errors = utils.add_validation_error(
validation_errors,
"location",
f"Location {LOCATION_FIELD} ({location_field_value}) does not match the price {LOCATION_FIELD} ({price_field_value})",
)
if not self.id: # skip these checks on update
for LOCATION_FIELD in Price.DUPLICATE_LOCATION_FIELDS:
location_field_value = getattr(
self.location, LOCATION_FIELD.replace("location_", "")
)
if location_field_value:
price_field_value = getattr(self, LOCATION_FIELD)
if str(location_field_value) != str(price_field_value):
validation_errors = utils.add_validation_error(
validation_errors,
"location",
f"Location {LOCATION_FIELD} ({location_field_value}) does not match the price {LOCATION_FIELD} ({price_field_value})",
)

else:
if self.location_osm_id:
Expand All @@ -389,7 +392,7 @@ def clean(self, *args, **kwargs):
)
# proof rules
# - proof must exist and belong to the price owner
# - some proof fields should be the same as the price fields
# - some proof fields should match the price fields (on create)
# - receipt_quantity can only be set for receipts (default to 1)
if self.proof_id:
proof = None
Expand All @@ -411,17 +414,18 @@ def clean(self, *args, **kwargs):
"proof",
"Proof does not belong to the current user",
)
if proof.type in proof_constants.TYPE_SINGLE_SHOP_LIST:
for PROOF_FIELD in Price.DUPLICATE_PROOF_FIELDS:
proof_field_value = getattr(proof, PROOF_FIELD)
if proof_field_value:
price_field_value = getattr(self, PROOF_FIELD)
if str(proof_field_value) != str(price_field_value):
validation_errors = utils.add_validation_error(
validation_errors,
"proof",
f"Proof {PROOF_FIELD} ({proof_field_value}) does not match the price {PROOF_FIELD} ({price_field_value})",
)
if not self.id: # skip these checks on update
if proof.type in proof_constants.TYPE_SINGLE_SHOP_LIST:
for PROOF_FIELD in Price.DUPLICATE_PROOF_FIELDS:
proof_field_value = getattr(proof, PROOF_FIELD)
if proof_field_value:
price_field_value = getattr(self, PROOF_FIELD)
if str(proof_field_value) != str(price_field_value):
validation_errors = utils.add_validation_error(
validation_errors,
"proof",
f"Proof {PROOF_FIELD} ({proof_field_value}) does not match the price {PROOF_FIELD} ({price_field_value})",
)
if proof.type in proof_constants.TYPE_SHOPPING_SESSION_LIST:
if not self.receipt_quantity:
self.receipt_quantity = 1
Expand Down
48 changes: 30 additions & 18 deletions open_prices/proofs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def with_stats(self):
class Proof(models.Model):
FILE_FIELDS = ["file_path", "mimetype", "image_thumb_path"]
UPDATE_FIELDS = [
# "location_osm_id",
# "location_osm_type",
"type",
"currency",
"date",
Expand Down Expand Up @@ -133,6 +135,7 @@ def clean(self, *args, **kwargs):
# - allow passing a location_id
# - location_osm_id should be set if location_osm_type is set
# - location_osm_type should be set if location_osm_id is set
# - some location fields should match the proof fields (on create)
if self.location_id:
location = None
from open_prices.locations.models import Location
Expand Down Expand Up @@ -161,18 +164,19 @@ def clean(self, *args, **kwargs):
"Can only be set if location type is OSM",
)
elif location.type == location_constants.TYPE_OSM:
for LOCATION_FIELD in Proof.DUPLICATE_LOCATION_FIELDS:
location_field_value = getattr(
self.location, LOCATION_FIELD.replace("location_", "")
)
if location_field_value:
proof_field_value = getattr(self, LOCATION_FIELD)
if str(location_field_value) != str(proof_field_value):
validation_errors = utils.add_validation_error(
validation_errors,
"location",
f"Location {LOCATION_FIELD} ({location_field_value}) does not match the proof {LOCATION_FIELD} ({proof_field_value})",
)
if not self.id: # skip these checks on update
for LOCATION_FIELD in Proof.DUPLICATE_LOCATION_FIELDS:
location_field_value = getattr(
self.location, LOCATION_FIELD.replace("location_", "")
)
if location_field_value:
proof_field_value = getattr(self, LOCATION_FIELD)
if str(location_field_value) != str(proof_field_value):
validation_errors = utils.add_validation_error(
validation_errors,
"location",
f"Location {LOCATION_FIELD} ({location_field_value}) does not match the proof {LOCATION_FIELD} ({proof_field_value})",
)
else:
if self.location_osm_id:
if not self.location_osm_type:
Expand Down Expand Up @@ -259,12 +263,8 @@ def update_location(self, location_osm_id, location_osm_type):
self.save()
self.refresh_from_db()
new_location = self.location
# update proof's prices location
for price in self.prices.all():
price.location = new_location
price.location_osm_id = new_location.osm_id
price.location_osm_type = new_location.osm_type
price.save()
# update proof's prices location?
# # done in post_save signal
# update old & new location price counts
if old_location:
old_location.update_price_count()
Expand Down Expand Up @@ -301,6 +301,18 @@ def set_missing_fields_from_prices(self):
self.save()


@receiver(signals.post_save, sender=Proof)
def proof_post_save_update_prices(sender, instance, created, **kwargs):
if not created:
if instance.is_type_single_shop and instance.prices.exists():
from open_prices.prices.models import Price

for price in instance.prices.all():
for field in Price.DUPLICATE_PROOF_FIELDS:
setattr(price, field, getattr(instance, field))
price.save()


@receiver(signals.post_delete, sender=Proof)
def proof_post_delete_remove_images(sender, instance, **kwargs):
import os
Expand Down
79 changes: 51 additions & 28 deletions open_prices/proofs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
"osm_type": location_constants.OSM_TYPE_NODE,
"osm_name": "Monoprix",
}
# LOCATION_OSM_NODE_6509705997 = {
# "osm_id": 6509705997,
# "osm_type": location_constants.OSM_TYPE_NODE,
# "osm_name": "Carrefour",
# }
LOCATION_OSM_NODE_6509705997 = {
"osm_id": 6509705997,
"osm_type": location_constants.OSM_TYPE_NODE,
"osm_name": "Carrefour",
}


class ProofModelSaveTest(TestCase):
Expand Down Expand Up @@ -168,11 +168,12 @@ def test_with_stats(self):
class ProofPropertyTest(TestCase):
@classmethod
def setUpTestData(cls):
cls.location = LocationFactory(**LOCATION_OSM_NODE_652825274)
cls.location_osm_1 = LocationFactory(**LOCATION_OSM_NODE_652825274)
cls.location_osm_2 = LocationFactory(**LOCATION_OSM_NODE_6509705997)
cls.proof_price_tag = ProofFactory(
type=proof_constants.TYPE_PRICE_TAG,
location_osm_id=cls.location.osm_id,
location_osm_type=cls.location.osm_type,
location_osm_id=cls.location_osm_1.osm_id,
location_osm_type=cls.location_osm_1.osm_type,
)
PriceFactory(
proof_id=cls.proof_price_tag.id,
Expand All @@ -189,8 +190,8 @@ def setUpTestData(cls):
cls.proof_receipt = ProofFactory(type=proof_constants.TYPE_RECEIPT)
PriceFactory(
proof_id=cls.proof_receipt.id,
location_osm_id=cls.location.osm_id,
location_osm_type=cls.location.osm_type,
location_osm_id=cls.location_osm_1.osm_id,
location_osm_type=cls.location_osm_1.osm_type,
price=2.0,
currency="EUR",
date="2024-06-30",
Expand All @@ -212,32 +213,37 @@ def test_update_price_count(self):
def test_update_location(self):
# existing
self.proof_price_tag.refresh_from_db()
self.location.refresh_from_db()
self.location_osm_1.refresh_from_db()
self.assertEqual(self.proof_price_tag.price_count, 2)
self.assertEqual(self.proof_price_tag.location.id, self.location.id)
self.assertEqual(self.location.price_count, 2 + 1)
self.assertEqual(self.proof_price_tag.location.id, self.location_osm_1.id)
self.assertEqual(self.location_osm_1.price_count, 2 + 1)
# update location
self.proof_price_tag.update_location(
location_osm_id=6509705997,
location_osm_type=location_constants.OSM_TYPE_NODE,
location_osm_id=self.location_osm_2.osm_id,
location_osm_type=self.location_osm_2.osm_type,
)
# check changes
self.proof_price_tag.refresh_from_db()
self.location.refresh_from_db()
new_location = self.proof_price_tag.location
self.assertNotEqual(self.location, new_location)
self.assertEqual(self.proof_price_tag.price_count, 2)
self.assertEqual(new_location.price_count, 2)
self.assertEqual(self.location.price_count, 3 - 2)
self.location_osm_1.refresh_from_db()
self.location_osm_2.refresh_from_db()
self.assertEqual(self.proof_price_tag.location, self.location_osm_2)
self.assertEqual(self.proof_price_tag.price_count, 2) # same
self.assertEqual(self.proof_price_tag.location.price_count, 2)
self.assertEqual(self.location_osm_1.price_count, 3 - 2)
self.assertEqual(self.location_osm_2.price_count, 2)
# update again, same location
self.proof_price_tag.update_location(
location_osm_id=6509705997,
location_osm_type=location_constants.OSM_TYPE_NODE,
location_osm_id=self.location_osm_2.osm_id,
location_osm_type=self.location_osm_2.osm_type,
)
self.proof_price_tag.refresh_from_db()
self.location.refresh_from_db()
self.location_osm_1.refresh_from_db()
self.location_osm_2.refresh_from_db()
self.assertEqual(self.proof_price_tag.location, self.location_osm_2)
self.assertEqual(self.proof_price_tag.price_count, 2)
self.assertEqual(self.proof_price_tag.location.price_count, 2)
self.assertEqual(self.location_osm_1.price_count, 1)
self.assertEqual(self.location_osm_2.price_count, 2)

def test_set_missing_fields_from_prices(self):
self.proof_receipt.refresh_from_db()
Expand All @@ -246,7 +252,7 @@ def test_set_missing_fields_from_prices(self):
self.assertTrue(self.proof_receipt.currency is None)
self.assertEqual(self.proof_receipt.price_count, 1)
self.proof_receipt.set_missing_fields_from_prices()
self.assertEqual(self.proof_receipt.location, self.location)
self.assertEqual(self.proof_receipt.location, self.location_osm_1)
self.assertEqual(
self.proof_receipt.date, self.proof_receipt.prices.first().date
)
Expand All @@ -258,11 +264,12 @@ def test_set_missing_fields_from_prices(self):
class ProofModelUpdateTest(TestCase):
@classmethod
def setUpTestData(cls):
cls.location = LocationFactory(**LOCATION_OSM_NODE_652825274)
cls.location_osm_1 = LocationFactory(**LOCATION_OSM_NODE_652825274)
cls.location_osm_2 = LocationFactory(**LOCATION_OSM_NODE_6509705997)
cls.proof_price_tag = ProofFactory(
type=proof_constants.TYPE_PRICE_TAG,
location_osm_id=cls.location.osm_id,
location_osm_type=cls.location.osm_type,
location_osm_id=cls.location_osm_1.osm_id,
location_osm_type=cls.location_osm_1.osm_type,
currency="EUR",
date="2024-06-30",
)
Expand All @@ -276,5 +283,21 @@ def setUpTestData(cls):
)

def test_proof_update(self):
# currency
self.assertEqual(self.proof_price_tag.prices.count(), 1)
self.proof_price_tag.currency = "USD"
self.proof_price_tag.save()
self.assertEqual(self.proof_price_tag.prices.first().currency, "USD")
# date
self.proof_price_tag.date = "2024-07-01"
self.proof_price_tag.save()
self.assertEqual(str(self.proof_price_tag.prices.first().date), "2024-07-01")
# location
self.proof_price_tag.location_osm_id = self.location_osm_2.osm_id
self.proof_price_tag.location_osm_type = self.location_osm_2.osm_type
self.proof_price_tag.save()
self.proof_price_tag.refresh_from_db()
self.assertEqual(self.proof_price_tag.location, self.location_osm_2)
self.assertEqual(
self.proof_price_tag.prices.first().location, self.location_osm_2
)