This repository has been archived by the owner on Dec 30, 2024. It is now read-only.

Merge pull request #1 from casavo/feat/operate-on-db-in-one-single-transaction

feat: operate on db in one single transaction
alex5995 authored Feb 22, 2023
2 parents 5fcec90 + cc7c973 commit 98de1e9
Showing 8 changed files with 212 additions and 144 deletions.
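The gist of the change, visible in the `jolteon/modules.py` diff below: instead of opening a cursor and committing once per updated table, `Updater.write_on_postgres` now receives a list of update tasks, executes every `UPDATE` on a single cursor, and commits once at the end. A minimal sketch of that pattern with psycopg2 (the connection string and statements are illustrative placeholders, not taken from the repository):

```python
# Minimal sketch of the single-transaction pattern; names and values are illustrative.
import psycopg2

conn = psycopg2.connect("dbname=lightdash")  # placeholder connection string

statements = [
    # In jolteon these statements are built per table from the config mapping.
    "UPDATE saved_queries_version_fields SET name = 'orders_count' WHERE name = 'order_count'",
    "UPDATE saved_queries_versions SET explore_name = 'orders' WHERE explore_name = 'order'",
]

with conn.cursor() as cur:
    for statement in statements:
        cur.execute(statement)  # every statement runs inside the same transaction
conn.commit()  # a single commit: either all updates apply, or none do
```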
5 changes: 5 additions & 0 deletions .env.example
@@ -0,0 +1,5 @@
PG_HOST=host
PG_PORT=5432
PG_DATABASE=db
PG_USER=user
PG_PASSWORD=password
31 changes: 31 additions & 0 deletions README.md
@@ -0,0 +1,31 @@
![jolteon](https://assets.pokemon.com/assets/cms2/img/pokedex/full/135.png)

Are you a Lightdash user? Have you ever had to change the name of a metric, dimension or model in your dbt project?

If so, you know that adapting the Lightdash charts to these changes is a huge pain.

This Python package aims to partially solve this issue by automatically updating the Lightdash database.

It works well most of the time, but there are still some corner cases where your charts will look slightly different after the migration. Even so, this package will save you hours of manual updates.

## How to install Jolteon

```
pip install jolteon
```

## How to use Jolteon

1. Create a `.env` file like the `.env.example` one you find in this repository and fill it with your Lightdash database connection parameters.

2. Create a `config.yaml` file like the `config_example.yaml` one you find in this repository. This file should be structured as follows (a sketch is shown after these steps):

- `old_table` should be filled with the previous name of your dbt model (if you have changed it) or with the current name of it (if you haven't changed it).

- `new_table` should be filled with the current name of your dbt model only when you have changed it, otherwise it should be left empty.

- `fields_raw_mapping` should be filled with the mapping from the old names of the metrics and dimensions you have changed to their new names. If you haven't changed any metric or dimension, it can be left empty.

- `query_ids` should be filled with the ids of the charts you want to affect when updating the database. If you don't know the ids of the charts (and you probably won't the first time), you can run `jolteon get-ids`: you will be presented with a table containing the id, the name, and the workspace of every chart in your Lightdash instance.

3. Run `jolteon update-db config.yaml`.
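For reference, here is a sketch of what `config.yaml` could look like. The keys follow the description above; every value is a placeholder, and `config_example.yaml` in this repository remains the authoritative template:

```yaml
# Illustrative placeholder values only; see config_example.yaml for the real template.
old_table: customers          # previous dbt model name (or current name, if unchanged)
new_table: clients            # new model name; leave empty if the model was not renamed
fields_raw_mapping:           # old metric/dimension names mapped to the new ones
  total_orders: order_count
  signup_date: registration_date
query_ids:                    # chart ids to update; find them with `jolteon get-ids`
  - 12
  - 34
```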
4 changes: 4 additions & 0 deletions jolteon/__init__.py
@@ -1,5 +1,9 @@
import warnings

from pydantic import BaseModel, validator

warnings.filterwarnings("ignore")


class Config(BaseModel):
old_table: str
35 changes: 5 additions & 30 deletions jolteon/__main__.py
@@ -3,48 +3,23 @@

from jolteon import Config
from jolteon.modules import Updater, print_query_ids
from jolteon.utils import get_pg_connection
from jolteon.utils import get_connection

app = typer.Typer()


@app.command()
def get_ids(
pg_host: str = "localhost",
pg_port: str = "5432",
pg_database: str = "postgres",
pg_user: str = "postgres",
pg_password: str = "postgres",
) -> None:
conn = get_pg_connection(
host=pg_host,
port=pg_port,
database=pg_database,
user=pg_user,
password=pg_password,
)
def get_ids() -> None:
conn = get_connection()
print_query_ids(conn)


@app.command()
def update_db(
config_path: str,
pg_host: str = "localhost",
pg_port: str = "5432",
pg_database: str = "postgres",
pg_user: str = "postgres",
pg_password: str = "postgres",
) -> None:
def update_db(config_path: str) -> None:
with open(config_path) as f:
config = Config(**yaml.safe_load(f))

conn = get_pg_connection(
host=pg_host,
port=pg_port,
database=pg_database,
user=pg_user,
password=pg_password,
)
conn = get_connection()
updater = Updater(config, conn)
updater.overwrite_db()
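The matching change in `jolteon/utils.py` (replacing `get_pg_connection` with a parameterless `get_connection`) is not shown in this excerpt. A plausible sketch, assuming the function simply reads the `PG_*` variables listed in `.env.example` from the environment and opens a psycopg2 connection; the real implementation may differ, for instance by loading the `.env` file with python-dotenv first:

```python
# Hypothetical sketch of get_connection; jolteon/utils.py may implement this differently.
import os

import psycopg2


def get_connection() -> psycopg2.extensions.connection:
    # Connection parameters come from the PG_* variables documented in .env.example.
    return psycopg2.connect(
        host=os.environ["PG_HOST"],
        port=os.environ["PG_PORT"],
        dbname=os.environ["PG_DATABASE"],
        user=os.environ["PG_USER"],
        password=os.environ["PG_PASSWORD"],
    )
```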

200 changes: 102 additions & 98 deletions jolteon/modules.py
@@ -2,6 +2,7 @@

import pandas as pd
import psycopg2
from pydantic import BaseModel

from jolteon import Config
from jolteon.utils import get_df_from_query
@@ -22,6 +23,18 @@ def print_query_ids(conn: psycopg2.extensions.connection) -> None:
)


class UpdateParams(BaseModel):
ids: pd.Series
new_values: pd.Series
table_name: str
field_id_name: str
field_name: str
field_type: str

class Config:
arbitrary_types_allowed = True


class Updater:
def __init__(self, config: Config, conn: psycopg2.extensions.connection) -> None:
self.config = config
@@ -35,29 +48,22 @@ def get_where_clause(ids: tuple[int, ...]) -> str:
return f"= {ids[0]}"
return f"IN {ids}"

def write_on_postgres(
self,
ids: pd.Series,
new_values: pd.Series,
table_name: str,
field_id_name: str,
field_name: str,
field_type: str,
) -> None:
zipped_vals = zip(ids, new_values)
tuple_to_str = str(tuple(zipped_vals))
entries_to_update = tuple_to_str[1 : len(tuple_to_str) - 1].strip(",")
if "[" in field_type:
entries_to_update = entries_to_update.replace("[", "ARRAY[")
if entries_to_update:
with self.conn.cursor() as cur:
update_sql_query = f"""
UPDATE {table_name} AS t
SET {field_name} = v.value::{field_type}
FROM (VALUES {entries_to_update}) AS v (id, value)
WHERE t.{field_id_name} = v.id
"""
cur.execute(update_sql_query)
def write_on_postgres(self, tasks: list[UpdateParams]) -> None:
with self.conn.cursor() as cur:
for task in tasks:
zipped_vals = zip(task.ids, task.new_values)
tuple_to_str = str(tuple(zipped_vals))
entries_to_update = tuple_to_str[1 : len(tuple_to_str) - 1].strip(",")
if "[" in task.field_type:
entries_to_update = entries_to_update.replace("[", "ARRAY[")
if entries_to_update:
update_sql_query = f"""
UPDATE {task.table_name} AS t
SET {task.field_name} = v.value::{task.field_type}
FROM (VALUES {entries_to_update}) AS v (id, value)
WHERE t.{task.field_id_name} = v.id
"""
cur.execute(update_sql_query)
self.conn.commit()

def get_saved_queries_version_ids(self) -> tuple[int, ...]:
@@ -75,7 +81,7 @@ def get_saved_queries_version_ids(self) -> tuple[int, ...]:
)["saved_queries_version_id"]
)

def get_fields_to_update(self, ids: tuple[int, ...]) -> pd.DataFrame:
def get_fields_to_update(self, ids: tuple[int, ...]) -> UpdateParams:
fields = get_df_from_query(
f"""
SELECT *
@@ -90,9 +96,16 @@ def get_fields_to_update(self, ids: tuple[int, ...]) -> pd.DataFrame:
x.replace(self.config.old_table, self.config.target_table),
)
)
return fields
return UpdateParams(
ids=fields["saved_queries_version_field_id"],
new_values=fields["new_name"],
table_name="saved_queries_version_fields",
field_id_name="saved_queries_version_field_id",
field_name="name",
field_type="VARCHAR",
)

def get_calculations_to_update(self, ids: tuple[int, ...]) -> pd.DataFrame:
def get_calculations_to_update(self, ids: tuple[int, ...]) -> UpdateParams:
def apply_mapping(s: str) -> str:
s = s.replace(self.config.old_table, self.config.target_table)
for k, v in self.config.calculations_mapping.items():
@@ -108,9 +121,16 @@ def apply_mapping(s: str) -> str:
self.conn,
)
calculations["new_calculations"] = calculations["calculation_raw_sql"].apply(apply_mapping)
return calculations
return UpdateParams(
ids=calculations["saved_queries_version_table_calculation_id"],
new_values=calculations["new_calculations"],
table_name="saved_queries_version_table_calculations",
field_id_name="saved_queries_version_table_calculation_id",
field_name="calculation_raw_sql",
field_type="VARCHAR",
)

def get_sorts_to_update(self, ids: tuple[int, ...]) -> pd.DataFrame:
def get_sorts_to_update(self, ids: tuple[int, ...]) -> UpdateParams:
sorts = get_df_from_query(
f"""
SELECT *
@@ -125,9 +145,16 @@ def get_sorts_to_update(self, ids: tuple[int, ...]) -> pd.DataFrame:
x.replace(self.config.old_table, self.config.target_table),
)
)
return sorts
return UpdateParams(
ids=sorts["saved_queries_version_sort_id"],
new_values=sorts["new_name"],
table_name="saved_queries_version_sorts",
field_id_name="saved_queries_version_sort_id",
field_name="field_name",
field_type="VARCHAR",
)

def get_charts_to_update(self, ids: tuple[int, ...]) -> pd.DataFrame:
def get_charts_to_update(self, ids: tuple[int, ...]) -> UpdateParams:
def apply_mapping(d: dict) -> str:
s = json.dumps(d).replace(self.config.old_table, self.config.target_table)
for k, v in self.config.fields_mapping.items():
@@ -143,9 +170,16 @@ def apply_mapping(d: dict) -> str:
self.conn,
)
charts["new_chart_config"] = charts["chart_config"].apply(apply_mapping)
return charts
return UpdateParams(
ids=charts["saved_queries_version_id"],
new_values=charts["new_chart_config"],
table_name="saved_queries_versions",
field_id_name="saved_queries_version_id",
field_name="chart_config",
field_type="JSONB",
)

def get_filters_to_update(self, ids: tuple[int, ...]) -> pd.DataFrame:
def get_filters_to_update(self, ids: tuple[int, ...]) -> UpdateParams:
filters = get_df_from_query(
f"""
SELECT saved_queries_version_id, filters
@@ -157,9 +191,16 @@ def get_filters_to_update(self, ids: tuple[int, ...]) -> pd.DataFrame:
filters["new_filters"] = filters["filters"].apply(
lambda x: json.dumps(x).replace(self.config.old_table, self.config.target_table)
)
return filters
return UpdateParams(
ids=filters["saved_queries_version_id"],
new_values=filters["new_filters"],
table_name="saved_queries_versions",
field_id_name="saved_queries_version_id",
field_name="filters",
field_type="JSONB",
)

def get_explore_names_to_update(self, ids: tuple[int, ...]) -> pd.DataFrame:
def get_explore_names_to_update(self, ids: tuple[int, ...]) -> UpdateParams:
explore_names = get_df_from_query(
f"""
SELECT saved_queries_version_id, explore_name
@@ -171,9 +212,16 @@ def get_explore_names_to_update(self, ids: tuple[int, ...]) -> pd.DataFrame:
explore_names["new_explore_name"] = explore_names["explore_name"].apply(
lambda x: self.config.target_table if x == self.config.old_table else x
)
return explore_names
return UpdateParams(
ids=explore_names["saved_queries_version_id"],
new_values=explore_names["new_explore_name"],
table_name="saved_queries_versions",
field_id_name="saved_queries_version_id",
field_name="explore_name",
field_type="VARCHAR",
)

def get_pivot_dimensions_to_update(self, ids: tuple[int, ...]) -> pd.DataFrame:
def get_pivot_dimensions_to_update(self, ids: tuple[int, ...]) -> UpdateParams:
pivot_dimensions = get_df_from_query(
f"""
SELECT saved_queries_version_id, pivot_dimensions
@@ -193,70 +241,26 @@ def get_pivot_dimensions_to_update(self, ids: tuple[int, ...]) -> pd.DataFrame:
if l is not None
else None
)
return pivot_dimensions.dropna(subset="new_pivot_dimensions")
pivot_dimensions.dropna(subset="new_pivot_dimensions", inplace=True)
return UpdateParams(
ids=pivot_dimensions["saved_queries_version_id"],
new_values=pivot_dimensions["new_pivot_dimensions"],
table_name="saved_queries_versions",
field_id_name="saved_queries_version_id",
field_name="pivot_dimensions",
field_type="VARCHAR[]",
)

def overwrite_db(self) -> None:
saved_queries_version_ids = self.get_saved_queries_version_ids()
fields = self.get_fields_to_update(saved_queries_version_ids)
self.write_on_postgres(
fields["saved_queries_version_field_id"],
fields["new_name"],
"saved_queries_version_fields",
"saved_queries_version_field_id",
"name",
"VARCHAR",
)
calculations = self.get_calculations_to_update(saved_queries_version_ids)
self.write_on_postgres(
calculations["saved_queries_version_table_calculation_id"],
calculations["new_calculations"],
"saved_queries_version_table_calculations",
"saved_queries_version_table_calculation_id",
"calculation_raw_sql",
"VARCHAR",
)
sorts = self.get_sorts_to_update(saved_queries_version_ids)
self.write_on_postgres(
sorts["saved_queries_version_sort_id"],
sorts["new_name"],
"saved_queries_version_sorts",
"saved_queries_version_sort_id",
"field_name",
"VARCHAR",
)
charts = self.get_charts_to_update(saved_queries_version_ids)
self.write_on_postgres(
charts["saved_queries_version_id"],
charts["new_chart_config"],
"saved_queries_versions",
"saved_queries_version_id",
"chart_config",
"JSONB",
)
filters = self.get_filters_to_update(saved_queries_version_ids)
self.write_on_postgres(
filters["saved_queries_version_id"],
filters["new_filters"],
"saved_queries_versions",
"saved_queries_version_id",
"filters",
"JSONB",
)
explore_names = self.get_explore_names_to_update(saved_queries_version_ids)
self.write_on_postgres(
explore_names["saved_queries_version_id"],
explore_names["new_explore_name"],
"saved_queries_versions",
"saved_queries_version_id",
"explore_name",
"VARCHAR",
)
pivot_dimensions = self.get_pivot_dimensions_to_update(saved_queries_version_ids)
self.write_on_postgres(
pivot_dimensions["saved_queries_version_id"],
pivot_dimensions["new_pivot_dimensions"],
"saved_queries_versions",
"saved_queries_version_id",
"pivot_dimensions",
"VARCHAR[]",
[
self.get_fields_to_update(saved_queries_version_ids),
self.get_calculations_to_update(saved_queries_version_ids),
self.get_sorts_to_update(saved_queries_version_ids),
self.get_charts_to_update(saved_queries_version_ids),
self.get_filters_to_update(saved_queries_version_ids),
self.get_explore_names_to_update(saved_queries_version_ids),
self.get_pivot_dimensions_to_update(saved_queries_version_ids),
]
)
