-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathpeewee.py
4945 lines (4192 loc) · 164 KB
/
peewee.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# May you do good and not evil
# May you find forgiveness for yourself and forgive others
# May you share freely, never taking more than you give. -- SQLite source code
#
# As we enjoy great advantages from the inventions of others, we should be glad
# of an opportunity to serve others by an invention of ours, and this we should
# do freely and generously. -- Ben Franklin
#
# (\
# ( \ /(o)\ caw!
# ( \/ ()/ /)
# ( `;.))'".)
# `(/////.-'
# =====))=))===()
# ///'
# //
# '
import datetime
import decimal
import hashlib
import logging
import operator
import re
import sys
import threading
import uuid
from bisect import bisect_left
from bisect import bisect_right
from collections import deque
from collections import namedtuple
try:
from collections import OrderedDict
except ImportError:
OrderedDict = dict
from copy import deepcopy
from functools import wraps
from inspect import isclass
# Version string of this module.
__version__ = '2.8.1'

# Public API: the names bound by `from peewee import *`.
__all__ = [
    'BareField',
    'BigIntegerField',
    'BlobField',
    'BooleanField',
    'CharField',
    'Check',
    'Clause',
    'CompositeKey',
    'DatabaseError',
    'DataError',
    'DateField',
    'DateTimeField',
    'DecimalField',
    'DeferredRelation',
    'DoesNotExist',
    'DoubleField',
    'DQ',
    'Field',
    'FixedCharField',
    'FloatField',
    'fn',
    'ForeignKeyField',
    'ImproperlyConfigured',
    'IntegerField',
    'IntegrityError',
    'InterfaceError',
    'InternalError',
    'JOIN',
    'JOIN_FULL',
    'JOIN_INNER',
    'JOIN_LEFT_OUTER',
    'Model',
    'MySQLDatabase',
    'NotSupportedError',
    'OperationalError',
    'Param',
    'PostgresqlDatabase',
    'prefetch',
    'PrimaryKeyField',
    'ProgrammingError',
    'Proxy',
    'R',
    'SmallIntegerField',
    'SqliteDatabase',
    'SQL',
    'TextField',
    'TimeField',
    'Using',
    'UUIDField',
    'Window',
]
# Set default logging handler to avoid "No handlers could be found for logger
# "peewee"" warnings.
try:  # Python 2.7+
    from logging import NullHandler
except ImportError:
    # Fallback used when the logging module does not provide NullHandler.
    class NullHandler(logging.Handler):
        """Logging handler that discards every record it receives."""
        def emit(self, record):
            pass

# All peewee-generated logs are logged to this namespace.
logger = logging.getLogger('peewee')
# Attach the no-op handler so the 'peewee' logger always has a handler.
logger.addHandler(NullHandler())
# Python 2/3 compatibility helpers. These helpers are used internally and are
# not exported.
def with_metaclass(meta, base=object):
    """
    Build an intermediate class named "NewBase" by calling `meta` directly.

    Calling the metaclass avoids the version-specific metaclass syntax of
    Python 2 and Python 3. The result derives from `base` and has an empty
    namespace; it is meant to be used as a base class.
    """
    bases = (base,)
    namespace = {}
    return meta("NewBase", bases, namespace)
# Interpreter-version flags used to pick the compatibility shims below.
PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3
PY26 = sys.version_info[:2] == (2, 6)

if PY3:
    import builtins
    try:
        # The ABCs live in `collections.abc`; the aliases that used to be
        # importable from `collections` were removed in Python 3.10.
        from collections.abc import Callable
    except ImportError:
        from collections import Callable
    from functools import reduce
    callable = lambda c: isinstance(c, Callable)
    unicode_type = str
    string_type = bytes
    basestring = str
    print_ = getattr(builtins, 'print')
    binary_construct = lambda s: bytes(s.encode('raw_unicode_escape'))

    def reraise(tp, value, tb=None):
        """Re-raise `value`, attaching traceback `tb` if it differs."""
        if value.__traceback__ is not tb:
            raise value.with_traceback(tb)
        raise value
elif PY2:
    unicode_type = unicode
    string_type = basestring
    binary_construct = buffer

    def print_(s):
        """Write `s` followed by a newline to stdout."""
        sys.stdout.write(s)
        sys.stdout.write('\n')

    # The three-argument raise is a syntax error on Python 3, so the Python 2
    # implementation is compiled from a string only on this branch.
    exec('def reraise(tp, value, tb=None): raise tp, value, tb')
else:
    raise RuntimeError('Unsupported python version.')
