Skip to content

Commit

Permalink
Implement secret santa with backtracking
Browse files Browse the repository at this point in the history
Prevent flaky failures from invalid combinations.
  • Loading branch information
mariocj89 committed Dec 29, 2024
1 parent cae3a08 commit 48c3560
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 42 deletions.
36 changes: 36 additions & 0 deletions eas/api/secret_santa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import random


def _is_valid_assignment(source, target, exclusions):
if source == target:
return False
if (source, target) in exclusions:
return False
return True


def _backtrack_assignments(participants, exclusions, assignments, targets):
if not participants:
return assignments

source = participants[0]
random.shuffle(targets) # Shuffle targets to ensure randomness
for target in targets:
if _is_valid_assignment(source, target, exclusions):
new_assignments = assignments + [(source, target)]
new_targets = targets[:]
new_targets.remove(target)
result = _backtrack_assignments(
participants[1:], exclusions, new_assignments, new_targets
)
if result:
return result

return None


def resolve_secret_santa(
participants,
exclusions,
):
return _backtrack_assignments(participants, exclusions, [], participants[:])
58 changes: 16 additions & 42 deletions eas/api/views.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import datetime as dt
import logging
import random

import requests.exceptions
from django.http import Http404
Expand All @@ -11,7 +10,7 @@
from rest_framework.exceptions import APIException, ValidationError
from rest_framework.response import Response

from . import amazonsqs, instagram, models, paypal, serializers, tiktok
from . import amazonsqs, instagram, models, paypal, secret_santa, serializers, tiktok

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -199,25 +198,6 @@ class LinkViewSet(BaseDrawViewSet):
queryset = MODEL.objects.all()


def _ss_find_target(targets, exclusions):
potential_targets = set(targets) - exclusions
if not potential_targets:
return
return random.choice(list(potential_targets))


def _ss_build_results(participants, exclusions_map):
results = []
targets = list(participants)
for source in participants:
target = _ss_find_target(targets, exclusions_map[source])
if not target:
return
targets.remove(target)
results.append((source, target))
return results


class SecretSantaSet(
mixins.CreateModelMixin, mixins.RetrieveModelMixin, viewsets.GenericViewSet
):
Expand All @@ -226,28 +206,22 @@ def create(self, request, *args, **kwargs):
serializer = serializers.SecretSantaSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
data = serializer.validated_data
emails_map = {
p["name"]: p["email"] for p in data["participants"] if p.get("email")
}
phones_map = {
p["name"]: p["phone_number"]
for p in data["participants"]
if p.get("phone_number")
}
exclusions_map = {
p["name"]: set(p.get("exclusions") or []) for p in data["participants"]
}
LOG.info("Using exclusion map: %r", exclusions_map)
for participant, exclusions in exclusions_map.items():
exclusions.add(participant)
participants = {p["name"] for p in data["participants"]}
for _ in range(min(50, len(participants))):
results = _ss_build_results(participants, exclusions_map)
if results is not None:
break
else:
exclusions = []
participants = []
phones_map = {}
emails_map = {}
for p in data["participants"]:
participant = p["name"]
participants.append(participant)
if p.get("phone_number"):
phones_map[participant] = p["phone_number"]
if p.get("email"):
emails_map[participant] = p["email"]
for e in p.get("exclusions", []):
exclusions.append((participant, e))
results = secret_santa.resolve_secret_santa(participants, exclusions)
if not results:
raise ValidationError("Unable to match participants")
LOG.info("Sending %s secret santa emails", len(results))
draw = models.SecretSanta()
draw.save()
emails = []
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[tool.isort]
line_length = 88

0 comments on commit 48c3560

Please sign in to comment.