diff --git a/Include/internal/pycore_intersectionobject.h b/Include/internal/pycore_intersectionobject.h new file mode 100644 index 000000000000000..14f2c422d054d20 --- /dev/null +++ b/Include/internal/pycore_intersectionobject.h @@ -0,0 +1,23 @@ +#ifndef Py_INTERNAL_INTERSECTIONOBJECT_H +#define Py_INTERNAL_INTERSECTIONOBJECT_H +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef Py_BUILD_CORE +# error "this header requires Py_BUILD_CORE define" +#endif + +extern PyTypeObject _PyIntersection_Type; +#define _PyIntersection_Check(op) Py_IS_TYPE((op), &_PyIntersection_Type) +extern PyObject *_Py_intersection_type_and(PyObject *, PyObject *); + +#define _PyGenericAlias_Check(op) PyObject_TypeCheck((op), &Py_GenericAliasType) +extern PyObject *_Py_subs_parameters(PyObject *, PyObject *, PyObject *, PyObject *); +extern PyObject *_Py_make_parameters(PyObject *); +extern PyObject *_Py_intersection_args(PyObject *self); + +#ifdef __cplusplus +} +#endif +#endif /* !Py_INTERNAL_INTERSECTIONOBJECT_H */ diff --git a/Lib/types.py b/Lib/types.py index 6110e6e1de7249e..d2d300f33716b79 100644 --- a/Lib/types.py +++ b/Lib/types.py @@ -328,6 +328,7 @@ def wrapped(*args, **kwargs): GenericAlias = type(list[int]) UnionType = type(int | str) +IntersectionType = type(int & str) EllipsisType = type(Ellipsis) NoneType = type(None) diff --git a/Lib/typing.py b/Lib/typing.py index 387b4c5ad5284b7..58fcd6677f4bd29 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -317,7 +317,7 @@ def _deduplicate(params): return params -def _remove_dups_flatten(parameters): +def _remove_dups_flatten_union(parameters): """Internal helper for Union creation and substitution. Flatten Unions among parameters, then remove duplicates. @@ -333,6 +333,22 @@ def _remove_dups_flatten(parameters): return tuple(_deduplicate(params)) +def _remove_dups_flatten_intersection(parameters): + """Internal helper for Intersection creation and substitution. + + Flatten Intersections among parameters, then remove duplicates. 
+ """ + # Flatten out Intersection[Intersection[...], ...]. + params = [] + for p in parameters: + if isinstance(p, (_IntersectionGenericAlias, types.IntersectionType)): + params.extend(p.__args__) + else: + params.append(p) + + return tuple(_deduplicate(params)) + + def _flatten_literal_params(parameters): """Internal helper for Literal creation: flatten Literals among parameters.""" params = [] @@ -474,6 +490,12 @@ def __or__(self, other): def __ror__(self, other): return Union[other, self] + def __and__(self, other): + return Intersection[self, other] + + def __rand__(self, other): + return Intersection[other, self] + def __instancecheck__(self, obj): raise TypeError(f"{self} cannot be used with isinstance()") @@ -696,7 +718,7 @@ def Union(self, parameters): parameters = (parameters,) msg = "Union[arg, ...]: each arg must be a type." parameters = tuple(_type_check(p, msg) for p in parameters) - parameters = _remove_dups_flatten(parameters) + parameters = _remove_dups_flatten_union(parameters) if len(parameters) == 1: return parameters[0] if len(parameters) == 2 and type(None) in parameters: @@ -712,6 +734,47 @@ def _make_union(left, right): """ return Union[left, right] +@_SpecialForm +def Intersection(self, parameters): + """Intersection type; Intersection[X, Y] means both X and Y. + + On Python 3.XX and higher, the & operator + can also be used to denote intersections; + X & Y means the same thing to the type checker as Intersection[X, Y]. + + To define an intersection, use e.g. Intersection[int, str]. Details: + - The arguments must be types and there must be at least one. + - None as an argument is a special case and is replaced by + type(None). 
+ - Intersections of intersections are flattened, e.g.:: + + assert Intersection[Intersection[int, str], float] == Intersection[int, str, float] + + - Intersections of a single argument vanish, e.g.:: + + assert Intersection[int] == int # The constructor actually returns int + + - Redundant arguments are skipped, e.g.:: + + assert Intersection[int, str, int] == Intersection[int, str] + + - When comparing intersections, the argument order is ignored, e.g.:: + + assert Intersection[int, str] == Intersection[str, int] + + - You cannot subclass or instantiate an intersection. + """ + if parameters == (): + raise TypeError("Cannot take a Intersection of no types.") + if not isinstance(parameters, tuple): + parameters = (parameters,) + msg = "Intersection[arg, ...]: each arg must be a type." + parameters = tuple(_type_check(p, msg) for p in parameters) + parameters = _remove_dups_flatten_intersection(parameters) + if len(parameters) == 1: + return parameters[0] + return _IntersectionGenericAlias(self, parameters) + @_SpecialForm def Optional(self, parameters): """Optional[X] is equivalent to Union[X, None].""" @@ -922,6 +985,12 @@ def __or__(self, other): def __ror__(self, other): return Union[other, self] + def __and__(self, other): + return Intersection[self, other] + + def __rand__(self, other): + return Intersection[other, self] + def __repr__(self): if self.__forward_module__ is None: module_repr = '' @@ -1234,6 +1303,12 @@ def __or__(self, right): def __ror__(self, left): return Union[left, self] + def __and__(self, other): + return Intersection[self, other] + + def __rand__(self, other): + return Intersection[other, self] + @_tp_cache def __getitem__(self, args): # Parameterizes an already-parameterized object. 
@@ -1448,6 +1523,11 @@ def __or__(self, right): def __ror__(self, left): return Union[left, self] + def __and__(self, other): + return Intersection[self, other] + + def __rand__(self, other): + return Intersection[other, self] class _DeprecatedGenericAlias(_SpecialGenericAlias, _root=True): def __init__( @@ -1562,6 +1642,34 @@ def __reduce__(self): return func, (Union, args) +class _IntersectionGenericAlias(_NotIterable, _GenericAlias, _root=True): + def copy_with(self, params): + return Intersection[params] + + def __eq__(self, other): + if not isinstance(other, (_IntersectionGenericAlias, types.IntersectionType)): + return NotImplemented + return set(self.__args__) == set(other.__args__) + + def __hash__(self): + return hash(frozenset(self.__args__)) + + def __repr__(self): + return super().__repr__() + + def __instancecheck__(self, obj): + return self.__subclasscheck__(type(obj)) + + def __subclasscheck__(self, cls): + # Intersection semantics: cls must be a subclass of every member. + return all(issubclass(cls, arg) + for arg in self.__args__) + + def __reduce__(self): + func, (origin, args) = super().__reduce__() + return func, (Intersection, args) + + def _value_and_type_iter(parameters): return ((p, type(p)) for p in parameters) @@ -3094,6 +3202,11 @@ def __or__(self, other): def __ror__(self, other): return Union[other, self] + def __and__(self, other): + return Intersection[self, other] + + def __rand__(self, other): + return Intersection[other, self] # Python-version-specific alias (Python 2: unicode; Python 3: str) Text = str diff --git a/Makefile.pre.in b/Makefile.pre.in index 3725feaca66ce3c..47a711f99ed4eb9 100644 --- a/Makefile.pre.in +++ b/Makefile.pre.in @@ -486,6 +486,7 @@ OBJECT_OBJS= \ Objects/unicodeobject.o \ Objects/unicodectype.o \ Objects/unionobject.o \ + Objects/intersectionobject.o \ Objects/weakrefobject.o \ @PERF_TRAMPOLINE_OBJ@ diff --git a/Objects/intersectionobject.c b/Objects/intersectionobject.c new file mode 100644 index 000000000000000..705824b01065a21 --- /dev/null +++ 
b/Objects/intersectionobject.c @@ -0,0 +1,407 @@ +// types.IntersectionType -- used to represent e.g. Intersection[int, str], int & str +#include "Python.h" +#include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK +#include "pycore_typevarobject.h" // _PyTypeAlias_Type +#include "pycore_intersectionobject.h" +#include "structmember.h" + + +static PyObject *make_intersection(PyObject *); + + +typedef struct { + PyObject_HEAD + PyObject *args; + PyObject *parameters; +} intersectionobject; + +static void +intersectionobject_dealloc(PyObject *self) +{ + intersectionobject *alias = (intersectionobject *)self; + + _PyObject_GC_UNTRACK(self); + + Py_XDECREF(alias->args); + Py_XDECREF(alias->parameters); + Py_TYPE(self)->tp_free(self); +} + +static int +intersection_traverse(PyObject *self, visitproc visit, void *arg) +{ + intersectionobject *alias = (intersectionobject *)self; + Py_VISIT(alias->args); + Py_VISIT(alias->parameters); + return 0; +} + +static Py_hash_t +intersection_hash(PyObject *self) +{ + intersectionobject *alias = (intersectionobject *)self; + PyObject *args = PyFrozenSet_New(alias->args); + if (args == NULL) { + return (Py_hash_t)-1; + } + Py_hash_t hash = PyObject_Hash(args); + Py_DECREF(args); + return hash; +} + +static PyObject * +intersection_richcompare(PyObject *a, PyObject *b, int op) +{ + if (!_PyIntersection_Check(b) || (op != Py_EQ && op != Py_NE)) { + Py_RETURN_NOTIMPLEMENTED; + } + + PyObject *a_set = PySet_New(((intersectionobject*)a)->args); + if (a_set == NULL) { + return NULL; + } + PyObject *b_set = PySet_New(((intersectionobject*)b)->args); + if (b_set == NULL) { + Py_DECREF(a_set); + return NULL; + } + PyObject *result = PyObject_RichCompare(a_set, b_set, op); + Py_DECREF(b_set); + Py_DECREF(a_set); + return result; +} + +static int +is_same(PyObject *left, PyObject *right) +{ + int is_ga = _PyGenericAlias_Check(left) && _PyGenericAlias_Check(right); + return is_ga ? 
PyObject_RichCompareBool(left, right, Py_EQ) : left == right; +} + +static int +contains(PyObject **items, Py_ssize_t size, PyObject *obj) +{ + for (Py_ssize_t i = 0; i < size; i++) { + int is_duplicate = is_same(items[i], obj); + if (is_duplicate) { // -1 or 1 + return is_duplicate; + } + } + return 0; +} + +static PyObject * +merge(PyObject **items1, Py_ssize_t size1, + PyObject **items2, Py_ssize_t size2) +{ + PyObject *tuple = NULL; + Py_ssize_t pos = 0; + + for (Py_ssize_t i = 0; i < size2; i++) { + PyObject *arg = items2[i]; + int is_duplicate = contains(items1, size1, arg); + if (is_duplicate < 0) { + Py_XDECREF(tuple); + return NULL; + } + if (is_duplicate) { + continue; + } + + if (tuple == NULL) { + tuple = PyTuple_New(size1 + size2 - i); + if (tuple == NULL) { + return NULL; + } + for (; pos < size1; pos++) { + PyObject *a = items1[pos]; + PyTuple_SET_ITEM(tuple, pos, Py_NewRef(a)); + } + } + PyTuple_SET_ITEM(tuple, pos, Py_NewRef(arg)); + pos++; + } + + if (tuple) { + (void) _PyTuple_Resize(&tuple, pos); + } + return tuple; +} + +static PyObject ** +get_types(PyObject **obj, Py_ssize_t *size) +{ + if (*obj == Py_None) { + *obj = (PyObject *)&_PyNone_Type; + } + if (_PyIntersection_Check(*obj)) { + PyObject *args = ((intersectionobject *) *obj)->args; + *size = PyTuple_GET_SIZE(args); + return &PyTuple_GET_ITEM(args, 0); + } + else { + *size = 1; + return obj; + } +} + +static int +is_intersectionable(PyObject *obj) +{ + if (obj == Py_None || + PyType_Check(obj) || + _PyGenericAlias_Check(obj) || + _PyIntersection_Check(obj) || + Py_IS_TYPE(obj, &_PyTypeAlias_Type)) { + return 1; + } + return 0; +} + +PyObject * +_Py_intersection_type_and(PyObject* self, PyObject* other) +{ + if (!is_intersectionable(self) || !is_intersectionable(other)) { + Py_RETURN_NOTIMPLEMENTED; + } + + Py_ssize_t size1, size2; + PyObject **items1 = get_types(&self, &size1); + PyObject **items2 = get_types(&other, &size2); + PyObject *tuple = merge(items1, size1, items2, size2); + if (tuple == 
NULL) { + if (PyErr_Occurred()) { + return NULL; + } + return Py_NewRef(self); + } + + PyObject *new_intersection = make_intersection(tuple); + Py_DECREF(tuple); + return new_intersection; +} + +static int +intersection_repr_item(_PyUnicodeWriter *writer, PyObject *p) +{ + PyObject *qualname = NULL; + PyObject *module = NULL; + PyObject *tmp; + PyObject *r = NULL; + int err; + + if (p == (PyObject *)&_PyNone_Type) { + return _PyUnicodeWriter_WriteASCIIString(writer, "None", 4); + } + + if (PyObject_GetOptionalAttr(p, &_Py_ID(__origin__), &tmp) < 0) { + goto exit; + } + + if (tmp) { + Py_DECREF(tmp); + if (PyObject_GetOptionalAttr(p, &_Py_ID(__args__), &tmp) < 0) { + goto exit; + } + if (tmp) { + // It looks like a GenericAlias + Py_DECREF(tmp); + goto use_repr; + } + } + + if (PyObject_GetOptionalAttr(p, &_Py_ID(__qualname__), &qualname) < 0) { + goto exit; + } + if (qualname == NULL) { + goto use_repr; + } + if (PyObject_GetOptionalAttr(p, &_Py_ID(__module__), &module) < 0) { + goto exit; + } + if (module == NULL || module == Py_None) { + goto use_repr; + } + + // Looks like a class + if (PyUnicode_Check(module) && + _PyUnicode_EqualToASCIIString(module, "builtins")) + { + // builtins don't need a module name + r = PyObject_Str(qualname); + goto exit; + } + else { + r = PyUnicode_FromFormat("%S.%S", module, qualname); + goto exit; + } + +use_repr: + r = PyObject_Repr(p); +exit: + Py_XDECREF(qualname); + Py_XDECREF(module); + if (r == NULL) { + return -1; + } + err = _PyUnicodeWriter_WriteStr(writer, r); + Py_DECREF(r); + return err; +} + +static PyObject * +intersection_repr(PyObject *self) +{ + intersectionobject *alias = (intersectionobject *)self; + Py_ssize_t len = PyTuple_GET_SIZE(alias->args); + + _PyUnicodeWriter writer; + _PyUnicodeWriter_Init(&writer); + for (Py_ssize_t i = 0; i < len; i++) { + if (i > 0 && _PyUnicodeWriter_WriteASCIIString(&writer, " & ", 3) < 0) { + goto error; + } + PyObject *p = PyTuple_GET_ITEM(alias->args, i); + if 
(intersection_repr_item(&writer, p) < 0) { + goto error; + } + } + return _PyUnicodeWriter_Finish(&writer); +error: + _PyUnicodeWriter_Dealloc(&writer); + return NULL; +} + +static PyMemberDef intersection_members[] = { + {"__args__", T_OBJECT, offsetof(intersectionobject, args), READONLY}, + {0} +}; + +static PyObject * +intersection_getitem(PyObject *self, PyObject *item) +{ + intersectionobject *alias = (intersectionobject *)self; + // Populate __parameters__ if needed. + if (alias->parameters == NULL) { + alias->parameters = _Py_make_parameters(alias->args); + if (alias->parameters == NULL) { + return NULL; + } + } + + PyObject *newargs = _Py_subs_parameters(self, alias->args, alias->parameters, item); + if (newargs == NULL) { + return NULL; + } + + PyObject *res; + Py_ssize_t nargs = PyTuple_GET_SIZE(newargs); + if (nargs == 0) { + res = make_intersection(newargs); + } + else { + res = Py_NewRef(PyTuple_GET_ITEM(newargs, 0)); + for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) { + PyObject *arg = PyTuple_GET_ITEM(newargs, iarg); + Py_SETREF(res, PyNumber_And(res, arg)); + if (res == NULL) { + break; + } + } + } + Py_DECREF(newargs); + return res; +} + +static PyMappingMethods intersection_as_mapping = { + .mp_subscript = intersection_getitem, +}; + +static PyObject * +intersection_parameters(PyObject *self, void *Py_UNUSED(unused)) +{ + intersectionobject *alias = (intersectionobject *)self; + if (alias->parameters == NULL) { + alias->parameters = _Py_make_parameters(alias->args); + if (alias->parameters == NULL) { + return NULL; + } + } + return Py_NewRef(alias->parameters); +} + +static PyGetSetDef intersection_properties[] = { + {"__parameters__", intersection_parameters, (setter)NULL, "Type variables in the types.IntersectionType.", NULL}, + {0} +}; + +static PyNumberMethods intersection_as_number = { + .nb_and = _Py_intersection_type_and, // Add __and__ function +}; + +static const char* const cls_attrs[] = { + "__module__", // Required for compatibility 
with typing module + NULL, +}; + +static PyObject * +intersection_getattro(PyObject *self, PyObject *name) +{ + intersectionobject *alias = (intersectionobject *)self; + if (PyUnicode_Check(name)) { + for (const char * const *p = cls_attrs; ; p++) { + if (*p == NULL) { + break; + } + if (_PyUnicode_EqualToASCIIString(name, *p)) { + return PyObject_GetAttr((PyObject *) Py_TYPE(alias), name); + } + } + } + return PyObject_GenericGetAttr(self, name); +} + +PyObject * +_Py_intersection_args(PyObject *self) +{ + assert(_PyIntersection_Check(self)); + return ((intersectionobject *) self)->args; +} + +PyTypeObject _PyIntersection_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + .tp_name = "types.IntersectionType", + .tp_doc = PyDoc_STR("Represent a PEP 604 intersection type\n" + "\n" + "E.g. for int & str"), + .tp_basicsize = sizeof(intersectionobject), + .tp_dealloc = intersectionobject_dealloc, + .tp_alloc = PyType_GenericAlloc, + .tp_free = PyObject_GC_Del, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, + .tp_traverse = intersection_traverse, + .tp_hash = intersection_hash, + .tp_getattro = intersection_getattro, + .tp_members = intersection_members, + .tp_richcompare = intersection_richcompare, + .tp_as_mapping = &intersection_as_mapping, + .tp_as_number = &intersection_as_number, + .tp_repr = intersection_repr, + .tp_getset = intersection_properties, +}; + +static PyObject * +make_intersection(PyObject *args) +{ + assert(PyTuple_CheckExact(args)); + + intersectionobject *result = PyObject_GC_New(intersectionobject, &_PyIntersection_Type); + if (result == NULL) { + return NULL; + } + + result->parameters = NULL; + result->args = Py_NewRef(args); + _PyObject_GC_TRACK(result); + return (PyObject*)result; +} diff --git a/Objects/typeobject.c b/Objects/typeobject.c index 7e5282cabd1bfb7..3d8c67638c705cb 100644 --- a/Objects/typeobject.c +++ b/Objects/typeobject.c @@ -1,24 +1,25 @@ /* Type object implementation */ #include "Python.h" -#include 
"pycore_abstract.h" // _PySequence_IterSearch() -#include "pycore_call.h" // _PyObject_VectorcallTstate() -#include "pycore_code.h" // CO_FAST_FREE -#include "pycore_dict.h" // _PyDict_KeysSize() -#include "pycore_frame.h" // _PyInterpreterFrame -#include "pycore_long.h" // _PyLong_IsNegative() -#include "pycore_memoryobject.h" // _PyMemoryView_FromBufferProc() -#include "pycore_modsupport.h" // _PyArg_NoKwnames() -#include "pycore_moduleobject.h" // _PyModule_GetDef() -#include "pycore_object.h" // _PyType_HasFeature() -#include "pycore_pyerrors.h" // _PyErr_Occurred() -#include "pycore_pystate.h" // _PyThreadState_GET() -#include "pycore_symtable.h" // _Py_Mangle() -#include "pycore_typeobject.h" // struct type_cache -#include "pycore_unionobject.h" // _Py_union_type_or -#include "pycore_weakref.h" // _PyWeakref_GET_REF() -#include "opcode.h" // MAKE_CELL -#include "structmember.h" // PyMemberDef +#include "pycore_abstract.h" // _PySequence_IterSearch() +#include "pycore_call.h" // _PyObject_VectorcallTstate() +#include "pycore_code.h" // CO_FAST_FREE +#include "pycore_dict.h" // _PyDict_KeysSize() +#include "pycore_frame.h" // _PyInterpreterFrame +#include "pycore_long.h" // _PyLong_IsNegative() +#include "pycore_memoryobject.h" // _PyMemoryView_FromBufferProc() +#include "pycore_modsupport.h" // _PyArg_NoKwnames() +#include "pycore_moduleobject.h" // _PyModule_GetDef() +#include "pycore_object.h" // _PyType_HasFeature() +#include "pycore_pyerrors.h" // _PyErr_Occurred() +#include "pycore_pystate.h" // _PyThreadState_GET() +#include "pycore_symtable.h" // _Py_Mangle() +#include "pycore_typeobject.h" // struct type_cache +#include "pycore_unionobject.h" // _Py_union_type_or +#include "pycore_intersectionobject.h" // _Py_intersection_type_and +#include "pycore_weakref.h" // _PyWeakref_GET_REF() +#include "opcode.h" // MAKE_CELL +#include "structmember.h" // PyMemberDef #include #include // ptrdiff_t @@ -5333,6 +5334,7 @@ type_is_gc(PyTypeObject *type) static 
PyNumberMethods type_as_number = { .nb_or = _Py_union_type_or, // Add __or__ function + .nb_and = _Py_intersection_type_and, // Add __and__ function }; PyTypeObject PyType_Type = {