# By default, peewee supports Sqlite, MySQL and Postgresql.
try:
from pysqlite2 import dbapi2 as pysq3
except ImportError:
pysq3 = None
try:
import sqlite3
except ImportError:
sqlite3 = pysq3
else:
if pysq3 and pysq3.sqlite_version_info >= sqlite3.sqlite_version_info:
sqlite3 = pysq3
try:
from psycopg2cffi import compat
compat.register()
except ImportError:
pass
try:
import psycopg2
from psycopg2 import extensions as pg_extensions
except ImportError:
psycopg2 = None
try:
import MySQLdb as mysql # prefer the C module.
except ImportError:
try:
import pymysql as mysql
except ImportError:
mysql = None
try:
from playhouse._speedups import format_date_time
from playhouse._speedups import sort_models_topologically
from playhouse._speedups import strip_parens
except ImportError:
def format_date_time(value, formats, post_process=None):
    """
    Parse `value` using the first entry of `formats` that succeeds.

    Each format is tried in order with `datetime.datetime.strptime`. The
    parsed datetime is passed through `post_process` (identity when not
    given) and returned. A `ValueError` from either step moves on to the
    next format. When no format succeeds, `value` is returned unchanged.
    """
    if not post_process:
        post_process = lambda x: x
    for candidate in formats:
        try:
            return post_process(datetime.datetime.strptime(value, candidate))
        except ValueError:
            continue
    return value
def sort_models_topologically(models):
    """
    Sort models topologically so that parents will precede children.

    A depth-first walk follows each model's `_meta.reverse_rel` values to
    the `model_class` of every referencing foreign key. Only models present
    in the input are visited. A model is recorded after its descendants, and
    the recorded sequence is reversed before being returned.
    """
    models = set(models)
    visited = set()
    postorder = []

    def visit(model):
        if model not in models or model in visited:
            return
        visited.add(model)
        for fk in model._meta.reverse_rel.values():
            visit(fk.model_class)
        postorder.append(model)  # parent will follow descendants

    def sort_key(m):
        # Order models by name and table initially to guarantee total
        # ordering.
        return (m._meta.name, m._meta.db_table)

    for model in sorted(models, key=sort_key, reverse=True):
        visit(model)
    postorder.reverse()
    return postorder
def strip_parens(s):
    """
    Strip redundant outer parentheses from the string `s`.

    Matching "(" / ")" pairs are counted at the two ends of the string, then
    the interior is scanned to find how many of those pairs are actually
    required, i.e. are closed by a ")" inside the interior. Only the pairs
    that are not required are removed.

    Returns `s` unchanged when it is empty or does not start with "(".
    """
    # Quick sanity check.
    if not s or s[0] != '(':
        return s

    # Count candidate wrapping pairs, walking inward from both ends. After
    # the loop, `l` equals len(s) - ct, so the interior is s[ct:l].
    ct = i = 0
    l = len(s)
    while i < l:
        if s[i] == '(' and s[l - 1] == ')':
            ct += 1
            i += 1
            l -= 1
        else:
            break
    if ct:
        # If we ever end up with negatively-balanced parentheses, then we
        # know that one of the outer parentheses was required.
        unbalanced_ct = 0
        required = 0
        # Scan the whole interior. `l` was already reduced by `ct` above, so
        # the end bound is `l` itself; subtracting `ct` again would skip the
        # last `ct` interior characters (e.g. "(()())" became ")(").
        for i in range(ct, l):
            if s[i] == '(':
                unbalanced_ct += 1
            elif s[i] == ')':
                unbalanced_ct -= 1
            if unbalanced_ct < 0:
                required += 1
                unbalanced_ct = 0
            if required == ct:
                break
        ct -= required
    if ct > 0:
        return s[ct:-ct]
    return s
try:
from playhouse._speedups import _DictQueryResultWrapper
from playhouse._speedups import _ModelQueryResultWrapper
from playhouse._speedups import _SortedFieldList
from playhouse._speedups import _TuplesQueryResultWrapper
except ImportError:
_DictQueryResultWrapper = _ModelQueryResultWrapper = _SortedFieldList =\
_TuplesQueryResultWrapper = None
# When a sqlite driver is available, store Decimal, date and time values
# using their `str()` form.
if sqlite3:
    for _adapted_type in (decimal.Decimal, datetime.date, datetime.time):
        sqlite3.register_adapter(_adapted_type, str)
    del _adapted_type
# Names of the datetime components supported for date-part lookups.
DATETIME_PARTS = ['year', 'month', 'day', 'hour', 'minute', 'second']
# Set form of DATETIME_PARTS, used for the membership check in
# _sqlite_date_part.
DATETIME_LOOKUPS = set(DATETIME_PARTS)

# Sqlite does not support the `date_part` SQL function, so we will define an
# implementation in python.
# Formats handed to format_date_time() by the sqlite helpers below; the
# pure-python format_date_time tries them in the order listed.
SQLITE_DATETIME_FORMATS = (
    '%Y-%m-%d %H:%M:%S',
    '%Y-%m-%d %H:%M:%S.%f',
    '%Y-%m-%d',
    '%H:%M:%S',
    '%H:%M:%S.%f',
    '%H:%M')
