Skip to content

Commit

Permalink
Add the IntersectionType
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasr8 committed Jul 24, 2023
1 parent f8b7fe2 commit c369029
Show file tree
Hide file tree
Showing 6 changed files with 567 additions and 20 deletions.
23 changes: 23 additions & 0 deletions Include/internal/pycore_intersectionobject.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef Py_INTERNAL_INTERSECTIONOBJECT_H
#define Py_INTERNAL_INTERSECTIONOBJECT_H
#ifdef __cplusplus
extern "C" {
#endif

#ifndef Py_BUILD_CORE
# error "this header requires Py_BUILD_CORE define"
#endif

extern PyTypeObject _PyIntersection_Type;
#define _PyIntersection_Check(op) Py_IS_TYPE((op), &_PyIntersection_Type)
extern PyObject *_Py_intersection_type_and(PyObject *, PyObject *);

#define _PyGenericAlias_Check(op) PyObject_TypeCheck((op), &Py_GenericAliasType)
extern PyObject *_Py_subs_parameters(PyObject *, PyObject *, PyObject *, PyObject *);
extern PyObject *_Py_make_parameters(PyObject *);
extern PyObject *_Py_intersection_args(PyObject *self);

#ifdef __cplusplus
}
#endif
#endif /* !Py_INTERNAL_INTESECTIONOBJECT_H */
1 change: 1 addition & 0 deletions Lib/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ def wrapped(*args, **kwargs):

GenericAlias = type(list[int])
UnionType = type(int | str)
IntersectionType = type(int & str)

EllipsisType = type(Ellipsis)
NoneType = type(None)
Expand Down
117 changes: 115 additions & 2 deletions Lib/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def _deduplicate(params):
return params


def _remove_dups_flatten(parameters):
def _remove_dups_flatten_union(parameters):
"""Internal helper for Union creation and substitution.
Flatten Unions among parameters, then remove duplicates.
Expand All @@ -333,6 +333,22 @@ def _remove_dups_flatten(parameters):
return tuple(_deduplicate(params))


def _remove_dups_flatten_intersection(parameters):
"""Internal helper for Intersection creation and substitution.
Flatten Intersections among parameters, then remove duplicates.
"""
# Flatten out Intersection[Intersection[...], ...].
params = []
for p in parameters:
if isinstance(p, (_IntersectionGenericAlias, types.IntersectionType)):
params.extend(p.__args__)
else:
params.append(p)

return tuple(_deduplicate(params))


def _flatten_literal_params(parameters):
"""Internal helper for Literal creation: flatten Literals among parameters."""
params = []
Expand Down Expand Up @@ -474,6 +490,12 @@ def __or__(self, other):
def __ror__(self, other):
return Union[other, self]

def __and__(self, other):
return Intersection[self, other]

def __rand__(self, other):
return Intersection[other, self]

def __instancecheck__(self, obj):
raise TypeError(f"{self} cannot be used with isinstance()")

Expand Down Expand Up @@ -696,7 +718,7 @@ def Union(self, parameters):
parameters = (parameters,)
msg = "Union[arg, ...]: each arg must be a type."
parameters = tuple(_type_check(p, msg) for p in parameters)
parameters = _remove_dups_flatten(parameters)
parameters = _remove_dups_flatten_union(parameters)
if len(parameters) == 1:
return parameters[0]
if len(parameters) == 2 and type(None) in parameters:
Expand All @@ -712,6 +734,47 @@ def _make_union(left, right):
"""
return Union[left, right]

@_SpecialForm
def Intersection(self, parameters):
"""Intersection type; Intersection[X, Y] means both X and Y.
On Python 3.XX and higher, the & operator
can also be used to denote intersections;
X & Y means the same thing to the type checker as Intersection[X, Y].
To define an intersection, use e.g. Intersection[int, str]. Details:
- The arguments must be types and there must be at least one.
- None as an argument is a special case and is replaced by
type(None).
- Intersections of intersections are flattened, e.g.::
assert Intersection[Intersection[int, str], float] == Intersection[int, str, float]
- Intersections of a single argument vanish, e.g.::
assert Intersection[int] == int # The constructor actually returns int
- Redundant arguments are skipped, e.g.::
assert Intersection[int, str, int] == Intersection[int, str]
- When comparing intersections, the argument order is ignored, e.g.::
assert Intersection[int, str] == Intersection[str, int]
- You cannot subclass or instantiate an intersection.
"""
if parameters == ():
raise TypeError("Cannot take a Intersection of no types.")
if not isinstance(parameters, tuple):
parameters = (parameters,)
msg = "Intersection[arg, ...]: each arg must be a type."
parameters = tuple(_type_check(p, msg) for p in parameters)
parameters = _remove_dups_flatten_intersection(parameters)
if len(parameters) == 1:
return parameters[0]
return _IntersectionGenericAlias(self, parameters)

@_SpecialForm
def Optional(self, parameters):
"""Optional[X] is equivalent to Union[X, None]."""
Expand Down Expand Up @@ -922,6 +985,12 @@ def __or__(self, other):
def __ror__(self, other):
return Union[other, self]

def __and__(self, other):
return Intersection[self, other]

def __rand__(self, other):
return Intersection[other, self]

def __repr__(self):
if self.__forward_module__ is None:
module_repr = ''
Expand Down Expand Up @@ -1234,6 +1303,12 @@ def __or__(self, right):
def __ror__(self, left):
return Union[left, self]

def __and__(self, other):
return Intersection[self, other]

def __rand__(self, other):
return Intersection[other, self]

@_tp_cache
def __getitem__(self, args):
# Parameterizes an already-parameterized object.
Expand Down Expand Up @@ -1448,6 +1523,11 @@ def __or__(self, right):
def __ror__(self, left):
return Union[left, self]

def __and__(self, other):
return Intersection[self, other]

def __rand__(self, other):
return Intersection[other, self]

class _DeprecatedGenericAlias(_SpecialGenericAlias, _root=True):
def __init__(
Expand Down Expand Up @@ -1562,6 +1642,34 @@ def __reduce__(self):
return func, (Union, args)


class _IntersectionGenericAlias(_NotIterable, _GenericAlias, _root=True):
def copy_with(self, params):
return Intersection[params]

def __eq__(self, other):
if not isinstance(other, (_IntersectionGenericAlias, types.IntersectionType)):
return NotImplemented
return set(self.__args__) == set(other.__args__)

def __hash__(self):
return hash(frozenset(self.__args__))

def __repr__(self):
return super().__repr__()

def __instancecheck__(self, obj):
return self.__subclasscheck__(type(obj))

def __subclasscheck__(self, cls):
for arg in self.__args__:
if issubclass(cls, arg):
return True

def __reduce__(self):
func, (origin, args) = super().__reduce__()
return func, (Intersection, args)


def _value_and_type_iter(parameters):
return ((p, type(p)) for p in parameters)

Expand Down Expand Up @@ -3094,6 +3202,11 @@ def __or__(self, other):
def __ror__(self, other):
return Union[other, self]

def __and__(self, other):
return Intersection[self, other]

def __rand__(self, other):
return Intersection[other, self]

# Python-version-specific alias (Python 2: unicode; Python 3: str)
Text = str
Expand Down
1 change: 1 addition & 0 deletions Makefile.pre.in
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ OBJECT_OBJS= \
Objects/unicodeobject.o \
Objects/unicodectype.o \
Objects/unionobject.o \
Objects/intersectionobject.o \
Objects/weakrefobject.o \
@PERF_TRAMPOLINE_OBJ@

Expand Down
Loading

0 comments on commit c369029

Please sign in to comment.