Skip to content

Commit

Permalink
Add the IntersectionType
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasr8 committed Jul 22, 2023
1 parent f8b7fe2 commit 2cce357
Show file tree
Hide file tree
Showing 5 changed files with 531 additions and 0 deletions.
23 changes: 23 additions & 0 deletions Include/internal/pycore_intersectionobject.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef Py_INTERNAL_INTERSECTIONOBJECT_H
#define Py_INTERNAL_INTERSECTIONOBJECT_H
#ifdef __cplusplus
extern "C" {
#endif

#ifndef Py_BUILD_CORE
# error "this header requires Py_BUILD_CORE define"
#endif

extern PyTypeObject _PyIntersection_Type;
#define _PyIntersection_Check(op) Py_IS_TYPE((op), &_PyIntersection_Type)
extern PyObject *_Py_intersection_type_and(PyObject *, PyObject *);

#define _PyGenericAlias_Check(op) PyObject_TypeCheck((op), &Py_GenericAliasType)
extern PyObject *_Py_subs_parameters(PyObject *, PyObject *, PyObject *, PyObject *);
extern PyObject *_Py_make_parameters(PyObject *);
extern PyObject *_Py_intersection_args(PyObject *self);

#ifdef __cplusplus
}
#endif
#endif /* !Py_INTERNAL_INTESECTIONOBJECT_H */
98 changes: 98 additions & 0 deletions Lib/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,12 @@ def __or__(self, other):
def __ror__(self, other):
return Union[other, self]

def __and__(self, other):
return Intersection[self, other]

def __rand__(self, other):
return Intersection[other, self]

def __instancecheck__(self, obj):
raise TypeError(f"{self} cannot be used with isinstance()")

Expand Down Expand Up @@ -712,6 +718,48 @@ def _make_union(left, right):
"""
return Union[left, right]

@_SpecialForm
def Intersection(self, parameters):
"""Union type; Union[X, Y] means either X or Y.
On Python 3.10 and higher, the | operator
can also be used to denote unions;
X | Y means the same thing to the type checker as Union[X, Y].
To define a union, use e.g. Union[int, str]. Details:
- The arguments must be types and there must be at least one.
- None as an argument is a special case and is replaced by
type(None).
- Unions of unions are flattened, e.g.::
assert Union[Union[int, str], float] == Union[int, str, float]
- Unions of a single argument vanish, e.g.::
assert Union[int] == int # The constructor actually returns int
- Redundant arguments are skipped, e.g.::
assert Union[int, str, int] == Union[int, str]
- When comparing unions, the argument order is ignored, e.g.::
assert Union[int, str] == Union[str, int]
- You cannot subclass or instantiate a union.
- You can use Optional[X] as a shorthand for Union[X, None].
"""
if parameters == ():
raise TypeError("Cannot take a Intersection of no types.")
if not isinstance(parameters, tuple):
parameters = (parameters,)
msg = "Intersection[arg, ...]: each arg must be a type."
parameters = tuple(_type_check(p, msg) for p in parameters)
parameters = _remove_dups_flatten(parameters)
if len(parameters) == 1:
return parameters[0]
return _UnionGenericAlias(self, parameters)

@_SpecialForm
def Optional(self, parameters):
"""Optional[X] is equivalent to Union[X, None]."""
Expand Down Expand Up @@ -921,6 +969,12 @@ def __or__(self, other):

def __ror__(self, other):
return Union[other, self]

def __and__(self, other):
return Intersection[self, other]

def __rand__(self, other):
return Intersection[other, self]

def __repr__(self):
if self.__forward_module__ is None:
Expand Down Expand Up @@ -1234,6 +1288,12 @@ def __or__(self, right):
def __ror__(self, left):
return Union[left, self]

def __and__(self, other):
return Intersection[self, other]

def __rand__(self, other):
return Intersection[other, self]

@_tp_cache
def __getitem__(self, args):
# Parameterizes an already-parameterized object.
Expand Down Expand Up @@ -1448,6 +1508,11 @@ def __or__(self, right):
def __ror__(self, left):
return Union[left, self]

def __and__(self, other):
return Intersection[self, other]

def __rand__(self, other):
return Intersection[other, self]

class _DeprecatedGenericAlias(_SpecialGenericAlias, _root=True):
def __init__(
Expand Down Expand Up @@ -1562,6 +1627,34 @@ def __reduce__(self):
return func, (Union, args)


class _IntersectionGenericAlias(_NotIterable, _GenericAlias, _root=True):
def copy_with(self, params):
return Intersection[params]

def __eq__(self, other):
if not isinstance(other, (_IntersectionGenericAlias, types.IntersectionType)):
return NotImplemented
return set(self.__args__) == set(other.__args__)

def __hash__(self):
return hash(frozenset(self.__args__))

def __repr__(self):
return super().__repr__()

def __instancecheck__(self, obj):
return self.__subclasscheck__(type(obj))

def __subclasscheck__(self, cls):
for arg in self.__args__:
if issubclass(cls, arg):
return True

def __reduce__(self):
func, (origin, args) = super().__reduce__()
return func, (Union, args)


def _value_and_type_iter(parameters):
return ((p, type(p)) for p in parameters)

Expand Down Expand Up @@ -3094,6 +3187,11 @@ def __or__(self, other):
def __ror__(self, other):
return Union[other, self]

def __and__(self, other):
return Intersection[self, other]

def __rand__(self, other):
return Intersection[other, self]

# Python-version-specific alias (Python 2: unicode; Python 3: str)
Text = str
Expand Down
1 change: 1 addition & 0 deletions Makefile.pre.in
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ OBJECT_OBJS= \
Objects/unicodeobject.o \
Objects/unicodectype.o \
Objects/unionobject.o \
Objects/intersectionobject.o \
Objects/weakrefobject.o \
@PERF_TRAMPOLINE_OBJ@

Expand Down
Loading

0 comments on commit 2cce357

Please sign in to comment.