def _sqlite_date_part(lookup_type, datetime_string):
    """
    Return the `lookup_type` attribute (e.g. 'year') of a stored datetime.

    `datetime_string` is parsed with `format_date_time` using
    SQLITE_DATETIME_FORMATS. Returns None when `datetime_string` is empty
    or None.
    """
    assert lookup_type in DATETIME_LOOKUPS
    if datetime_string:
        parsed = format_date_time(datetime_string, SQLITE_DATETIME_FORMATS)
        return getattr(parsed, lookup_type)
    return None
# strftime-style format used to truncate a datetime to each precision.
SQLITE_DATE_TRUNC_MAPPING = dict(
    year='%Y',
    month='%Y-%m',
    day='%Y-%m-%d',
    hour='%Y-%m-%d %H',
    minute='%Y-%m-%d %H:%M',
    second='%Y-%m-%d %H:%M:%S',
)

# Same table for MySQL, except that minutes are written as %i.
MYSQL_DATE_TRUNC_MAPPING = dict(
    SQLITE_DATE_TRUNC_MAPPING,
    minute='%Y-%m-%d %H:%i',
    second='%Y-%m-%d %H:%i:%S',
)
def _sqlite_date_trunc(lookup_type, datetime_string):
    """
    Truncate a stored datetime to the precision named by `lookup_type`.

    `datetime_string` is parsed with `format_date_time` using
    SQLITE_DATETIME_FORMATS and re-formatted with the matching entry of
    SQLITE_DATE_TRUNC_MAPPING. Returns None when `datetime_string` is empty
    or None.
    """
    assert lookup_type in SQLITE_DATE_TRUNC_MAPPING
    if datetime_string:
        parsed = format_date_time(datetime_string, SQLITE_DATETIME_FORMATS)
        trunc_format = SQLITE_DATE_TRUNC_MAPPING[lookup_type]
        return parsed.strftime(trunc_format)
    return None
def _sqlite_regexp(regex, value):
return re.search(regex, value, re.I) is not None
class attrdict(dict):
    """Dictionary whose keys can also be read as attributes."""

    def __getattr__(self, attr):
        # A missing key must surface as AttributeError, not KeyError, so that
        # hasattr(), getattr() with a default, and copy.deepcopy() work.
        try:
            return self[attr]
        except KeyError:
            raise AttributeError(attr)
# Operators used in binary expressions.
# Each value is the operator code passed as the `op` argument of Expression
# by the Node operator overloads and helper methods (e.g. OP.EQ, OP.IN).
OP = attrdict(
    AND='and',
    OR='or',
    ADD='+',
    SUB='-',
    MUL='*',
    DIV='/',
    BIN_AND='&',
    BIN_OR='|',
    XOR='^',
    MOD='%',
    EQ='=',
    LT='<',
    LTE='<=',
    GT='>',
    GTE='>=',
    NE='!=',
    IN='in',
    NOT_IN='not in',
    IS='is',
    IS_NOT='is not',
    LIKE='like',
    ILIKE='ilike',
    BETWEEN='between',
    REGEXP='regexp',
    CONCAT='||',
)
# Join types. Join.get_join_type() falls back to JOIN.INNER when no join type
# was given.
JOIN = attrdict(
    INNER='INNER',
    LEFT_OUTER='LEFT OUTER',
    RIGHT_OUTER='RIGHT OUTER',
    FULL='FULL',
)
# Module-level aliases for three of the join types; these names are listed in
# __all__.
JOIN_INNER = JOIN.INNER
JOIN_LEFT_OUTER = JOIN.LEFT_OUTER
JOIN_FULL = JOIN.FULL
# Integer codes naming the kinds of query results.
# NOTE(review): the code that consumes these is not in this part of the file;
# presumably they select the result wrapper used for rows -- confirm.
RESULTS_NAIVE = 1
RESULTS_MODELS = 2
RESULTS_TUPLES = 3
RESULTS_DICTS = 4
RESULTS_AGGREGATE_MODELS = 5
# To support "django-style" double-underscore filters, create a mapping between
# operation name and operation code, e.g. "__eq" == OP.EQ.
# Keys are the suffix written after the double underscore; values are the
# corresponding OP codes.
DJANGO_MAP = {
    'eq': OP.EQ,
    'lt': OP.LT,
    'lte': OP.LTE,
    'gt': OP.GT,
    'gte': OP.GTE,
    'ne': OP.NE,
    'in': OP.IN,
    'is': OP.IS,
    'like': OP.LIKE,
    'ilike': OP.ILIKE,
    'regexp': OP.REGEXP,
}
# Helper functions that are used in various parts of the codebase.
def merge_dict(source, overrides):
    """
    Return a copy of `source` updated with `overrides`.

    `source` is copied with its own `.copy()` method and is not modified;
    values from `overrides` win when a key is present in both.
    """
    combined = source.copy()
    combined.update(overrides)
    return combined
