diff --git a/main.tf b/main.tf index 6a8e88c..76d5d80 100644 --- a/main.tf +++ b/main.tf @@ -51,9 +51,9 @@ module "lambda" { DYNAMODB_TIME_TO_LIVE = var.dynamodb_time_to_live DYNAMODB_TABLE_NAME = try(module.cloudtrail_to_slack_dynamodb_table[0].dynamodb_table_id, "") + USE_DEFAULT_RULES = var.use_default_rules + PUSH_ACCESS_DENIED_CLOUDWATCH_METRICS = var.push_access_denied_cloudwatch_metrics }, - var.use_default_rules ? { USE_DEFAULT_RULES = "True" } : {}, - var.push_access_denied_cloudwatch_metrics ? { PUSH_ACCESS_DENIED_CLOUDWATCH_METRICS = "True" } : {} ) memory_size = var.lambda_memory_size diff --git a/src/config.py b/src/config.py index 1292e79..ae2cdcb 100644 --- a/src/config.py +++ b/src/config.py @@ -61,12 +61,12 @@ def __init__(self): # noqa: ANN101 ANN204 self.rules_separator: str = os.environ.get("RULES_SEPARATOR", ",") self.user_rules: List[str] = self.parse_rules_from_string(os.environ.get("RULES"), self.rules_separator) # noqa: E501 self.ignore_rules: List[str] = self.parse_rules_from_string(os.environ.get("IGNORE_RULES"), self.rules_separator) # noqa: E501 - self.use_default_rules: bool = os.environ.get("USE_DEFAULT_RULES", True) # type: ignore # noqa: PGH003 + self.use_default_rules: bool = self.get_bool_from_env_var("USE_DEFAULT_RULES") self.events_to_track: str | None = os.environ.get("EVENTS_TO_TRACK") self.dynamodb_table_name: str | None = os.environ.get("DYNAMODB_TABLE_NAME") self.dynamodb_time_to_live: int = int(os.environ.get("DYNAMODB_TIME_TO_LIVE", 900)) - self.push_access_denied_cloudwatch_metrics: bool = os.environ.get("PUSH_ACCESS_DENIED_CLOUDWATCH_METRICS") # type: ignore # noqa: PGH003, E501 + self.push_access_denied_cloudwatch_metrics: bool = self.get_bool_from_env_var("PUSH_ACCESS_DENIED_CLOUDWATCH_METRICS") self.rules = [] if self.use_default_rules: @@ -87,6 +87,10 @@ def parse_rules_from_string(rules_as_string: str | None, rules_separator: str) - # make sure there are no empty strings in the list return [x for x in rules_as_list if x] + @staticmethod + def get_bool_from_env_var(env_var_name: str) -> bool: + return os.environ.get(env_var_name, "").lower() in ["true", "1"] + class JsonFormatter(logging.Formatter):