Skip to content

Commit

Permalink
refactor: [ACI-978] a little more enhancements
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrii committed May 17, 2024
1 parent 20bd0f0 commit 69bcad2
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 66 deletions.
99 changes: 35 additions & 64 deletions credentials/apps/badges/admin_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from credentials.apps.badges.credly.api_client import CredlyAPIClient
from credentials.apps.badges.credly.exceptions import CredlyAPIError
from credentials.apps.badges.models import AbstractDataRule, BadgePenalty, BadgeRequirement, CredlyOrganization, DataRule, PenaltyDataRule
from credentials.apps.badges.utils import get_event_type_keypaths
from credentials.apps.badges.utils import get_event_type_keypaths, get_event_data_attr_type_by_keypath


class CredlyOrganizationAdminForm(forms.ModelForm):
Expand Down Expand Up @@ -84,29 +84,7 @@ def clean(self):
return cleaned_data


class DataRuleBoolValidationMixin:
"""
Mixin for DataRule form to validate boolean fields.
"""

def clean(self):
"""
Validate boolean fields.
"""

cleaned_data = super().clean()

last_key = cleaned_data.get("data_path").split(".")[-1]
if "is_" in last_key and cleaned_data.get("value") not in AbstractDataRule.BOOL_VALUES:
raise forms.ValidationError(_("Value must be a boolean."))

return cleaned_data


class DataRuleFormSet(forms.BaseInlineFormSet):
"""
Formset for DataRule model.
"""
class ParentMixin:
def get_form_kwargs(self, index):
"""
Pass parent instance to the form.
Expand All @@ -117,15 +95,10 @@ def get_form_kwargs(self, index):
return kwargs


class DataRuleForm(DataRuleBoolValidationMixin, forms.ModelForm):
class DataRuleExtensionsMixin:
"""
Form for DataRule model.
Mixin for DataRule form to extend logic.
"""
class Meta:
model = DataRule
fields = "__all__"

data_path = forms.ChoiceField()

def __init__(self, *args, parent_instance=None, **kwargs):
"""
Expand All @@ -138,14 +111,36 @@ def __init__(self, *args, parent_instance=None, **kwargs):
event_type = self.parent_instance.event_type
self.fields["data_path"].choices = Choices(*get_event_type_keypaths(event_type=event_type))

def clean(self):
"""
Validate boolean fields.
"""

class BadgeRequirementFormSet(forms.BaseInlineFormSet):
def get_form_kwargs(self, index):
kwargs = super().get_form_kwargs(index)
kwargs["parent_instance"] = self.instance
return kwargs
cleaned_data = super().clean()

data_path_type = get_event_data_attr_type_by_keypath(
self.parent_instance.event_type, cleaned_data.get("data_path")
)

if data_path_type == bool and cleaned_data.get("value") not in AbstractDataRule.BOOL_VALUES:
raise forms.ValidationError(_("Value must be a boolean."))

return cleaned_data


class DataRuleFormSet(ParentMixin, forms.BaseInlineFormSet): ...
class DataRuleForm(DataRuleExtensionsMixin, forms.ModelForm):
"""
Form for DataRule model.
"""
class Meta:
model = DataRule
fields = "__all__"

data_path = forms.ChoiceField()


class BadgeRequirementFormSet(ParentMixin, forms.BaseInlineFormSet): ...
class BadgeRequirementForm(forms.ModelForm):
class Meta:
model = BadgeRequirement
Expand All @@ -163,38 +158,14 @@ def __init__(self, *args, parent_instance=None, **kwargs):
self.fields["group"].initial = chr(65 + self.template.requirements.count())


class PenaltyDataRuleFormSet(forms.BaseInlineFormSet):
"""
Formset for PenaltyDataRule model.
"""
def get_form_kwargs(self, index):
"""
Pass parent instance to the form.
"""

kwargs = super().get_form_kwargs(index)
kwargs["parent_instance"] = self.instance
return kwargs


class PenaltyDataRuleForm(DataRuleBoolValidationMixin, forms.ModelForm):
class PenaltyDataRuleFormSet(ParentMixin, forms.BaseInlineFormSet): ...
class PenaltyDataRuleForm(DataRuleExtensionsMixin, forms.ModelForm):
"""
Form for PenaltyDataRule model.
"""

data_path = forms.ChoiceField()

class Meta:
model = PenaltyDataRule
fields = "__all__"

data_path = forms.ChoiceField()

def __init__(self, *args, parent_instance=None, **kwargs):
"""
Load data paths based on the parent instance event type.
"""
self.parent_instance = parent_instance
super().__init__(*args, **kwargs)

if self.parent_instance:
event_type = self.parent_instance.event_type
self.fields["data_path"].choices = Choices(*get_event_type_keypaths(event_type=event_type))
50 changes: 48 additions & 2 deletions credentials/apps/badges/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,21 @@ def extract_payload(public_signal_kwargs: dict) -> attr.s:
return value


def get_event_type_data(event_type: str) -> attr.s:
"""
Extracts the dataclass for a given event type.
Parameters:
- event_type: The event type to extract dataclass for.
Returns:
attr.s: The dataclass for the given event type.
"""

signal = OpenEdxPublicSignal.get_signal_by_type(event_type)
return extract_payload(signal.init_data)


def get_event_type_keypaths(event_type: str) -> list:
"""
Extracts all possible keypaths for a given event type.
Expand All @@ -126,8 +141,7 @@ def get_event_type_keypaths(event_type: str) -> list:
list: A list of all possible keypaths for the given event type.
"""

signal = OpenEdxPublicSignal.get_signal_by_type(event_type)
data = extract_payload(signal.init_data)
data = get_event_type_data(event_type)

def get_data_keypaths(data):
"""
Expand Down Expand Up @@ -157,3 +171,35 @@ def get_data_keypaths(data):
else:
keypaths.append(field.name)
return keypaths


def get_event_data_attr_type_by_keypath(event_type: str, keypath: str):
"""
Extracts the attribute type for a given keypath in the event data.
Parameters:
- event_type: The event type to extract dataclass for.
- keypath: The keypath to extract attribute type for.
Returns:
type: The attribute type for the given keypath in the event data.
"""

data = get_event_type_data(event_type)
data_attrs = attr.fields(data)

def get_attr_type_by_keypath(data_attrs, keypath):
"""
Extracts the attribute type for a given keypath in the dataclass.
"""

keypath_parts = keypath.split(".")
for attr_ in data_attrs:
if attr_.name == keypath_parts[0]:
if len(keypath_parts) == 1:
return attr_.type
elif attr.has(attr_.type):
return get_attr_type_by_keypath(attr.fields(attr_.type), ".".join(keypath_parts[1:]))
return None

return get_attr_type_by_keypath(data_attrs, keypath)

0 comments on commit 69bcad2

Please sign in to comment.