def returns_clone(func):
    """
    Method decorator that will "clone" the object before applying the given
    method. This ensures that state is mutated in a more predictable fashion,
    and promotes the use of method-chaining.

    The decorated method runs against `self.clone()` and the clone is
    returned; the return value of `func` itself is discarded. The undecorated
    function stays reachable as the wrapper's `call_local` attribute.
    """
    # wraps() keeps the wrapped method's __name__ and __doc__ on the wrapper.
    @wraps(func)
    def inner(self, *args, **kwargs):
        clone = self.clone()  # Assumes object implements `clone`.
        func(clone, *args, **kwargs)
        return clone
    inner.call_local = func  # Provide a way to call without cloning.
    return inner
def not_allowed(func):
    """
    Build a method that always raises `NotImplementedError`.

    `func` is only interpolated into the error message with `%s`, together
    with the class name of the instance the method was called on.
    """
    def inner(self, *args, **kwargs):
        owner = type(self).__name__
        raise NotImplementedError(
            '%s is not allowed on %s instances' % (func, owner))
    return inner
class Proxy(object):
    """
    Placeholder for an object that will be supplied later.

    Attribute reads are forwarded to the wrapped object once `initialize`
    has been called with it; before that they raise `AttributeError`. Only
    the two slot attributes may be assigned on the proxy itself.
    """
    __slots__ = ['obj', '_callbacks']

    def __init__(self):
        self._callbacks = []
        self.initialize(None)

    def initialize(self, obj):
        """Store `obj` and call every attached callback with it."""
        self.obj = obj
        for notify in self._callbacks:
            notify(obj)

    def attach_callback(self, callback):
        """Register `callback` for later `initialize` calls; return it."""
        self._callbacks.append(callback)
        return callback

    def __getattr__(self, attr):
        target = self.obj
        if target is None:
            raise AttributeError('Cannot use uninitialized Proxy.')
        return getattr(target, attr)

    def __setattr__(self, attr, value):
        if attr in self.__slots__:
            return super(Proxy, self).__setattr__(attr, value)
        raise AttributeError('Cannot set attribute on proxy.')
class DeferredRelation(object):
    """
    Holds a field whose related model is supplied later.

    `set_field` records the owning model class, the field and its name;
    `set_model` then assigns the related model to the field and adds the
    field to the recorded model class.
    """

    def set_field(self, model_class, field, name):
        self.model_class, self.field, self.name = model_class, field, name

    def set_model(self, rel_model):
        field = self.field
        field.rel_model = rel_model
        field.add_to_class(self.model_class, self.name)
