diff --git a/c3s_eqc_automatic_quality_control/diagnostics.py b/c3s_eqc_automatic_quality_control/diagnostics.py index 4799392..5059d1a 100644 --- a/c3s_eqc_automatic_quality_control/diagnostics.py +++ b/c3s_eqc_automatic_quality_control/diagnostics.py @@ -137,13 +137,22 @@ def time_weighted_linear_trend( ) output["linear_trend"] *= 1.0e9 # 1/ns to 1/s - def attrs_func(attrs: dict[str, Any]) -> dict[str, Any]: + def attrs_func_linear(attrs: dict[str, Any]) -> dict[str, Any]: return { "long_name": f"Linear trend of {attrs.get('long_name', '')}", "units": f"{attrs.get('units', '')} s-1", } - output["linear_trend"] = _apply_attrs_func(output["linear_trend"], obj, attrs_func) + def attrs_func_rmse(attrs: dict[str, Any]) -> dict[str, Any]: + return { + "units": f"{attrs.get('units', '')}", + } + + output["linear_trend"] = _apply_attrs_func( + output["linear_trend"], obj, attrs_func_linear + ) + if "rmse" in output: + output["rmse"] = _apply_attrs_func(output["rmse"], obj, attrs_func_rmse) return output["linear_trend"] if not (p_value or rmse) else output