diff --git a/peewee_async.py b/peewee_async.py index fce5f23..6ea8da9 100644 --- a/peewee_async.py +++ b/peewee_async.py @@ -24,13 +24,11 @@ import peewee from importlib.metadata import version from playhouse.db_url import register_database - -IntegrityErrors = (peewee.IntegrityError,) +from peewee_async_compat import Manager, count, execute, prefetch, scalar try: import aiopg import psycopg2 - IntegrityErrors += (psycopg2.IntegrityError,) except ImportError: aiopg = None psycopg2 = None @@ -47,23 +45,24 @@ except AttributeError: asyncio_current_task = asyncio.Task.current_task -__version__ = version("peewee-async") +__version__ = version('peewee-async') __all__ = [ - # High level API ### - - 'Manager', + # TODO: Define new classes here + # ... 'PostgresqlDatabase', 'PooledPostgresqlDatabase', 'MySQLDatabase', 'PooledMySQLDatabase', - # Low level API ### - "execute", + # Compatibility API (deprecated in v1.0 release) + 'Manager', + 'execute', 'count', 'scalar', 'atomic', + 'prefetch', 'transaction', 'savepoint', ] @@ -72,403 +71,6 @@ __log__.addHandler(logging.NullHandler()) -################# -# Async manager # -################# - - -class Manager: - """Async peewee model's manager. - - :param database: (optional) async database driver - - Example:: - - class User(peewee.Model): - username = peewee.CharField(max_length=40, unique=True) - - objects = Manager(PostgresqlDatabase('test')) - - async def my_async_func(): - user0 = await objects.create(User, username='test') - user1 = await objects.get(User, id=user0.id) - user2 = await objects.get(User, username='test') - # All should be the same - print(user1.id, user2.id, user3.id) - - If you don't pass database to constructor, you should define - ``database`` as a class member like that:: - - database = PostgresqlDatabase('test') - - class MyManager(Manager): - database = database - - objects = MyManager() - - """ - #: Async database driver for manager. Must be provided - #: in constructor or as a class member. - database = None - - def __init__(self, database=None): - assert database or self.database, \ - ("Error, database must be provided via " - "argument or class member.") - - self.database = database or self.database - - @property - def is_connected(self): - """Check if database is connected. - """ - return self.database.aio_pool.pool is not None - - async def get(self, source_, *args, **kwargs): - """Get the model instance. - - :param source_: model or base query for lookup - - Example:: - - async def my_async_func(): - obj1 = await objects.get(MyModel, id=1) - obj2 = await objects.get(MyModel, MyModel.id==1) - obj3 = await objects.get(MyModel.select().where(MyModel.id==1)) - - All will return `MyModel` instance with `id = 1` - """ - await self.connect() - - if isinstance(source_, peewee.Query): - query = source_ - model = query.model - else: - query = source_.select() - model = source_ - - conditions = list(args) + [(getattr(model, k) == v) - for k, v in kwargs.items()] - - if conditions: - query = query.where(*conditions) - - try: - result = await self.execute(query) - return list(result)[0] - except IndexError: - raise model.DoesNotExist - - async def create(self, model_, **data): - """Create a new object saved to database. - """ - inst = model_(**data) - query = model_.insert(**dict(inst.__data__)) - - pk = await self.execute(query) - if inst._pk is None: - inst._pk = pk - return inst - - async def get_or_create(self, model_, defaults=None, **kwargs): - """Try to get an object or create it with the specified defaults. - - Return 2-tuple containing the model instance and a boolean - indicating whether the instance was created. - """ - try: - return (await self.get(model_, **kwargs)), False - except model_.DoesNotExist: - data = defaults or {} - data.update({k: v for k, v in kwargs.items() if '__' not in k}) - return (await self.create(model_, **data)), True - - async def get_or_none(self, model_, *args, **kwargs): - """Try to get an object and return None if it doesn't exist.""" - try: - return (await self.get(model_, *args, **kwargs)) - except model_.DoesNotExist: - pass - - async def update(self, obj, only=None): - """Update the object in the database. Optionally, update only - the specified fields. For creating a new object use :meth:`.create()` - - :param only: (optional) the list/tuple of fields or - field names to update - """ - field_dict = dict(obj.__data__) - pk_field = obj._meta.primary_key - - if only: - self._prune_fields(field_dict, only) - - if obj._meta.only_save_dirty: - self._prune_fields(field_dict, obj.dirty_fields) - - if obj._meta.composite_key: - for pk_part_name in pk_field.field_names: - field_dict.pop(pk_part_name, None) - else: - field_dict.pop(pk_field.name, None) - - query = obj.update(**field_dict).where(obj._pk_expr()) - result = await self.execute(query) - obj._dirty.clear() - return result - - async def delete(self, obj, recursive=False, delete_nullable=False): - """Delete object from database. - """ - if recursive: - dependencies = obj.dependencies(delete_nullable) - for cond, fk in reversed(list(dependencies)): - model = fk.model - if fk.null and not delete_nullable: - sq = model.update(**{fk.name: None}).where(cond) - else: - sq = model.delete().where(cond) - await self.execute(sq) - - query = obj.delete().where(obj._pk_expr()) - return (await self.execute(query)) - - async def create_or_get(self, model_, **kwargs): - """Try to create new object with specified data. If object already - exists, then try to get it by unique fields. - """ - try: - return (await self.create(model_, **kwargs)), True - except IntegrityErrors: - query = [] - for field_name, value in kwargs.items(): - field = getattr(model_, field_name) - if field.unique or field.primary_key: - query.append(field == value) - return (await self.get(model_, *query)), False - - async def execute(self, query): - """Execute query asyncronously. - """ - return await self.database.aio_execute(query) - - async def prefetch(self, query, *subqueries, prefetch_type=peewee.PREFETCH_TYPE.JOIN): - """Asynchronous version of the `prefetch()` from peewee. - - :return: Query that has already cached data for subqueries - """ - query = self._swap_database(query) - subqueries = map(self._swap_database, subqueries) - return (await prefetch(query, *subqueries, prefetch_type=prefetch_type)) - - async def count(self, query, clear_limit=False): - """Perform *COUNT* aggregated query asynchronously. - - :return: number of objects in ``select()`` query - """ - query = self._swap_database(query) - return (await count(query, clear_limit=clear_limit)) - - async def scalar(self, query, as_tuple=False): - """Get single value from ``select()`` query, i.e. for aggregation. - - :return: result is the same as after sync ``query.scalar()`` call - """ - query = self._swap_database(query) - return (await scalar(query, as_tuple=as_tuple)) - - async def connect(self): - """Open database async connection if not connected. - """ - await self.database.connect_async() - - async def close(self): - """Close database async connection if connected. - """ - await self.database.close_async() - - def atomic(self): - """Similar to `peewee.Database.atomic()` method, but returns - **asynchronous** context manager. - - Example:: - - async with objects.atomic(): - await objects.create( - PageBlock, key='intro', - text="There are more things in heaven and earth, " - "Horatio, than are dreamt of in your philosophy.") - await objects.create( - PageBlock, key='signature', text="William Shakespeare") - """ - return atomic(self.database) - - def transaction(self): - """Similar to `peewee.Database.transaction()` method, but returns - **asynchronous** context manager. - """ - return transaction(self.database) - - def savepoint(self, sid=None): - """Similar to `peewee.Database.savepoint()` method, but returns - **asynchronous** context manager. - """ - return savepoint(self.database, sid=sid) - - def allow_sync(self): - """Allow sync queries within context. Close the sync - database connection on exit if connected. - - Example:: - - with objects.allow_sync(): - PageBlock.create_table(True) - """ - return self.database.allow_sync() - - def _swap_database(self, query): - """Swap database for query if swappable. Return **new query** - with swapped database. - - This is experimental feature which allows us to have multiple - managers configured against different databases for single model - definition. - - The essential limitation though is that database backend have - to be **the same type** for model and manager! - """ - database = _query_db(query) - - if database == self.database: - return query - - if self._subclassed(peewee.PostgresqlDatabase, database, - self.database): - can_swap = True - elif self._subclassed(peewee.MySQLDatabase, database, - self.database): - can_swap = True - else: - can_swap = False - - if can_swap: - # **Experimental** database swapping! - query = query.clone() - query._database = self.database - return query - - assert False, ( - "Error, query's database and manager's database are " - "different. Query: %s Manager: %s" % (database, self.database) - ) - - return None - - @staticmethod - def _subclassed(base, *classes): - """Check if all classes are subclassed from base. - """ - return all(map(lambda obj: isinstance(obj, base), classes)) - - @staticmethod - def _prune_fields(field_dict, only): - """Filter fields data **in place** with `only` list. - - Example:: - - self._prune_fields(field_dict, ['slug', 'text']) - self._prune_fields(field_dict, [MyModel.slug]) - """ - fields = [(isinstance(f, str) and f or f.name) for f in only] - for f in list(field_dict.keys()): - if f not in fields: - field_dict.pop(f) - return field_dict - - -################# -# Async queries # -################# - - -async def execute(query): - warnings.warn( - "`execute` is deprecated, use `database.aio_execute` method.", - DeprecationWarning - ) - database = _query_db(query) - return await database.aio_execute(query) - - -async def count(query, clear_limit=False): - """Perform *COUNT* aggregated query asynchronously. - - :return: number of objects in ``select()`` query - """ - clone = query.clone() - database = _query_db(query) - if query._distinct or query._group_by or query._limit or query._offset: - if clear_limit: - clone._limit = clone._offset = None - sql, params = clone.sql() - wrapped = 'SELECT COUNT(1) FROM (%s) AS wrapped_select' % sql - async def fetch_results(cursor): - row = await cursor.fetchone() - if row: - return row[0] - else: - return row - result = await database.aio_execute_sql(wrapped, params, fetch_results) - return result or 0 - else: - clone._returning = [peewee.fn.Count(peewee.SQL('*'))] - clone._order_by = None - return (await scalar(clone)) or 0 - - -async def scalar(query, as_tuple=False): - warnings.warn( - "`scalar` is deprecated, use `query.aio_scalar` method.", - DeprecationWarning - ) - return await query.aio_scalar(as_tuple=as_tuple) - - -async def prefetch(sq, *subqueries, prefetch_type): - """Asynchronous version of the `prefetch()` from peewee. - """ - database = _query_db(sq) - if not subqueries: - result = await database.aio_execute(sq) - return result - - fixed_queries = peewee.prefetch_add_subquery(sq, subqueries, prefetch_type) - deps = {} - rel_map = {} - - for pq in reversed(fixed_queries): - query_model = pq.model - if pq.fields: - for rel_model in pq.rel_models: - rel_map.setdefault(rel_model, []) - rel_map[rel_model].append(pq) - - deps[query_model] = {} - id_map = deps[query_model] - has_relations = bool(rel_map.get(query_model)) - database = _query_db(pq.query) - result = await database.aio_execute(pq.query) - - for instance in result: - if pq.fields: - pq.store_instance(instance, id_map) - if has_relations: - for rel in rel_map[query_model]: - rel.populate_instance(instance, deps[rel.model]) - - return result - - ################### # Result wrappers # ################### @@ -586,7 +188,6 @@ def __init__(self, database, **kwargs): **self.connect_params_async ) - def __setattr__(self, name, value): if name == 'allow_sync': warnings.warn( @@ -698,11 +299,13 @@ def execute_sql(self, *args, **kwargs): return super().execute_sql(*args, **kwargs) async def fetch_results(self, query, cursor): + # TODO: Probably we don't need this method at all? + # We might get here if we use older `Manager` interface. if isinstance(query, peewee.BaseModelSelect): return await AsyncQueryWrapper.make_for_all_rows(cursor, query) if isinstance(query, peewee.RawQuery): return await AsyncQueryWrapper.make_for_all_rows(cursor, query) - raise Exception("Unknown type of query") + assert False, "Unsupported type of query '%s', use AioModel instead" % type(query) def connection(self) -> ConnectionContext: return ConnectionContext(self.aio_pool, self._task_data) @@ -727,8 +330,11 @@ async def aio_execute(self, query, fetch_results=None): ``query.execute()`` """ sql, params = query.sql() - default_fetch_results = getattr(query, "fetch_results", functools.partial(self.fetch_results, query)) - return await self.aio_execute_sql(sql, params, fetch_results=fetch_results or default_fetch_results) + if fetch_results is None: + query_fetch_results = getattr(query, 'fetch_results', None) + database_fetch_results = functools.partial(self.fetch_results, query) + fetch_results = query_fetch_results or database_fetch_results + return await self.aio_execute_sql(sql, params, fetch_results=fetch_results) class AioPool(metaclass=abc.ABCMeta): @@ -859,8 +465,7 @@ def init(self, database, **kwargs): register_database(PostgresqlDatabase, 'postgres+async', 'postgresql+async') -class PooledPostgresqlDatabase(AsyncPostgresqlMixin, - peewee.PostgresqlDatabase): +class PooledPostgresqlDatabase(AsyncPostgresqlMixin, peewee.PostgresqlDatabase): """PosgreSQL database driver providing **single drop-in sync** connection and **async connections pool** interface. @@ -887,8 +492,7 @@ def init(self, database, **kwargs): self.init_async() -register_database(PooledPostgresqlDatabase, 'postgres+pool+async', - 'postgresql+pool+async') +register_database(PooledPostgresqlDatabase, 'postgres+pool+async', 'postgresql+pool+async') ######### @@ -1085,13 +689,6 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): # Internal helpers # #################### - -def _query_db(query): - """Get database instance bound to query. This helper - incapsulates internal peewee's access to database. - """ - return query._database - class TaskLocals: """Simple `dict` wrapper to get and set values on per `asyncio` task basis. @@ -1193,7 +790,8 @@ async def fetch_results(self, cursor): @peewee.database_required async def aio_scalar(self, database, as_tuple=False): - """Get single value from ``select()`` query, i.e. for aggregation. + """ + Get single value from ``select()`` query, i.e. for aggregation. :return: result is the same as after sync ``query.scalar()`` call """ @@ -1206,7 +804,6 @@ async def fetch_results(cursor): return await database.aio_execute(self, fetch_results=fetch_results) - async def aio_get(self, database=None): clone = self.paginate(1, 1) try: @@ -1267,10 +864,7 @@ def delete(cls): @classmethod async def aio_get(cls, *query, **filters): - """ - Async version of **peewee.Model.get** - """ - + """Async version of **peewee.Model.get**""" sq = cls.select() if query: if len(query) == 1 and isinstance(query[0], int): @@ -1293,11 +887,9 @@ async def aio_get_or_none(cls, *query, **filters): @classmethod async def aio_create(cls, **data): - """ - INSERT new row into table and return corresponding model instance. - """ - inst = cls(**data) - pk = await cls.insert(**dict(inst.__data__)).aio_execute() - if inst._pk is None: - inst._pk = pk - return inst + """INSERT new row into table and return corresponding model instance.""" + obj = cls(**data) + pk = await cls.insert(**dict(obj.__data__)).aio_execute() + if obj._pk is None: + obj._pk = pk + return obj diff --git a/peewee_async_compat.py b/peewee_async_compat.py new file mode 100644 index 0000000..4fa3b64 --- /dev/null +++ b/peewee_async_compat.py @@ -0,0 +1,442 @@ +""" +Compatibility layer for `peewee_async` to navigate smooth migration towards v1.0. + +In the initial implementation the `Manager` class was introduced to avoid models +subclassing (which is not always possible or just undesirable). The newer interface +relies more to models subclassing, please check the `peewee_async.AioModel`. + +Licensed under The MIT License (MIT) + +Copyright (c) 2024, Alexey Kinëv + +""" +from functools import partial + +import warnings +import peewee + +IntegrityErrors = (peewee.IntegrityError,) + +try: + import aiopg + import psycopg2 + IntegrityErrors += (psycopg2.IntegrityError,) +except ImportError: + aiopg = None + psycopg2 = None + +__all__ = [ + 'Manager', + 'count', + 'prefetch', + 'execute', + 'scalar', +] + + +def _patch_query_with_compat_methods(query, async_query_cls): + """ + Patches original peewee's query with methods from AioQueryMixin, etc. + + This is the central (hacky) place where we glue the new classes with older style + `Manager` interface that operates on original peewee's models and query classes. + + Methods added to peewee's original query: + + - aio_execute + - fetch_results + - make_async_query_wrapper + - aio_get (for SELECT) + - aio_scalar (for SELECT) + """ + from peewee_async import AioModelSelect + + query.aio_execute = partial(async_query_cls.aio_execute, query) + query.fetch_results = partial(async_query_cls.fetch_results, query) + query.make_async_query_wrapper = partial(async_query_cls.make_async_query_wrapper, query) + + if async_query_cls is AioModelSelect: + query.aio_get = partial(async_query_cls.aio_get, query) + query.aio_scalar = partial(async_query_cls.aio_scalar, query) + + +def _query_db(query): + """ + Get database instance bound to query. This helper + incapsulates internal peewee's access to database. + """ + return query._database + + +async def count(query, clear_limit=False): + """ + Perform *COUNT* aggregated query asynchronously. + + :return: number of objects in `select()` query + """ + database = _query_db(query) + clone = query.clone() + if query._distinct or query._group_by or query._limit or query._offset: + if clear_limit: + clone._limit = clone._offset = None + sql, params = clone.sql() + wrapped = 'SELECT COUNT(1) FROM (%s) AS wrapped_select' % sql + async def fetch_results(cursor): + row = await cursor.fetchone() + if row: + return row[0] + else: + return row + result = await database.aio_execute_sql(wrapped, params, fetch_results) + return result or 0 + else: + clone._returning = [peewee.fn.Count(peewee.SQL('*'))] + clone._order_by = None + return (await scalar(clone)) or 0 + + +async def prefetch(sq, *subqueries, prefetch_type): + """Asynchronous version of the `prefetch()` from peewee.""" + from peewee_async import AioModelSelect + + database = _query_db(sq) + if not subqueries: + _patch_query_with_compat_methods(sq, AioModelSelect) + result = await database.aio_execute(sq) + return result + + fixed_queries = peewee.prefetch_add_subquery(sq, subqueries, prefetch_type) + deps = {} + rel_map = {} + + for pq in reversed(fixed_queries): + query_model = pq.model + if pq.fields: + for rel_model in pq.rel_models: + rel_map.setdefault(rel_model, []) + rel_map[rel_model].append(pq) + + deps[query_model] = {} + id_map = deps[query_model] + has_relations = bool(rel_map.get(query_model)) + database = _query_db(pq.query) + _patch_query_with_compat_methods(pq.query, AioModelSelect) + result = await database.aio_execute(pq.query) + + for instance in result: + if pq.fields: + pq.store_instance(instance, id_map) + if has_relations: + for rel in rel_map[query_model]: + rel.populate_instance(instance, deps[rel.model]) + + return result + + +async def execute(query): + warnings.warn( + "`execute` is deprecated, use `database.aio_execute` method.", + DeprecationWarning + ) + database = _query_db(query) + return await database.aio_execute(query) + + +async def scalar(query, as_tuple=False): + from peewee_async import AioModelSelect # noqa + warnings.warn( + "`scalar` is deprecated, use `query.aio_scalar` method.", + DeprecationWarning + ) + _patch_query_with_compat_methods(query, AioModelSelect) + return await query.aio_scalar(as_tuple=as_tuple) + + +class Manager: + """ + Async peewee model's manager. + + :param database: (optional) async database driver + + Example:: + + class User(peewee.Model): + username = peewee.CharField(max_length=40, unique=True) + + objects = Manager(PostgresqlDatabase('test')) + + async def my_async_func(): + user0 = await objects.create(User, username='test') + user1 = await objects.get(User, id=user0.id) + user2 = await objects.get(User, username='test') + # All should be the same + print(user1.id, user2.id, user3.id) + + If you don't pass database to constructor, you should define + ``database`` as a class member like that:: + + database = PostgresqlDatabase('test') + + class MyManager(Manager): + database = database + + objects = MyManager() + + """ + #: Async database driver for manager. Must be provided + #: in constructor or as a class member. + database = None + + def __init__(self, database=None): + assert database or self.database, \ + ("Error, database must be provided via " + "argument or class member.") + + self.database = database or self.database + + @property + def is_connected(self): + """Check if database is connected. + """ + return self.database.aio_pool.pool is not None + + async def get(self, source_, *args, **kwargs): + """Get the model instance. + + :param source_: model or base query for lookup + + Example:: + + async def my_async_func(): + obj1 = await objects.get(MyModel, id=1) + obj2 = await objects.get(MyModel, MyModel.id==1) + obj3 = await objects.get(MyModel.select().where(MyModel.id==1)) + + All will return `MyModel` instance with `id = 1` + """ + from peewee_async import AioModelSelect # noqa + + await self.connect() + + if isinstance(source_, peewee.Query): + query = source_ + model = query.model + else: + query = source_.select() + model = source_ + + conditions = list(args) + [(getattr(model, k) == v) for k, v in kwargs.items()] + + if conditions: + query = query.where(*conditions) + + _patch_query_with_compat_methods(query, AioModelSelect) + + try: + result = await self.execute(query) + return list(result)[0] + except IndexError: + raise model.DoesNotExist + + async def create(self, model_, **data): + """Create a new object saved to database.""" + from peewee_async import AioModelInsert + + obj = model_(**data) + query = model_.insert(**dict(obj.__data__)) + + _patch_query_with_compat_methods(query, AioModelInsert) + + pk = await self.execute(query) + if obj._pk is None: + obj._pk = pk + + return obj + + async def get_or_create(self, model_, defaults=None, **kwargs): + """ + Try to get an object or create it with the specified defaults. + + Return 2-tuple containing the model instance and a boolean + indicating whether the instance was created. + """ + try: + return (await self.get(model_, **kwargs)), False + except model_.DoesNotExist: + data = defaults or {} + data.update({k: v for k, v in kwargs.items() if '__' not in k}) + return (await self.create(model_, **data)), True + + async def get_or_none(self, model_, *args, **kwargs): + """Try to get an object and return None if it doesn't exist.""" + try: + return (await self.get(model_, *args, **kwargs)) + except model_.DoesNotExist: + pass + + async def update(self, obj, only=None): + """ + Update the object in the database. Optionally, update only + the specified fields. For creating a new object use :meth:`.create()` + + :param only: (optional) the list/tuple of fields or + field names to update + """ + from peewee_async import AioModelUpdate # noqa + + field_dict = dict(obj.__data__) + pk_field = obj._meta.primary_key + + if only: + self._prune_fields(field_dict, only) + + if obj._meta.only_save_dirty: + self._prune_fields(field_dict, obj.dirty_fields) + + if obj._meta.composite_key: + for pk_part_name in pk_field.field_names: + field_dict.pop(pk_part_name, None) + else: + field_dict.pop(pk_field.name, None) + + query = obj.update(**field_dict).where(obj._pk_expr()) + + _patch_query_with_compat_methods(query, AioModelUpdate) + + result = await self.execute(query) + obj._dirty.clear() + + return result + + async def delete(self, obj, recursive=False, delete_nullable=False): + """Delete object from database.""" + from peewee_async import AioModelDelete, AioModelUpdate + if recursive: + dependencies = obj.dependencies(delete_nullable) + for cond, fk in reversed(list(dependencies)): + model = fk.model + if fk.null and not delete_nullable: + sq = model.update(**{fk.name: None}).where(cond) + _patch_query_with_compat_methods(sq, AioModelUpdate) + else: + sq = model.delete().where(cond) + _patch_query_with_compat_methods(sq, AioModelDelete) + await self.execute(sq) + + query = obj.delete().where(obj._pk_expr()) + _patch_query_with_compat_methods(query, AioModelDelete) + + return (await self.execute(query)) + + async def create_or_get(self, model_, **kwargs): + """ + Try to create new object with specified data. If object already + exists, then try to get it by unique fields. + """ + try: + return (await self.create(model_, **kwargs)), True + except IntegrityErrors: + query = [] + for field_name, value in kwargs.items(): + field = getattr(model_, field_name) + if field.unique or field.primary_key: + query.append(field == value) + return (await self.get(model_, *query)), False + + async def execute(self, query): + """Execute query asyncronously.""" + return await self.database.aio_execute(query) + + async def prefetch(self, query, *subqueries, prefetch_type=peewee.PREFETCH_TYPE.JOIN): + """ + Asynchronous version of the `prefetch()` from peewee. + + :return: Query that has already cached data for subqueries + """ + return (await prefetch(query, *subqueries, prefetch_type=prefetch_type)) + + async def count(self, query, clear_limit=False): + """ + Perform *COUNT* aggregated query asynchronously. + + :return: number of objects in ``select()`` query + """ + return (await count(query, clear_limit=clear_limit)) + + async def scalar(self, query, as_tuple=False): + """ + Get single value from ``select()`` query, i.e. for aggregation. + + :return: result is the same as after sync ``query.scalar()`` call + """ + return (await scalar(query, as_tuple=as_tuple)) + + async def connect(self): + """Open database async connection if not connected.""" + await self.database.connect_async() + + async def close(self): + """Close database async connection if connected.""" + await self.database.close_async() + + def atomic(self): + """ + Similar to `peewee.Database.atomic()` method, but returns + **asynchronous** context manager. + + Example:: + + async with objects.atomic(): + await objects.create( + PageBlock, key='intro', + text="There are more things in heaven and earth, " + "Horatio, than are dreamt of in your philosophy.") + await objects.create( + PageBlock, key='signature', text="William Shakespeare") + """ + from peewee_async import atomic # noqa + return atomic(self.database) + + def transaction(self): + """ + Similar to `peewee.Database.transaction()` method, but returns + **asynchronous** context manager. + """ + from peewee_async import transaction # noqa + return transaction(self.database) + + def savepoint(self, sid=None): + """ + Similar to `peewee.Database.savepoint()` method, but returns + **asynchronous** context manager. + """ + from peewee_async import savepoint # noqa + return savepoint(self.database, sid=sid) + + def allow_sync(self): + """ + Allow sync queries within context. Close the sync + database connection on exit if connected. + + Example:: + + with objects.allow_sync(): + PageBlock.create_table(True) + """ + return self.database.allow_sync() + + @staticmethod + def _prune_fields(field_dict, only): + """ + Filter fields data **in place** with `only` list. + + Example:: + + self._prune_fields(field_dict, ['slug', 'text']) + self._prune_fields(field_dict, [MyModel.slug]) + """ + fields = [(isinstance(f, str) and f or f.name) for f in only] + for f in list(field_dict.keys()): + if f not in fields: + field_dict.pop(f) + return field_dict diff --git a/pyproject.toml b/pyproject.toml index 7362563..691e8ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,7 @@ authors = ["Alexey Kinev ", "Gorshkov Nikolay(contributor) %s' % (self.__class__.__name__, self.id, self.text) -class TestModelAlpha(AioModel): +class TestModelAlpha(peewee_async.AioModel): __test__ = False text = peewee.CharField() @@ -22,7 +21,7 @@ def __str__(self): return '<%s id=%s> %s' % (self.__class__.__name__, self.id, self.text) -class TestModelBeta(AioModel): +class TestModelBeta(peewee_async.AioModel): __test__ = False alpha = peewee.ForeignKeyField(TestModelAlpha, backref='betas') text = peewee.CharField() @@ -31,7 +30,7 @@ def __str__(self): return '<%s id=%s> %s' % (self.__class__.__name__, self.id, self.text) -class TestModelGamma(AioModel): +class TestModelGamma(peewee_async.AioModel): __test__ = False text = peewee.CharField() beta = peewee.ForeignKeyField(TestModelBeta, backref='gammas') @@ -40,7 +39,7 @@ def __str__(self): return '<%s id=%s> %s' % (self.__class__.__name__, self.id, self.text) -class UUIDTestModel(AioModel): +class UUIDTestModel(peewee_async.AioModel): id = peewee.UUIDField(primary_key=True, default=uuid.uuid4) text = peewee.CharField() @@ -48,7 +47,16 @@ def __str__(self): return '<%s id=%s> %s' % (self.__class__.__name__, self.id, self.text) -class CompositeTestModel(AioModel): +class CompatTestModel(peewee.Model): + id = peewee.UUIDField(primary_key=True, default=uuid.uuid4) + text = peewee.CharField(max_length=100, unique=True) + data = peewee.TextField(default='') + + def __str__(self): + return '<%s id=%s> %s' % (self.__class__.__name__, self.id, self.text) + + +class CompositeTestModel(peewee_async.AioModel): """A simple "through" table for many-to-many relationship.""" uuid = peewee.ForeignKeyField(UUIDTestModel) alpha = peewee.ForeignKeyField(TestModelAlpha) @@ -58,6 +66,6 @@ class Meta: ALL_MODELS = ( - TestModel, UUIDTestModel, TestModelAlpha, - TestModelBeta, TestModelGamma, CompositeTestModel + TestModel, UUIDTestModel, TestModelAlpha, TestModelBeta, TestModelGamma, + CompatTestModel, CompositeTestModel ) diff --git a/tests/test_common.py b/tests/test_common.py index ee07d7c..4348b31 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,7 +1,7 @@ import asyncio import uuid -import peewee as pw +import peewee import pytest import peewee_async @@ -28,7 +28,6 @@ async def test_multiple_iterate_over_result(manager): TestModel.select().order_by(TestModel.text)) assert list(result) == [obj1, obj2] - assert list(result) == [obj1, obj2] @all_dbs @@ -122,7 +121,7 @@ async def test_allow_sync_is_reverted_for_exc(manager): ununique_text = "ununique_text" await manager.create(TestModel, text=ununique_text) await manager.create(TestModel, text=ununique_text) - except pw.IntegrityError: + except peewee.IntegrityError: pass assert manager.database._allow_sync is False @@ -196,11 +195,9 @@ async def test_deferred_init(params, db_cls): [ (DB_DEFAULTS[name], db_cls) for name, db_cls in DB_CLASSES.items() ] - ) async def test_proxy_database(params, db_cls): - - database = pw.Proxy() + database = peewee.Proxy() TestModel._meta.database = database manager = peewee_async.Manager(database) diff --git a/tests/test_compat.py b/tests/test_compat.py new file mode 100644 index 0000000..93457a6 --- /dev/null +++ b/tests/test_compat.py @@ -0,0 +1,41 @@ +import uuid + +from tests.conftest import all_dbs +from tests.models import CompatTestModel + + +@all_dbs +async def test_create_select_compat_mode(manager): + obj1 = await manager.create(CompatTestModel, text="Test 1") + obj2 = await manager.create(CompatTestModel, text="Test 2") + query = CompatTestModel.select().order_by(CompatTestModel.text) + result = await manager.execute(query) + assert list(result) == [obj1, obj2] + + +@all_dbs +async def test_update_compat_mode(manager): + obj_draft = await manager.create(CompatTestModel, text="Draft 1") + obj_draft.text = "Final result" + await manager.update(obj_draft) + obj = await manager.get(CompatTestModel, id=obj_draft.id) + assert obj.text == "Final result" + + +@all_dbs +async def test_count_compat_mode(manager): + obj = await manager.create(CompatTestModel, text="Unique title %s" % uuid.uuid4()) + search = CompatTestModel.select().where(CompatTestModel.text == obj.text) + count = await manager.count(search) + assert count == 1 + + +@all_dbs +async def test_delete_compat_mode(manager): + obj = await manager.create(CompatTestModel, text="Expired item %s" % uuid.uuid4()) + search = CompatTestModel.select().where(CompatTestModel.id == obj.id) + count_before = await manager.count(search) + assert count_before == 1 + await manager.delete(obj) + count_after = await manager.count(search) + assert count_after == 0 diff --git a/tests/test_shortcuts.py b/tests/test_shortcuts.py index ad639a5..5bca9db 100644 --- a/tests/test_shortcuts.py +++ b/tests/test_shortcuts.py @@ -1,7 +1,6 @@ import uuid import peewee -import peewee as pw import pytest from tests.conftest import all_dbs @@ -94,7 +93,7 @@ async def test_scalar_query(manager): text = "Test %s" % uuid.uuid4() await manager.create(TestModel, text=text) - fn = pw.fn.Count(TestModel.id) + fn = peewee.fn.Count(TestModel.id) count = await manager.scalar(TestModel.select(fn)) assert count == 2 diff --git a/tests/test_transaction.py b/tests/test_transaction.py index edab20e..0ed9c2a 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -1,7 +1,5 @@ import asyncio -import pytest - from tests.conftest import all_dbs from tests.models import TestModel