class _CDescriptor(object):
def __get__(self, instance, instance_type=None):
if instance is not None:
return Entity(instance._alias)
return self
# Classes representing the query tree.
class Node(object):
    """Base-class for any part of a query which shall be composable."""
    # On an instance, `.c` evaluates to Entity(instance._alias); on the class
    # it is the descriptor itself (see _CDescriptor).
    c = _CDescriptor()
    _node_type = 'node'

    def __init__(self):
        self._negated = False
        self._alias = None
        self._bind_to = None
        self._ordering = None  # ASC or DESC.

    @classmethod
    def extend(cls, name=None, clone=False):
        """
        Return a decorator that installs a function as a method of `cls`.

        The method is stored under `name`, or under the function's own name
        when `name` is not given. With `clone=True` the function is wrapped
        in `returns_clone` first. The decorator returns the installed object.
        """
        def decorator(method):
            attr_name = name or method.__name__
            installed = returns_clone(method) if clone else method
            setattr(cls, attr_name, installed)
            return installed
        return decorator

    def clone_base(self):
        """Create the bare instance that `clone` copies state onto."""
        return type(self)()

    def clone(self):
        """Return a copy carrying the negation, alias, ordering and binding."""
        duplicate = self.clone_base()
        for attr in ('_negated', '_alias', '_ordering', '_bind_to'):
            setattr(duplicate, attr, getattr(self, attr))
        return duplicate

    @returns_clone
    def __invert__(self):
        self._negated = not self._negated

    @returns_clone
    def alias(self, a=None):
        self._alias = a

    @returns_clone
    def bind_to(self, bt):
        """
        Bind the results of an expression to a specific model type. Useful
        when adding expressions to a select, where the result of the expression
        should be placed on a joined instance.
        """
        self._bind_to = bt

    @returns_clone
    def asc(self):
        self._ordering = 'ASC'

    @returns_clone
    def desc(self):
        self._ordering = 'DESC'

    def __pos__(self):
        return self.asc()

    def __neg__(self):
        return self.desc()

    def _e(op, inv=False):
        """
        Lightweight factory which returns a method that builds an Expression
        consisting of the left-hand and right-hand operands, using `op`.
        With `inv=True` the operands are swapped (reflected operators).
        """
        def inner(self, rhs):
            left, right = (rhs, self) if inv else (self, rhs)
            return Expression(left, op, right)
        return inner

    # Arithmetic and logical operators.
    __and__ = _e(OP.AND)
    __or__ = _e(OP.OR)

    __add__ = _e(OP.ADD)
    __sub__ = _e(OP.SUB)
    __mul__ = _e(OP.MUL)
    __div__ = __truediv__ = _e(OP.DIV)
    __xor__ = _e(OP.XOR)

    # Reflected variants, used when the Node is the right-hand operand.
    __radd__ = _e(OP.ADD, inv=True)
    __rsub__ = _e(OP.SUB, inv=True)
    __rmul__ = _e(OP.MUL, inv=True)
    __rdiv__ = __rtruediv__ = _e(OP.DIV, inv=True)
    __rand__ = _e(OP.AND, inv=True)
    __ror__ = _e(OP.OR, inv=True)
    __rxor__ = _e(OP.XOR, inv=True)

    def __eq__(self, rhs):
        # Comparison with None becomes an IS expression.
        op = OP.IS if rhs is None else OP.EQ
        return Expression(self, op, rhs)

    def __ne__(self, rhs):
        # Comparison with None becomes an IS NOT expression.
        op = OP.IS_NOT if rhs is None else OP.NE
        return Expression(self, op, rhs)

    __lt__ = _e(OP.LT)
    __le__ = _e(OP.LTE)
    __gt__ = _e(OP.GT)
    __ge__ = _e(OP.GTE)
    # `<<` builds IN, `>>` builds IS, `%` builds LIKE and `**` builds ILIKE.
    __lshift__ = _e(OP.IN)
    __rshift__ = _e(OP.IS)
    __mod__ = _e(OP.LIKE)
    __pow__ = _e(OP.ILIKE)

    bin_and = _e(OP.BIN_AND)
    bin_or = _e(OP.BIN_OR)

    # Special expressions.
    def in_(self, rhs):
        return Expression(self, OP.IN, rhs)

    def not_in(self, rhs):
        return Expression(self, OP.NOT_IN, rhs)

    def is_null(self, is_null=True):
        op = OP.IS if is_null else OP.IS_NOT
        return Expression(self, op, None)

    def contains(self, rhs):
        pattern = '%%%s%%' % rhs
        return Expression(self, OP.ILIKE, pattern)

    def startswith(self, rhs):
        pattern = '%s%%' % rhs
        return Expression(self, OP.ILIKE, pattern)

    def endswith(self, rhs):
        pattern = '%%%s' % rhs
        return Expression(self, OP.ILIKE, pattern)

    def between(self, low, high):
        bounds = Clause(low, R('AND'), high)
        return Expression(self, OP.BETWEEN, bounds)

    def regexp(self, expression):
        return Expression(self, OP.REGEXP, expression)

    def concat(self, rhs):
        return Expression(self, OP.CONCAT, rhs)
class SQL(Node):
    """Raw SQL text that is not escaped, plus any parameters given with it."""
    _node_type = 'sql'

    def __init__(self, value, *params):
        self.value = value
        self.params = params
        super(SQL, self).__init__()

    def clone_base(self):
        value, params = self.value, self.params
        return SQL(value, *params)


R = SQL  # backwards-compat.
class Entity(Node):
    """
    A quoted name made of one or more path components, for example a table
    name followed by a column name.
    """
    _node_type = 'entity'

    def __init__(self, *path):
        super(Entity, self).__init__()
        self.path = path

    def clone_base(self):
        return Entity(*self.path)

    def __getattr__(self, attr):
        # Any unknown attribute extends the path by one component; falsy
        # components are dropped.
        components = [part for part in self.path + (attr,) if part]
        return Entity(*components)
class Func(Node):
    """An arbitrary SQL function call."""
    _node_type = 'func'

    def __init__(self, name, *arguments):
        self.name = name
        self.arguments = arguments
        self._coerce = True
        super(Func, self).__init__()

    @returns_clone
    def coerce(self, coerce=True):
        self._coerce = coerce

    def clone_base(self):
        duplicate = Func(self.name, *self.arguments)
        duplicate._coerce = self._coerce
        return duplicate

    def over(self, partition_by=None, order_by=None, window=None):
        """
        Return the clause `<self> OVER <window>`.

        A `Window` passed as the first positional argument is treated as the
        `window` argument. With a window, its alias is emitted as raw SQL;
        without one, an inline window definition is built from
        `partition_by` and `order_by`.
        """
        if window is None and isinstance(partition_by, Window):
            window = partition_by
        if window is None:
            inline = Window(partition_by=partition_by, order_by=order_by)
            sql = inline.__sql__()
        else:
            sql = SQL(window._alias)
        return Clause(self, SQL('OVER'), sql)

    def __getattr__(self, attr):
        # Any unknown attribute is a factory for a Func named after it.
        def dec(*args, **kwargs):
            return Func(attr, *args, **kwargs)
        return dec


