Skip to content

Commit

Permalink
feat: test
Browse files Browse the repository at this point in the history
  • Loading branch information
talboren committed Oct 14, 2024
1 parent 1e86599 commit eeaf6a2
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 29 deletions.
41 changes: 21 additions & 20 deletions keep/contextmanager/contextmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@

class ContextManager:
def __init__(self, tenant_id, workflow_id=None, workflow_execution_id=None):

self.logger = logging.getLogger(__name__)
self.logger_adapter = WorkflowLoggerAdapter(
self.logger, self, tenant_id, workflow_id, workflow_execution_id
)
self.workflow_id = workflow_id
self.workflow_execution_id = workflow_execution_id
self.tenant_id = tenant_id
self.logger = None
self.logger_adapter = None
self.set_logger(logging.getLogger(__name__))
self.steps_context = {}
self.steps_context_size = 0
self.providers_context = {}
Expand Down Expand Up @@ -54,6 +54,7 @@ def __init__(self, tenant_id, workflow_id=None, workflow_execution_id=None):
self.dependencies = set()
self.workflow_execution_id = None
self._api_key = None
self.__loggers = []

@property
def api_key(self):
Expand All @@ -74,23 +75,23 @@ def api_key(self):
def set_execution_context(self, workflow_execution_id):
self.workflow_execution_id = workflow_execution_id
self.logger_adapter.workflow_execution_id = workflow_execution_id

def set_logger_by_name(self, logger_name):
logger = logging.getLogger(logger_name)
log_level_string = os.environ.get("KEEP_PROVIDER_{}_LOG_LEVEL".format(logger_name.upper()), None)
if log_level_string:
log_level = logging.getLevelName(log_level_string)
logger.setLevel(log_level)
self.set_logger(logger)

def set_logger(self, logger):
self.logger = logger
self.logger_adapter = WorkflowLoggerAdapter(
self.logger, self, self.tenant_id, self.workflow_id, self.workflow_execution_id
for logger in self.__loggers:
logger.workflow_execution_id = workflow_execution_id

def get_logger(self, name=None):
if not name:
return self.logger_adapter

logger = logging.getLogger(name)
logger_adapter = WorkflowLoggerAdapter(
logger,
self,
self.tenant_id,
self.workflow_id,
self.workflow_execution_id,
)

def get_logger(self):
return self.logger_adapter
self.__loggers.append(logger_adapter)
return logger_adapter

def set_event_context(self, event):
self.event_context = event
Expand Down
3 changes: 1 addition & 2 deletions keep/providers/base/base_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ def __init__(
self.webhook_markdown = webhook_markdown
self.provider_description = provider_description
self.context_manager = context_manager
context_manager.set_logger_by_name("provider_{}".format(self.provider_id))
self.logger = context_manager.get_logger()
self.logger = context_manager.get_logger(self.__class__.__name__)
self.validate_config()
self.logger.debug(
"Base provider initalized", extra={"provider": self.__class__.__name__}
Expand Down
62 changes: 55 additions & 7 deletions tests/test_workflow_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,33 @@
"""


workflow_definition_with_two_providers = """workflow:
id: susu-and-sons
description: Just to test the logs of 2 providers
triggers:
- type: alert
filters:
- key: name
value: "server-is-hamburger"
steps:
- name: keep_step
provider:
type: keep
with:
filters:
- key: status
value: open
actions:
- name: console_action
provider:
type: console
with:
message: |
"Tier 1 Alert: {{ alert.name }} - {{ alert.description }}
Alert details: {{ alert }}"
"""


@pytest.fixture(scope="module")
def workflow_manager():
"""
Expand Down Expand Up @@ -77,6 +104,25 @@ def setup_workflow(db_session):
db_session.commit()


@pytest.fixture
def setup_workflow_with_two_providers(db_session):
"""
Fixture to set up a workflow in the database before each test.
It creates a Workflow object with the predefined workflow definition and adds it to the database.
"""
workflow = Workflow(
id="susu-and-sons",
name="susu-and-sons",
tenant_id=SINGLE_TENANT_UUID,
description="some stuff for unit testing",
created_by="[email protected]",
interval=0,
workflow_raw=workflow_definition_with_two_providers,
)
db_session.add(workflow)
db_session.commit()


@pytest.mark.parametrize(
"test_app, test_case, alert_statuses, expected_tier, db_session",
[
Expand Down Expand Up @@ -807,7 +853,7 @@ def test_workflow_execution_logs(
db_session,
test_app,
create_alert,
setup_workflow,
setup_workflow_with_two_providers,
workflow_manager,
test_case,
alert_statuses,
Expand All @@ -828,7 +874,7 @@ def test_workflow_execution_logs(
current_alert = AlertDto(
id="grafana-1",
source=["grafana"],
name="server-is-down",
name="server-is-hamburger",
status=AlertStatus.FIRING,
severity="critical",
fingerprint="fp1",
Expand All @@ -843,7 +889,7 @@ def test_workflow_execution_logs(
status = None
while workflow_execution is None and count < 30 and status != "success":
workflow_execution = get_last_workflow_execution_by_workflow_id(
SINGLE_TENANT_UUID, "alert-time-check"
SINGLE_TENANT_UUID, "susu-and-sons"
)
if workflow_execution is not None:
status = workflow_execution.status
Expand All @@ -854,8 +900,10 @@ def test_workflow_execution_logs(
assert workflow_execution is not None
assert workflow_execution.status == "success"

logs = db_session.query(WorkflowExecutionLog).filter(
WorkflowExecutionLog.workflow_execution_id == workflow_execution.id
).all()
logs = (
db_session.query(WorkflowExecutionLog)
.filter(WorkflowExecutionLog.workflow_execution_id == workflow_execution.id)
.all()
)

assert len(logs) == 4
assert len(logs) == 15

0 comments on commit eeaf6a2

Please sign in to comment.