From 8eb176db861e5b9ac573c9b0f2106b1eb4e5abc5 Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Fri, 15 Mar 2024 14:50:34 +0100 Subject: [PATCH] Added checks for UC-incompatible task clusters in Apache Airflow DAGs --- src/databricks/labs/pylint/airflow.py | 14 +++++++++----- tests/test_airflow.py | 28 +++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/src/databricks/labs/pylint/airflow.py b/src/databricks/labs/pylint/airflow.py index d1c21cb..fd543e4 100644 --- a/src/databricks/labs/pylint/airflow.py +++ b/src/databricks/labs/pylint/airflow.py @@ -30,7 +30,8 @@ def _check_tasks(self, value: astroid.NodeNG): for task in self._infer_value(inferred): if "new_cluster" not in task: continue - raise ValueError("new_cluster is missing data_security_mode") + if "data_security_mode" not in task["new_cluster"]: + self.add_message("missing-data-security-mode", node=value, args=(task["task_key"],)) def _check_job_clusters(self, value: astroid.NodeNG): for inferred in value.infer(): @@ -38,23 +39,26 @@ def _check_job_clusters(self, value: astroid.NodeNG): if "new_cluster" not in job_cluster: continue # add message that this job cluster is missing data_security_mode - self.add_message("missing-data-security-mode", node=value, args=(job_cluster["job_cluster_key"],)) + if "data_security_mode" not in job_cluster["new_cluster"]: + self.add_message("missing-data-security-mode", node=value, args=(job_cluster["job_cluster_key"],)) def _infer_value(self, value: astroid.NodeNG): - if isinstance(value, (str, int, bool, list, dict, type(None))): - return value if isinstance(value, astroid.Dict): return self._infer_dict(value) if isinstance(value, astroid.List): return self._infer_list(value) if isinstance(value, astroid.Const): return value.value + if isinstance(value, astroid.Tuple): + return tuple(self._infer_value(elem) for elem in value.elts) + if isinstance(value, astroid.DictUnpack): + return {self._infer_value(key): self._infer_value(value) for key, value in value.items} raise ValueError(f"Unsupported type {type(value)}") def _infer_dict(self, in_dict: astroid.Dict): out_dict = {} for in_key, in_value in in_dict.items: - out_key = self._infer_value(in_key.value) + out_key = self._infer_value(in_key) out_value = self._infer_value(in_value) out_dict[out_key] = out_value return out_dict diff --git a/tests/test_airflow.py b/tests/test_airflow.py index f7a036c..414852c 100644 --- a/tests/test_airflow.py +++ b/tests/test_airflow.py @@ -34,3 +34,31 @@ def test_missing_data_security_mode_in_job_clusters(lint_with): "[missing-data-security-mode] banana cluster missing 'data_security_mode' " "required for Unity Catalog compatibility" ) in messages + + +def test_missing_data_security_mode_in_task_clusters(lint_with): + messages = ( + lint_with(AirflowChecker) + << """from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator +tasks = [ + { + "task_key": "banana", + "notebook_task": { + "notebook_path": "/Shared/test", + }, + 'new_cluster': { + "spark_version": "7.3.x-scala2.12", + "node_type_id": "i3.xlarge", + "num_workers": 2, + }, + }, +] +DatabricksCreateJobsOperator( #@ + task_id="jobs_create_named", + tasks=tasks +)""" + ) + assert ( + "[missing-data-security-mode] banana cluster missing 'data_security_mode' " + "required for Unity Catalog compatibility" + ) in messages