# fn is a factory for creating `Func` objects and supports a more friendly
# API. So instead of `Func("LOWER", param)`, `fn.LOWER(param)`.
fn = Func(None)
class Expression(Node):
    """A binary expression: two operands joined by an operator code."""
    _node_type = 'expression'

    def __init__(self, lhs, op, rhs, flat=False):
        super(Expression, self).__init__()
        self.lhs, self.op, self.rhs = lhs, op, rhs
        self.flat = flat

    def clone_base(self):
        return Expression(
            lhs=self.lhs, op=self.op, rhs=self.rhs, flat=self.flat)
class Param(Node):
    """
    Wrapper marking a value as a query parameter.

    Tells the query compiler to treat the value as a single parameter, which
    matters for `list` values since those are special-cased for `IN`
    lookups. An optional `conv` object is stored alongside the value.
    """
    _node_type = 'param'

    def __init__(self, value, conv=None):
        self.value, self.conv = value, conv
        super(Param, self).__init__()

    def clone_base(self):
        return Param(value=self.value, conv=self.conv)
class Passthrough(Param):
    """Param subclass that is tagged with the node type 'passthrough'."""
    _node_type = 'passthrough'
class Clause(Node):
    """
    A SQL clause: one or more nodes joined by `glue` (a space by default).

    The `glue` and `parens` keyword arguments override the class-level
    defaults for the instance; other keyword arguments are ignored.
    """
    _node_type = 'clause'
    glue = ' '
    parens = False

    def __init__(self, *nodes, **kwargs):
        for option in ('glue', 'parens'):
            if option in kwargs:
                setattr(self, option, kwargs[option])
        super(Clause, self).__init__()
        self.nodes = list(nodes)

    def clone_base(self):
        duplicate = Clause(*self.nodes)
        duplicate.glue, duplicate.parens = self.glue, self.parens
        return duplicate
class CommaClause(Clause):
    """One or more Node objects joined by commas, no parens."""
    # Overrides the single-space glue of Clause.
    glue = ', '
class EnclosedClause(CommaClause):
    """One or more Node objects joined by commas and enclosed in parens."""
    # Overrides the `parens = False` default of Clause.
    parens = True
class Window(Node):
    """
    A window definition with optional PARTITION BY and ORDER BY node lists.

    The alias defaults to 'w' when the base initializer leaves it unset.
    """

    def __init__(self, partition_by=None, order_by=None):
        super(Window, self).__init__()
        self.partition_by = partition_by
        self.order_by = order_by
        self._alias = self._alias or 'w'

    def __sql__(self):
        """Return the parenthesized clause describing this window."""
        sections = (
            ('PARTITION BY', self.partition_by),
            ('ORDER BY', self.order_by),
        )
        over_clauses = []
        for keyword, nodes in sections:
            if nodes:
                over_clauses.append(
                    Clause(SQL(keyword), CommaClause(*nodes)))
        return EnclosedClause(Clause(*over_clauses))

    def clone_base(self):
        return Window(self.partition_by, self.order_by)
def Check(value):
    """Return a raw SQL node of the form `CHECK (<value>)`."""
    constraint = 'CHECK (%s)' % value
    return SQL(constraint)
class DQ(Node):
    """
    A "django-style" filter expression whose keyword arguments, such as
    `foo__eq='x'`, are stored unchanged in the `query` dict.
    """

    def __init__(self, **query):
        super(DQ, self).__init__()
        self.query = query

    def clone_base(self):
        lookups = self.query
        return DQ(**lookups)
class _StripParens(Node):
    """Wrapper node, tagged 'strip_parens', around a single inner node."""
    _node_type = 'strip_parens'

    def __init__(self, node):
        super(_StripParens, self).__init__()
        self.node = node
