Skip to content

Commit

Permalink
Revert "Add update override arg to upsert_multi"
Browse files Browse the repository at this point in the history
  • Loading branch information
igorbenav authored Sep 6, 2024
1 parent a5458a8 commit c32275d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 91 deletions.
26 changes: 5 additions & 21 deletions fastcrud/crud/fast_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,6 @@ async def upsert_multi(
return_columns: Optional[list[str]] = None,
schema_to_select: Optional[type[BaseModel]] = None,
return_as_model: bool = False,
update_override: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Optional[Dict[str, Any]]:
"""
Expand All @@ -790,7 +789,6 @@ async def upsert_multi(
return_columns: Optional list of column names to return after the upsert operation.
schema_to_select: Optional Pydantic schema for selecting specific columns. Required if return_as_model is True.
return_as_model: If True, returns data as instances of the specified Pydantic model.
update_override: Optional dictionary to override the update values for the upsert operation.
**kwargs: Filters to identify the record(s) to update on conflict, supporting advanced comparison operators for refined querying.
Returns:
Expand All @@ -800,18 +798,12 @@ async def upsert_multi(
ValueError: If the MySQL dialect is used with filters, return_columns, schema_to_select, or return_as_model.
NotImplementedError: If the database dialect is not supported for upsert multi.
"""
if update_override is None:
update_override = {}
filters = self._parse_filters(**kwargs)

if db.bind.dialect.name == "postgresql":
statement, params = await self._upsert_multi_postgresql(
instances, filters, update_override
)
statement, params = await self._upsert_multi_postgresql(instances, filters)
elif db.bind.dialect.name == "sqlite":
statement, params = await self._upsert_multi_sqlite(
instances, filters, update_override
)
statement, params = await self._upsert_multi_sqlite(instances, filters)
elif db.bind.dialect.name in ["mysql", "mariadb"]:
if filters:
raise ValueError(
Expand All @@ -821,9 +813,7 @@ async def upsert_multi(
raise ValueError(
"MySQL does not support the returning clause for insert operations."
)
statement, params = await self._upsert_multi_mysql(
instances, update_override
)
statement, params = await self._upsert_multi_mysql(instances)
else: # pragma: no cover
raise NotImplementedError(
f"Upsert multi is not implemented for {db.bind.dialect.name}"
Expand All @@ -848,7 +838,6 @@ async def _upsert_multi_postgresql(
self,
instances: list[Union[UpdateSchemaType, CreateSchemaType]],
filters: list[ColumnElement],
update_set_override: dict[str, Any],
) -> tuple[Insert, list[dict]]:
statement = postgresql.insert(self.model)
statement = statement.on_conflict_do_update(
Expand All @@ -857,8 +846,7 @@ async def _upsert_multi_postgresql(
column.name: getattr(statement.excluded, column.name)
for column in self.model.__table__.columns
if not column.primary_key and not column.unique
}
| update_set_override,
},
where=and_(*filters) if filters else None,
)
params = [
Expand All @@ -870,7 +858,6 @@ async def _upsert_multi_sqlite(
self,
instances: list[Union[UpdateSchemaType, CreateSchemaType]],
filters: list[ColumnElement],
update_set_override: dict[str, Any],
) -> tuple[Insert, list[dict]]:
statement = sqlite.insert(self.model)
statement = statement.on_conflict_do_update(
Expand All @@ -879,8 +866,7 @@ async def _upsert_multi_sqlite(
column.name: getattr(statement.excluded, column.name)
for column in self.model.__table__.columns
if not column.primary_key and not column.unique
}
| update_set_override,
},
where=and_(*filters) if filters else None,
)
params = [
Expand All @@ -891,7 +877,6 @@ async def _upsert_multi_sqlite(
async def _upsert_multi_mysql(
self,
instances: list[Union[UpdateSchemaType, CreateSchemaType]],
update_set_override: dict[str, Any],
) -> tuple[Insert, list[dict]]:
statement = mysql.insert(self.model)
statement = statement.on_duplicate_key_update(
Expand All @@ -902,7 +887,6 @@ async def _upsert_multi_mysql(
and not column.unique
and column.name != self.deleted_at_column
}
| update_set_override,
)
params = [
self.model(**instance.model_dump()).__dict__ for instance in instances
Expand Down
70 changes: 0 additions & 70 deletions tests/sqlalchemy/crud/test_upsert.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,35 +58,6 @@ async def test_upsert_successful(async_session, test_model, read_schema):
marks=pytest.mark.dialect("postgresql"),
id="postgresql-dict",
),
pytest.param(
{
"kwargs": {"return_columns": ["id", "name"]},
"expected_result": {
"data": [
{
"id": 1,
"name": "New Record",
}
]
},
},
{
"kwargs": {
"return_columns": ["id", "name"],
"update_override": {"name": "New"},
},
"expected_result": {
"data": [
{
"id": 1,
"name": "New",
}
]
},
},
marks=pytest.mark.dialect("postgresql"),
id="postgresql-dict-update-override",
),
pytest.param(
{
"kwargs": {"return_columns": ["id", "name"]},
Expand Down Expand Up @@ -175,35 +146,6 @@ async def test_upsert_successful(async_session, test_model, read_schema):
marks=pytest.mark.dialect("sqlite"),
id="sqlite-dict",
),
pytest.param(
{
"kwargs": {"return_columns": ["id", "name"]},
"expected_result": {
"data": [
{
"id": 1,
"name": "New Record",
}
]
},
},
{
"kwargs": {
"return_columns": ["id", "name"],
"update_override": {"name": "New"},
},
"expected_result": {
"data": [
{
"id": 1,
"name": "New",
}
]
},
},
marks=pytest.mark.dialect("sqlite"),
id="sqlite-dict-update-override",
),
pytest.param(
{
"kwargs": {"return_columns": ["id", "name"]},
Expand Down Expand Up @@ -266,18 +208,6 @@ async def test_upsert_successful(async_session, test_model, read_schema):
marks=pytest.mark.dialect("mysql"),
id="mysql-none",
),
pytest.param(
{
"kwargs": {},
"expected_result": None,
},
{
"kwargs": {"update_override": {"name": "New"}},
"expected_result": None,
},
marks=pytest.mark.dialect("mysql"),
id="mysql-dict-update-override",
),
],
)
@pytest.mark.asyncio
Expand Down

0 comments on commit c32275d

Please sign in to comment.