diff --git a/docs/advanced/crud.md b/docs/advanced/crud.md new file mode 100644 index 0000000..70f0853 --- /dev/null +++ b/docs/advanced/crud.md @@ -0,0 +1,70 @@ + +# Advanced Use of FastCRUD + +FastCRUD offers a flexible and powerful approach to handling CRUD operations in FastAPI applications, leveraging the SQLAlchemy ORM. Beyond basic CRUD functionality, FastCRUD provides advanced features like `allow_multiple` for updates and deletes, and support for advanced filters (e.g., less than, greater than). These features enable more complex and fine-grained data manipulation and querying capabilities. + +## Allow Multiple Updates and Deletes + +One of FastCRUD's advanced features is the ability to update or delete multiple records at once based on specified conditions. This is particularly useful for batch operations where you need to modify or remove several records that match certain criteria. + +### Updating Multiple Records + +To update multiple records, you can set the `allow_multiple=True` parameter in the `update` method. This allows FastCRUD to apply the update to all records matching the given filters. + +```python +# Assuming setup for FastCRUD instance `item_crud` and SQLAlchemy async session `db` + +# Update all items priced below $10 to a new price +await item_crud.update( + db=db, + object={"price": 9.99}, + allow_multiple=True, + price__lt=10 +) +``` + +### Deleting Multiple Records + +Similarly, you can delete multiple records by using the `allow_multiple=True` parameter in the `delete` or `db_delete` method, depending on whether you're performing a soft or hard delete. + +```python +# Soft delete all items not sold in the last year +await item_crud.delete( + db=db, + allow_multiple=True, + last_sold__lt=datetime.datetime.now() - datetime.timedelta(days=365) +) +``` + +## Advanced Filters + +FastCRUD supports advanced filtering options, allowing you to query records using operators such as greater than (`__gt`), less than (`__lt`), and their inclusive counterparts (`__gte`, `__lte`). These filters can be used in any method that retrieves or operates on records, including `get`, `get_multi`, `exists`, `count`, `update`, and `delete`. + +### Using Advanced Filters + +The following examples demonstrate how to use advanced filters for querying and manipulating data: + +#### Fetching Records with Advanced Filters + +```python +# Fetch items priced between $5 and $20 +items = await item_crud.get_multi( + db=db, + price__gte=5, + price__lte=20 +) +``` + +#### Counting Records + +```python +# Count items added in the last month +item_count = await item_crud.count( + db=db, + added_at__gte=datetime.datetime.now() - datetime.timedelta(days=30) +) +``` + +## Conclusion + +The advanced features of FastCRUD, such as `allow_multiple` and support for advanced filters, empower developers to efficiently manage database records with complex conditions. By leveraging these capabilities, you can build more dynamic, robust, and scalable FastAPI applications that effectively interact with your data model. diff --git a/docs/advanced/endpoint.md b/docs/advanced/endpoint.md index c3690e1..022128f 100644 --- a/docs/advanced/endpoint.md +++ b/docs/advanced/endpoint.md @@ -165,6 +165,70 @@ my_router = crud_router( app.include_router(my_router) ``` +## Custom Soft Delete + +To implement custom soft delete columns using `EndpointCreator` and `crud_router` in FastCRUD, you need to specify the names of the columns used for indicating deletion status and the deletion timestamp in your model. FastCRUD provides flexibility in handling soft deletes by allowing you to configure these column names directly when setting up CRUD operations or API endpoints. + +Here's how to specify custom soft delete columns when utilizing `EndpointCreator` and `crud_router`: + +### Defining Models with Custom Soft Delete Columns + +First, ensure your SQLAlchemy model is equipped with the custom soft delete columns. Here's an example model with custom columns for soft deletion: + +```python +from sqlalchemy import Column, Integer, String, DateTime, Boolean +from sqlalchemy.ext.declarative import declarative_base +from datetime import datetime + +Base = declarative_base() + +class MyModel(Base): + __tablename__ = 'my_model' + id = Column(Integer, primary_key=True) + name = Column(String) + archived = Column(Boolean, default=False) # Custom soft delete column + archived_at = Column(DateTime) # Custom timestamp column for soft delete +``` + +### Using `EndpointCreator` and `crud_router` with Custom Soft Delete Columns + +When initializing `crud_router` or creating a custom `EndpointCreator`, you can pass the names of your custom soft delete columns through the `FastCRUD` initialization. This informs FastCRUD which columns to check and update for soft deletion operations. + +Here's an example of using `crud_router` with custom soft delete columns: + +```python +from fastapi import FastAPI +from fastcrud import FastCRUD, crud_router +from sqlalchemy.ext.asyncio import AsyncSession + +app = FastAPI() + +# Assuming async_session is your AsyncSession generator +# and MyModel is your SQLAlchemy model + +# Initialize FastCRUD with custom soft delete columns +my_model_crud = FastCRUD(MyModel, + is_deleted_column='archived', # Custom 'is_deleted' column name + deleted_at_column='archived_at' # Custom 'deleted_at' column name + ) + +# Setup CRUD router with the FastCRUD instance +app.include_router(crud_router( + session=async_session, + model=MyModel, + crud=my_model_crud, + create_schema=CreateMyModelSchema, + update_schema=UpdateMyModelSchema, + delete_schema=DeleteMyModelSchema, + path="/mymodel", + tags=["MyModel"] +)) +``` + +This setup ensures that the soft delete functionality within your application utilizes the `archived` and `archived_at` columns for marking records as deleted, rather than the default `is_deleted` and `deleted_at` fields. + +By specifying custom column names for soft deletion, you can adapt FastCRUD to fit the design of your database models, providing a flexible solution for handling deleted records in a way that best suits your application's needs. + ## Conclusion The `EndpointCreator` class in FastCRUD offers flexibility and control over CRUD operations and custom endpoint creation. By extending this class or using the `included_methods` and `deleted_methods` parameters, you can tailor your API's functionality to your specific requirements, ensuring a more customizable and streamlined experience. diff --git a/docs/advanced/overview.md b/docs/advanced/overview.md new file mode 100644 index 0000000..d820d64 --- /dev/null +++ b/docs/advanced/overview.md @@ -0,0 +1,28 @@ +# Advanced Usage Overview + +The Advanced section of our documentation delves into the sophisticated capabilities and features of our application, tailored for users looking to leverage advanced functionalities. This part of our guide aims to unlock deeper insights and efficiencies through more complex use cases and configurations. + +## Key Topics + +### 1. Advanced Filtering and Searching +Explore how to implement advanced filtering and searching capabilities in your application. This guide covers the use of comparison operators (such as greater than, less than, etc.), pattern matching, and more to perform complex queries. + +- [Advanced Filtering Guide](crud.md#advanced-filters) + +### 2. Bulk Operations and Batch Processing +Learn how to efficiently handle bulk operations and batch processing. This section provides insights into performing mass updates, deletes, and inserts, optimizing performance for large datasets. + +- [Bulk Operations Guide](crud.md#allow-multiple-updates-and-deletes) + +### 3. Soft Delete Mechanisms and Strategies +Understand the implementation of soft delete mechanisms within our application. This guide covers configuring and using custom columns for soft deletes, restoring deleted records, and filtering queries to exclude soft-deleted entries. + +- [Soft Delete Strategies Guide](endpoint.md#custom-soft-delete) + +### 4. Advanced Use of EndpointCreator and crud_router +This topic extends the use of `EndpointCreator` and `crud_router` for advanced endpoint management, including creating custom routes, selective method exposure, and integrating soft delete functionalities. + +- [Advanced Endpoint Management Guide](endpoint.md#advanced-use-of-endpointcreator) + +## Prerequisites +Advanced usage assumes a solid understanding of the basic features and functionalities of our application. Knowledge of FastAPI, SQLAlchemy, and Pydantic is highly recommended to fully grasp the concepts discussed. diff --git a/docs/usage/crud.md b/docs/usage/crud.md index 7350961..b50cceb 100644 --- a/docs/usage/crud.md +++ b/docs/usage/crud.md @@ -65,8 +65,9 @@ new_item = await item_crud.create(db, ItemCreateSchema(name="New Item")) get( db: AsyncSession, schema_to_select: Optional[type[BaseModel]] = None, + return_as_model: bool = False, **kwargs: Any -) -> Optional[dict] +) -> Optional[Union[dict, BaseModel]] ``` **Purpose**: To fetch a single record based on filters, with an option to select specific columns using a Pydantic schema. @@ -136,6 +137,7 @@ items = await item_crud.get_multi(db, offset=10, limit=5) update( db: AsyncSession, object: Union[UpdateSchemaType, dict[str, Any]], + allow_multiple: bool = False, **kwargs: Any ) -> None ``` @@ -153,6 +155,7 @@ await item_crud.update(db, ItemUpdateSchema(description="Updated"), id=item_id) delete( db: AsyncSession, db_row: Optional[Row] = None, + allow_multiple: bool = False, **kwargs: Any ) -> None ``` @@ -169,6 +172,7 @@ await item_crud.delete(db, id=item_id) ```python db_delete( db: AsyncSession, + allow_multiple: bool = False, **kwargs: Any ) -> None ``` @@ -217,7 +221,8 @@ get_joined( join_on: Optional[Union[Join, None]] = None, schema_to_select: Optional[type[BaseModel]] = None, join_schema_to_select: Optional[type[BaseModel]] = None, - join_type: str = "left", **kwargs: Any + join_type: str = "left", + **kwargs: Any ) -> Optional[dict[str, Any]] ``` diff --git a/fastcrud/crud/fast_crud.py b/fastcrud/crud/fast_crud.py index 0737cbf..3caa9d8 100644 --- a/fastcrud/crud/fast_crud.py +++ b/fastcrud/crud/fast_crud.py @@ -2,17 +2,18 @@ from datetime import datetime, timezone from pydantic import BaseModel, ValidationError -import sqlalchemy.sql.selectable -from sqlalchemy import select, update, delete, func, and_, inspect, asc, desc, true -from sqlalchemy.exc import ArgumentError +from sqlalchemy import select, update, delete, func, inspect, asc, desc +from sqlalchemy.exc import ArgumentError, MultipleResultsFound, NoResultFound + from sqlalchemy.sql import Join from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.engine.row import Row from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.sql.elements import BinaryExpression +from sqlalchemy.sql.selectable import Select from .helper import ( _extract_matching_columns_from_schema, - _extract_matching_columns_from_kwargs, _auto_detect_join_condition, _add_column_with_prefix, ) @@ -41,50 +42,42 @@ class FastCRUD( Args: model: The SQLAlchemy model type. + is_deleted_column: Optional column name to use for indicating a soft delete. Defaults to "is_deleted". + deleted_at_column: Optional column name to use for storing the timestamp of a soft delete. Defaults to "deleted_at". Methods: - create(db: AsyncSession, object: CreateSchemaType) -> ModelType: - Creates a new record in the database. The 'object' parameter is a Pydantic schema - containing the data to be saved. + create: + Creates a new record in the database from the provided Pydantic schema. - get(db: AsyncSession, schema_to_select: Optional[Union[type[BaseModel], list]] = None, **kwargs: Any) -> Optional[dict]: - Retrieves a single record based on filters. You can specify a Pydantic schema to - select specific columns, and pass filter conditions as keyword arguments. + get: + Retrieves a single record based on filters. Supports advanced filtering through comparison operators like '__gt', '__lt', etc. - exists(db: AsyncSession, **kwargs: Any) -> bool: - Checks if a record exists based on the provided filters. Returns True if the record - exists, False otherwise. + exists: + Checks if a record exists based on the provided filters. - count(db: AsyncSession, **kwargs: Any) -> int: - Counts the number of records matching the provided filters. Useful for pagination - and analytics. + count: + Counts the number of records matching the provided filters. - get_multi(db: AsyncSession, offset: int = 0, limit: int = 100, schema_to_select: Optional[type[BaseModel]] = None, sort_columns: Optional[Union[str, list[str]]] = None, sort_orders: Optional[Union[str, list[str]]] = None, return_as_model: bool = False, **kwargs: Any) -> dict[str, Any]: + get_multi: Fetches multiple records with optional sorting, pagination, and model conversion. - Filters, sorting, and pagination parameters can be provided. - get_joined(db: AsyncSession, join_model: type[ModelType], join_prefix: Optional[str] = None, join_on: Optional[Union[Join, None]] = None, schema_to_select: Optional[Union[type[BaseModel], list]] = None, join_schema_to_select: Optional[Union[type[BaseModel], list]] = None, join_type: str = "left", **kwargs: Any) -> Optional[dict[str, Any]]: - Performs a join operation with another model. Supports custom join conditions and - selection of specific columns using Pydantic schemas. + get_joined: + Performs a join operation with another model, supporting custom join conditions and selection of specific columns. - get_multi_joined(db: AsyncSession, join_model: type[ModelType], join_prefix: Optional[str] = None, join_on: Optional[Join] = None, schema_to_select: Optional[type[BaseModel]] = None, join_schema_to_select: Optional[type[BaseModel]] = None, join_type: str = "left", offset: int = 0, limit: int = 100, sort_columns: Optional[Union[str, list[str]]] = None, sort_orders: Optional[Union[str, list[str]]] = None, return_as_model: bool = False, **kwargs: Any) -> dict[str, Any]: - Similar to 'get_joined', but for fetching multiple records. Offers pagination and - sorting functionalities for the joined tables. + get_multi_joined: + Fetches multiple records with a join on another model, offering pagination and sorting for the joined tables. - get_multi_by_cursor(db: AsyncSession, cursor: Any = None, limit: int = 100, schema_to_select: Optional[type[BaseModel]] = None, sort_column: str = "id", sort_order: str = "asc", **kwargs: Any) -> dict[str, Any]: - Implements cursor-based pagination for fetching records. Useful for large datasets - and infinite scrolling features. + get_multi_by_cursor: + Implements cursor-based pagination for fetching records, ideal for large datasets and infinite scrolling features. - update(db: AsyncSession, object: Union[UpdateSchemaType, dict[str, Any]], **kwargs: Any) -> None: - Updates an existing record. The 'object' can be a Pydantic schema or dictionary - containing update data. + update: + Updates an existing record or multiple records based on specified filters. - db_delete(db: AsyncSession, **kwargs: Any) -> None: - Hard deletes a record from the database based on provided filters. + db_delete: + Hard deletes a record or multiple records from the database based on provided filters. - delete(db: AsyncSession, db_row: Optional[Row] = None, **kwargs: Any) -> None: - Soft deletes a record if it has an "is_deleted" attribute; otherwise, performs a - hard delete. Filters or an existing database row can be provided for deletion. + delete: + Soft deletes a record if it has an "is_deleted" attribute; otherwise, performs a hard delete. Examples: Example 1: Basic Usage @@ -145,17 +138,58 @@ class FastCRUD( completed_tasks = await task_crud.get_multi(db, status='completed') high_priority_task_count = await task_crud.count(db, priority='high') ``` + + Example 6: Using Custom Column Names for Soft Delete + ---------------------------------------------------- + If your model uses different column names for indicating a soft delete and its timestamp, you can specify these when creating the FastCRUD instance. + ```python + custom_user_crud = FastCRUD(User, UserCreateSchema, UserUpdateSchema, is_deleted_column="archived", deleted_at_column="archived_at") + # Now 'archived' and 'archived_at' will be used for soft delete operations. + ``` """ - def __init__(self, model: type[ModelType]) -> None: + def __init__( + self, + model: type[ModelType], + is_deleted_column: str = "is_deleted", + deleted_at_column: str = "deleted_at", + ) -> None: self.model = model + self.is_deleted_column = is_deleted_column + self.deleted_at_column = deleted_at_column + + def _parse_filters(self, **kwargs) -> list[BinaryExpression]: + filters = [] + for key, value in kwargs.items(): + if "__" in key: + field_name, op = key.rsplit("__", 1) + column = getattr(self.model, field_name, None) + if column is None: + raise ValueError(f"Invalid filter column: {field_name}") + + if op == "gt": + filters.append(column > value) + elif op == "lt": + filters.append(column < value) + elif op == "gte": + filters.append(column >= value) + elif op == "lte": + filters.append(column <= value) + elif op == "ne": + filters.append(column != value) + else: + column = getattr(self.model, key, None) + if column is not None: + filters.append(column == value) + + return filters def _apply_sorting( self, - stmt: sqlalchemy.sql.selectable.Select, + stmt: Select, sort_columns: Union[str, list[str]], sort_orders: Optional[Union[str, list[str]]] = None, - ) -> sqlalchemy.sql.selectable.Select: + ) -> Select: """ Apply sorting to a SQLAlchemy query based on specified column names and sort orders. @@ -246,79 +280,159 @@ async def get( self, db: AsyncSession, schema_to_select: Optional[type[BaseModel]] = None, + return_as_model: bool = False, **kwargs: Any, - ) -> Optional[dict]: + ) -> Optional[Union[dict, BaseModel]]: """ - Fetch a single record based on filters. + Fetches a single record based on specified filters. + This method allows for advanced filtering through comparison operators, enabling queries to be refined beyond simple equality checks. + Supported operators include: + '__gt' (greater than), + '__lt' (less than), + '__gte' (greater than or equal to), + '__lte' (less than or equal to), and + '__ne' (not equal). Args: - db: The SQLAlchemy async session. - schema_to_select: Pydantic schema for selecting specific columns. - Default is None to select all columns. - **kwargs: Filters to apply to the query. + db: The database session to use for the operation. + schema_to_select: Optional Pydantic schema for selecting specific columns. + **kwargs: Filters to apply to the query, using field names for direct matches or appending comparison operators for advanced queries. + + Raises: + ValueError: If return_as_model is True but schema_to_select is not provided. Returns: - The fetched database row or None if not found. + A dictionary or a Pydantic model instance of the fetched database row, or None if no match is found. + + Examples: + Fetch a user by ID: + ```python + user = await crud.get(db, id=1) + ``` + + Fetch a user with an age greater than 30: + ```python + user = await crud.get(db, age__gt=30) + ``` + + Fetch a user with a registration date before Jan 1, 2020: + ```python + user = await crud.get(db, registration_date__lt=datetime(2020, 1, 1)) + ``` + + Fetch a user not equal to a specific username: + ```python + user = await crud.get(db, username__ne='admin') + ``` """ to_select = _extract_matching_columns_from_schema( model=self.model, schema=schema_to_select ) - stmt = select(*to_select).filter_by(**kwargs) + filters = self._parse_filters(**kwargs) + stmt = select(*to_select).filter(*filters) db_row = await db.execute(stmt) result: Row = db_row.first() if result is not None: out: dict = dict(result._mapping) + if return_as_model: + if not schema_to_select: + raise ValueError( + "schema_to_select must be provided when return_as_model is True." + ) + return schema_to_select(**out) return out return None async def exists(self, db: AsyncSession, **kwargs: Any) -> bool: """ - Check if a record exists based on filters. + Checks if any records exist that match the given filter conditions. + This method supports advanced filtering with comparison operators: + '__gt' (greater than), + '__lt' (less than), + '__gte' (greater than or equal to), + '__lte' (less than or equal to), and + '__ne' (not equal). Args: - db: The SQLAlchemy async session. - **kwargs: Filters to apply to the query. + db: The database session to use for the operation. + **kwargs: Filters to apply to the query, supporting both direct matches and advanced comparison operators for refined search criteria. Returns: - True if a record exists, False otherwise. + True if at least one record matches the filter conditions, False otherwise. + + Examples: + Fetch a user by ID exists: + ```python + exists = await crud.exists(db, id=1) + ``` + + Check if any user is older than 30: + ```python + exists = await crud.exists(db, age__gt=30) + ``` + + Check if any user registered before Jan 1, 2020: + ```python + exists = await crud.exists(db, registration_date__lt=datetime(2020, 1, 1)) + ``` + + Check if a username other than 'admin' exists: + ```python + exists = await crud.exists(db, username__ne='admin') + ``` """ - to_select = _extract_matching_columns_from_kwargs( - model=self.model, kwargs=kwargs - ) - stmt = select(*to_select).filter_by(**kwargs).limit(1) + filters = self._parse_filters(**kwargs) + stmt = select(self.model).filter(*filters).limit(1) result = await db.execute(stmt) return result.first() is not None async def count(self, db: AsyncSession, **kwargs: Any) -> int: """ - Count the records based on filters. + Counts records that match specified filters, supporting advanced filtering through comparison operators: + '__gt' (greater than), + '__lt' (less than), + '__gte' (greater than or equal to), + '__lte' (less than or equal to), and + '__ne' (not equal). Args: - db: The SQLAlchemy async session. - **kwargs: Filters to apply to the query. + db: The database session to use for the operation. + **kwargs: Filters to apply for the count, including field names for equality checks or with comparison operators for advanced queries. Returns: - Total count of records that match the applied filters. + The total number of records matching the filter conditions. - Note: - This method provides a quick way to get the count of records without retrieving the actual data. + Examples: + Count users by ID: + ```python + exists = await crud.count(db, id=1) + ``` + + Count users older than 30: + ```python + exists = await crud.count(db, age__gt=30) + ``` + + Count users who registered before Jan 1, 2020: + ```python + exists = await crud.count(db, registration_date__lt=datetime(2020, 1, 1)) + ``` + + Count users with a username other than 'admin': + ```python + exists = await crud.count(db, username__ne='admin') + ``` """ - conditions = [ - getattr(self.model, key) == value for key, value in kwargs.items() - ] - if conditions: - combined_conditions = and_(*conditions) + filters = self._parse_filters(**kwargs) + if filters: + count_query = select(func.count()).select_from(self.model).filter(*filters) else: - combined_conditions = true() + count_query = select(func.count()).select_from(self.model) - count_query = ( - select(func.count()).select_from(self.model).where(combined_conditions) - ) total_count: int = await db.scalar(count_query) - return total_count async def get_multi( @@ -333,20 +447,25 @@ async def get_multi( **kwargs: Any, ) -> dict[str, Any]: """ - Fetch multiple records based on filters, with optional sorting, pagination, and model conversion. + Fetches multiple records based on filters, supporting sorting, pagination, and advanced filtering with comparison operators: + '__gt' (greater than), + '__lt' (less than), + '__gte' (greater than or equal to), + '__lte' (less than or equal to), and + '__ne' (not equal). Args: - db: The SQLAlchemy async session. - offset: Number of rows to skip before fetching. Must be non-negative. - limit: Maximum number of rows to fetch. Must be non-negative. - schema_to_select: Pydantic schema for selecting specific columns. - sort_columns: Single column name or a list of column names for sorting. - sort_orders: Single sort direction ('asc' or 'desc') or a list of directions corresponding to the columns in sort_columns. Defaults to 'asc'. - return_as_model: If True, returns the data as instances of the Pydantic model. - **kwargs: Filters to apply to the query. + db: The database session to use for the operation. + offset: Starting index for records to fetch, useful for pagination. + limit: Maximum number of records to fetch in one call. + schema_to_select: Optional Pydantic schema for selecting specific columns. Required if `return_as_model` is True. + sort_columns: Column names to sort the results by. + sort_orders: Corresponding sort orders ('asc', 'desc') for each column in sort_columns. + return_as_model: If True, returns data as instances of the specified Pydantic model. + **kwargs: Filters to apply to the query, including advanced comparison operators for more detailed querying. Returns: - A dictionary containing the fetched rows under 'data' key and total count under 'total_count'. + A dictionary containing 'data' with fetched records and 'total_count' indicating the total number of records matching the filters. Raises: ValueError: If limit or offset is negative, or if schema_to_select is required but not provided or invalid. @@ -357,11 +476,26 @@ async def get_multi( users = await crud.get_multi(db, 0, 10) ``` - Fetch next 10 users with sorting: + Fetch next 10 users with sorted by username: ```python users = await crud.get_multi(db, 10, 10, sort_columns='username', sort_orders='desc') ``` + Fetch 10 users older than 30, sorted by age in descending order: + ```python + get_multi(db, offset=0, limit=10, age__gt=30, sort_columns='age', sort_orders='desc') + ``` + + Fetch 10 users with a registration date before Jan 1, 2020: + ```python + get_multi(db, offset=0, limit=10, registration_date__lt=datetime(2020, 1, 1)) + ``` + + Fetch 10 users with a username other than 'admin', returning as model instances (ensure appropriate schema is passed): + ```python + get_multi(db, offset=0, limit=10, username__ne='admin', schema_to_select=UserSchema, return_as_model=True) + ``` + Fetch users with filtering and multiple column sorting: ```python users = await crud.get_multi(db, 0, 10, is_active=True, sort_columns=['username', 'email'], sort_orders=['asc', 'desc']) @@ -371,7 +505,8 @@ async def get_multi( raise ValueError("Limit and offset must be non-negative.") to_select = _extract_matching_columns_from_schema(self.model, schema_to_select) - stmt = select(*to_select).filter_by(**kwargs) + filters = self._parse_filters(**kwargs) + stmt = select(*to_select).filter(*filters) if sort_columns: stmt = self._apply_sorting(stmt, sort_columns, sort_orders) @@ -408,7 +543,12 @@ async def get_joined( ) -> Optional[dict[str, Any]]: """ Fetches a single record with a join on another model. If 'join_on' is not provided, the method attempts - to automatically detect the join condition using foreign key relationships. + to automatically detect the join condition using foreign key relationships. Advanced filters supported: + '__gt' (greater than), + '__lt' (less than), + '__gte' (greater than or equal to), + '__lte' (less than or equal to), and + '__ne' (not equal). Args: db: The SQLAlchemy async session. @@ -416,13 +556,13 @@ async def get_joined( join_prefix: Optional prefix to be added to all columns of the joined model. If None, no prefix is added. join_on: SQLAlchemy Join object for specifying the ON clause of the join. If None, the join condition is auto-detected based on foreign keys. - schema_to_select: Pydantic schema for selecting specific columns from the primary model. + schema_to_select: Pydantic schema for selecting specific columns from the primary model. Required if `return_as_model` is True. join_schema_to_select: Pydantic schema for selecting specific columns from the joined model. join_type: Specifies the type of join operation to perform. Can be "left" for a left outer join or "inner" for an inner join. - **kwargs: Filters to apply to the query. + **kwargs: Filters to apply to the primary model query, supporting advanced comparison operators for refined searching. Returns: - The fetched database row or None if not found. + A dictionary representing the joined record, or None if no record matches the criteria. Examples: Simple example: Joining User and Tier models without explicitly providing join_on @@ -435,6 +575,21 @@ async def get_joined( ) ``` + Fetch a user and their associated tier, filtering by user ID: + ```python + get_joined(db, User, Tier, schema_to_select=UserSchema, join_schema_to_select=TierSchema, id=1) + ``` + + Fetch a user and their associated tier, where the user's age is greater than 30: + ```python + get_joined(db, User, Tier, schema_to_select=UserSchema, join_schema_to_select=TierSchema, age__gt=30) + ``` + + Fetch a user and their associated tier, excluding users with the 'admin' username: + ```python + get_joined(db, User, Tier, schema_to_select=UserSchema, join_schema_to_select=TierSchema, username__ne='admin') + ``` + Complex example: Joining with a custom join condition, additional filter parameters, and a prefix ```python from sqlalchemy import and_ @@ -502,9 +657,9 @@ async def get_joined( f"Invalid join type: {join_type}. Only 'left' or 'inner' are valid." ) - for key, value in kwargs.items(): - if hasattr(self.model, key): - stmt = stmt.where(getattr(self.model, key) == value) + filters = self._parse_filters(**kwargs) + if filters: + stmt = stmt.filter(*filters) db_row = await db.execute(stmt) result: Row = db_row.first() @@ -531,14 +686,20 @@ async def get_multi_joined( **kwargs: Any, ) -> dict[str, Any]: """ - Fetch multiple records with a join on another model, allowing for pagination, optional sorting, and model conversion. + Fetch multiple records with a join on another model, allowing for pagination, optional sorting, and model conversion, + supporting advanced filtering with comparison operators: + '__gt' (greater than), + '__lt' (less than), + '__gte' (greater than or equal to), + '__lte' (less than or equal to), and + '__ne' (not equal). Args: db: The SQLAlchemy async session. join_model: The model to join with. join_prefix: Optional prefix to be added to all columns of the joined model. If None, no prefix is added. join_on: SQLAlchemy Join object for specifying the ON clause of the join. If None, the join condition is auto-detected based on foreign keys. - schema_to_select: Pydantic schema for selecting specific columns from the primary model. + schema_to_select: Pydantic schema for selecting specific columns from the primary model. Required if `return_as_model` is True. join_schema_to_select: Pydantic schema for selecting specific columns from the joined model. join_type: Specifies the type of join operation to perform. Can be "left" for a left outer join or "inner" for an inner join. offset: The offset (number of records to skip) for pagination. @@ -546,7 +707,7 @@ async def get_multi_joined( sort_columns: A single column name or a list of column names on which to apply sorting. sort_orders: A single sort order ('asc' or 'desc') or a list of sort orders corresponding to the columns in sort_columns. If not provided, defaults to 'asc' for each column. return_as_model: If True, converts the fetched data to Pydantic models based on schema_to_select. Defaults to False. - **kwargs: Filters to apply to the primary query. + **kwargs: Filters to apply to the primary query, including advanced comparison operators for refined searching. Returns: A dictionary containing the fetched rows under 'data' key and total count under 'total_count'. @@ -556,43 +717,76 @@ async def get_multi_joined( Examples: Fetching multiple User records joined with Tier records, using left join, returning raw data: - >>> users = await crud_user.get_multi_joined( - db=session, - join_model=Tier, - join_prefix="tier_", - schema_to_select=UserSchema, - join_schema_to_select=TierSchema, - offset=0, - limit=10 - ) + ```python + users = await crud_user.get_multi_joined( + db=session, + join_model=Tier, + join_prefix="tier_", + schema_to_select=UserSchema, + join_schema_to_select=TierSchema, + offset=0, + limit=10 + ) + ``` + + Fetch users joined with their tiers, sorted by username, where user's age is greater than 30: + ```python + users = get_multi_joined( + db, + User, + Tier, + schema_to_select=UserSchema, + join_schema_to_select=TierSchema, + age__gt=30, + sort_columns='username', + sort_orders='asc' + ) + ``` + + Fetch users joined with their tiers, excluding users with 'admin' username, returning as model instances: + ```python + users = get_multi_joined( + db, + User, + Tier, + schema_to_select=UserSchema, + join_schema_to_select=TierSchema, + username__ne='admin', + return_as_model=True + ) + ``` Fetching and sorting by username in descending order, returning as Pydantic model: - >>> users = await crud_user.get_multi_joined( - db=session, - join_model=Tier, - join_prefix="tier_", - schema_to_select=UserSchema, - join_schema_to_select=TierSchema, - offset=0, - limit=10, - sort_columns=['username'], - sort_orders=['desc'], - return_as_model=True - ) + ```python + users = await crud_user.get_multi_joined( + db=session, + join_model=Tier, + join_prefix="tier_", + schema_to_select=UserSchema, + join_schema_to_select=TierSchema, + offset=0, + limit=10, + sort_columns=['username'], + sort_orders=['desc'], + return_as_model=True + ) + ``` Fetching with complex conditions and custom join, returning as Pydantic model: - >>> users = await crud_user.get_multi_joined( - db=session, - join_model=Tier, - join_prefix="tier_", - join_on=User.tier_id == Tier.id, - schema_to_select=UserSchema, - join_schema_to_select=TierSchema, - offset=0, - limit=10, - is_active=True, - return_as_model=True - ) + ```python + users = await crud_user.get_multi_joined( + db=session, + join_model=Tier, + join_prefix="tier_", + join_on=User.tier_id == Tier.id, + schema_to_select=UserSchema, + join_schema_to_select=TierSchema, + offset=0, + limit=10, + is_active=True, + return_as_model=True + ) + ``` """ if limit < 0 or offset < 0: raise ValueError("Limit and offset must be non-negative.") @@ -628,23 +822,22 @@ async def get_multi_joined( f"Invalid join type: {join_type}. Only 'left' or 'inner' are valid." ) - for key, value in kwargs.items(): - if hasattr(self.model, key): - stmt = stmt.where(getattr(self.model, key) == value) + filters = self._parse_filters(**kwargs) + if filters: + stmt = stmt.filter(*filters) if sort_columns: stmt = self._apply_sorting(stmt, sort_columns, sort_orders) stmt = stmt.offset(offset).limit(limit) - db_rows = await db.execute(stmt) - data = [dict(row._mapping) for row in db_rows] + result = await db.execute(stmt) + data = result.mappings().all() if return_as_model and schema_to_select: data = [schema_to_select.model_construct(**row) for row in data] total_count = await self.count(db=db, **kwargs) - return {"data": data, "total_count": total_count} async def get_multi_by_cursor( @@ -658,7 +851,13 @@ async def get_multi_by_cursor( **kwargs: Any, ) -> dict[str, Any]: """ - Fetch multiple records based on a cursor for pagination, with optional sorting. + Implements cursor-based pagination for fetching records. This method is designed for efficient data retrieval in large datasets and is ideal for features like infinite scrolling. + It supports advanced filtering with comparison operators: + '__gt' (greater than), + '__lt' (less than), + '__gte' (greater than or equal to), + '__lte' (less than or equal to), and + '__ne' (not equal). Args: db: The SQLAlchemy async session. @@ -667,18 +866,30 @@ async def get_multi_by_cursor( schema_to_select: Pydantic schema for selecting specific columns. sort_column: Column name to use for sorting and cursor pagination. sort_order: Sorting direction, either 'asc' or 'desc'. - **kwargs: Additional filters to apply to the query. + **kwargs: Filters to apply to the query, including advanced comparison operators for detailed querying. Returns: A dictionary containing the fetched rows under 'data' key and the next cursor value under 'next_cursor'. - Usage Examples: - # Fetch the first set of records (e.g., the first page in an infinite scrolling scenario) - >>> first_page = await crud.get_multi_by_cursor(db, limit=10, sort_column='created_at', sort_order='desc') + Examples: + Fetch the first set of records (e.g., the first page in an infinite scrolling scenario) + ```python + first_page = await crud.get_multi_by_cursor(db, limit=10, sort_column='created_at', sort_order='desc') + + Fetch the next set of records using the cursor from the first page + next_cursor = first_page['next_cursor'] + second_page = await crud.get_multi_by_cursor(db, cursor=next_cursor, limit=10, sort_column='created_at', sort_order='desc') + ``` + + Fetch records with age greater than 30 using cursor-based pagination: + ```python + get_multi_by_cursor(db, limit=10, sort_column='age', sort_order='asc', age__gt=30) + ``` - # Fetch the next set of records using the cursor from the first page - >>> next_cursor = first_page['next_cursor'] - >>> second_page = await crud.get_multi_by_cursor(db, cursor=next_cursor, limit=10, sort_column='created_at', sort_order='desc') + Fetch records excluding a specific username using cursor-based pagination: + ```python + get_multi_by_cursor(db, limit=10, sort_column='username', sort_order='asc', username__ne='admin') + ``` Note: This method is designed for efficient pagination in large datasets and is ideal for infinite scrolling features. @@ -689,13 +900,17 @@ async def get_multi_by_cursor( return {"data": [], "next_cursor": None} to_select = _extract_matching_columns_from_schema(self.model, schema_to_select) - stmt = select(*to_select).filter_by(**kwargs) + filters = self._parse_filters(**kwargs) + + stmt = select(*to_select) + if filters: + stmt = stmt.filter(*filters) if cursor: if sort_order == "asc": - stmt = stmt.where(getattr(self.model, sort_column) > cursor) + stmt = stmt.filter(getattr(self.model, sort_column) > cursor) else: - stmt = stmt.where(getattr(self.model, sort_column) < cursor) + stmt = stmt.filter(getattr(self.model, sort_column) < cursor) stmt = stmt.order_by( asc(getattr(self.model, sort_column)) @@ -709,7 +924,10 @@ async def get_multi_by_cursor( next_cursor = None if len(data) == limit: - next_cursor = data[-1][sort_column] + if sort_order == "asc": + next_cursor = data[-1][sort_column] + else: + data[0][sort_column] return {"data": data, "next_cursor": next_cursor} @@ -717,22 +935,53 @@ async def update( self, db: AsyncSession, object: Union[UpdateSchemaType, dict[str, Any]], + allow_multiple: bool = False, **kwargs: Any, ) -> None: """ - Update an existing record in the database. + Updates an existing record or multiple records in the database based on specified filters. This method allows for precise targeting of records to update. + It supports advanced filtering through comparison operators: + '__gt' (greater than), + '__lt' (less than), + '__gte' (greater than or equal to), + '__lte' (less than or equal to), and + '__ne' (not equal). Args: - db: The SQLAlchemy async session. - object: The Pydantic schema or dictionary containing the data to be updated. - **kwargs: Filters for the update. + db: The database session to use for the operation. + object: A Pydantic schema or dictionary containing the update data. + allow_multiple: If True, allows updating multiple records that match the filters. If False, raises an error if more than one record matches the filters. + **kwargs: Filters to identify the record(s) to update, supporting advanced comparison operators for refined querying. Returns: None Raises: + MultipleResultsFound: If `allow_multiple` is False and more than one record matches the filters. ValueError: If extra fields not present in the model are provided in the update data. + + Examples: + Update a user's email based on their ID: + ```python + update(db, {'email': 'new_email@example.com'}, id=1) + ``` + + Update users' statuses to 'inactive' where age is greater than 30 and allow updates to multiple records: + ```python + update(db, {'status': 'inactive'}, allow_multiple=True, age__gt=30) + ``` + + Update a user's username excluding specific user ID and prevent multiple updates: + ```python + update(db, {'username': 'new_username'}, id__ne=1, allow_multiple=False) + ``` """ + total_count = await self.count(db, **kwargs) + if not allow_multiple and total_count > 1: + raise MultipleResultsFound( + f"Expected exactly one record to update, found {total_count}." + ) + if isinstance(object, dict): update_data = object else: @@ -746,53 +995,141 @@ async def update( if extra_fields: raise ValueError(f"Extra fields provided: {extra_fields}") - stmt = update(self.model).filter_by(**kwargs).values(update_data) + filters = self._parse_filters(**kwargs) + stmt = update(self.model).filter(*filters).values(update_data) await db.execute(stmt) await db.commit() - async def db_delete(self, db: AsyncSession, **kwargs: Any) -> None: + async def db_delete( + self, db: AsyncSession, allow_multiple: bool = False, **kwargs: Any + ) -> None: """ - Delete a record in the database. + Deletes a record or multiple records from the database based on specified filters, with support for advanced filtering through comparison operators: + '__gt' (greater than), + '__lt' (less than), + '__gte' (greater than or equal to), + '__lte' (less than or equal to), and + '__ne' (not equal). Args: - db: The SQLAlchemy async session. - **kwargs: Filters for the delete. + db: The database session to use for the operation. + allow_multiple: If True, allows deleting multiple records that match the filters. If False, raises an error if more than one record matches the filters. + **kwargs: Filters to identify the record(s) to delete, including advanced comparison operators for detailed querying. Returns: None + + Raises: + MultipleResultsFound: If `allow_multiple` is False and more than one record matches the filters. + + Examples: + Delete a user based on their ID: + ```python + db_delete(db, id=1) + ``` + + Delete users older than 30 years and allow deletion of multiple records: + ```python + db_delete(db, allow_multiple=True, age__gt=30) + ``` + + Delete a user with a specific username, ensuring only one record is deleted: + ```python + db_delete(db, username='unique_username', allow_multiple=False) + ``` """ - stmt = delete(self.model).filter_by(**kwargs) + total_count = await self.count(db, **kwargs) + if not allow_multiple and total_count > 1: + raise MultipleResultsFound( + f"Expected exactly one record to delete, found {total_count}." + ) + + filters = self._parse_filters(**kwargs) + stmt = delete(self.model).filter(*filters) await db.execute(stmt) await db.commit() async def delete( - self, db: AsyncSession, db_row: Optional[Row] = None, **kwargs: Any + self, + db: AsyncSession, + db_row: Optional[Row] = None, + allow_multiple: bool = False, + **kwargs: Any, ) -> None: """ - Soft delete a record if it has "is_deleted" attribute, otherwise perform a hard delete. + Soft deletes a record or optionally multiple records if it has an "is_deleted" attribute, otherwise performs a hard delete, based on specified filters. + Supports advanced filtering through comparison operators: + '__gt' (greater than), + '__lt' (less than), + '__gte' (greater than or equal to), + '__lte' (less than or equal to), and + '__ne' (not equal). Args: - db: The SQLAlchemy async session. - db_row: Existing database row to delete. If None, it will be fetched based on `kwargs`. Default is None. - **kwargs: Filters for fetching the database row if not provided. + db: The database session to use for the operation. + db_row: Optional existing database row to delete. If provided, the method will attempt to delete this specific row, ignoring other filters. + allow_multiple: If True, allows deleting multiple records that match the filters. If False, raises an error if more than one record matches the filters. + **kwargs: Filters to identify the record(s) to delete, supporting advanced comparison operators for refined querying. + + Raises: + MultipleResultsFound: If `allow_multiple` is False and more than one record matches the filters. + NoResultFound: If no record matches the filters. Returns: None + + Examples: + Soft delete a specific user by ID: + ```python + delete(db, id=1) + ``` + + Hard delete users with account creation dates before 2020, allowing deletion of multiple records: + ```python + delete(db, allow_multiple=True, creation_date__lt=datetime(2020, 1, 1)) + ``` + + Soft delete a user with a specific email, ensuring only one record is deleted: + ```python + delete(db, email='unique@example.com', allow_multiple=False) + ``` """ - db_row = db_row or await self.exists(db=db, **kwargs) + filters = self._parse_filters(**kwargs) if db_row: - if "is_deleted" in self.model.__table__.columns: - object_dict = { - "is_deleted": True, - "deleted_at": datetime.now(timezone.utc), - } - stmt = update(self.model).filter_by(**kwargs).values(object_dict) - - await db.execute(stmt) - await db.commit() + if hasattr(db_row, self.is_deleted_column): + is_deleted_col = getattr(self.model, self.is_deleted_column) + deleted_at_col = getattr(self.model, self.deleted_at_column, None) + update_values = { + is_deleted_col: True, + deleted_at_col: datetime.now(timezone.utc), + } + update_stmt = ( + update(self.model).filter(*filters).values(**update_values) + ) + await db.execute(update_stmt) else: - stmt = delete(self.model).filter_by(**kwargs) - await db.execute(stmt) + await db.delete(db_row) await db.commit() + + total_count = await self.count(db, **kwargs) + if total_count == 0: + raise NoResultFound("No record found to delete.") + if not allow_multiple and total_count > 1: + raise MultipleResultsFound( + f"Expected exactly one record to delete, found {total_count}." + ) + + if self.is_deleted_column in self.model.__table__.columns: + update_stmt = ( + update(self.model) + .filter(*filters) + .values(is_deleted=True, deleted_at=datetime.now(timezone.utc)) + ) + await db.execute(update_stmt) + else: + delete_stmt = delete(self.model).filter(*filters) + await db.execute(delete_stmt) + + await db.commit() diff --git a/fastcrud/endpoint/crud_router.py b/fastcrud/endpoint/crud_router.py index 09ef194..949dca7 100644 --- a/fastcrud/endpoint/crud_router.py +++ b/fastcrud/endpoint/crud_router.py @@ -30,6 +30,8 @@ def crud_router( included_methods: Optional[list[str]] = None, deleted_methods: Optional[list[str]] = None, endpoint_creator: Optional[Type[EndpointCreator]] = None, + is_deleted_column: str = "is_deleted", + deleted_at_column: str = "deleted_at", ) -> APIRouter: """ Creates and configures a FastAPI router with CRUD endpoints for a given model. @@ -57,6 +59,8 @@ def crud_router( included_methods: Optional list of CRUD methods to include. If None, all methods are included. deleted_methods: Optional list of CRUD methods to exclude. endpoint_creator: Optional custom class derived from EndpointCreator for advanced customization. + is_deleted_column: Optional column name to use for indicating a soft delete. Defaults to "is_deleted". + deleted_at_column: Optional column name to use for storing the timestamp of a soft delete. Defaults to "deleted_at". Returns: Configured APIRouter instance with the CRUD endpoints. @@ -210,6 +214,8 @@ async def add_routes_to_router(self, ...): delete_schema=delete_schema, path=path, tags=tags, + is_deleted_column=is_deleted_column, + deleted_at_column=deleted_at_column, ) endpoint_creator_instance.add_routes_to_router( diff --git a/fastcrud/endpoint/endpoint_creator.py b/fastcrud/endpoint/endpoint_creator.py index 3dbe366..f2285ee 100644 --- a/fastcrud/endpoint/endpoint_creator.py +++ b/fastcrud/endpoint/endpoint_creator.py @@ -35,6 +35,8 @@ class EndpointCreator: include_in_schema: Whether to include the created endpoints in the OpenAPI schema. path: Base path for the CRUD endpoints. tags: List of tags for grouping endpoints in the documentation. + is_deleted_column: Optional column name to use for indicating a soft delete. Defaults to "is_deleted". + deleted_at_column: Optional column name to use for storing the timestamp of a soft delete. Defaults to "deleted_at". Raises: ValueError: If both `included_methods` and `deleted_methods` are provided. @@ -122,10 +124,16 @@ def __init__( delete_schema: Optional[Type[DeleteSchemaType]] = None, path: str = "", tags: Optional[list[str]] = None, + is_deleted_column: str = "is_deleted", + deleted_at_column: str = "deleted_at", ) -> None: self.primary_key_name = _get_primary_key(model) self.session = session - self.crud = crud + self.crud = crud or FastCRUD( + model=model, + is_deleted_column=is_deleted_column, + deleted_at_column=deleted_at_column, + ) self.model = model self.create_schema = create_schema self.update_schema = update_schema @@ -134,6 +142,8 @@ def __init__( self.path = path self.tags = tags or [] self.router = APIRouter() + self.is_deleted_column = is_deleted_column + self.deleted_at_column = deleted_at_column def _create_item(self): """Creates an endpoint for creating items in the database.""" diff --git a/mkdocs.yml b/mkdocs.yml index 6b1c945..7deb8cd 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -51,7 +51,9 @@ nav: - Automatic Endpoints: usage/endpoint.md - CRUD Utilities: usage/crud.md - Advanced: + - Overview: advanced/overview.md - Custom Endpoints: advanced/endpoint.md + - Advanced CRUD Usage: advanced/crud.md - API Reference: - Overview: api/overview.md - FastCRUD: api/fastcrud.md diff --git a/pyproject.toml b/pyproject.toml index b812e0d..38bf27d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "fastcrud" -version = "0.4.0" +version = "0.5.0" description = "FastCRUD is a Python package for FastAPI, offering robust async CRUD operations and flexible endpoint creation utilities." authors = ["Igor Benav "] license = "MIT" diff --git a/tests/sqlalchemy/conftest.py b/tests/sqlalchemy/conftest.py index fe9fe9b..820b9e1 100644 --- a/tests/sqlalchemy/conftest.py +++ b/tests/sqlalchemy/conftest.py @@ -40,6 +40,12 @@ class CreateSchemaTest(BaseModel): tier_id: int +class ReadSchemaTest(BaseModel): + id: int + name: str + tier_id: int + + class UpdateSchemaTest(BaseModel): name: str @@ -99,6 +105,7 @@ def test_data() -> list[dict]: {"id": 8, "name": "Hannah", "tier_id": 2}, {"id": 9, "name": "Ivan", "tier_id": 1}, {"id": 10, "name": "Judy", "tier_id": 2}, + {"id": 11, "name": "Alice", "tier_id": 1}, ] @@ -122,6 +129,11 @@ def create_schema(): return CreateSchemaTest +@pytest.fixture +def read_schema(): + return ReadSchemaTest + + @pytest.fixture def tier_schema(): return TierSchemaTest diff --git a/tests/sqlalchemy/crud/test_count.py b/tests/sqlalchemy/crud/test_count.py index d9321f1..051b6ce 100644 --- a/tests/sqlalchemy/crud/test_count.py +++ b/tests/sqlalchemy/crud/test_count.py @@ -34,3 +34,70 @@ async def test_count_no_matching_records(async_session, test_model): count = await crud.count(async_session, **non_existent_filter) assert count == 0 + + +@pytest.mark.asyncio +async def test_count_with_advanced_filters(async_session, test_model, test_data): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + + count_gt = await crud.count(async_session, tier_id__gt=1) + assert count_gt == len([item for item in test_data if item["tier_id"] > 1]) + + count_lt = await crud.count(async_session, tier_id__lt=2) + assert count_lt == len([item for item in test_data if item["tier_id"] < 2]) + + count_ne = await crud.count(async_session, name__ne=test_data[0]["name"]) + assert count_ne == len(test_data) - 1 + + +@pytest.mark.asyncio +async def test_update_multiple_records_allow_multiple( + async_session, test_model, test_data +): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + + await crud.update( + async_session, {"name": "UpdatedName"}, allow_multiple=True, tier_id=1 + ) + updated_count = await crud.count(async_session, name="UpdatedName") + expected_count = len([item for item in test_data if item["tier_id"] == 1]) + + assert updated_count == expected_count + + +@pytest.mark.asyncio +async def test_soft_delete_custom_columns(async_session, test_model, test_data): + crud = FastCRUD( + test_model, + is_deleted_column="custom_is_deleted", + deleted_at_column="custom_deleted_at", + ) + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + existing_record = await crud.get(async_session, id=test_data[0]["id"]) + assert existing_record is not None, "Record should exist before deletion" + + await crud.delete(async_session, id=test_data[0]["id"], allow_multiple=False) + + deleted_record = await crud.get(async_session, id=test_data[0]["id"]) + + if deleted_record is None: + assert True, "Record is considered 'deleted' and is not fetched by default" + else: + assert ( + deleted_record.get("custom_is_deleted") is True + ), "Record should be marked as deleted" + assert ( + "custom_deleted_at" in deleted_record + and deleted_record["custom_deleted_at"] is not None + ), "Deletion timestamp should be set" diff --git a/tests/sqlalchemy/crud/test_delete.py b/tests/sqlalchemy/crud/test_delete.py index 74d639a..93e64b4 100644 --- a/tests/sqlalchemy/crud/test_delete.py +++ b/tests/sqlalchemy/crud/test_delete.py @@ -53,3 +53,55 @@ async def test_delete_hard_delete_as_fallback( select(tier_model).where(tier_model.id == some_existing_id) ) assert hard_deleted_record.scalar_one_or_none() is None + + +@pytest.mark.asyncio +async def test_delete_multiple_records(async_session, test_data, test_model): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + with pytest.raises(Exception): + await crud.delete(db=async_session, allow_multiple=False, tier_id=1) + + +@pytest.mark.asyncio +async def test_get_with_advanced_filters(async_session, test_data, test_model): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + records = await crud.get_multi(db=async_session, id__gt=5) + for record in records["data"]: + assert record["id"] > 5, "All fetched records should have 'id' greater than 5" + + +@pytest.mark.asyncio +async def test_soft_delete_with_custom_columns(async_session, test_data, test_model): + crud = FastCRUD( + test_model, is_deleted_column="is_deleted", deleted_at_column="deleted_at" + ) + some_existing_id = test_data[0]["id"] + + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + await crud.delete(db=async_session, id=some_existing_id, allow_multiple=False) + + deleted_record = await async_session.execute( + select(test_model) + .where(test_model.id == some_existing_id) + .where(getattr(test_model, "is_deleted") == True) # noqa + ) + deleted_record = deleted_record.scalar_one_or_none() + + assert deleted_record is not None, "Record should exist after soft delete" + assert ( + getattr(deleted_record, "is_deleted") == True # noqa + ), "Record should be marked as soft deleted" + assert ( + getattr(deleted_record, "deleted_at") is not None + ), "Record should have a deletion timestamp" diff --git a/tests/sqlalchemy/crud/test_exists.py b/tests/sqlalchemy/crud/test_exists.py index 6d9007f..a792943 100644 --- a/tests/sqlalchemy/crud/test_exists.py +++ b/tests/sqlalchemy/crud/test_exists.py @@ -21,3 +21,32 @@ async def test_exists_record_not_found(async_session, test_model): exists = await crud.exists(async_session, **non_existent_filter) assert exists is False + + +@pytest.mark.asyncio +async def test_exists_with_advanced_filters(async_session, test_model, test_data): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + exists_gt = await crud.exists(db=async_session, id__gt=1) + assert exists_gt is True, "Should find records with ID greater than 1" + + advanced_filter_lt = {"id__lt": max([d["id"] for d in test_data])} + exists_lt = await crud.exists(async_session, **advanced_filter_lt) + assert exists_lt is True, "Should find records with ID less than the max ID" + + +@pytest.mark.asyncio +async def test_exists_multiple_records_match(async_session, test_model, test_data): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + duplicate_tier_id = test_data[0]["tier_id"] + crud = FastCRUD(test_model) + exists = await crud.exists(async_session, tier_id=duplicate_tier_id) + assert ( + exists is True + ), "Should return True if multiple records match the filter criteria" diff --git a/tests/sqlalchemy/crud/test_get.py b/tests/sqlalchemy/crud/test_get.py index e9f807c..75b01c6 100644 --- a/tests/sqlalchemy/crud/test_get.py +++ b/tests/sqlalchemy/crud/test_get.py @@ -1,4 +1,6 @@ import pytest +from pydantic import BaseModel + from fastcrud.crud.fast_crud import FastCRUD from ...sqlalchemy.conftest import ModelTest from ...sqlalchemy.conftest import CreateSchemaTest @@ -52,3 +54,62 @@ async def test_get_selecting_columns(async_session, test_data): assert fetched_record is not None assert "name" in fetched_record + + +@pytest.mark.asyncio +async def test_get_with_advanced_filters(async_session, test_data): + for item in test_data: + async_session.add(ModelTest(**item)) + await async_session.commit() + + crud = FastCRUD(ModelTest) + advanced_filter = {"id__gt": 1} + fetched_record_gt = await crud.get(async_session, **advanced_filter) + + assert fetched_record_gt is not None + assert fetched_record_gt["id"] > 1, "Should fetch a record with ID greater than 1" + + ne_filter = {"name__ne": test_data[0]["name"]} + fetched_record_ne = await crud.get(async_session, **ne_filter) + + assert fetched_record_ne is not None + assert ( + fetched_record_ne["name"] != test_data[0]["name"] + ), "Should fetch a record with a different name" + + +@pytest.mark.asyncio +async def test_get_with_schema_selecting_specific_columns(async_session, test_data): + async_session.add(ModelTest(**test_data[0])) + await async_session.commit() + + class PartialSchema(BaseModel): + name: str + + crud = FastCRUD(ModelTest) + fetched_record = await crud.get( + async_session, schema_to_select=PartialSchema, id=test_data[0]["id"] + ) + + assert fetched_record is not None + assert ( + "name" in fetched_record and "tier_id" not in fetched_record + ), "Should only fetch the 'name' column based on the PartialSchema" + + +@pytest.mark.asyncio +async def test_get_return_as_model_instance(async_session, test_data, read_schema): + async_session.add(ModelTest(**test_data[0])) + await async_session.commit() + + crud = FastCRUD(ModelTest) + fetched_record = await crud.get( + async_session, + return_as_model=True, + schema_to_select=read_schema, + id=test_data[0]["id"], + ) + + assert isinstance( + fetched_record, read_schema + ), "The fetched record should be an instance of the ReadSchemaTest Pydantic model" diff --git a/tests/sqlalchemy/crud/test_get_joined.py b/tests/sqlalchemy/crud/test_get_joined.py index acc825e..b0d7380 100644 --- a/tests/sqlalchemy/crud/test_get_joined.py +++ b/tests/sqlalchemy/crud/test_get_joined.py @@ -1,7 +1,12 @@ import pytest from sqlalchemy import and_ from fastcrud.crud.fast_crud import FastCRUD -from ...sqlalchemy.conftest import ModelTest, TierModel, CreateSchemaTest, TierSchemaTest +from ...sqlalchemy.conftest import ( + ModelTest, + TierModel, + CreateSchemaTest, + TierSchemaTest, +) @pytest.mark.asyncio @@ -133,3 +138,39 @@ async def test_get_joined_with_filters(async_session, test_data, test_data_tier) assert result is not None assert result["name"] == "Alice" + + +@pytest.mark.asyncio +async def test_update_multiple_records_allow_multiple( + async_session, test_model, test_data +): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + await crud.update( + db=async_session, + object={"name": "Updated Name"}, + allow_multiple=True, + name="Alice", + ) + + updated_records = await crud.get_multi(db=async_session, name="Updated Name") + assert ( + len(updated_records["data"]) > 1 + ), "Should update multiple records when allow_multiple is True" + + +@pytest.mark.asyncio +async def test_count_with_advanced_filters(async_session, test_model, test_data): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + count_gt = await crud.count(async_session, id__gt=1) + assert count_gt > 0, "Should count records with ID greater than 1" + + count_lt = await crud.count(async_session, id__lt=10) + assert count_lt > 0, "Should count records with ID less than 10" diff --git a/tests/sqlalchemy/crud/test_get_multi.py b/tests/sqlalchemy/crud/test_get_multi.py index d42aebf..352fb61 100644 --- a/tests/sqlalchemy/crud/test_get_multi.py +++ b/tests/sqlalchemy/crud/test_get_multi.py @@ -112,3 +112,73 @@ async def test_get_multi_return_model( ) assert all(isinstance(item, create_schema) for item in result["data"]) + + +@pytest.mark.asyncio +async def test_get_multi_advanced_filtering(async_session, test_model, test_data): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + filtered_results = await crud.get_multi(async_session, id__gt=5) + + assert all( + item["id"] > 5 for item in filtered_results["data"] + ), "Should only include records with ID greater than 5" + + +@pytest.mark.asyncio +async def test_get_multi_multiple_sorting(async_session, test_model, test_data): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + result = await crud.get_multi( + async_session, sort_columns=["tier_id", "name"], sort_orders=["asc", "desc"] + ) + + assert len(result["data"]) > 0, "Should fetch sorted records" + + tier_ids = [item["tier_id"] for item in result["data"]] + assert tier_ids == sorted(tier_ids), "tier_id should be sorted in ascending order" + + current_tier_id = None + names_in_current_tier = [] + for item in result["data"]: + if item["tier_id"] != current_tier_id: + if names_in_current_tier: + assert ( + names_in_current_tier == sorted(names_in_current_tier, reverse=True) + ), f"Names within tier_id {current_tier_id} should be sorted in descending order" + current_tier_id = item["tier_id"] + names_in_current_tier = [item["name"]] + else: + names_in_current_tier.append(item["name"]) + + if names_in_current_tier: + assert ( + names_in_current_tier == sorted(names_in_current_tier, reverse=True) + ), f"Names within tier_id {current_tier_id} should be sorted in descending order" + + +@pytest.mark.asyncio +async def test_get_multi_advanced_filtering_return_model( + async_session, test_model, test_data, read_schema +): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + result = await crud.get_multi( + async_session, id__lte=5, return_as_model=True, schema_to_select=read_schema + ) + + assert all( + isinstance(item, read_schema) for item in result["data"] + ), "All items should be instances of the schema" + assert all( + item.id <= 5 for item in result["data"] + ), "Should only include records with ID less than or equal to 5" diff --git a/tests/sqlalchemy/crud/test_get_multi_by_cursor.py b/tests/sqlalchemy/crud/test_get_multi_by_cursor.py index 42cb9f3..daa8e14 100644 --- a/tests/sqlalchemy/crud/test_get_multi_by_cursor.py +++ b/tests/sqlalchemy/crud/test_get_multi_by_cursor.py @@ -86,3 +86,55 @@ async def test_get_multi_by_cursor_edge_cases(async_session, test_data): zero_limit_result = await crud.get_multi_by_cursor(db=async_session, limit=0) assert len(zero_limit_result["data"]) == 0 assert zero_limit_result["next_cursor"] is None + + +@pytest.mark.asyncio +async def test_get_multi_by_cursor_with_advanced_filters(async_session, test_data): + for item in test_data: + async_session.add(ModelTest(**item)) + await async_session.commit() + + crud = FastCRUD(ModelTest) + advanced_filter_gt = await crud.get_multi_by_cursor( + db=async_session, limit=5, id__gt=5 + ) + + assert len(advanced_filter_gt["data"]) <= 5 + assert all( + item["id"] > 5 for item in advanced_filter_gt["data"] + ), "All fetched records should have ID greater than 5" + + advanced_filter_lt = await crud.get_multi_by_cursor( + db=async_session, limit=5, id__lt=5 + ) + assert ( + len(advanced_filter_lt["data"]) <= 5 + ), "Should correctly paginate records with ID less than 5" + + +@pytest.mark.asyncio +async def test_get_multi_by_cursor_pagination_integrity(async_session, test_data): + for item in test_data: + async_session.add(ModelTest(**item)) + await async_session.commit() + + crud = FastCRUD(ModelTest) + first_batch = await crud.get_multi_by_cursor(db=async_session, limit=5) + + await crud.update( + db=async_session, + object={"name": "Updated Name"}, + allow_multiple=True, + name="SpecificName", + ) + + second_batch = await crud.get_multi_by_cursor( + db=async_session, cursor=first_batch["next_cursor"], limit=5 + ) + + assert ( + len(second_batch["data"]) == 5 + ), "Pagination should fetch the correct number of records despite updates" + assert ( + first_batch["data"][-1]["id"] < second_batch["data"][0]["id"] + ), "Pagination should maintain order across batches" diff --git a/tests/sqlalchemy/crud/test_get_multi_joined.py b/tests/sqlalchemy/crud/test_get_multi_joined.py index 020be8c..3e15aa5 100644 --- a/tests/sqlalchemy/crud/test_get_multi_joined.py +++ b/tests/sqlalchemy/crud/test_get_multi_joined.py @@ -1,6 +1,12 @@ import pytest from fastcrud.crud.fast_crud import FastCRUD -from ...sqlalchemy.conftest import ModelTest, TierModel, CreateSchemaTest, TierSchemaTest +from ...sqlalchemy.conftest import ( + ModelTest, + TierModel, + CreateSchemaTest, + TierSchemaTest, + ReadSchemaTest, +) @pytest.mark.asyncio @@ -141,6 +147,14 @@ async def test_get_multi_joined_return_model(async_session, test_data, test_data @pytest.mark.asyncio async def test_get_multi_joined_no_results(async_session, test_data, test_data_tier): + for tier_item in test_data_tier: + async_session.add(TierModel(**tier_item)) + await async_session.commit() + + for user_item in test_data: + async_session.add(ModelTest(**user_item)) + await async_session.commit() + crud = FastCRUD(ModelTest) result = await crud.get_multi_joined( db=async_session, @@ -158,6 +172,14 @@ async def test_get_multi_joined_no_results(async_session, test_data, test_data_t @pytest.mark.asyncio async def test_get_multi_joined_large_offset(async_session, test_data, test_data_tier): + for tier_item in test_data_tier: + async_session.add(TierModel(**tier_item)) + await async_session.commit() + + for user_item in test_data: + async_session.add(ModelTest(**user_item)) + await async_session.commit() + crud = FastCRUD(ModelTest) result = await crud.get_multi_joined( db=async_session, @@ -175,6 +197,14 @@ async def test_get_multi_joined_large_offset(async_session, test_data, test_data async def test_get_multi_joined_invalid_limit_offset( async_session, test_data, test_data_tier ): + for tier_item in test_data_tier: + async_session.add(TierModel(**tier_item)) + await async_session.commit() + + for user_item in test_data: + async_session.add(ModelTest(**user_item)) + await async_session.commit() + crud = FastCRUD(ModelTest) with pytest.raises(ValueError): await crud.get_multi_joined( @@ -194,3 +224,35 @@ async def test_get_multi_joined_invalid_limit_offset( offset=0, limit=-1, ) + + +@pytest.mark.asyncio +async def test_get_multi_joined_advanced_filtering( + async_session, test_data, test_data_tier +): + for tier_item in test_data_tier: + async_session.add(TierModel(**tier_item)) + await async_session.commit() + + for user_item in test_data: + async_session.add(ModelTest(**user_item)) + await async_session.commit() + + crud = FastCRUD(ModelTest) + advanced_filter_result = await crud.get_multi_joined( + db=async_session, + join_model=TierModel, + schema_to_select=ReadSchemaTest, + join_schema_to_select=TierSchemaTest, + join_prefix="tier_", + offset=0, + limit=10, + id__gt=5, + ) + + assert ( + len(advanced_filter_result["data"]) > 0 + ), "Should fetch records with ID greater than 5" + assert all( + item["id"] > 5 for item in advanced_filter_result["data"] + ), "All fetched records should meet the advanced filter condition" diff --git a/tests/sqlalchemy/crud/test_update.py b/tests/sqlalchemy/crud/test_update.py index 0760686..ad5bb41 100644 --- a/tests/sqlalchemy/crud/test_update.py +++ b/tests/sqlalchemy/crud/test_update.py @@ -1,6 +1,7 @@ import pytest from sqlalchemy import select +from sqlalchemy.exc import MultipleResultsFound from fastcrud.crud.fast_crud import FastCRUD from ...sqlalchemy.conftest import ModelTest @@ -91,3 +92,61 @@ async def test_update_additional_fields(async_session, test_data): await crud.update(db=async_session, object=updated_data, id=some_existing_id) assert "Extra fields provided" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_update_with_advanced_filters(async_session, test_data): + for item in test_data: + async_session.add(ModelTest(**item)) + await async_session.commit() + + advanced_filter = {"id__gt": 5} + updated_data = {"name": "Updated for Advanced Filter"} + + crud = FastCRUD(ModelTest) + await crud.update( + db=async_session, object=updated_data, allow_multiple=True, **advanced_filter + ) + + updated_records = await async_session.execute( + select(ModelTest).where(ModelTest.id > 5) + ) + assert all( + record.name == "Updated for Advanced Filter" + for record in updated_records.scalars() + ) + + +@pytest.mark.asyncio +async def test_update_multiple_records(async_session, test_data): + for item in test_data: + async_session.add(ModelTest(**item)) + await async_session.commit() + + crud = FastCRUD(ModelTest) + updated_data = {"name": "Updated Multiple"} + await crud.update( + db=async_session, object=updated_data, allow_multiple=True, tier_id=2 + ) + + updated_records = await async_session.execute( + select(ModelTest).where(ModelTest.tier_id == 2) + ) + assert all( + record.name == "Updated Multiple" for record in updated_records.scalars() + ) + + +@pytest.mark.asyncio +async def test_update_multiple_records_restriction(async_session, test_data): + for item in test_data: + async_session.add(ModelTest(**item)) + await async_session.commit() + + crud = FastCRUD(ModelTest) + updated_data = {"name": "Should Fail"} + + with pytest.raises(MultipleResultsFound) as exc_info: + await crud.update(db=async_session, object=updated_data, id__lt=10) + + assert "Expected exactly one record to update" in str(exc_info.value) diff --git a/tests/sqlmodel/conftest.py b/tests/sqlmodel/conftest.py index c97b4ce..89c4538 100644 --- a/tests/sqlmodel/conftest.py +++ b/tests/sqlmodel/conftest.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker -from pydantic import BaseModel, ConfigDict +from pydantic import ConfigDict from sqlmodel import SQLModel, Field, Relationship from fastapi import FastAPI from fastapi.testclient import TestClient @@ -38,6 +38,12 @@ class CreateSchemaTest(SQLModel): tier_id: int +class ReadSchemaTest(SQLModel): + id: int + name: str + tier_id: int + + class UpdateSchemaTest(SQLModel): name: str @@ -97,6 +103,7 @@ def test_data() -> list[dict]: {"id": 8, "name": "Hannah", "tier_id": 2}, {"id": 9, "name": "Ivan", "tier_id": 1}, {"id": 10, "name": "Judy", "tier_id": 2}, + {"id": 11, "name": "Alice", "tier_id": 1}, ] @@ -120,6 +127,11 @@ def create_schema(): return CreateSchemaTest +@pytest.fixture +def read_schema(): + return ReadSchemaTest + + @pytest.fixture def tier_schema(): return TierSchemaTest diff --git a/tests/sqlmodel/crud/test_count.py b/tests/sqlmodel/crud/test_count.py index d9321f1..051b6ce 100644 --- a/tests/sqlmodel/crud/test_count.py +++ b/tests/sqlmodel/crud/test_count.py @@ -34,3 +34,70 @@ async def test_count_no_matching_records(async_session, test_model): count = await crud.count(async_session, **non_existent_filter) assert count == 0 + + +@pytest.mark.asyncio +async def test_count_with_advanced_filters(async_session, test_model, test_data): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + + count_gt = await crud.count(async_session, tier_id__gt=1) + assert count_gt == len([item for item in test_data if item["tier_id"] > 1]) + + count_lt = await crud.count(async_session, tier_id__lt=2) + assert count_lt == len([item for item in test_data if item["tier_id"] < 2]) + + count_ne = await crud.count(async_session, name__ne=test_data[0]["name"]) + assert count_ne == len(test_data) - 1 + + +@pytest.mark.asyncio +async def test_update_multiple_records_allow_multiple( + async_session, test_model, test_data +): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + + await crud.update( + async_session, {"name": "UpdatedName"}, allow_multiple=True, tier_id=1 + ) + updated_count = await crud.count(async_session, name="UpdatedName") + expected_count = len([item for item in test_data if item["tier_id"] == 1]) + + assert updated_count == expected_count + + +@pytest.mark.asyncio +async def test_soft_delete_custom_columns(async_session, test_model, test_data): + crud = FastCRUD( + test_model, + is_deleted_column="custom_is_deleted", + deleted_at_column="custom_deleted_at", + ) + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + existing_record = await crud.get(async_session, id=test_data[0]["id"]) + assert existing_record is not None, "Record should exist before deletion" + + await crud.delete(async_session, id=test_data[0]["id"], allow_multiple=False) + + deleted_record = await crud.get(async_session, id=test_data[0]["id"]) + + if deleted_record is None: + assert True, "Record is considered 'deleted' and is not fetched by default" + else: + assert ( + deleted_record.get("custom_is_deleted") is True + ), "Record should be marked as deleted" + assert ( + "custom_deleted_at" in deleted_record + and deleted_record["custom_deleted_at"] is not None + ), "Deletion timestamp should be set" diff --git a/tests/sqlmodel/crud/test_delete.py b/tests/sqlmodel/crud/test_delete.py index 74d639a..93e64b4 100644 --- a/tests/sqlmodel/crud/test_delete.py +++ b/tests/sqlmodel/crud/test_delete.py @@ -53,3 +53,55 @@ async def test_delete_hard_delete_as_fallback( select(tier_model).where(tier_model.id == some_existing_id) ) assert hard_deleted_record.scalar_one_or_none() is None + + +@pytest.mark.asyncio +async def test_delete_multiple_records(async_session, test_data, test_model): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + with pytest.raises(Exception): + await crud.delete(db=async_session, allow_multiple=False, tier_id=1) + + +@pytest.mark.asyncio +async def test_get_with_advanced_filters(async_session, test_data, test_model): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + records = await crud.get_multi(db=async_session, id__gt=5) + for record in records["data"]: + assert record["id"] > 5, "All fetched records should have 'id' greater than 5" + + +@pytest.mark.asyncio +async def test_soft_delete_with_custom_columns(async_session, test_data, test_model): + crud = FastCRUD( + test_model, is_deleted_column="is_deleted", deleted_at_column="deleted_at" + ) + some_existing_id = test_data[0]["id"] + + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + await crud.delete(db=async_session, id=some_existing_id, allow_multiple=False) + + deleted_record = await async_session.execute( + select(test_model) + .where(test_model.id == some_existing_id) + .where(getattr(test_model, "is_deleted") == True) # noqa + ) + deleted_record = deleted_record.scalar_one_or_none() + + assert deleted_record is not None, "Record should exist after soft delete" + assert ( + getattr(deleted_record, "is_deleted") == True # noqa + ), "Record should be marked as soft deleted" + assert ( + getattr(deleted_record, "deleted_at") is not None + ), "Record should have a deletion timestamp" diff --git a/tests/sqlmodel/crud/test_exists.py b/tests/sqlmodel/crud/test_exists.py index 6d9007f..a792943 100644 --- a/tests/sqlmodel/crud/test_exists.py +++ b/tests/sqlmodel/crud/test_exists.py @@ -21,3 +21,32 @@ async def test_exists_record_not_found(async_session, test_model): exists = await crud.exists(async_session, **non_existent_filter) assert exists is False + + +@pytest.mark.asyncio +async def test_exists_with_advanced_filters(async_session, test_model, test_data): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + exists_gt = await crud.exists(db=async_session, id__gt=1) + assert exists_gt is True, "Should find records with ID greater than 1" + + advanced_filter_lt = {"id__lt": max([d["id"] for d in test_data])} + exists_lt = await crud.exists(async_session, **advanced_filter_lt) + assert exists_lt is True, "Should find records with ID less than the max ID" + + +@pytest.mark.asyncio +async def test_exists_multiple_records_match(async_session, test_model, test_data): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + duplicate_tier_id = test_data[0]["tier_id"] + crud = FastCRUD(test_model) + exists = await crud.exists(async_session, tier_id=duplicate_tier_id) + assert ( + exists is True + ), "Should return True if multiple records match the filter criteria" diff --git a/tests/sqlmodel/crud/test_get.py b/tests/sqlmodel/crud/test_get.py index 603752a..75b01c6 100644 --- a/tests/sqlmodel/crud/test_get.py +++ b/tests/sqlmodel/crud/test_get.py @@ -1,7 +1,9 @@ import pytest +from pydantic import BaseModel + from fastcrud.crud.fast_crud import FastCRUD -from ..conftest import ModelTest -from ..conftest import CreateSchemaTest +from ...sqlalchemy.conftest import ModelTest +from ...sqlalchemy.conftest import CreateSchemaTest @pytest.mark.asyncio @@ -52,3 +54,62 @@ async def test_get_selecting_columns(async_session, test_data): assert fetched_record is not None assert "name" in fetched_record + + +@pytest.mark.asyncio +async def test_get_with_advanced_filters(async_session, test_data): + for item in test_data: + async_session.add(ModelTest(**item)) + await async_session.commit() + + crud = FastCRUD(ModelTest) + advanced_filter = {"id__gt": 1} + fetched_record_gt = await crud.get(async_session, **advanced_filter) + + assert fetched_record_gt is not None + assert fetched_record_gt["id"] > 1, "Should fetch a record with ID greater than 1" + + ne_filter = {"name__ne": test_data[0]["name"]} + fetched_record_ne = await crud.get(async_session, **ne_filter) + + assert fetched_record_ne is not None + assert ( + fetched_record_ne["name"] != test_data[0]["name"] + ), "Should fetch a record with a different name" + + +@pytest.mark.asyncio +async def test_get_with_schema_selecting_specific_columns(async_session, test_data): + async_session.add(ModelTest(**test_data[0])) + await async_session.commit() + + class PartialSchema(BaseModel): + name: str + + crud = FastCRUD(ModelTest) + fetched_record = await crud.get( + async_session, schema_to_select=PartialSchema, id=test_data[0]["id"] + ) + + assert fetched_record is not None + assert ( + "name" in fetched_record and "tier_id" not in fetched_record + ), "Should only fetch the 'name' column based on the PartialSchema" + + +@pytest.mark.asyncio +async def test_get_return_as_model_instance(async_session, test_data, read_schema): + async_session.add(ModelTest(**test_data[0])) + await async_session.commit() + + crud = FastCRUD(ModelTest) + fetched_record = await crud.get( + async_session, + return_as_model=True, + schema_to_select=read_schema, + id=test_data[0]["id"], + ) + + assert isinstance( + fetched_record, read_schema + ), "The fetched record should be an instance of the ReadSchemaTest Pydantic model" diff --git a/tests/sqlmodel/crud/test_get_joined.py b/tests/sqlmodel/crud/test_get_joined.py index bdd2eb4..b0d7380 100644 --- a/tests/sqlmodel/crud/test_get_joined.py +++ b/tests/sqlmodel/crud/test_get_joined.py @@ -1,7 +1,12 @@ import pytest from sqlalchemy import and_ from fastcrud.crud.fast_crud import FastCRUD -from ..conftest import ModelTest, TierModel, CreateSchemaTest, TierSchemaTest +from ...sqlalchemy.conftest import ( + ModelTest, + TierModel, + CreateSchemaTest, + TierSchemaTest, +) @pytest.mark.asyncio @@ -133,3 +138,39 @@ async def test_get_joined_with_filters(async_session, test_data, test_data_tier) assert result is not None assert result["name"] == "Alice" + + +@pytest.mark.asyncio +async def test_update_multiple_records_allow_multiple( + async_session, test_model, test_data +): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + await crud.update( + db=async_session, + object={"name": "Updated Name"}, + allow_multiple=True, + name="Alice", + ) + + updated_records = await crud.get_multi(db=async_session, name="Updated Name") + assert ( + len(updated_records["data"]) > 1 + ), "Should update multiple records when allow_multiple is True" + + +@pytest.mark.asyncio +async def test_count_with_advanced_filters(async_session, test_model, test_data): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + count_gt = await crud.count(async_session, id__gt=1) + assert count_gt > 0, "Should count records with ID greater than 1" + + count_lt = await crud.count(async_session, id__lt=10) + assert count_lt > 0, "Should count records with ID less than 10" diff --git a/tests/sqlmodel/crud/test_get_multi.py b/tests/sqlmodel/crud/test_get_multi.py index d42aebf..352fb61 100644 --- a/tests/sqlmodel/crud/test_get_multi.py +++ b/tests/sqlmodel/crud/test_get_multi.py @@ -112,3 +112,73 @@ async def test_get_multi_return_model( ) assert all(isinstance(item, create_schema) for item in result["data"]) + + +@pytest.mark.asyncio +async def test_get_multi_advanced_filtering(async_session, test_model, test_data): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + filtered_results = await crud.get_multi(async_session, id__gt=5) + + assert all( + item["id"] > 5 for item in filtered_results["data"] + ), "Should only include records with ID greater than 5" + + +@pytest.mark.asyncio +async def test_get_multi_multiple_sorting(async_session, test_model, test_data): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + result = await crud.get_multi( + async_session, sort_columns=["tier_id", "name"], sort_orders=["asc", "desc"] + ) + + assert len(result["data"]) > 0, "Should fetch sorted records" + + tier_ids = [item["tier_id"] for item in result["data"]] + assert tier_ids == sorted(tier_ids), "tier_id should be sorted in ascending order" + + current_tier_id = None + names_in_current_tier = [] + for item in result["data"]: + if item["tier_id"] != current_tier_id: + if names_in_current_tier: + assert ( + names_in_current_tier == sorted(names_in_current_tier, reverse=True) + ), f"Names within tier_id {current_tier_id} should be sorted in descending order" + current_tier_id = item["tier_id"] + names_in_current_tier = [item["name"]] + else: + names_in_current_tier.append(item["name"]) + + if names_in_current_tier: + assert ( + names_in_current_tier == sorted(names_in_current_tier, reverse=True) + ), f"Names within tier_id {current_tier_id} should be sorted in descending order" + + +@pytest.mark.asyncio +async def test_get_multi_advanced_filtering_return_model( + async_session, test_model, test_data, read_schema +): + for item in test_data: + async_session.add(test_model(**item)) + await async_session.commit() + + crud = FastCRUD(test_model) + result = await crud.get_multi( + async_session, id__lte=5, return_as_model=True, schema_to_select=read_schema + ) + + assert all( + isinstance(item, read_schema) for item in result["data"] + ), "All items should be instances of the schema" + assert all( + item.id <= 5 for item in result["data"] + ), "Should only include records with ID less than or equal to 5" diff --git a/tests/sqlmodel/crud/test_get_multi_by_cursor.py b/tests/sqlmodel/crud/test_get_multi_by_cursor.py index 255fb8f..daa8e14 100644 --- a/tests/sqlmodel/crud/test_get_multi_by_cursor.py +++ b/tests/sqlmodel/crud/test_get_multi_by_cursor.py @@ -1,6 +1,6 @@ import pytest from fastcrud.crud.fast_crud import FastCRUD -from ..conftest import ModelTest +from ...sqlalchemy.conftest import ModelTest @pytest.mark.asyncio @@ -86,3 +86,55 @@ async def test_get_multi_by_cursor_edge_cases(async_session, test_data): zero_limit_result = await crud.get_multi_by_cursor(db=async_session, limit=0) assert len(zero_limit_result["data"]) == 0 assert zero_limit_result["next_cursor"] is None + + +@pytest.mark.asyncio +async def test_get_multi_by_cursor_with_advanced_filters(async_session, test_data): + for item in test_data: + async_session.add(ModelTest(**item)) + await async_session.commit() + + crud = FastCRUD(ModelTest) + advanced_filter_gt = await crud.get_multi_by_cursor( + db=async_session, limit=5, id__gt=5 + ) + + assert len(advanced_filter_gt["data"]) <= 5 + assert all( + item["id"] > 5 for item in advanced_filter_gt["data"] + ), "All fetched records should have ID greater than 5" + + advanced_filter_lt = await crud.get_multi_by_cursor( + db=async_session, limit=5, id__lt=5 + ) + assert ( + len(advanced_filter_lt["data"]) <= 5 + ), "Should correctly paginate records with ID less than 5" + + +@pytest.mark.asyncio +async def test_get_multi_by_cursor_pagination_integrity(async_session, test_data): + for item in test_data: + async_session.add(ModelTest(**item)) + await async_session.commit() + + crud = FastCRUD(ModelTest) + first_batch = await crud.get_multi_by_cursor(db=async_session, limit=5) + + await crud.update( + db=async_session, + object={"name": "Updated Name"}, + allow_multiple=True, + name="SpecificName", + ) + + second_batch = await crud.get_multi_by_cursor( + db=async_session, cursor=first_batch["next_cursor"], limit=5 + ) + + assert ( + len(second_batch["data"]) == 5 + ), "Pagination should fetch the correct number of records despite updates" + assert ( + first_batch["data"][-1]["id"] < second_batch["data"][0]["id"] + ), "Pagination should maintain order across batches" diff --git a/tests/sqlmodel/crud/test_get_multi_joined.py b/tests/sqlmodel/crud/test_get_multi_joined.py index 9777c47..3e15aa5 100644 --- a/tests/sqlmodel/crud/test_get_multi_joined.py +++ b/tests/sqlmodel/crud/test_get_multi_joined.py @@ -1,6 +1,12 @@ import pytest from fastcrud.crud.fast_crud import FastCRUD -from ..conftest import ModelTest, TierModel, CreateSchemaTest, TierSchemaTest +from ...sqlalchemy.conftest import ( + ModelTest, + TierModel, + CreateSchemaTest, + TierSchemaTest, + ReadSchemaTest, +) @pytest.mark.asyncio @@ -141,6 +147,14 @@ async def test_get_multi_joined_return_model(async_session, test_data, test_data @pytest.mark.asyncio async def test_get_multi_joined_no_results(async_session, test_data, test_data_tier): + for tier_item in test_data_tier: + async_session.add(TierModel(**tier_item)) + await async_session.commit() + + for user_item in test_data: + async_session.add(ModelTest(**user_item)) + await async_session.commit() + crud = FastCRUD(ModelTest) result = await crud.get_multi_joined( db=async_session, @@ -158,6 +172,14 @@ async def test_get_multi_joined_no_results(async_session, test_data, test_data_t @pytest.mark.asyncio async def test_get_multi_joined_large_offset(async_session, test_data, test_data_tier): + for tier_item in test_data_tier: + async_session.add(TierModel(**tier_item)) + await async_session.commit() + + for user_item in test_data: + async_session.add(ModelTest(**user_item)) + await async_session.commit() + crud = FastCRUD(ModelTest) result = await crud.get_multi_joined( db=async_session, @@ -175,6 +197,14 @@ async def test_get_multi_joined_large_offset(async_session, test_data, test_data async def test_get_multi_joined_invalid_limit_offset( async_session, test_data, test_data_tier ): + for tier_item in test_data_tier: + async_session.add(TierModel(**tier_item)) + await async_session.commit() + + for user_item in test_data: + async_session.add(ModelTest(**user_item)) + await async_session.commit() + crud = FastCRUD(ModelTest) with pytest.raises(ValueError): await crud.get_multi_joined( @@ -194,3 +224,35 @@ async def test_get_multi_joined_invalid_limit_offset( offset=0, limit=-1, ) + + +@pytest.mark.asyncio +async def test_get_multi_joined_advanced_filtering( + async_session, test_data, test_data_tier +): + for tier_item in test_data_tier: + async_session.add(TierModel(**tier_item)) + await async_session.commit() + + for user_item in test_data: + async_session.add(ModelTest(**user_item)) + await async_session.commit() + + crud = FastCRUD(ModelTest) + advanced_filter_result = await crud.get_multi_joined( + db=async_session, + join_model=TierModel, + schema_to_select=ReadSchemaTest, + join_schema_to_select=TierSchemaTest, + join_prefix="tier_", + offset=0, + limit=10, + id__gt=5, + ) + + assert ( + len(advanced_filter_result["data"]) > 0 + ), "Should fetch records with ID greater than 5" + assert all( + item["id"] > 5 for item in advanced_filter_result["data"] + ), "All fetched records should meet the advanced filter condition" diff --git a/tests/sqlmodel/crud/test_update.py b/tests/sqlmodel/crud/test_update.py index da66281..ad5bb41 100644 --- a/tests/sqlmodel/crud/test_update.py +++ b/tests/sqlmodel/crud/test_update.py @@ -1,9 +1,10 @@ import pytest from sqlalchemy import select +from sqlalchemy.exc import MultipleResultsFound from fastcrud.crud.fast_crud import FastCRUD -from ..conftest import ModelTest +from ...sqlalchemy.conftest import ModelTest @pytest.mark.asyncio @@ -91,3 +92,61 @@ async def test_update_additional_fields(async_session, test_data): await crud.update(db=async_session, object=updated_data, id=some_existing_id) assert "Extra fields provided" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_update_with_advanced_filters(async_session, test_data): + for item in test_data: + async_session.add(ModelTest(**item)) + await async_session.commit() + + advanced_filter = {"id__gt": 5} + updated_data = {"name": "Updated for Advanced Filter"} + + crud = FastCRUD(ModelTest) + await crud.update( + db=async_session, object=updated_data, allow_multiple=True, **advanced_filter + ) + + updated_records = await async_session.execute( + select(ModelTest).where(ModelTest.id > 5) + ) + assert all( + record.name == "Updated for Advanced Filter" + for record in updated_records.scalars() + ) + + +@pytest.mark.asyncio +async def test_update_multiple_records(async_session, test_data): + for item in test_data: + async_session.add(ModelTest(**item)) + await async_session.commit() + + crud = FastCRUD(ModelTest) + updated_data = {"name": "Updated Multiple"} + await crud.update( + db=async_session, object=updated_data, allow_multiple=True, tier_id=2 + ) + + updated_records = await async_session.execute( + select(ModelTest).where(ModelTest.tier_id == 2) + ) + assert all( + record.name == "Updated Multiple" for record in updated_records.scalars() + ) + + +@pytest.mark.asyncio +async def test_update_multiple_records_restriction(async_session, test_data): + for item in test_data: + async_session.add(ModelTest(**item)) + await async_session.commit() + + crud = FastCRUD(ModelTest) + updated_data = {"name": "Should Fail"} + + with pytest.raises(MultipleResultsFound) as exc_info: + await crud.update(db=async_session, object=updated_data, id__lt=10) + + assert "Expected exactly one record to update" in str(exc_info.value)