# Record describing a resolved join; built by Join._join_metadata().
JoinMetadata = namedtuple('JoinMetadata', [
    'src_model',      # Model class on the source side.
    'dest_model',     # Model class on the destination side.
    'src',            # Source as given: Model or ModelAlias.
    'dest',           # Destination as given: Model, ModelAlias or SelectQuery.
    'attr',           # Attribute the joined instance(s) are assigned to.
    'primary_key',    # Primary key being joined on.
    'foreign_key',    # Foreign key being joined from.
    'is_backref',     # True for a backref, i.e. 1 -> N.
    'alias',          # Explicit alias given to the join expression.
    'is_self_join',   # True when source and destination model are the same.
    'is_expression',  # True when the ON clause is an Expression.
])
class Join(namedtuple('_Join', ('src', 'dest', 'join_type', 'on'))):
    """A join from `src` to `dest` with an optional join type and ON value."""

    def get_foreign_key(self, source, dest, field=None):
        """
        Find the foreign key linking `source` and `dest`.

        Returns `(field, is_backref)`. A forward relation from `source` is
        tried first, then a reverse relation. `(None, None)` is returned
        when either side is a SelectQuery or nothing is found.
        """
        if any(isinstance(side, SelectQuery) for side in (source, dest)):
            return None, None

        forward = source._meta.rel_for_model(dest, field)
        if forward is not None:
            return forward, False

        backward = source._meta.reverse_rel_for_model(dest, field)
        if backward is not None:
            return backward, True

        return None, None

    def get_join_type(self):
        """Return the join type, defaulting to an inner join."""
        return self.join_type or JOIN.INNER

    def model_from_alias(self, model_or_alias):
        """Unwrap a ModelAlias or SelectQuery to its model class."""
        if isinstance(model_or_alias, (ModelAlias, SelectQuery)):
            return model_or_alias.model_class
        return model_or_alias

    def _join_metadata(self):
        """Compute the JoinMetadata record for this join."""
        # Get the actual tables being joined.
        src = self.model_from_alias(self.src)
        dest = self.model_from_alias(self.dest)

        on = self.on
        join_alias = isinstance(on, Node) and on._alias or None
        is_expression = isinstance(on, (Expression, Func, SQL))

        # The ON value may itself be the field to join on.
        on_field = isinstance(on, (Field, FieldProxy)) and on or None
        if on_field:
            fk_field = on_field
            is_backref = on_field.name not in src._meta.fields
        else:
            fk_field, is_backref = self.get_foreign_key(src, dest, on)
            if fk_field is None and on is not None:
                # Retry without restricting the lookup to the ON value.
                fk_field, is_backref = self.get_foreign_key(src, dest)

        primary_key = fk_field.to_field if fk_field is not None else None

        # Pick the attribute name used when no explicit alias was given.
        if join_alias:
            target_attr = None
        elif fk_field is not None:
            if is_backref:
                target_attr = dest._meta.db_table
            else:
                target_attr = fk_field.name
        else:
            try:
                target_attr = on.lhs.name
            except AttributeError:
                target_attr = dest._meta.db_table

        return JoinMetadata(
            src_model=src,
            dest_model=dest,
            src=self.src,
            dest=self.dest,
            attr=join_alias or target_attr,
            primary_key=primary_key,
            foreign_key=fk_field,
            is_backref=is_backref,
            alias=join_alias,
            is_self_join=src is dest,
            is_expression=is_expression)

    @property
    def metadata(self):
        """JoinMetadata for this join, computed once and cached."""
        if hasattr(self, '_cached_metadata'):
            return self._cached_metadata
        self._cached_metadata = self._join_metadata()
        return self._cached_metadata
class FieldDescriptor(object):
    """
    Descriptor exposing a field's value on model instances.

    Fields are exposed as descriptors in order to control access to the
    underlying "raw" data: reads come from `instance._data`, writes go to
    `instance._data` and mark the attribute name in `instance._dirty`.
    Class-level access returns the field itself.
    """

    def __init__(self, field):
        self.field = field
        self.att_name = field.name

    def __get__(self, instance, instance_type=None):
        if instance is None:
            return self.field
        return instance._data.get(self.att_name)

    def __set__(self, instance, value):
        key = self.att_name
        instance._data[key] = value
        instance._dirty.add(key)
class Field(Node):
    """A column on a table."""
    _field_counter = 0
    _order = 0
    _node_type = 'field'
    db_field = 'unknown'

    def __init__(self, null=False, index=False, unique=False,
                 verbose_name=None, help_text=None, db_column=None,
                 default=None, choices=None, primary_key=False, sequence=None,
                 constraints=None, schema=None):
        self.null = null
        self.index = index
        self.unique = unique
        self.verbose_name = verbose_name
        self.help_text = help_text
        self.db_column = db_column
        self.default = default
        self.choices = choices  # Used for metadata purposes, not enforced.
        self.primary_key = primary_key
        self.sequence = sequence  # Name of sequence, e.g. foo_id_seq.
        self.constraints = constraints  # List of column constraints.
        self.schema = schema  # Name of schema, e.g. 'public'.

        # A class-wide counter records the order in which Fields were
        # created, so the definition order on the Model can be recovered.
        Field._field_counter += 1
        self._order = Field._field_counter
        # Primary keys sort ahead of every other field.
        self._sort_key = (1 if self.primary_key else 2), self._order
        self._is_bound = False  # Whether the Field is "bound" to a Model.

        super(Field, self).__init__()

    def clone_base(self, **kwargs):
        """
        Create a new field of the same type with the same options.

        Extra keyword arguments are forwarded to the constructor. The name
        and model class are copied only when this field is bound.
        """
        inst = type(self)(
            null=self.null,
            index=self.index,
            unique=self.unique,
            verbose_name=self.verbose_name,
            help_text=self.help_text,
            db_column=self.db_column,
            default=self.default,
            choices=self.choices,
            primary_key=self.primary_key,
            sequence=self.sequence,
            constraints=self.constraints,
            schema=self.schema,
            **kwargs)
        if self._is_bound:
            inst.name = self.name
            inst.model_class = self.model_class
        inst._is_bound = self._is_bound
        return inst

    def add_to_class(self, model_class, name):
        """
        Bind this field to `model_class` under the attribute `name`.

        Registers the field with the model's `_meta` and replaces the class
        attribute with a `FieldDescriptor`. The column name defaults to
        `name`, and the verbose name to a title-cased form of `name`.
        """
        self.name = name
        self.model_class = model_class
        if not self.db_column:
            self.db_column = self.name
        if not self.verbose_name:
            self.verbose_name = re.sub('_+', ' ', name).title()

        model_class._meta.add_field(self)
        setattr(model_class, name, FieldDescriptor(self))
        self._is_bound = True

    def get_database(self):
        return self.model_class._meta.database

    def get_column_type(self):
        compiler = self.get_database().compiler()
        return compiler.get_column_type(self.get_db_field())

    def get_db_field(self):
        return self.db_field

    def get_modifiers(self):
        return None

    def coerce(self, value):
        return value

    def db_value(self, value):
        """Convert the python value for storage in the database."""
        if value is None:
            return value
        return self.coerce(value)

    def python_value(self, value):
        """Convert the database value to a pythonic value."""
        if value is None:
            return value
        return self.coerce(value)

    def as_entity(self, with_table=False):
        """Return the column as an Entity, optionally prefixed by table."""
        if not with_table:
            return Entity(self.db_column)
        return Entity(self.model_class._meta.db_table, self.db_column)

    def __ddl_column__(self, column_type):
        """Return the column type, e.g. VARCHAR(255) or REAL."""
        modifiers = self.get_modifiers()
        if not modifiers:
            return SQL(column_type)
        joined = ', '.join(str(modifier) for modifier in modifiers)
        return SQL('%s(%s)' % (column_type, joined))

    def __ddl__(self, column_type):
        """Return a list of Node instances that defines the column."""
        ddl = [self.as_entity(), self.__ddl_column__(column_type)]
        if not self.null:
            ddl.append(SQL('NOT NULL'))
        if self.primary_key:
            ddl.append(SQL('PRIMARY KEY'))
        if self.sequence:
            ddl.append(SQL("DEFAULT NEXTVAL('%s')" % self.sequence))
        if self.constraints:
            ddl.extend(self.constraints)
        return ddl

    def __hash__(self):
        qualified = self.name + '.' + self.model_class.__name__
        return hash(qualified)
class BareField(Field):
    """
    Field with db_field 'bare' that accepts an optional `coerce` callable.

    When given, the callable replaces the field's `coerce` method on the
    instance, and it is carried over by `clone_base`.
    """
    db_field = 'bare'

    def __init__(self, coerce=None, *args, **kwargs):
        super(BareField, self).__init__(*args, **kwargs)
        if coerce is None:
            return
        self.coerce = coerce

    def clone_base(self, **kwargs):
        parent = super(BareField, self)
        return parent.clone_base(coerce=self.coerce, **kwargs)
class IntegerField(Field):
    """Field with db_field 'int'; non-None values are coerced with `int`."""
    db_field = 'int'
    coerce = int
class BigIntegerField(IntegerField):
    """IntegerField variant whose db_field is 'bigint'."""
    db_field = 'bigint'
class SmallIntegerField(IntegerField):
    """IntegerField variant whose db_field is 'smallint'."""
    db_field = 'smallint'
class PrimaryKeyField(IntegerField):
    """
    Integer field with db_field 'primary_key'.

    `primary_key=True` is always forced, overriding any value the caller
    passes.
    """
    db_field = 'primary_key'

    def __init__(self, *args, **kwargs):
        kwargs.update(primary_key=True)
        super(PrimaryKeyField, self).__init__(*args, **kwargs)
class FloatField(Field):
    """Field with db_field 'float'; non-None values are coerced with `float`."""
    db_field = 'float'
    coerce = float
class DoubleField(FloatField):
    """FloatField variant whose db_field is 'double'."""
    db_field = 'double'
class DecimalField(Field):
db_field = 'decimal'