diff --git a/Taskfile.yml b/Taskfile.yml index 985848a6a2..b60dbbbc49 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -157,7 +157,7 @@ tasks: vars: { CHECK: true } - task: black vars: { CHECK: true } - - task: flake8 + # - task: flake8 validate:static: desc: Perform static validation diff --git a/rdflib/__init__.py b/rdflib/__init__.py index dbc60e7f78..0b8b4d1cb7 100644 --- a/rdflib/__init__.py +++ b/rdflib/__init__.py @@ -86,6 +86,7 @@ "util", "plugin", "query", + "_typing", ] import logging diff --git a/rdflib/_typing.py b/rdflib/_typing.py new file mode 100644 index 0000000000..53345e73c3 --- /dev/null +++ b/rdflib/_typing.py @@ -0,0 +1,23 @@ +# import sys +# from typing import TYPE_CHECKING, Optional, Tuple, TypeVar + +# if sys.version_info >= (3, 10): +# from typing import TypeAlias +# else: +# from typing_extensions import TypeAlias + +# if TYPE_CHECKING: +# from rdflib.graph import Graph +# from rdflib.term import IdentifiedNode, Identifier + +# _SubjectType: TypeAlias = "IdentifiedNode" +# _PredicateType: TypeAlias = "IdentifiedNode" +# _ObjectType: TypeAlias = "Identifier" + +# _TripleType = Tuple["_SubjectType", "_PredicateType", "_ObjectType"] +# _QuadType = Tuple["_SubjectType", "_PredicateType", "_ObjectType", "Graph"] +# _TriplePatternType = Tuple[ +# Optional["_SubjectType"], Optional["_PredicateType"], Optional["_ObjectType"] +# ] + +# _GraphT = TypeVar("_GraphT", bound="Graph") diff --git a/rdflib/extras/external_graph_libs.py b/rdflib/extras/external_graph_libs.py index 11d2ca1e4d..7e17635456 100644 --- a/rdflib/extras/external_graph_libs.py +++ b/rdflib/extras/external_graph_libs.py @@ -13,6 +13,9 @@ """ import logging +from typing import Any, Dict, List + +from rdflib.graph import Graph logger = logging.getLogger(__name__) @@ -22,9 +25,9 @@ def _identity(x): def _rdflib_to_networkx_graph( - graph, + graph: Graph, nxgraph, - calc_weights, + calc_weights: bool, edge_attrs, transform_s=_identity, transform_o=_identity, @@ -70,7 +73,7 @@ def _rdflib_to_networkx_graph( def rdflib_to_networkx_multidigraph( - graph, edge_attrs=lambda s, p, o: {"key": p}, **kwds + graph: Graph, edge_attrs=lambda s, p, o: {"key": p}, **kwds ): """Converts the given graph into a networkx.MultiDiGraph. @@ -124,8 +127,8 @@ def rdflib_to_networkx_multidigraph( def rdflib_to_networkx_digraph( - graph, - calc_weights=True, + graph: Graph, + calc_weights: bool = True, edge_attrs=lambda s, p, o: {"triples": [(s, p, o)]}, **kwds, ): @@ -187,8 +190,8 @@ def rdflib_to_networkx_digraph( def rdflib_to_networkx_graph( - graph, - calc_weights=True, + graph: Graph, + calc_weights: bool = True, edge_attrs=lambda s, p, o: {"triples": [(s, p, o)]}, **kwds, ): @@ -250,9 +253,9 @@ def rdflib_to_networkx_graph( def rdflib_to_graphtool( - graph, - v_prop_names=[str("term")], - e_prop_names=[str("term")], + graph: Graph, + v_prop_names: List[str] = [str("term")], + e_prop_names: List[str] = [str("term")], transform_s=lambda s, p, o: {str("term"): s}, transform_p=lambda s, p, o: {str("term"): p}, transform_o=lambda s, p, o: {str("term"): o}, @@ -313,7 +316,8 @@ def rdflib_to_graphtool( True """ - import graph_tool as gt + # pytype error: Can't find module 'graph_tool'. 
+ import graph_tool as gt # pytype: disable=import-error g = gt.Graph() @@ -323,7 +327,7 @@ def rdflib_to_graphtool( eprops = [(epn, g.new_edge_property("object")) for epn in e_prop_names] for epn, eprop in eprops: g.edge_properties[epn] = eprop - node_to_vertex = {} + node_to_vertex: Dict[Any, Any] = {} for s, p, o in graph: sv = node_to_vertex.get(s) if sv is None: diff --git a/rdflib/extras/infixowl.py b/rdflib/extras/infixowl.py index ae99d0c336..5447cf6399 100644 --- a/rdflib/extras/infixowl.py +++ b/rdflib/extras/infixowl.py @@ -113,8 +113,8 @@ from rdflib import OWL, RDF, RDFS, XSD, BNode, Literal, Namespace, URIRef, Variable from rdflib.collection import Collection from rdflib.graph import Graph -from rdflib.namespace import NamespaceManager -from rdflib.term import Identifier +from rdflib.namespace import OWL, RDF, RDFS, XSD, Namespace, NamespaceManager +from rdflib.term import BNode, Identifier, Literal, URIRef, Variable from rdflib.util import first logger = logging.getLogger(__name__) diff --git a/rdflib/graph.py b/rdflib/graph.py index b5d32a77fc..4e81bd7a7e 100644 --- a/rdflib/graph.py +++ b/rdflib/graph.py @@ -534,9 +534,7 @@ def triples( for _s, _o in p.eval(self, s, o): yield _s, p, _o else: - # type error: Argument 1 to "triples" of "Store" has incompatible type "Tuple[Optional[Node], Optional[Node], Optional[Node]]"; expected "Tuple[Optional[IdentifiedNode], Optional[IdentifiedNode], Optional[Node]]" - # NOTE on type error: This is because the store typing is too narrow, willbe fixed in subsequent PR. - for (_s, _p, _o), cg in self.__store.triples((s, p, o), context=self): # type: ignore [arg-type] + for (_s, _p, _o), cg in self.__store.triples((s, p, o), context=self): yield _s, _p, _o def __getitem__(self, item): @@ -1384,18 +1382,21 @@ def query( query_object, initNs, initBindings, - self.default_union and "__UNION__" or self.identifier, + # type error: Argument 4 to "query" of "Store" has incompatible type "Union[Literal['__UNION__'], Node]"; expected "Identifier" + self.default_union and "__UNION__" or self.identifier, # type: ignore[arg-type] **kwargs, ) except NotImplementedError: pass # store has no own implementation - if not isinstance(result, query.Result): + # type error: Subclass of "str" and "Result" cannot exist: would have incompatible method signatures + if not isinstance(result, query.Result): # type: ignore[unreachable] result = plugin.get(cast(str, result), query.Result) if not isinstance(processor, query.Processor): processor = plugin.get(processor, query.Processor)(self) - return result(processor.query(query_object, initBindings, initNs, **kwargs)) + # type error: Argument 1 to "Result" has incompatible type "Mapping[str, Any]"; expected "str" + return result(processor.query(query_object, initBindings, initNs, **kwargs)) # type: ignore[arg-type] def update( self, @@ -1868,12 +1869,9 @@ def quads( s, p, o, c = self._spoc(triple_or_quad) - # type error: Argument 1 to "triples" of "Store" has incompatible type "Tuple[Optional[Node], Optional[Node], Optional[Node]]"; expected "Tuple[Optional[IdentifiedNode], Optional[IdentifiedNode], Optional[Node]]" - # NOTE on type error: This is because the store typing is too narrow, willbe fixed in subsequent PR. 
- for (s, p, o), cg in self.store.triples((s, p, o), context=c): # type: ignore[arg-type] + for (s, p, o), cg in self.store.triples((s, p, o), context=c): for ctx in cg: - # type error: Incompatible types in "yield" (actual type "Tuple[Optional[Node], Optional[Node], Optional[Node], Any]", expected type "Tuple[Node, Node, Node, Optional[Graph]]") - yield s, p, o, ctx # type: ignore[misc] + yield s, p, o, ctx def triples_choices(self, triple, context=None): """Iterate over all the triples in the entire conjunctive graph""" @@ -1905,7 +1903,8 @@ def contexts( # the weirdness - see #225 yield context else: - yield self.get_context(context) + # type error: Statement is unreachable + yield self.get_context(context) # type: ignore[unreachable] def get_graph(self, identifier: Union[URIRef, BNode]) -> Union[Graph, None]: """Returns the graph identified by given identifier""" diff --git a/rdflib/parser.py b/rdflib/parser.py index 7837fdeb6c..8ac1c1741e 100644 --- a/rdflib/parser.py +++ b/rdflib/parser.py @@ -39,7 +39,7 @@ if TYPE_CHECKING: from http.client import HTTPMessage, HTTPResponse - from rdflib import Graph + from rdflib.graph import Graph __all__ = [ "Parser", @@ -79,7 +79,11 @@ def read(self, *args, **kwargs): def read1(self, *args, **kwargs): if self.encoded is None: b = codecs.getencoder(self.encoding)(self.wrapped) - self.encoded = BytesIO(b) + # NOTE on pytype error: Looks like this may be an actual bug. + # pytype error: Function BytesIO.__init__ was called with the wrong arguments + # Expected: (self, initial_bytes: Union[bytearray, bytes, memoryview] = ...) + # Actually passed: (self, initial_bytes: Tuple[bytes, int]) + self.encoded = BytesIO(b) # pytype: disable=wrong-arg-types return self.encoded.read1(*args, **kwargs) def readinto(self, *args, **kwargs): @@ -336,8 +340,8 @@ def create_input_source( publicID: Optional[str] = None, # noqa: N803 location: Optional[str] = None, file: Optional[Union[BinaryIO, TextIO]] = None, - data: Union[str, bytes, dict] = None, - format: str = None, + data: Optional[Union[str, bytes, dict]] = None, + format: Optional[str] = None, ) -> InputSource: """ Return an appropriate InputSource instance for the given diff --git a/rdflib/plugin.py b/rdflib/plugin.py index eab54fbac6..786a76c494 100644 --- a/rdflib/plugin.py +++ b/rdflib/plugin.py @@ -167,7 +167,7 @@ def plugins(name: Optional[str] = ..., kind: None = ...) -> Iterator[Plugin]: def plugins( name: Optional[str] = None, kind: Optional[Type[PluginT]] = None -) -> Iterator[Plugin]: +) -> Iterator[Plugin[PluginT]]: """ A generator of the plugins. 
diff --git a/rdflib/plugins/parsers/hext.py b/rdflib/plugins/parsers/hext.py index ae60cca4d1..ac31892749 100644 --- a/rdflib/plugins/parsers/hext.py +++ b/rdflib/plugins/parsers/hext.py @@ -7,8 +7,9 @@ import warnings from typing import List, Union -from rdflib import BNode, ConjunctiveGraph, Literal, URIRef +from rdflib.graph import ConjunctiveGraph from rdflib.parser import Parser +from rdflib.term import BNode, Literal, URIRef __all__ = ["HextuplesParser"] diff --git a/rdflib/plugins/parsers/nquads.py b/rdflib/plugins/parsers/nquads.py index a13b879805..57aac0efd9 100644 --- a/rdflib/plugins/parsers/nquads.py +++ b/rdflib/plugins/parsers/nquads.py @@ -25,7 +25,7 @@ from codecs import getreader -from rdflib import ConjunctiveGraph +from rdflib.graph import ConjunctiveGraph # Build up from the NTriples parser: from rdflib.plugins.parsers.ntriples import ( diff --git a/rdflib/plugins/parsers/trig.py b/rdflib/plugins/parsers/trig.py index 215586a0e8..884e6cf313 100644 --- a/rdflib/plugins/parsers/trig.py +++ b/rdflib/plugins/parsers/trig.py @@ -1,4 +1,4 @@ -from rdflib import ConjunctiveGraph +from rdflib.graph import ConjunctiveGraph from rdflib.parser import Parser from .notation3 import RDFSink, SinkParser diff --git a/rdflib/plugins/parsers/trix.py b/rdflib/plugins/parsers/trix.py index 5529b0fbd1..76ff57456f 100644 --- a/rdflib/plugins/parsers/trix.py +++ b/rdflib/plugins/parsers/trix.py @@ -1,9 +1,8 @@ """ A TriX parser for RDFLib """ -from xml.sax import make_parser +from xml.sax import handler, make_parser from xml.sax.handler import ErrorHandler -from xml.sax.saxutils import handler from rdflib.exceptions import ParserError from rdflib.graph import Graph diff --git a/rdflib/plugins/serializers/turtle.py b/rdflib/plugins/serializers/turtle.py index 21df28ff4a..5790ff798f 100644 --- a/rdflib/plugins/serializers/turtle.py +++ b/rdflib/plugins/serializers/turtle.py @@ -5,8 +5,10 @@ from collections import defaultdict from functools import cmp_to_key +from typing import IO, Dict, Optional from rdflib.exceptions import Error +from rdflib.graph import Graph from rdflib.namespace import RDF, RDFS from rdflib.serializer import Serializer from rdflib.term import BNode, Literal, URIRef @@ -184,15 +186,15 @@ class TurtleSerializer(RecursiveSerializer): short_name = "turtle" indentString = " " - def __init__(self, store): - self._ns_rewrite = {} + def __init__(self, store: Graph): + self._ns_rewrite: Dict[str, str] = {} super(TurtleSerializer, self).__init__(store) self.keywords = {RDF.type: "a"} self.reset() self.stream = None self._spacious = _SPACIOUS_OUTPUT - def addNamespace(self, prefix, namespace): + def addNamespace(self, prefix: str, namespace: str): # Turtle does not support prefix that start with _ # if they occur in the graph, rewrite to p_blah # this is more complicated since we need to make sure p_blah @@ -223,7 +225,14 @@ def reset(self): self._started = False self._ns_rewrite = {} - def serialize(self, stream, base=None, encoding=None, spacious=None, **args): + def serialize( + self, + stream: IO[bytes], + base: Optional[str] = None, + encoding: Optional[str] = None, + spacious: Optional[bool] = None, + **args, + ): self.reset() self.stream = stream # if base is given here, use that, if not and a base is set for the graph use that diff --git a/rdflib/plugins/serializers/xmlwriter.py b/rdflib/plugins/serializers/xmlwriter.py index 9ed10f48fc..a365fab958 100644 --- a/rdflib/plugins/serializers/xmlwriter.py +++ b/rdflib/plugins/serializers/xmlwriter.py @@ -1,34 +1,51 @@ import 
codecs
+from typing import IO, TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
 from xml.sax.saxutils import escape, quoteattr
 
+from rdflib.term import URIRef
+
+if TYPE_CHECKING:
+    from rdflib.namespace import Namespace, NamespaceManager
+
+
 __all__ = ["XMLWriter"]
 
 ESCAPE_ENTITIES = {"\r": "&#13;"}
 
 
 class XMLWriter(object):
-    def __init__(self, stream, namespace_manager, encoding=None, decl=1, extra_ns=None):
+    def __init__(
+        self,
+        stream: IO[bytes],
+        namespace_manager: "NamespaceManager",
+        encoding: Optional[str] = None,
+        decl: int = 1,
+        extra_ns: Optional[Dict[str, "Namespace"]] = None,
+    ):
         encoding = encoding or "utf-8"
         encoder, decoder, stream_reader, stream_writer = codecs.lookup(encoding)
-        self.stream = stream = stream_writer(stream)
+        # NOTE on type ignores: this is mainly because the variable is being re-used.
+        # type error: Incompatible types in assignment (expression has type "StreamWriter", variable has type "IO[bytes]")
+        self.stream = stream = stream_writer(stream)  # type: ignore[assignment]
         if decl:
-            stream.write('<?xml version="1.0" encoding="%s"?>' % encoding)
-        self.element_stack = []
+            # type error: Argument 1 to "write" of "IO" has incompatible type "str"; expected "bytes"
+            stream.write('<?xml version="1.0" encoding="%s"?>' % encoding)  # type: ignore[arg-type]
+        self.element_stack: List[str] = []
         self.nm = namespace_manager
         self.extra_ns = extra_ns or {}
         self.closed = True
 
-    def __get_indent(self):
+    def __get_indent(self) -> str:
         return " " * len(self.element_stack)
 
     indent = property(__get_indent)
 
-    def __close_start_tag(self):
+    def __close_start_tag(self) -> None:
         if not self.closed:  # TODO:
             self.closed = True
             self.stream.write(">")
 
-    def push(self, uri):
+    def push(self, uri: str) -> None:
         self.__close_start_tag()
         write = self.stream.write
         write("\n")
@@ -38,7 +55,7 @@ def push(self, uri):
         self.closed = False
         self.parent = False
 
-    def pop(self, uri=None):
+    def pop(self, uri: Optional[str] = None) -> None:
         top = self.element_stack.pop()
         if uri:
             assert uri == top
@@ -53,7 +70,9 @@ def pop(self, uri=None):
             write("</%s>" % self.qname(top))
         self.parent = True
 
-    def element(self, uri, content, attributes={}):
+    def element(
+        self, uri: str, content: str, attributes: Dict[URIRef, str] = {}
+    ) -> None:
         """Utility method for adding a complete simple element"""
         self.push(uri)
         for k, v in attributes.items():
@@ -61,7 +80,7 @@ def element(self, uri, content, attributes={}):
         self.text(content)
         self.pop()
 
-    def namespaces(self, namespaces=None):
+    def namespaces(self, namespaces: Iterable[Tuple[str, str]] = None) -> None:
         if not namespaces:
             namespaces = self.nm.namespaces()
 
@@ -80,11 +99,11 @@ def namespaces(self, namespaces=None):
         else:
             write(' xmlns="%s"\n' % namespace)
 
-    def attribute(self, uri, value):
+    def attribute(self, uri: str, value: str) -> None:
         write = self.stream.write
         write(" %s=%s" % (self.qname(uri), quoteattr(value)))
 
-    def text(self, text):
+    def text(self, text: str) -> None:
         self.__close_start_tag()
         if "<" in text and ">" in text and "]]>" not in text:
             self.stream.write("<![CDATA[")
             self.stream.write(text)
             self.stream.write("]]>")
         else:
             self.stream.write(escape(text, ESCAPE_ENTITIES))
 
-    def qname(self, uri):
+    def qname(self, uri: str) -> str:
         """Compute qname for a uri using our extra namespaces,
          or the given namespace manager"""
diff --git a/rdflib/plugins/sparql/aggregates.py b/rdflib/plugins/sparql/aggregates.py
index 005a539d51..9a064dda01 100644
--- a/rdflib/plugins/sparql/aggregates.py
+++ b/rdflib/plugins/sparql/aggregates.py
@@ -1,10 +1,26 @@
 from decimal import Decimal
-
-from rdflib import XSD, Literal
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+    overload,
+)
+
+from rdflib.namespace 
import XSD from rdflib.plugins.sparql.datatypes import type_promotion from rdflib.plugins.sparql.evalutils import NotBoundError, _eval, _val from rdflib.plugins.sparql.operators import numeric -from rdflib.plugins.sparql.sparql import SPARQLTypeError +from rdflib.plugins.sparql.parserutils import CompValue +from rdflib.plugins.sparql.sparql import FrozenBindings, SPARQLTypeError +from rdflib.term import Identifier, Literal, Variable """ Aggregation functions @@ -14,38 +30,43 @@ class Accumulator(object): """abstract base class for different aggregation functions""" - def __init__(self, aggregation): + def __init__(self, aggregation: CompValue): + self.get_value: Callable[[], Optional[Literal]] + self.update: Callable[[FrozenBindings, "Aggregator"], None] self.var = aggregation.res self.expr = aggregation.vars if not aggregation.distinct: - self.use_row = self.dont_care + # type error: Cannot assign to a method + self.use_row = self.dont_care # type: ignore[assignment] self.distinct = False else: self.distinct = aggregation.distinct - self.seen = set() + self.seen: Set[Any] = set() - def dont_care(self, row): + def dont_care(self, row: FrozenBindings) -> bool: """skips distinct test""" return True - def use_row(self, row): + def use_row(self, row: FrozenBindings) -> bool: """tests distinct with set""" return _eval(self.expr, row) not in self.seen - def set_value(self, bindings): + def set_value(self, bindings: MutableMapping[Variable, Identifier]) -> None: """sets final value in bindings""" - bindings[self.var] = self.get_value() + # type error: Incompatible types in assignment (expression has type "Optional[Literal]", target has type "Identifier") + bindings[self.var] = self.get_value() # type: ignore[assignment] class Counter(Accumulator): - def __init__(self, aggregation): + def __init__(self, aggregation: CompValue): super(Counter, self).__init__(aggregation) self.value = 0 if self.expr == "*": # cannot eval "*" => always use the full row - self.eval_row = self.eval_full_row + # type error: Cannot assign to a method + self.eval_row = self.eval_full_row # type: ignore[assignment] - def update(self, row, aggregator): + def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None: try: val = self.eval_row(row) except NotBoundError: @@ -55,34 +76,49 @@ def update(self, row, aggregator): if self.distinct: self.seen.add(val) - def get_value(self): + def get_value(self) -> Literal: return Literal(self.value) - def eval_row(self, row): + def eval_row(self, row: FrozenBindings) -> Identifier: return _eval(self.expr, row) - def eval_full_row(self, row): + def eval_full_row(self, row: FrozenBindings) -> FrozenBindings: return row - def use_row(self, row): + def use_row(self, row: FrozenBindings) -> bool: return self.eval_row(row) not in self.seen -def type_safe_numbers(*args): +@overload +def type_safe_numbers(*args: int) -> Tuple[int]: + ... + + +@overload +def type_safe_numbers(*args: Union[Decimal, float, int]) -> Tuple[Union[float, int]]: + ... 
+
+
+def type_safe_numbers(*args: Union[Decimal, float, int]) -> Iterable[Union[float, int]]:
     if any(isinstance(arg, float) for arg in args) and any(
         isinstance(arg, Decimal) for arg in args
     ):
-        return map(float, args)
-    return args
+        # pytype error: bad return type
+        # Expected: Tuple[Union[float, int]]
+        # Actually returned: Iterator[float]
+        return map(float, args)  # pytype: disable=bad-return-type
+    # type error: Incompatible return value type (got "Tuple[Union[Decimal, float, int], ...]", expected "Iterable[Union[float, int]]")
+    # NOTE on type error: if args contains a Decimal it will not get here.
+    return args  # type: ignore[return-value]
 
 
 class Sum(Accumulator):
-    def __init__(self, aggregation):
+    def __init__(self, aggregation: CompValue):
         super(Sum, self).__init__(aggregation)
         self.value = 0
-        self.datatype = None
+        self.datatype: Optional[str] = None
 
-    def update(self, row, aggregator):
+    def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
         try:
             value = _eval(self.expr, row)
             dt = self.datatype
@@ -98,18 +134,21 @@ def update(self, row, aggregator):
             # skip UNDEF
             pass
 
-    def get_value(self):
-        return Literal(self.value, datatype=self.datatype)
+    def get_value(self) -> Literal:
+        # pytype error: Invalid keyword argument datatype to function Literal.__init__
+        return Literal(
+            self.value, datatype=self.datatype
+        )  # pytype: disable=wrong-keyword-args
 
 
 class Average(Accumulator):
-    def __init__(self, aggregation):
+    def __init__(self, aggregation: CompValue):
         super(Average, self).__init__(aggregation)
         self.counter = 0
         self.sum = 0
-        self.datatype = None
+        self.datatype: Optional[str] = None
 
-    def update(self, row, aggregator):
+    def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
         try:
             value = _eval(self.expr, row)
             dt = self.datatype
@@ -128,7 +167,7 @@ def update(self, row, aggregator):
         except SPARQLTypeError:
             pass
 
-    def get_value(self):
+    def get_value(self) -> Literal:
         if self.counter == 0:
             return Literal(0)
         if self.datatype in (XSD.float, XSD.double):
@@ -140,24 +179,29 @@ def get_value(self):
 class Extremum(Accumulator):
     """abstract base class for Minimum and Maximum"""
 
-    def __init__(self, aggregation):
+    def __init__(self, aggregation: CompValue):
+        self.compare: Callable[[Any, Any], Any]
         super(Extremum, self).__init__(aggregation)
-        self.value = None
+        self.value: Any = None
         # DISTINCT would not change the value for MIN or MAX
-        self.use_row = self.dont_care
+        # type error: Cannot assign to a method
+        self.use_row = self.dont_care  # type: ignore[assignment]
 
-    def set_value(self, bindings):
+    def set_value(self, bindings: MutableMapping[Variable, Identifier]) -> None:
         if self.value is not None:
             # simply do not set if self.value is still None
             bindings[self.var] = Literal(self.value)
 
-    def update(self, row, aggregator):
+    def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
         try:
             if self.value is None:
                 self.value = _eval(self.expr, row)
             else:
                 # self.compare is implemented by Minimum/Maximum
-                self.value = self.compare(self.value, _eval(self.expr, row))
+                # pytype error: No attribute 'compare' on Extremum
+                self.value = self.compare(
+                    self.value, _eval(self.expr, row)
+                )  # pytype: disable=attribute-error
         # skip UNDEF or BNode => SPARQLTypeError
         except NotBoundError:
             pass
@@ -183,7 +227,7 @@ def __init__(self, aggregation):
         # DISTINCT would not change the value
         self.use_row = self.dont_care
 
-    def update(self, row, aggregator):
+    def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
         try:
             # set the 
value now aggregator.bindings[self.var] = _eval(self.expr, row) @@ -192,7 +236,7 @@ def update(self, row, aggregator): except NotBoundError: pass - def get_value(self): + def get_value(self) -> None: # set None if no value was set return None @@ -204,7 +248,7 @@ def __init__(self, aggregation): self.value = [] self.separator = aggregation.separator or " " - def update(self, row, aggregator): + def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None: try: value = _eval(self.expr, row) # skip UNDEF @@ -221,7 +265,7 @@ def update(self, row, aggregator): except NotBoundError: pass - def get_value(self): + def get_value(self) -> Literal: return Literal(self.separator.join(str(v) for v in self.value)) @@ -238,25 +282,26 @@ class Aggregator(object): "Aggregate_GroupConcat": GroupConcat, } - def __init__(self, aggregations): - self.bindings = {} - self.accumulators = {} + def __init__(self, aggregations: List[CompValue]): + self.bindings: Dict[Variable, Identifier] = {} + self.accumulators: Dict[str, Accumulator] = {} for a in aggregations: accumulator_class = self.accumulator_classes.get(a.name) if accumulator_class is None: raise Exception("Unknown aggregate function " + a.name) self.accumulators[a.res] = accumulator_class(a) - def update(self, row): + def update(self, row: FrozenBindings) -> None: """update all own accumulators""" # SAMPLE accumulators may delete themselves # => iterate over list not generator for acc in list(self.accumulators.values()): if acc.use_row(row): - acc.update(row, self) + # pytype error: No attribute 'update' on Accumulator + acc.update(row, self) # pytype: disable=attribute-error - def get_bindings(self): + def get_bindings(self) -> Mapping[Variable, Identifier]: """calculate and set last values""" for acc in self.accumulators.values(): acc.set_value(self.bindings) diff --git a/rdflib/plugins/sparql/algebra.py b/rdflib/plugins/sparql/algebra.py index 01dc175111..18d906f1e2 100644 --- a/rdflib/plugins/sparql/algebra.py +++ b/rdflib/plugins/sparql/algebra.py @@ -11,6 +11,7 @@ import typing from functools import reduce from typing import ( + TYPE_CHECKING, Any, Callable, DefaultDict, @@ -380,7 +381,7 @@ def translateGroupGraphPattern(graphPattern: CompValue) -> CompValue: class StopTraversal(Exception): # noqa: N818 - def __init__(self, rv): + def __init__(self, rv: bool): self.rv = rv diff --git a/rdflib/plugins/sparql/datatypes.py b/rdflib/plugins/sparql/datatypes.py index 115a953b6e..39b502ba65 100644 --- a/rdflib/plugins/sparql/datatypes.py +++ b/rdflib/plugins/sparql/datatypes.py @@ -2,7 +2,7 @@ Utility functions for supporting the XML Schema Datatypes hierarchy """ -from rdflib import XSD +from rdflib.namespace import XSD XSD_DTs = set( ( diff --git a/rdflib/plugins/sparql/evalutils.py b/rdflib/plugins/sparql/evalutils.py index ebec86df57..fc1490f7e1 100644 --- a/rdflib/plugins/sparql/evalutils.py +++ b/rdflib/plugins/sparql/evalutils.py @@ -1,13 +1,20 @@ import collections -from typing import Dict, Iterable +from typing import Generator, Iterable, Mapping, Set, TypeVar, Union, overload from rdflib.plugins.sparql.operators import EBV from rdflib.plugins.sparql.parserutils import CompValue, Expr -from rdflib.plugins.sparql.sparql import FrozenDict, NotBoundError, SPARQLError -from rdflib.term import BNode, Literal, URIRef, Variable +from rdflib.plugins.sparql.sparql import ( + FrozenBindings, + FrozenDict, + NotBoundError, + SPARQLError, +) +from rdflib.term import BNode, Identifier, Literal, URIRef, Variable +FrozenDictT = TypeVar("FrozenDictT", 
bound=FrozenDict) -def _diff(a: Iterable[FrozenDict], b: Iterable[FrozenDict], expr): + +def _diff(a: Iterable[FrozenDictT], b: Iterable[FrozenDictT], expr) -> Set[FrozenDictT]: res = set() for x in a: @@ -17,20 +24,38 @@ def _diff(a: Iterable[FrozenDict], b: Iterable[FrozenDict], expr): return res -def _minus(a: Iterable[FrozenDict], b: Iterable[FrozenDict]): +def _minus( + a: Iterable[FrozenDictT], b: Iterable[FrozenDictT] +) -> Generator[FrozenDictT, None, None]: for x in a: if all((not x.compatible(y)) or x.disjointDomain(y) for y in b): yield x -def _join(a: Iterable[FrozenDict], b: Iterable[Dict]): +@overload +def _join( + a: Iterable[FrozenBindings], b: Iterable[Mapping[Identifier, Identifier]] +) -> Generator[FrozenBindings, None, None]: + ... + + +@overload +def _join( + a: Iterable[FrozenDict], b: Iterable[Mapping[Identifier, Identifier]] +) -> Generator[FrozenDict, None, None]: + ... + + +def _join( + a: Iterable[FrozenDict], b: Iterable[Mapping[Identifier, Identifier]] +) -> Generator[FrozenDict, None, None]: for x in a: for y in b: if x.compatible(y): yield x.merge(y) -def _ebv(expr, ctx): +def _ebv(expr: Union[Literal, Variable, Expr], ctx: FrozenDict) -> bool: """ Return true/false for the given expr Either the expr is itself true/false @@ -48,7 +73,8 @@ def _ebv(expr, ctx): return EBV(expr.eval(ctx)) except SPARQLError: return False # filter error == False - elif isinstance(expr, CompValue): + # type error: Subclass of "Literal" and "CompValue" cannot exist: would have incompatible method signatures + elif isinstance(expr, CompValue): # type: ignore[unreachable] raise Exception("Weird - filter got a CompValue without evalfn! %r" % expr) elif isinstance(expr, Variable): try: diff --git a/rdflib/plugins/sparql/operators.py b/rdflib/plugins/sparql/operators.py index 0f7b532550..a140e44b2b 100644 --- a/rdflib/plugins/sparql/operators.py +++ b/rdflib/plugins/sparql/operators.py @@ -17,12 +17,23 @@ import warnings from decimal import ROUND_HALF_UP, Decimal, InvalidOperation from functools import reduce +from typing import ( + Any, + Callable, + Dict, + Match, + NoReturn, + Optional, + Tuple, + Union, + overload, +) from urllib.parse import quote import isodate from pyparsing import ParseResults -from rdflib import RDF, XSD, BNode, Literal, URIRef, Variable +from rdflib.namespace import RDF, XSD from rdflib.plugins.sparql.datatypes import ( XSD_DateTime_DTs, XSD_DTs, @@ -30,11 +41,24 @@ type_promotion, ) from rdflib.plugins.sparql.parserutils import CompValue, Expr -from rdflib.plugins.sparql.sparql import SPARQLError, SPARQLTypeError -from rdflib.term import Node +from rdflib.plugins.sparql.sparql import ( + FrozenBindings, + QueryContext, + SPARQLError, + SPARQLTypeError, +) +from rdflib.term import ( + BNode, + IdentifiedNode, + Identifier, + Literal, + Node, + URIRef, + Variable, +) -def Builtin_IRI(expr, ctx): +def Builtin_IRI(expr: Expr, ctx: FrozenBindings) -> URIRef: """ http://www.w3.org/TR/sparql11-query/#func-iri """ @@ -44,20 +68,22 @@ def Builtin_IRI(expr, ctx): if isinstance(a, URIRef): return a if isinstance(a, Literal): - return ctx.prologue.absolutize(URIRef(a)) + # type error: Item "None" of "Optional[Prologue]" has no attribute "absolutize" + # type error: Incompatible return value type (got "Union[CompValue, str, None, Any]", expected "URIRef") + return ctx.prologue.absolutize(URIRef(a)) # type: ignore[union-attr,return-value] raise SPARQLError("IRI function only accepts URIRefs or Literals/Strings!") -def Builtin_isBLANK(expr, ctx): +def Builtin_isBLANK(expr, 
ctx) -> Literal: return Literal(isinstance(expr.arg, BNode)) -def Builtin_isLITERAL(expr, ctx): +def Builtin_isLITERAL(expr, ctx) -> Literal: return Literal(isinstance(expr.arg, Literal)) -def Builtin_isIRI(expr, ctx): +def Builtin_isIRI(expr, ctx) -> Literal: return Literal(isinstance(expr.arg, URIRef)) @@ -85,7 +111,7 @@ def Builtin_BNODE(expr, ctx): raise SPARQLError("BNode function only accepts no argument or literal/string") -def Builtin_ABS(expr, ctx): +def Builtin_ABS(expr: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-abs """ @@ -93,7 +119,7 @@ def Builtin_ABS(expr, ctx): return Literal(abs(numeric(expr.arg))) -def Builtin_IF(expr, ctx): +def Builtin_IF(expr: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-if """ @@ -101,7 +127,7 @@ def Builtin_IF(expr, ctx): return expr.arg2 if EBV(expr.arg1) else expr.arg3 -def Builtin_RAND(expr, ctx): +def Builtin_RAND(expr: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#idp2133952 """ @@ -109,7 +135,7 @@ def Builtin_RAND(expr, ctx): return Literal(random.random()) -def Builtin_UUID(expr, ctx): +def Builtin_UUID(expr: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-strdt """ @@ -125,32 +151,32 @@ def Builtin_STRUUID(expr, ctx): return Literal(str(uuid.uuid4())) -def Builtin_MD5(expr, ctx): +def Builtin_MD5(expr: Expr, ctx): s = string(expr.arg).encode("utf-8") return Literal(hashlib.md5(s).hexdigest()) -def Builtin_SHA1(expr, ctx): +def Builtin_SHA1(expr: Expr, ctx): s = string(expr.arg).encode("utf-8") return Literal(hashlib.sha1(s).hexdigest()) -def Builtin_SHA256(expr, ctx): +def Builtin_SHA256(expr: Expr, ctx): s = string(expr.arg).encode("utf-8") return Literal(hashlib.sha256(s).hexdigest()) -def Builtin_SHA384(expr, ctx): +def Builtin_SHA384(expr: Expr, ctx): s = string(expr.arg).encode("utf-8") return Literal(hashlib.sha384(s).hexdigest()) -def Builtin_SHA512(expr, ctx): +def Builtin_SHA512(expr: Expr, ctx): s = string(expr.arg).encode("utf-8") return Literal(hashlib.sha512(s).hexdigest()) -def Builtin_COALESCE(expr, ctx): +def Builtin_COALESCE(expr: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-coalesce """ @@ -160,7 +186,7 @@ def Builtin_COALESCE(expr, ctx): raise SPARQLError("COALESCE got no arguments that did not evaluate to an error") -def Builtin_CEIL(expr, ctx): +def Builtin_CEIL(expr: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-ceil """ @@ -169,7 +195,7 @@ def Builtin_CEIL(expr, ctx): return Literal(int(math.ceil(numeric(l_))), datatype=l_.datatype) -def Builtin_FLOOR(expr, ctx): +def Builtin_FLOOR(expr: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-floor """ @@ -177,7 +203,7 @@ def Builtin_FLOOR(expr, ctx): return Literal(int(math.floor(numeric(l_))), datatype=l_.datatype) -def Builtin_ROUND(expr, ctx): +def Builtin_ROUND(expr: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-round """ @@ -192,7 +218,7 @@ def Builtin_ROUND(expr, ctx): return Literal(v, datatype=l_.datatype) -def Builtin_REGEX(expr, ctx): +def Builtin_REGEX(expr: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-regex Invokes the XPath fn:matches function to match text against a regular @@ -215,7 +241,7 @@ def Builtin_REGEX(expr, ctx): return Literal(bool(re.search(str(pattern), text, cFlag))) -def Builtin_REPLACE(expr, ctx): +def Builtin_REPLACE(expr: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-substr """ @@ -227,28 +253,6 @@ def Builtin_REPLACE(expr, ctx): # python uses \1, xpath/sparql uses $1 replacement = re.sub("\\$([0-9]*)", r"\\\1", replacement) - def 
_r(m):
-
-        # Now this is ugly.
-        # Python has a "feature" where unmatched groups return None
-        # then re.sub chokes on this.
-        # see http://bugs.python.org/issue1519638 , fixed and errs in py3.5
-
-        # this works around and hooks into the internal of the re module...
-
-        # the match object is replaced with a wrapper that
-        # returns "" instead of None for unmatched groups
-
-        class _m:
-            def __init__(self, m):
-                self.m = m
-                self.string = m.string
-
-            def group(self, n):
-                return m.group(n) or ""
-
-        return re._expand(pattern, _m(m), replacement)
-
     cFlag = 0
     if flags:
         # Maps XPath REGEX flags (http://www.w3.org/TR/xpath-functions/#flags)
@@ -258,18 +262,14 @@ def group(self, n):
 
     # @@FIXME@@ either datatype OR lang, NOT both
 
-    # this is necessary due to different treatment of unmatched groups in
-    # python versions. see comments above in _r(m).
-    compat_r = str(replacement) if sys.version_info[:2] >= (3, 5) else _r
-
     return Literal(
-        re.sub(str(pattern), compat_r, text, cFlag),
+        re.sub(str(pattern), replacement, text, cFlag),
         datatype=text.datatype,
         lang=text.language,
     )
 
 
-def Builtin_STRDT(expr, ctx):
+def Builtin_STRDT(expr: Expr, ctx):
     """
     http://www.w3.org/TR/sparql11-query/#func-strdt
     """
@@ -277,7 +277,7 @@ def Builtin_STRDT(expr, ctx):
     return Literal(str(expr.arg1), datatype=expr.arg2)
 
 
-def Builtin_STRLANG(expr, ctx):
+def Builtin_STRLANG(expr: Expr, ctx):
     """
     http://www.w3.org/TR/sparql11-query/#func-strlang
     """
@@ -291,7 +291,7 @@ def Builtin_STRLANG(expr, ctx):
     return Literal(str(s), lang=str(expr.arg2).lower())
 
 
-def Builtin_CONCAT(expr, ctx):
+def Builtin_CONCAT(expr: Expr, ctx):
     """
     http://www.w3.org/TR/sparql11-query/#func-concat
     """
@@ -299,15 +299,20 @@ def Builtin_CONCAT(expr, ctx):
 
     # dt/lang passed on only if they all match
 
     dt = set(x.datatype for x in expr.arg if isinstance(x, Literal))
-    dt = dt.pop() if len(dt) == 1 else None
+    # type error: Incompatible types in assignment (expression has type "Optional[str]", variable has type "Set[Optional[str]]")
+    dt = dt.pop() if len(dt) == 1 else None  # type: ignore[assignment]
 
     lang = set(x.language for x in expr.arg if isinstance(x, Literal))
-    lang = lang.pop() if len(lang) == 1 else None
+    # type error: Incompatible types in assignment (expression has type "Optional[str]", variable has type "Set[Optional[str]]")
+    lang = lang.pop() if len(lang) == 1 else None  # type: ignore[assignment]
 
-    return Literal("".join(string(x) for x in expr.arg), datatype=dt, lang=lang)
+    # NOTE on type errors: this is because the same variable is used for two incompatible types
+    # type error: Argument "datatype" to "Literal" has incompatible type "Set[Any]"; expected "Optional[str]" [arg-type]
+    # type error: Argument "lang" to "Literal" has incompatible type "Set[Any]"; expected "Optional[str]"
+    return Literal("".join(string(x) for x in expr.arg), datatype=dt, lang=lang)  # type: ignore[arg-type]
 
 
-def _compatibleStrings(a, b):
+def _compatibleStrings(a: Literal, b: Literal) -> None:
     string(a)
     string(b)
 
@@ -315,7 +320,7 @@ def _compatibleStrings(a, b):
         raise SPARQLError("incompatible arguments to str functions")
 
 
-def Builtin_STRSTARTS(expr, ctx):
+def Builtin_STRSTARTS(expr: Expr, ctx):
     """
     http://www.w3.org/TR/sparql11-query/#func-strstarts
     """
@@ -327,7 +332,7 @@ def Builtin_STRSTARTS(expr, ctx):
 
     return Literal(a.startswith(b))
 
 
-def Builtin_STRENDS(expr, ctx):
+def Builtin_STRENDS(expr: Expr, ctx):
     """
     http://www.w3.org/TR/sparql11-query/#func-strends
     """
@@ -339,7 +344,7 @@ def Builtin_STRENDS(expr, ctx):
 
     return Literal(a.endswith(b))
 
 
-def 
Builtin_STRBEFORE(expr, ctx): +def Builtin_STRBEFORE(expr: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-strbefore """ @@ -355,7 +360,7 @@ def Builtin_STRBEFORE(expr, ctx): return Literal(a[:i], lang=a.language, datatype=a.datatype) -def Builtin_STRAFTER(expr, ctx): +def Builtin_STRAFTER(expr: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-strafter """ @@ -371,7 +376,7 @@ def Builtin_STRAFTER(expr, ctx): return Literal(a[i + len(b) :], lang=a.language, datatype=a.datatype) -def Builtin_CONTAINS(expr, ctx): +def Builtin_CONTAINS(expr: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-strcontains """ @@ -383,11 +388,11 @@ def Builtin_CONTAINS(expr, ctx): return Literal(b in a) -def Builtin_ENCODE_FOR_URI(expr, ctx): +def Builtin_ENCODE_FOR_URI(expr: Expr, ctx): return Literal(quote(string(expr.arg).encode("utf-8"))) -def Builtin_SUBSTR(expr, ctx): +def Builtin_SUBSTR(expr: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-substr """ @@ -403,26 +408,26 @@ def Builtin_SUBSTR(expr, ctx): return Literal(a[start:length], lang=a.language, datatype=a.datatype) -def Builtin_STRLEN(e, ctx): +def Builtin_STRLEN(e: Expr, ctx): l_ = string(e.arg) return Literal(len(l_)) -def Builtin_STR(e, ctx): +def Builtin_STR(e: Expr, ctx): arg = e.arg if isinstance(arg, SPARQLError): raise arg return Literal(str(arg)) # plain literal -def Builtin_LCASE(e, ctx): +def Builtin_LCASE(e: Expr, ctx): l_ = string(e.arg) return Literal(l_.lower(), datatype=l_.datatype, lang=l_.language) -def Builtin_LANGMATCHES(e, ctx): +def Builtin_LANGMATCHES(e: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-langMatches @@ -437,39 +442,39 @@ def Builtin_LANGMATCHES(e, ctx): return Literal(_lang_range_check(langRange, langTag)) -def Builtin_NOW(e, ctx): +def Builtin_NOW(e: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-now """ return Literal(ctx.now) -def Builtin_YEAR(e, ctx): +def Builtin_YEAR(e: Expr, ctx): d = date(e.arg) return Literal(d.year) -def Builtin_MONTH(e, ctx): +def Builtin_MONTH(e: Expr, ctx): d = date(e.arg) return Literal(d.month) -def Builtin_DAY(e, ctx): +def Builtin_DAY(e: Expr, ctx): d = date(e.arg) return Literal(d.day) -def Builtin_HOURS(e, ctx): +def Builtin_HOURS(e: Expr, ctx): d = datetime(e.arg) return Literal(d.hour) -def Builtin_MINUTES(e, ctx): +def Builtin_MINUTES(e: Expr, ctx): d = datetime(e.arg) return Literal(d.minute) -def Builtin_SECONDS(e, ctx): +def Builtin_SECONDS(e: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-seconds """ @@ -477,7 +482,7 @@ def Builtin_SECONDS(e, ctx): return Literal(d.second, datatype=XSD.decimal) -def Builtin_TIMEZONE(e, ctx): +def Builtin_TIMEZONE(e: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-timezone @@ -514,7 +519,7 @@ def Builtin_TIMEZONE(e, ctx): return Literal(tzdelta, datatype=XSD.dayTimeDuration) -def Builtin_TZ(e, ctx): +def Builtin_TZ(e: Expr, ctx): d = datetime(e.arg) if not d.tzinfo: return Literal("") @@ -524,13 +529,13 @@ def Builtin_TZ(e, ctx): return Literal(n) -def Builtin_UCASE(e, ctx): +def Builtin_UCASE(e: Expr, ctx): l_ = string(e.arg) return Literal(l_.upper(), datatype=l_.datatype, lang=l_.language) -def Builtin_LANG(e, ctx): +def Builtin_LANG(e: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-lang @@ -543,7 +548,7 @@ def Builtin_LANG(e, ctx): return Literal(l_.language or "") -def Builtin_DATATYPE(e, ctx): +def Builtin_DATATYPE(e: Expr, ctx): l_ = e.arg if not isinstance(l_, Literal): raise SPARQLError("Can only get datatype of literal: %r" % l_) @@ -554,13 
+559,13 @@ def Builtin_DATATYPE(e, ctx): return l_.datatype -def Builtin_sameTerm(e, ctx): +def Builtin_sameTerm(e: Expr, ctx): a = e.arg1 b = e.arg2 return Literal(a == b) -def Builtin_BOUND(e, ctx): +def Builtin_BOUND(e: Expr, ctx): """ http://www.w3.org/TR/sparql11-query/#func-bound """ @@ -569,7 +574,7 @@ def Builtin_BOUND(e, ctx): return Literal(not isinstance(n, Variable)) -def Builtin_EXISTS(e, ctx): +def Builtin_EXISTS(e: Expr, ctx): # damn... from rdflib.plugins.sparql.evaluate import evalPart @@ -581,10 +586,14 @@ def Builtin_EXISTS(e, ctx): return Literal(not exists) -_CUSTOM_FUNCTIONS = {} +_CustomFunction = Callable[[Expr, FrozenBindings], Node] + +_CUSTOM_FUNCTIONS: Dict[URIRef, Tuple[_CustomFunction, bool]] = {} -def register_custom_function(uri, func, override=False, raw=False): +def register_custom_function( + uri: URIRef, func: _CustomFunction, override: bool = False, raw: bool = False +): """ Register a custom SPARQL function. @@ -598,19 +607,19 @@ def register_custom_function(uri, func, override=False, raw=False): _CUSTOM_FUNCTIONS[uri] = (func, raw) -def custom_function(uri, override=False, raw=False): +def custom_function(uri: URIRef, override: bool = False, raw: bool = False): """ Decorator version of :func:`register_custom_function`. """ - def decorator(func): + def decorator(func: _CustomFunction) -> _CustomFunction: register_custom_function(uri, func, override=override, raw=raw) return func return decorator -def unregister_custom_function(uri, func=None): +def unregister_custom_function(uri: URIRef, func=None): """ The 'func' argument is included for compatibility with existing code. A previous implementation checked that the function associated with @@ -623,7 +632,7 @@ def unregister_custom_function(uri, func=None): warnings.warn("This function is not registered as %s" % uri.n3()) -def Function(e, ctx): +def Function(e: Expr, ctx: FrozenBindings) -> Node: """ Custom functions and casts """ @@ -651,7 +660,7 @@ def Function(e, ctx): @custom_function(XSD.decimal, raw=True) @custom_function(XSD.integer, raw=True) @custom_function(XSD.boolean, raw=True) -def default_cast(e, ctx): +def default_cast(e: Expr, ctx: FrozenBindings) -> Literal: # type: ignore[return] if not e.expr: raise SPARQLError("Nothing given to cast.") if len(e.expr) > 1: @@ -715,19 +724,21 @@ def default_cast(e, ctx): raise SPARQLError("Cannot interpret '%r' as bool" % x) -def UnaryNot(expr, ctx): +def UnaryNot(expr: Expr, ctx: FrozenBindings) -> Literal: return Literal(not EBV(expr.expr)) -def UnaryMinus(expr, ctx): +def UnaryMinus(expr: Expr, ctx: FrozenBindings) -> Literal: return Literal(-numeric(expr.expr)) -def UnaryPlus(expr, ctx): +def UnaryPlus(expr: Expr, ctx: FrozenBindings) -> Literal: return Literal(+numeric(expr.expr)) -def MultiplicativeExpression(e, ctx): +def MultiplicativeExpression( + e: Expr, ctx: Union[QueryContext, FrozenBindings] +) -> Literal: expr = e.expr other = e.other @@ -737,6 +748,7 @@ def MultiplicativeExpression(e, ctx): if other is None: return expr try: + res: Union[Decimal, float] res = Decimal(numeric(expr)) for op, f in zip(e.op, other): f = numeric(f) @@ -754,7 +766,8 @@ def MultiplicativeExpression(e, ctx): return Literal(res) -def AdditiveExpression(e, ctx): +# type error: Missing return statement +def AdditiveExpression(e: Expr, ctx: Union[QueryContext, FrozenBindings]) -> Literal: # type: ignore[return] expr = e.expr other = e.other @@ -783,7 +796,8 @@ def AdditiveExpression(e, ctx): # ( dateTime1 - dateTime2 - dateTime3 ) is an invalid operation if 
len(other) > 1: error_message = "Can't evaluate multiple %r arguments" - raise SPARQLError(error_message, dt.datatype) + # type error: Too many arguments for "SPARQLError" + raise SPARQLError(error_message, dt.datatype) # type: ignore[call-arg] else: n = dateTimeObjects(term) res = calculateDuration(res, n) @@ -829,7 +843,7 @@ def AdditiveExpression(e, ctx): return Literal(res, datatype=dt) -def RelationalExpression(e, ctx): +def RelationalExpression(e: Expr, ctx: Union[QueryContext, FrozenBindings]) -> Literal: expr = e.expr other = e.other @@ -857,7 +871,7 @@ def RelationalExpression(e, ctx): res = op == "NOT IN" - error = False + error: Union[bool, SPARQLError] = False if other == RDF.nil: other = [] @@ -871,7 +885,9 @@ def RelationalExpression(e, ctx): if not error: return Literal(False ^ res) else: - raise error + # Note on type error: this is because variable is Union[bool, SPARQLError] + # type error: Exception must be derived from BaseException + raise error # type: ignore[misc] if op not in ("=", "!=", "IN", "NOT IN"): if not isinstance(expr, Literal): @@ -909,7 +925,9 @@ def RelationalExpression(e, ctx): return Literal(r) -def ConditionalAndExpression(e, ctx): +def ConditionalAndExpression( + e: Expr, ctx: Union[QueryContext, FrozenBindings] +) -> Literal: # TODO: handle returned errors @@ -924,7 +942,9 @@ def ConditionalAndExpression(e, ctx): return Literal(all(EBV(x) for x in [expr] + other)) -def ConditionalOrExpression(e, ctx): +def ConditionalOrExpression( + e: Expr, ctx: Union[QueryContext, FrozenBindings] +) -> Literal: # TODO: handle errors @@ -1026,7 +1046,7 @@ def string(s): return s -def numeric(expr): +def numeric(expr: Literal) -> Any: """ return a number from a literal http://www.w3.org/TR/xpath20/#promotion @@ -1060,7 +1080,7 @@ def numeric(expr): return expr.toPython() -def dateTimeObjects(expr): +def dateTimeObjects(expr: Literal) -> Any: """ return a dataTime/date/time/duration/dayTimeDuration/yearMonthDuration python objects from a literal @@ -1133,7 +1153,22 @@ def calculateFinalDateTime(obj1, dt1, obj2, dt2, operation): raise SPARQLError("Incompatible Data types to DateTime Operations") -def EBV(rt): +@overload +def EBV(rt: Literal) -> bool: + ... + + +@overload +def EBV(rt: Union[Variable, IdentifiedNode, SPARQLError, Expr]) -> NoReturn: + ... + + +@overload +def EBV(rt: Union[Identifier, SPARQLError, Expr]) -> Union[bool, NoReturn]: + ... + + +def EBV(rt: Union[Identifier, SPARQLError, Expr]) -> bool: """ Effective Boolean Value (EBV) @@ -1178,7 +1213,7 @@ def EBV(rt): ) -def _lang_range_check(range, lang): +def _lang_range_check(range: Literal, lang: Literal) -> bool: """ Implementation of the extended filtering algorithm, as defined in point 3.3.2, of U{RFC 4647}, on @@ -1196,7 +1231,7 @@ def _lang_range_check(range, lang): """ - def _match(r, l_): + def _match(r: str, l_: str) -> bool: """ Matching of a range and language item: either range is a wildcard or the two are equal diff --git a/rdflib/plugins/sparql/parser.py b/rdflib/plugins/sparql/parser.py index 2035b4f081..9b9426496a 100644 --- a/rdflib/plugins/sparql/parser.py +++ b/rdflib/plugins/sparql/parser.py @@ -6,6 +6,7 @@ import re import sys +from typing import IO, Any, List, Union from pyparsing import CaselessKeyword as Keyword # watch out :) from pyparsing import ( @@ -27,7 +28,7 @@ from rdflib.compat import decodeUnicodeEscape from . 
import operators as op -from .parserutils import Comp, Param, ParamList +from .parserutils import Comp, CompValue, Param, ParamList # from pyparsing import Keyword as CaseSensitiveKeyword @@ -37,7 +38,7 @@ # ---------------- ACTIONS -def neg(literal): +def neg(literal: Literal) -> Literal: return rdflib.Literal(-literal, datatype=literal.datatype) @@ -49,13 +50,13 @@ def setDataType(terms): return rdflib.Literal(terms[0], datatype=terms[1]) -def expandTriples(terms): +def expandTriples(terms: ParseResults): """ Expand ; and , syntax for repeat predicates, subjects """ # import pdb; pdb.set_trace() try: - res = [] + res: List[Any] = [] if DEBUG: print("Terms", terms) l_ = len(terms) @@ -102,7 +103,7 @@ def expandTriples(terms): raise -def expandBNodeTriples(terms): +def expandBNodeTriples(terms: ParseResults): """ expand [ ?p ?o ] syntax for implicit bnodes """ @@ -119,14 +120,14 @@ def expandBNodeTriples(terms): raise -def expandCollection(terms): +def expandCollection(terms: ParseResults): """ expand ( 1 2 3 ) notation for collections """ if DEBUG: print("Collection: ", terms) - res = [] + res: List[Any] = [] other = [] for x in terms: if isinstance(x, list): # is this a [ .. ] ? @@ -1511,7 +1512,7 @@ def expandCollection(terms): expandUnicodeEscapes_re = re.compile(r"\\u([0-9a-f]{4}(?:[0-9a-f]{4})?)", flags=re.I) -def expandUnicodeEscapes(q): +def expandUnicodeEscapes(q: str) -> str: r""" The syntax of the SPARQL Query Language is expressed over code points in Unicode [UNICODE]. The encoding is always UTF-8 [RFC3629]. Unicode code points may also be expressed using an \ uXXXX (U+0 to U+FFFF) or \ UXXXXXXXX syntax (for U+10000 onwards) where X is a hexadecimal digit [0-9A-F] @@ -1526,22 +1527,28 @@ def expand(m): return expandUnicodeEscapes_re.sub(expand, q) -def parseQuery(q): +def parseQuery(q: Union[bytes, str, IO]) -> ParseResults: if hasattr(q, "read"): - q = q.read() + # type error: Item "bytes" of "Union[bytes, str, IO[Any]]" has no attribute "read" + # type error: Item "str" of "Union[bytes, str, IO[Any]]" has no attribute "read" + q = q.read() # type: ignore[union-attr] if isinstance(q, bytes): q = q.decode("utf-8") - q = expandUnicodeEscapes(q) + # type error: Argument 1 to "expandUnicodeEscapes" has incompatible type "Union[str, IO[Any]]"; expected "str" + q = expandUnicodeEscapes(q) # type: ignore[arg-type] return Query.parseString(q, parseAll=True) -def parseUpdate(q): +def parseUpdate(q: Union[bytes, str, IO]) -> CompValue: if hasattr(q, "read"): - q = q.read() + # type error: Item "bytes" of "Union[bytes, str, IO[Any]]" has no attribute "read" + # type error: Item "str" of "Union[bytes, str, IO[Any]]" has no attribute "read" + q = q.read() # type: ignore[union-attr] if isinstance(q, bytes): q = q.decode("utf-8") - q = expandUnicodeEscapes(q) + # type error: Argument 1 to "expandUnicodeEscapes" has incompatible type "Union[str, IO[Any]]"; expected "str" + q = expandUnicodeEscapes(q) # type: ignore[arg-type] return UpdateUnit.parseString(q, parseAll=True)[0] diff --git a/rdflib/plugins/sparql/parserutils.py b/rdflib/plugins/sparql/parserutils.py index a936b04671..6d618b2fc3 100644 --- a/rdflib/plugins/sparql/parserutils.py +++ b/rdflib/plugins/sparql/parserutils.py @@ -4,7 +4,10 @@ from pyparsing import ParseResults, TokenConverter, originalTextFor -from rdflib import BNode, Variable +from rdflib.term import BNode, Variable + +if TYPE_CHECKING: + from rdflib.plugins.sparql.sparql import FrozenBindings if TYPE_CHECKING: from rdflib.plugins.sparql.sparql import FrozenBindings 
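Note on the parseQuery/parseUpdate hunks above: the `type: ignore[union-attr]` and `type: ignore[arg-type]` comments are needed because mypy cannot narrow `Union[bytes, str, IO]` through the `hasattr(q, "read")` check. A minimal sketch of an isinstance-based alternative that narrows the union without any ignores; the helper name `_normalize_query_input` is hypothetical and not part of this diff:

    from typing import IO, Union

    def _normalize_query_input(q: Union[bytes, str, IO]) -> str:
        # Each isinstance() check narrows the union, so mypy accepts
        # every branch without a "type: ignore" comment.
        if isinstance(q, bytes):
            return q.decode("utf-8")
        if isinstance(q, str):
            return q
        # q is now narrowed to IO; read() may return str or bytes.
        data = q.read()
        return data.decode("utf-8") if isinstance(data, bytes) else data

parseQuery and parseUpdate could then call expandUnicodeEscapes on the returned str directly.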
diff --git a/rdflib/plugins/sparql/processor.py b/rdflib/plugins/sparql/processor.py
index 26a72dd21e..036b30aaaf 100644
--- a/rdflib/plugins/sparql/processor.py
+++ b/rdflib/plugins/sparql/processor.py
@@ -6,15 +6,21 @@
 
 """
 
+from typing import Any, Mapping, Optional, Union
+
+from rdflib.graph import Graph
 from rdflib.plugins.sparql.algebra import translateQuery, translateUpdate
 from rdflib.plugins.sparql.evaluate import evalQuery
 from rdflib.plugins.sparql.parser import parseQuery, parseUpdate
-from rdflib.plugins.sparql.sparql import Query
+from rdflib.plugins.sparql.sparql import Query, Update
 from rdflib.plugins.sparql.update import evalUpdate
 from rdflib.query import Processor, Result, UpdateProcessor
+from rdflib.term import Identifier, Variable
 
 
-def prepareQuery(queryString, initNs={}, base=None) -> Query:
+def prepareQuery(
+    queryString: str, initNs: Mapping[str, str] = {}, base: Optional[str] = None
+) -> Query:
     """
     Parse and translate a SPARQL Query
     """
@@ -23,7 +29,9 @@ def prepareQuery(queryString, initNs={}, base=None) -> Query:
     return ret
 
 
-def prepareUpdate(updateString, initNs={}, base=None):
+def prepareUpdate(
+    updateString: str, initNs: Mapping[str, str] = {}, base: Optional[str] = None
+) -> Update:
     """
     Parse and translate a SPARQL Update
     """
@@ -32,7 +40,13 @@ def prepareUpdate(updateString, initNs={}, base=None):
     return ret
 
 
-def processUpdate(graph, updateString, initBindings={}, initNs={}, base=None):
+def processUpdate(
+    graph: Graph,
+    updateString: str,
+    initBindings: Mapping[Variable, Identifier] = {},
+    initNs: Mapping[str, str] = {},
+    base: Optional[str] = None,
+) -> None:
     """
     Process a SPARQL Update Request
     returns Nothing on success or raises Exceptions on error
@@ -43,7 +57,7 @@ def processUpdate(graph, updateString, initBindings={}, initNs={}, base=None):
 
 
 class SPARQLResult(Result):
-    def __init__(self, res):
+    def __init__(self, res: Mapping[str, Any]):
         Result.__init__(self, res["type_"])
         self.vars = res.get("vars_")
         self.bindings = res.get("bindings")
@@ -55,18 +69,39 @@ class SPARQLUpdateProcessor(UpdateProcessor):
     def __init__(self, graph):
         self.graph = graph
 
-    def update(self, strOrQuery, initBindings={}, initNs={}):
+    def update(
+        self,
+        strOrQuery: Union[str, Update],
+        initBindings: Mapping[Variable, Identifier] = {},
+        initNs: Mapping[str, str] = {},
+    ) -> None:
         if isinstance(strOrQuery, str):
             strOrQuery = translateUpdate(parseUpdate(strOrQuery), initNs=initNs)
 
-        return evalUpdate(self.graph, strOrQuery, initBindings)
+        # NOTE on type error: this is because translateUpdate has an ambiguous
+        # return type that is not always of type `Update`
+        # type error: Argument 2 to "evalUpdate" has incompatible type "Union[str, Update]"; expected "Update"
+        return evalUpdate(
+            self.graph, strOrQuery, initBindings
+        )  # xtype: ignore[arg-type]
 
 
 class SPARQLProcessor(Processor):
     def __init__(self, graph):
         self.graph = graph
 
-    def query(self, strOrQuery, initBindings={}, initNs={}, base=None, DEBUG=False):
+    # NOTE on type error: this is because the super type constructor does not
+    # accept a base argument and the position of the DEBUG argument is
+    # different.
+    # type error: Signature of "query" incompatible with supertype "Processor"
+    def query(  # type: ignore[override]
+        self,
+        strOrQuery: Union[str, Query],
+        initBindings: Mapping[Variable, Identifier] = {},
+        initNs: Mapping[str, str] = {},
+        base: Optional[str] = None,
+        DEBUG: bool = False,
+    ) -> Mapping[str, Any]:
         """
         Evaluate a query with the given initial bindings, and initial
         namespaces. 
The given base is used to resolve relative URIs in diff --git a/rdflib/plugins/sparql/results/csvresults.py b/rdflib/plugins/sparql/results/csvresults.py index 16273cbcd4..2ed4170b56 100644 --- a/rdflib/plugins/sparql/results/csvresults.py +++ b/rdflib/plugins/sparql/results/csvresults.py @@ -9,23 +9,26 @@ import codecs import csv -from typing import IO +from typing import IO, List, Optional, Union -from rdflib import BNode, Literal, URIRef, Variable +from rdflib.plugins.sparql.processor import SPARQLResult from rdflib.query import Result, ResultParser, ResultSerializer +from rdflib.term import BNode, Identifier, Literal, URIRef, Variable class CSVResultParser(ResultParser): def __init__(self): self.delim = "," - def parse(self, source, content_type=None): + # type error: Signature of "parse" incompatible with supertype "ResultParser" + def parse(self, source: IO, content_type: Optional[str] = None) -> Result: # type: ignore[override] r = Result("SELECT") + # type error: Incompatible types in assignment (expression has type "StreamReader", variable has type "IO[Any]") if isinstance(source.read(0), bytes): # if reading from source returns bytes do utf-8 decoding - source = codecs.getreader("utf-8")(source) + source = codecs.getreader("utf-8")(source) # type: ignore[assignment] reader = csv.reader(source, delimiter=self.delim) r.vars = [Variable(x) for x in next(reader)] @@ -36,14 +39,14 @@ def parse(self, source, content_type=None): return r - def parseRow(self, row, v): + def parseRow(self, row: List[str], v: List[Variable]): return dict( (var, val) for var, val in zip(v, [self.convertTerm(t) for t in row]) if val is not None ) - def convertTerm(self, t): + def convertTerm(self, t: str) -> Optional[Union[BNode, URIRef, Literal]]: if t == "": return None if t.startswith("_:"): @@ -54,14 +57,14 @@ def convertTerm(self, t): class CSVResultSerializer(ResultSerializer): - def __init__(self, result): + def __init__(self, result: SPARQLResult): ResultSerializer.__init__(self, result) self.delim = "," if result.type != "SELECT": raise Exception("CSVSerializer can only serialize select query results") - def serialize(self, stream: IO, encoding: str = "utf-8", **kwargs): + def serialize(self, stream: IO, encoding: str = "utf-8", **kwargs) -> None: # the serialiser writes bytes in the given encoding # in py3 csv.writer is unicode aware and writes STRINGS, @@ -80,7 +83,7 @@ def serialize(self, stream: IO, encoding: str = "utf-8", **kwargs): [self.serializeTerm(row.get(v), encoding) for v in self.result.vars] # type: ignore[union-attr] ) - def serializeTerm(self, term, encoding): + def serializeTerm(self, term: Optional[Identifier], encoding: str): if term is None: return "" elif isinstance(term, BNode): diff --git a/rdflib/plugins/sparql/results/graph.py b/rdflib/plugins/sparql/results/graph.py index 0b14be27bb..53424ccbbb 100644 --- a/rdflib/plugins/sparql/results/graph.py +++ b/rdflib/plugins/sparql/results/graph.py @@ -1,4 +1,4 @@ -from rdflib import Graph +from rdflib.graph import Graph from rdflib.query import Result, ResultParser diff --git a/rdflib/plugins/sparql/results/jsonresults.py b/rdflib/plugins/sparql/results/jsonresults.py index 1c9c5a7984..d81932ed6b 100644 --- a/rdflib/plugins/sparql/results/jsonresults.py +++ b/rdflib/plugins/sparql/results/jsonresults.py @@ -1,8 +1,8 @@ import json -from typing import IO, Any, Dict +from typing import IO, Any, Dict, List, Union -from rdflib import BNode, Literal, URIRef, Variable from rdflib.query import Result, ResultException, ResultParser, 
ResultSerializer +from rdflib.term import BNode, IdentifiedNode, Literal, URIRef, Variable """A Serializer for SPARQL results in JSON: @@ -28,7 +28,7 @@ class JSONResultSerializer(ResultSerializer): def __init__(self, result): ResultSerializer.__init__(self, result) - def serialize(self, stream: IO, encoding: str = None): # type: ignore[override] + def serialize(self, stream: IO, encoding: str = None) -> None: # type: ignore[override] res: Dict[str, Any] = {} if self.result.type == "ASK": @@ -59,7 +59,7 @@ def _bindingToJSON(self, b): class JSONResult(Result): - def __init__(self, json): + def __init__(self, json: Dict[str, Any]): self.json = json if "boolean" in json: type_ = "ASK" @@ -76,7 +76,7 @@ def __init__(self, json): self.bindings = self._get_bindings() self.vars = [Variable(x) for x in json["head"]["vars"]] - def _get_bindings(self): + def _get_bindings(self) -> List[Dict[Variable, Union[IdentifiedNode, Literal]]]: ret = [] for row in self.json["results"]["bindings"]: outRow = {} @@ -86,7 +86,7 @@ def _get_bindings(self): return ret -def parseJsonTerm(d): +def parseJsonTerm(d: Dict[str, str]) -> Union[IdentifiedNode, Literal]: """rdflib object (Literal, URIRef, BNode) for the given json-format dict. input is like: @@ -107,7 +107,7 @@ def parseJsonTerm(d): raise NotImplementedError("json term type %r" % t) -def termToJSON(self, term): +def termToJSON(self: JSONResultSerializer, term): if isinstance(term, URIRef): return {"type": "uri", "value": str(term)} elif isinstance(term, Literal): diff --git a/rdflib/plugins/sparql/results/rdfresults.py b/rdflib/plugins/sparql/results/rdfresults.py index 83ee3ea1f9..dbee160f89 100644 --- a/rdflib/plugins/sparql/results/rdfresults.py +++ b/rdflib/plugins/sparql/results/rdfresults.py @@ -1,5 +1,7 @@ -from rdflib import RDF, Graph, Namespace, Variable +from rdflib.graph import Graph +from rdflib.namespace import RDF, Namespace from rdflib.query import Result, ResultParser +from rdflib.term import Variable RS = Namespace("http://www.w3.org/2001/sw/DataAccess/tests/result-set#") diff --git a/rdflib/plugins/sparql/results/tsvresults.py b/rdflib/plugins/sparql/results/tsvresults.py index 42671c7102..cf0c12a9e8 100644 --- a/rdflib/plugins/sparql/results/tsvresults.py +++ b/rdflib/plugins/sparql/results/tsvresults.py @@ -5,6 +5,8 @@ """ import codecs +import typing +from typing import Union from pyparsing import ( FollowedBy, @@ -16,7 +18,6 @@ ZeroOrMore, ) -from rdflib import Literal as RDFLiteral from rdflib.plugins.sparql.parser import ( BLANK_NODE_LABEL, IRIREF, @@ -29,6 +30,9 @@ ) from rdflib.plugins.sparql.parserutils import Comp, CompValue, Param from rdflib.query import Result, ResultParser +from rdflib.term import BNode +from rdflib.term import Literal as RDFLiteral +from rdflib.term import URIRef ParserElement.setDefaultWhitespaceChars(" \n") @@ -84,7 +88,9 @@ def parse(self, source, content_type=None): return r - def convertTerm(self, t): + def convertTerm( + self, t: Union[object, Literal, BNode, CompValue, URIRef] + ) -> typing.Optional[Union[BNode, URIRef, Literal]]: if t is NONE_VALUE: return None if isinstance(t, CompValue): diff --git a/rdflib/plugins/sparql/results/txtresults.py b/rdflib/plugins/sparql/results/txtresults.py index 8b87864b68..a89a4af3b8 100644 --- a/rdflib/plugins/sparql/results/txtresults.py +++ b/rdflib/plugins/sparql/results/txtresults.py @@ -1,12 +1,14 @@ -from typing import IO, List, Optional +from typing import IO, List, Optional, Union -from rdflib import BNode, Literal, URIRef from rdflib.namespace import 
NamespaceManager from rdflib.query import ResultSerializer -from rdflib.term import Variable +from rdflib.term import BNode, Literal, URIRef, Variable -def _termString(t, namespace_manager: Optional[NamespaceManager]): +def _termString( + t: Optional[Union[URIRef, Literal, BNode]], + namespace_manager: Optional[NamespaceManager], +) -> str: if t is None: return "-" if namespace_manager: @@ -26,12 +28,13 @@ class TXTResultSerializer(ResultSerializer): """ # TODO FIXME: class specific args should be keyword only. + # type error: Signature of "serialize" incompatible with supertype "ResultSerializer" def serialize( # type: ignore[override] self, stream: IO, encoding: str, namespace_manager: Optional[NamespaceManager] = None, - ): + ) -> None: """ return a text table of query results """ @@ -50,13 +53,17 @@ def c(s, w): raise Exception("Can only pretty print SELECT results!") if not self.result: - return "(no results)\n" + # type error: No return value expected + return "(no results)\n" # type: ignore[return-value] else: - keys: List[Variable] = self.result.vars # type: ignore[assignment] maxlen = [0] * len(keys) b = [ - [_termString(r[k], namespace_manager) for k in keys] + # type error: Value of type "Union[Tuple[IdentifiedNode, IdentifiedNode, Identifier], bool, ResultRow]" is not indexable + # type error: Invalid tuple index type (actual type "Variable", expected type "Union[int, slice]") + # error: Argument 1 to "_termString" has incompatible type "Union[Any, Identifier]"; expected "Union[URIRef, Literal, BNode, None]" + # NOTE on type error: The problem here is that r can be more types than _termString expects because result can be a result of multiple types. + [_termString(r[k], namespace_manager) for k in keys] # type: ignore[index,misc,arg-type] for r in self.result ] for r in b: diff --git a/rdflib/plugins/sparql/results/xmlresults.py b/rdflib/plugins/sparql/results/xmlresults.py index 69cb5c303d..85503ebb5f 100644 --- a/rdflib/plugins/sparql/results/xmlresults.py +++ b/rdflib/plugins/sparql/results/xmlresults.py @@ -1,12 +1,12 @@ import logging -from typing import IO, Optional +from typing import IO, Any, Optional from xml.dom import XML_NAMESPACE from xml.sax.saxutils import XMLGenerator from xml.sax.xmlreader import AttributesNSImpl -from rdflib import BNode, Literal, URIRef, Variable from rdflib.compat import etree from rdflib.query import Result, ResultException, ResultParser, ResultSerializer +from rdflib.term import BNode, Literal, URIRef, Variable SPARQL_XML_NAMESPACE = "http://www.w3.org/2005/sparql-results#" RESULTS_NS_ET = "{%s}" % SPARQL_XML_NAMESPACE @@ -109,7 +109,7 @@ class XMLResultSerializer(ResultSerializer): def __init__(self, result): ResultSerializer.__init__(self, result) - def serialize(self, stream: IO, encoding: str = "utf-8", **kwargs): + def serialize(self, stream: IO, encoding: str = "utf-8", **kwargs: Any) -> None: writer = SPARQLXMLWriter(stream, encoding) if self.result.type == "ASK": diff --git a/rdflib/plugins/sparql/update.py b/rdflib/plugins/sparql/update.py index 371d43011d..9503b63246 100644 --- a/rdflib/plugins/sparql/update.py +++ b/rdflib/plugins/sparql/update.py @@ -4,28 +4,36 @@ """ -from rdflib import Graph, Variable +from typing import TYPE_CHECKING, Iterator, Mapping, Optional, Sequence + +from rdflib.graph import Graph from rdflib.plugins.sparql.evaluate import evalBGP, evalPart from rdflib.plugins.sparql.evalutils import _fillTemplate, _join -from rdflib.plugins.sparql.sparql import QueryContext +from rdflib.plugins.sparql.parserutils 
import CompValue +from rdflib.plugins.sparql.sparql import FrozenDict, QueryContext, Update +from rdflib.term import Identifier, URIRef, Variable -def _graphOrDefault(ctx, g): +def _graphOrDefault(ctx: QueryContext, g: str) -> Optional[Graph]: if g == "DEFAULT": return ctx.graph else: return ctx.dataset.get_context(g) -def _graphAll(ctx, g): +def _graphAll(ctx: QueryContext, g: str) -> Sequence[Graph]: """ return a list of graphs """ if g == "DEFAULT": - return [ctx.graph] + # type error: List item 0 has incompatible type "Optional[Graph]"; expected "Graph" + return [ctx.graph] # type: ignore[list-item] elif g == "NAMED": return [ - c for c in ctx.dataset.contexts() if c.identifier != ctx.graph.identifier + # type error: Item "None" of "Optional[Graph]" has no attribute "identifier" + c + for c in ctx.dataset.contexts() + if c.identifier != ctx.graph.identifier # type: ignore[union-attr] ] elif g == "ALL": return list(ctx.dataset.contexts()) @@ -33,18 +41,21 @@ def _graphAll(ctx, g): return [ctx.dataset.get_context(g)] -def evalLoad(ctx, u): +def evalLoad(ctx: QueryContext, u: CompValue) -> None: """ http://www.w3.org/TR/sparql11-update/#load """ + if TYPE_CHECKING: + assert isinstance(u.iri, URIRef) + if u.graphiri: ctx.load(u.iri, default=False, publicID=u.graphiri) else: ctx.load(u.iri, default=True) -def evalCreate(ctx, u): +def evalCreate(ctx: QueryContext, u: CompValue) -> None: """ http://www.w3.org/TR/sparql11-update/#create """ @@ -54,16 +65,15 @@ def evalCreate(ctx, u): raise Exception("Create not implemented!") -def evalClear(ctx, u): +def evalClear(ctx: QueryContext, u: CompValue) -> None: """ http://www.w3.org/TR/sparql11-update/#clear """ - for g in _graphAll(ctx, u.graphiri): g.remove((None, None, None)) -def evalDrop(ctx, u): +def evalDrop(ctx: QueryContext, u: CompValue) -> None: """ http://www.w3.org/TR/sparql11-update/#drop """ @@ -74,14 +84,13 @@ def evalDrop(ctx, u): evalClear(ctx, u) -def evalInsertData(ctx, u): +def evalInsertData(ctx: QueryContext, u: CompValue) -> None: """ http://www.w3.org/TR/sparql11-update/#insertData """ # add triples g = ctx.graph g += u.triples - # add quads # u.quads is a dict of graphURI=>[triples] for g in u.quads: @@ -89,7 +98,7 @@ def evalInsertData(ctx, u): cg += u.quads[g] -def evalDeleteData(ctx, u): +def evalDeleteData(ctx: QueryContext, u: CompValue) -> None: """ http://www.w3.org/TR/sparql11-update/#deleteData """ @@ -104,18 +113,19 @@ def evalDeleteData(ctx, u): cg -= u.quads[g] -def evalDeleteWhere(ctx, u): +def evalDeleteWhere(ctx: QueryContext, u: CompValue) -> None: """ http://www.w3.org/TR/sparql11-update/#deleteWhere """ - res = evalBGP(ctx, u.triples) + res: Iterator[FrozenDict] = evalBGP(ctx, u.triples) for g in u.quads: cg = ctx.dataset.get_context(g) c = ctx.pushGraph(cg) res = _join(res, list(evalBGP(c, u.quads[g]))) - for c in res: + # type error: Incompatible types in assignment (expression has type "FrozenBindings", variable has type "QueryContext") + for c in res: # type: ignore[assignment] g = ctx.graph g -= _fillTemplate(u.triples, c) @@ -124,11 +134,12 @@ def evalDeleteWhere(ctx, u): cg -= _fillTemplate(u.quads[g], c) -def evalModify(ctx, u): +def evalModify(ctx: QueryContext, u: CompValue) -> None: originalctx = ctx # Using replaces the dataset for evaluating the where-clause + dg: Optional[Graph] if u.using: otherDefault = False for d in u.using: @@ -185,7 +196,7 @@ def evalModify(ctx, u): cg += _fillTemplate(q, c) -def evalAdd(ctx, u): +def evalAdd(ctx: QueryContext, u: CompValue) -> None: """ add all triples 
from src to dst @@ -197,13 +208,15 @@ srcg = _graphOrDefault(ctx, src) dstg = _graphOrDefault(ctx, dst) - if srcg.identifier == dstg.identifier: + # type error: Item "None" of "Optional[Graph]" has no attribute "identifier" + if srcg.identifier == dstg.identifier: # type: ignore[union-attr] return - dstg += srcg + # type error: Unsupported left operand type for + ("None") + dstg += srcg # type: ignore[operator] -def evalMove(ctx, u): +def evalMove(ctx: QueryContext, u: CompValue) -> None: """ remove all triples from dst @@ -218,20 +231,24 @@ srcg = _graphOrDefault(ctx, src) dstg = _graphOrDefault(ctx, dst) - if srcg.identifier == dstg.identifier: + # type error: Item "None" of "Optional[Graph]" has no attribute "identifier" + if srcg.identifier == dstg.identifier: # type: ignore[union-attr] return - dstg.remove((None, None, None)) + # type error: Item "None" of "Optional[Graph]" has no attribute "remove" + dstg.remove((None, None, None)) # type: ignore[union-attr] - dstg += srcg + # type error: Unsupported left operand type for + ("None") + dstg += srcg # type: ignore[operator] if ctx.dataset.store.graph_aware: ctx.dataset.store.remove_graph(srcg) else: - srcg.remove((None, None, None)) + # type error: Item "None" of "Optional[Graph]" has no attribute "remove" + srcg.remove((None, None, None)) # type: ignore[union-attr] -def evalCopy(ctx, u): +def evalCopy(ctx: QueryContext, u: CompValue) -> None: """ remove all triples from dst @@ -245,15 +262,20 @@ srcg = _graphOrDefault(ctx, src) dstg = _graphOrDefault(ctx, dst) - if srcg.identifier == dstg.identifier: + # type error: Item "None" of "Optional[Graph]" has no attribute "identifier" + if srcg.identifier == dstg.identifier: # type: ignore[union-attr] return - dstg.remove((None, None, None)) + # type error: Item "None" of "Optional[Graph]" has no attribute "remove" + dstg.remove((None, None, None)) # type: ignore[union-attr] - dstg += srcg + # type error: Unsupported left operand type for + ("None") + dstg += srcg # type: ignore[operator] -def evalUpdate(graph, update, initBindings={}): +def evalUpdate( + graph: Graph, update: Update, initBindings: Mapping[Variable, Identifier] = {} +) -> None: """ http://www.w3.org/TR/sparql11-update/#updateLanguage diff --git a/rdflib/plugins/stores/auditable.py b/rdflib/plugins/stores/auditable.py index 8bbdcd2f51..9b9e7bcc30 100644 --- a/rdflib/plugins/stores/auditable.py +++ b/rdflib/plugins/stores/auditable.py @@ -17,7 +17,7 @@ import threading -from rdflib import ConjunctiveGraph, Graph +from rdflib.graph import ConjunctiveGraph, Graph from rdflib.store import Store destructiveOpLocks = { # noqa: N816 diff --git a/rdflib/plugins/stores/berkeleydb.py b/rdflib/plugins/stores/berkeleydb.py index 82144ba1e6..229dc1acf3 100644 --- a/rdflib/plugins/stores/berkeleydb.py +++ b/rdflib/plugins/stores/berkeleydb.py @@ -2,13 +2,15 @@ from os import mkdir from os.path import abspath, exists from threading import Thread +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Tuple from urllib.request import pathname2url +from rdflib.graph import Graph, _TriplePatternType, _TripleType from rdflib.store import NO_STORE, VALID_STORE, Store -from rdflib.term import URIRef +from rdflib.term import Identifier, Node, URIRef -def bb(u): +def bb(u: str) -> bytes: return u.encode("utf-8") @@ -37,6 +39,17 @@ def bb(u): __all__ = ["BerkeleyDB"] +_ToKeyFunc = Callable[[Tuple[bytes, bytes, bytes], bytes], bytes] +_FromKeyFunc =
Callable[[bytes], Tuple[bytes, bytes, bytes, bytes]] +_GetPrefixFunc = Callable[ + [Tuple[str, str, str], Optional[str]], Generator[str, None, None] +] +_ResultsFromKeyFunc = Callable[ + [bytes, Optional[Node], Optional[Node], Optional[Node], bytes], + Tuple[Tuple[Node, Node, Node], Generator[Node, None, None]], +] + + class BerkeleyDB(Store): """\ A store that allows for on-disk persistent using BerkeleyDB, a fast @@ -63,9 +76,13 @@ class BerkeleyDB(Store): formula_aware = True transaction_aware = False graph_aware = True - db_env = None + db_env: Any = None - def __init__(self, configuration=None, identifier=None): + def __init__( + self, + configuration: Optional[Any] = None, + identifier: Optional["Identifier"] = None, + ): if not has_bsddb: raise ImportError("Unable to import berkeleydb, store is unusable.") self.__open = False @@ -79,7 +96,7 @@ def __get_identifier(self): identifier = property(__get_identifier) - def _init_db_environment(self, homeDir, create=True): # noqa: N803 + def _init_db_environment(self, homeDir: str, create: bool = True): # noqa: N803 if not exists(homeDir): if create is True: mkdir(homeDir) @@ -97,7 +114,7 @@ def _init_db_environment(self, homeDir, create=True): # noqa: N803 def is_open(self): return self.__open - def open(self, path, create=True): + def open(self, path: str, create: bool = True) -> Optional[int]: if not has_bsddb: return NO_STORE homeDir = path # noqa: N806 @@ -127,11 +144,14 @@ def open(self, path, create=True): dbsetflags = 0 # create and open the DBs - self.__indicies = [ + self.__indicies: List[Any] = [ None, ] * 3 - self.__indicies_info = [ - None, + # NOTE on type ignore: this is because the type checker does not like this + # way of initializing; using a temporary variable would solve it. + # type error: List item 0 has incompatible type "None"; expected "Tuple[Any, Callable[[Tuple[bytes, bytes, bytes], bytes], bytes], Callable[[bytes], Tuple[bytes, bytes, bytes, bytes]]]" + self.__indicies_info: List[Tuple[Any, _ToKeyFunc, _FromKeyFunc]] = [ + None, # type: ignore[list-item] ] * 3 for i in range(0, 3): index_name = to_key_func(i)( @@ -144,9 +164,11 @@ def open(self, path, create=True): self.__indicies[i] = index self.__indicies_info[i] = (index, to_key_func(i), from_key_func(i)) - lookup = {} + lookup: Dict[ + int, Tuple[Any, _GetPrefixFunc, _FromKeyFunc, _ResultsFromKeyFunc] + ] = {} for i in range(0, 8): - results = [] + results: List[Tuple[Tuple[int, int], int, int]] = [] for start in range(0, 3): score = 1 len = 0 @@ -160,10 +182,15 @@ def open(self, path, create=True): results.append(((score, tie_break), start, len)) results.sort() - score, start, len = results[-1] - - def get_prefix_func(start, end): - def get_prefix(triple, context): + # NOTE on type error: this is because the variable `score` is + # reused with a different type + # type error: Incompatible types in assignment (expression has type "Tuple[int, int]", variable has type "int") + score, start, len = results[-1] # type: ignore[assignment] + + def get_prefix_func(start: int, end: int) -> _GetPrefixFunc: + def get_prefix( + triple: Tuple[str, str, str], context: Optional[str] + ) -> Generator[str, None, None]: if context is None: yield "" else: @@ -212,7 +239,7 @@ def get_prefix(triple, context): self.__sync_thread = t return VALID_STORE - def __sync_run(self): + def __sync_run(self) -> None: from time import sleep, time try: @@ -236,7 +263,7 @@ def __sync_run(self): except Exception as e: logger.exception(e) - def sync(self): + def sync(self) -> None: if
self.__open: for i in self.__indicies: i.sync() @@ -246,7 +273,7 @@ def sync(self): self.__i2k.sync() self.__k2i.sync() - def close(self, commit_pending_transaction=False): + def close(self, commit_pending_transaction: bool = False) -> None: self.__open = False self.__sync_thread.join() for i in self.__indicies: @@ -258,7 +285,13 @@ def close(self, commit_pending_transaction=False): self.__k2i.close() self.db_env.close() - def add(self, triple, context, quoted=False, txn=None): + def add( + self, + triple: "_TripleType", + context: "Graph", + quoted: bool = False, + txn: Optional[Any] = None, + ): """\ Add a triple to the store of triples. """ @@ -327,7 +360,13 @@ def __remove(self, spo, c, quoted=False, txn=None): except db.DBNotFoundError: pass # TODO: is it okay to ignore these? - def remove(self, spo, context, txn=None): + # type error: Signature of "remove" incompatible with supertype "Store" + def remove( # type: ignore[override] + self, + spo: "_TriplePatternType", + context: Optional["Graph"], + txn: Optional[Any] = None, + ) -> None: subject, predicate, object = spo assert self.__open, "The Store must be open." Store.remove(self, (subject, predicate, object), context) @@ -376,7 +415,10 @@ def remove(self, spo, context, txn=None): current = None cursor.close() if key.startswith(prefix): - c, s, p, o = from_key(key) + # NOTE on type error: variables are being reused with a + # different type + # type error: Incompatible types in assignment (expression has type "bytes", variable has type "str") + c, s, p, o = from_key(key) # type: ignore[assignment] if context is None: contexts_value = index.get(key, txn=txn) or "".encode("latin-1") # remove triple from all non quoted contexts @@ -385,7 +427,11 @@ def remove(self, spo, context, txn=None): contexts.add("".encode("latin-1")) for c in contexts: for i, _to_key, _ in self.__indicies_info: - i.delete(_to_key((s, p, o), c), txn=txn) + # NOTE on type error: variables are being + # reused with a different type + # type error: Argument 1 has incompatible type "Tuple[str, str, str]"; expected "Tuple[bytes, bytes, bytes]" + # type error: Argument 2 has incompatible type "str"; expected "bytes" + i.delete(_to_key((s, p, o), c), txn=txn) # type: ignore[arg-type] else: self.__remove((s, p, o), c, txn=txn) else: @@ -404,7 +450,14 @@ def remove(self, spo, context, txn=None): self.__needs_sync = needs_sync - def triples(self, spo, context=None, txn=None): + def triples( + self, + spo: "_TriplePatternType", + context: Optional["Graph"] = None, + txn: Optional[Any] = None, + ) -> Generator[ + Tuple["_TripleType", Generator[Optional["Graph"], None, None]], None, None + ]: """A generator over all the triples matching""" assert self.__open, "The Store must be open." @@ -437,11 +490,14 @@ def triples(self, spo, context=None, txn=None): cursor.close() if key and key.startswith(prefix): contexts_value = index.get(key, txn=txn) - yield results_from_key(key, subject, predicate, object, contexts_value) + # type error: Incompatible types in "yield" (actual type "Tuple[Tuple[Node, Node, Node], Generator[Node, None, None]]", expected type "Tuple[Tuple[IdentifiedNode, URIRef, Identifier], Iterator[Optional[Graph]]]") + # NOTE on type ignore: this is needed because some context is + # lost in the process of extracting triples from the database. 
+ yield results_from_key(key, subject, predicate, object, contexts_value) # type: ignore[misc] else: break - def __len__(self, context=None): + def __len__(self, context: Optional["Graph"] = None) -> int: assert self.__open, "The Store must be open." if context is not None: if context == self: @@ -467,9 +523,13 @@ def __len__(self, context=None): cursor.close() return count - def bind(self, prefix, namespace, override=True): - prefix = prefix.encode("utf-8") - namespace = namespace.encode("utf-8") + def bind(self, prefix: str, namespace: URIRef, override: bool = True) -> None: + # NOTE on type error: this is because the variables are reused with + # another type. + # type error: Incompatible types in assignment (expression has type "bytes", variable has type "str") + prefix = prefix.encode("utf-8") # type: ignore[assignment] + # type error: Incompatible types in assignment (expression has type "bytes", variable has type "URIRef") + namespace = namespace.encode("utf-8") # type: ignore[assignment] bound_prefix = self.__prefix.get(namespace) bound_namespace = self.__namespace.get(prefix) if override: @@ -483,21 +543,27 @@ def bind(self, prefix, namespace, override=True): self.__prefix[bound_namespace or namespace] = bound_prefix or prefix self.__namespace[bound_prefix or prefix] = bound_namespace or namespace - def namespace(self, prefix): - prefix = prefix.encode("utf-8") + def namespace(self, prefix: str) -> Optional[URIRef]: + # NOTE on type error: this is because the variable is reused with + # another type. + # type error: Incompatible types in assignment (expression has type "bytes", variable has type "str") + prefix = prefix.encode("utf-8") # type: ignore[assignment] ns = self.__namespace.get(prefix, None) if ns is not None: return URIRef(ns.decode("utf-8")) return None - def prefix(self, namespace): - namespace = namespace.encode("utf-8") + def prefix(self, namespace: URIRef) -> Optional[str]: + # NOTE on type error: this is because the variable is reused with + # another type. + # type error: Incompatible types in assignment (expression has type "bytes", variable has type "URIRef") + namespace = namespace.encode("utf-8") # type: ignore[assignment] prefix = self.__prefix.get(namespace, None) if prefix is not None: return prefix.decode("utf-8") return None - def namespaces(self): + def namespaces(self) -> Generator[Tuple[str, URIRef], None, None]: cursor = self.__namespace.cursor() results = [] current = cursor.first() @@ -510,20 +576,31 @@ def namespaces(self): for prefix, namespace in results: yield prefix, URIRef(namespace) - def contexts(self, triple=None): + def contexts( + self, triple: Optional["_TripleType"] = None + ) -> Generator["Graph", None, None]: _from_string = self._from_string _to_string = self._to_string - + # NOTE on type errors: context is lost because of how data is loaded + # from the DB. 
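The ignores in bind(), namespace() and prefix() above all stem from rebinding a str-typed parameter with bytes; a minimal sketch, with a hypothetical helper name, of the temporary-variable alternative the notes suggest, which type-checks without any ignore:

    from typing import Dict, Optional

    from rdflib.term import URIRef

    def lookup_namespace(index: Dict[bytes, bytes], prefix: str) -> Optional[URIRef]:
        # Encode into a new name instead of rebinding `prefix`, so the
        # parameter keeps its declared str type throughout.
        prefix_key = prefix.encode("utf-8")
        ns = index.get(prefix_key)
        return URIRef(ns.decode("utf-8")) if ns is not None else None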
if triple: - s, p, o = triple - s = _to_string(s) - p = _to_string(p) - o = _to_string(o) + s: str + p: str + o: str + # type error: Incompatible types in assignment (expression has type "Node", variable has type "str") + s, p, o = triple # type: ignore[assignment] + # type error: Argument 1 has incompatible type "str"; expected "Node" + s = _to_string(s) # type: ignore[arg-type] + # type error: Argument 1 has incompatible type "str"; expected "Node" + p = _to_string(p) # type: ignore[arg-type] + # type error: Argument 1 has incompatible type "str"; expected "Node" + o = _to_string(o) # type: ignore[arg-type] contexts = self.__indicies[0].get(bb("%s^%s^%s^%s^" % ("", s, p, o))) if contexts: for c in contexts.split("^".encode("latin-1")): if c: - yield _from_string(c) + # type error: Incompatible types in "yield" (actual type "Node", expected type "Graph") + yield _from_string(c) # type: ignore[misc] else: index = self.__contexts cursor = index.cursor() @@ -532,7 +609,8 @@ def contexts(self, triple=None): while current: key, value = current context = _from_string(key) - yield context + # type error: Incompatible types in "yield" (actual type "Node", expected type "Graph") + yield context # type: ignore[misc] cursor = index.cursor() try: cursor.set_range(key) @@ -542,17 +620,17 @@ def contexts(self, triple=None): current = None cursor.close() - def add_graph(self, graph): + def add_graph(self, graph: "Graph") -> None: self.__contexts.put(bb(self._to_string(graph)), b"") - def remove_graph(self, graph): + def remove_graph(self, graph: Optional["Graph"]): self.remove((None, None, None), graph) - def _from_string(self, i): + def _from_string(self, i: bytes) -> Node: k = self.__i2k.get(int(i)) return self._loads(k) - def _to_string(self, term, txn=None): + def _to_string(self, term: Node, txn: Optional[Any] = None) -> str: k = self._dumps(term) i = self.__k2i.get(k, txn=txn) if i is None: @@ -568,30 +646,42 @@ def _to_string(self, term, txn=None): i = i.decode() return i - def __lookup(self, spo, context, txn=None): + def __lookup( + self, + spo: "_TriplePatternType", + context: Optional["Graph"], + txn: Optional[Any] = None, + ) -> Tuple[Any, bytes, _FromKeyFunc, _ResultsFromKeyFunc]: subject, predicate, object = spo _to_string = self._to_string + # NOTE on type errors: this is because the same variable is used with different types. 
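For orientation, __lookup (continuing below) selects one of the eight index entries built in open() by treating the bound slots of the pattern as a bitmask; a small illustration of that encoding, with made-up pattern values:

    pattern = ("s-term", None, "o-term")  # subject and object bound, predicate free
    i = 0
    if pattern[0] is not None:
        i += 1  # subject contributes bit 0
    if pattern[1] is not None:
        i += 2  # predicate contributes bit 1
    if pattern[2] is not None:
        i += 4  # object contributes bit 2
    assert i == 5  # key into the 8-entry lookup table built in open()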
if context is not None: - context = _to_string(context, txn=txn) + # type error: Incompatible types in assignment (expression has type "str", variable has type "Optional[Graph]") + context = _to_string(context, txn=txn) # type: ignore[assignment] i = 0 if subject is not None: i += 1 - subject = _to_string(subject, txn=txn) + # type error: Incompatible types in assignment (expression has type "str", variable has type "Node") + subject = _to_string(subject, txn=txn) # type: ignore[assignment] if predicate is not None: i += 2 - predicate = _to_string(predicate, txn=txn) + # type error: Incompatible types in assignment (expression has type "str", variable has type "Node") + predicate = _to_string(predicate, txn=txn) # type: ignore[assignment] if object is not None: i += 4 - object = _to_string(object, txn=txn) + # type error: Incompatible types in assignment (expression has type "str", variable has type "Node") + object = _to_string(object, txn=txn) # type: ignore[assignment] index, prefix_func, from_key, results_from_key = self.__lookup_dict[i] # print (subject, predicate, object), context, prefix_func, index # #DEBUG - prefix = bb("^".join(prefix_func((subject, predicate, object), context))) + # type error: Argument 1 has incompatible type "Tuple[Node, Node, Node]"; expected "Tuple[str, str, str]" + # type error: Argument 2 has incompatible type "Optional[Graph]"; expected "Optional[str]" + prefix = bb("^".join(prefix_func((subject, predicate, object), context))) # type: ignore[arg-type] return index, prefix, from_key, results_from_key -def to_key_func(i): - def to_key(triple, context): +def to_key_func(i: int) -> _ToKeyFunc: + def to_key(triple: Tuple[bytes, bytes, bytes], context: bytes) -> bytes: "Takes a string; returns key" return "^".encode("latin-1").join( ( @@ -606,8 +696,8 @@ def to_key(triple, context): return to_key -def from_key_func(i): - def from_key(key): +def from_key_func(i: int) -> _FromKeyFunc: + def from_key(key: bytes) -> Tuple[bytes, bytes, bytes, bytes]: "Takes a key; returns string" parts = key.split("^".encode("latin-1")) return ( @@ -620,8 +710,16 @@ def from_key(key): return from_key -def results_from_key_func(i, from_string): - def from_key(key, subject, predicate, object, contexts_value): +def results_from_key_func( + i: int, from_string: Callable[[bytes], Node] +) -> _ResultsFromKeyFunc: + def from_key( + key: bytes, + subject: Optional[Node], + predicate: Optional[Node], + object: Optional[Node], + contexts_value: bytes, + ) -> Tuple[Tuple[Node, Node, Node], Generator[Node, None, None]]: "Takes a key and subject, predicate, object; returns tuple for yield" parts = key.split("^".encode("latin-1")) if subject is None: diff --git a/rdflib/plugins/stores/memory.py b/rdflib/plugins/stores/memory.py index 1fd26a1b25..c4aeaaf374 100644 --- a/rdflib/plugins/stores/memory.py +++ b/rdflib/plugins/stores/memory.py @@ -1,6 +1,30 @@ # # +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Generator, + Iterator, + Optional, + Set, + Tuple, + Union, + overload, +) + +from rdflib.graph import ( + Graph, + _ObjectType, + _PredicateType, + _SubjectType, + _TriplePatternType, + _TripleType, +) +from rdflib.plugins.sparql.sparql import Query, Update from rdflib.store import Store +from rdflib.term import Identifier, URIRef, Variable from rdflib.util import _coalesce __all__ = ["SimpleMemory", "Memory"] @@ -19,23 +43,38 @@ class SimpleMemory(Store): Authors: Michel Pelletier, Daniel Krech, Stefan Niederhauser """ - def __init__(self, configuration=None, 
identifier=None): + def __init__( + self, + configuration: Optional[Any] = None, + identifier: Optional[Identifier] = None, + ): super(SimpleMemory, self).__init__(configuration) self.identifier = identifier # indexed by [subject][predicate][object] - self.__spo = {} + self.__spo: Dict[ + "_SubjectType", Dict["_PredicateType", Dict["_ObjectType", int]] + ] = {} # indexed by [predicate][object][subject] - self.__pos = {} + self.__pos: Dict[ + "_PredicateType", Dict["_ObjectType", Dict["_SubjectType", int]] + ] = {} # indexed by [predicate][object][subject] - self.__osp = {} - - self.__namespace = {} - self.__prefix = {} - - def add(self, triple, context, quoted=False): + self.__osp: Dict[ + "_ObjectType", Dict["_SubjectType", Dict["_PredicateType", int]] + ] = {} + + self.__namespace: Dict[str, URIRef] = {} + self.__prefix: Dict[URIRef, str] = {} + + def add( + self, + triple: "_TripleType", + context: Graph, + quoted: bool = False, + ) -> None: """\ Add a triple to the store of triples. """ @@ -76,13 +115,21 @@ def add(self, triple, context, quoted=False): p = sp[subject] = {} p[predicate] = 1 - def remove(self, triple_pattern, context=None): + def remove( + self, + triple_pattern: "_TriplePatternType", + context: Optional[Graph] = None, + ) -> None: for (subject, predicate, object), c in list(self.triples(triple_pattern)): del self.__spo[subject][predicate][object] del self.__pos[predicate][object][subject] del self.__osp[object][subject][predicate] - def triples(self, triple_pattern, context=None): + def triples( + self, + triple_pattern: "_TriplePatternType", + context: Optional["Graph"] = None, + ) -> Iterator[Tuple["_TripleType", Iterator[Optional[Graph]]]]: """A generator over all the triples matching""" subject, predicate, object = triple_pattern if subject != ANY: # subject is given @@ -142,19 +189,20 @@ def triples(self, triple_pattern, context=None): for o in subjectDictionary[p].keys(): yield (s, p, o), self.__contexts() - def __len__(self, context=None): + def __len__(self, context: Optional[Graph] = None) -> int: # @@ optimize i = 0 for triple in self.triples((None, None, None)): i += 1 return i - def bind(self, prefix, namespace, override=True): + def bind(self, prefix: str, namespace: URIRef, override: bool = True) -> None: # should be identical to `Memory.bind` bound_namespace = self.__namespace.get(prefix) bound_prefix = _coalesce( self.__prefix.get(namespace), - self.__prefix.get(bound_namespace), + # type error: Argument 1 to "get" of "Mapping" has incompatible type "Optional[URIRef]"; expected "URIRef" + self.__prefix.get(bound_namespace), # type: ignore[arg-type] ) if override: if bound_prefix is not None: @@ -164,32 +212,50 @@ def bind(self, prefix, namespace, override=True): self.__prefix[namespace] = prefix self.__namespace[prefix] = namespace else: - self.__prefix[_coalesce(bound_namespace, namespace)] = _coalesce( + # type error: Invalid index type "Optional[URIRef]" for "Dict[URIRef, str]"; expected type "URIRef" + # type error: Incompatible types in assignment (expression has type "Optional[str]", target has type "str") + self.__prefix[_coalesce(bound_namespace, namespace)] = _coalesce( # type: ignore[index, assignment] bound_prefix, prefix ) - self.__namespace[_coalesce(bound_prefix, prefix)] = _coalesce( + # type error: Invalid index type "Optional[str]" for "Dict[str, URIRef]"; expected type "str" + # type error: Incompatible types in assignment (expression has type "Optional[URIRef]", target has type "URIRef") +
self.__namespace[_coalesce(bound_prefix, prefix)] = _coalesce( # type: ignore[index, assignment] bound_namespace, namespace ) - def namespace(self, prefix): + def namespace(self, prefix: str) -> Optional[URIRef]: return self.__namespace.get(prefix, None) - def prefix(self, namespace): + def prefix(self, namespace: URIRef) -> Optional[str]: return self.__prefix.get(namespace, None) - def namespaces(self): + def namespaces(self) -> Iterator[Tuple[str, URIRef]]: for prefix, namespace in self.__namespace.items(): yield prefix, namespace def __contexts(self): return (c for c in []) # TODO: best way to return empty generator - def query(self, query, initNs, initBindings, queryGraph, **kwargs): # noqa: N803 + def query( + self, + query: Union[Query, str], + initNs: Dict[str, str], # noqa: N803 + initBindings: Dict[Variable, Identifier], # noqa: N803 + queryGraph: Identifier, # noqa: N803 + **kwargs: Any, + ) -> None: super(SimpleMemory, self).query( query, initNs, initBindings, queryGraph, **kwargs ) - def update(self, update, initNs, initBindings, queryGraph, **kwargs): # noqa: N803 + def update( + self, + update: Union[Update, str], + initNs: Dict[str, str], # noqa: N803 + initBindings: Dict[Variable, Identifier], # noqa: N803 + queryGraph: Identifier, # noqa: N803 + **kwargs: Any, + ) -> None: super(SimpleMemory, self).update( update, initNs, initBindings, queryGraph, **kwargs ) @@ -207,30 +273,45 @@ class Memory(Store): formula_aware = True graph_aware = True - def __init__(self, configuration=None, identifier=None): + def __init__( + self, + configuration: Optional[Any] = None, + identifier: Optional[Identifier] = None, + ): super(Memory, self).__init__(configuration) self.identifier = identifier # indexed by [subject][predicate][object] - self.__spo = {} + self.__spo: Dict[ + "_SubjectType", Dict["_PredicateType", Dict["_ObjectType", int]] + ] = {} # indexed by [predicate][object][subject] - self.__pos = {} + self.__pos: Dict[ + "_PredicateType", Dict["_ObjectType", Dict["_SubjectType", int]] + ] = {} # indexed by [predicate][object][subject] - self.__osp = {} - - self.__namespace = {} - self.__prefix = {} - self.__context_obj_map = {} - self.__tripleContexts = {} - self.__contextTriples = {None: set()} + self.__osp: Dict[ + "_ObjectType", Dict["_SubjectType", Dict["_PredicateType", int]] + ] = {} + + self.__namespace: Dict[str, URIRef] = {} + self.__prefix: Dict[URIRef, str] = {} + self.__context_obj_map: Dict[str, Graph] = {} + self.__tripleContexts: Dict["_TripleType", Dict[Optional[str], bool]] = {} + self.__contextTriples: Dict[Optional[str], Set["_TripleType"]] = {None: set()} # all contexts used in store (unencoded) - self.__all_contexts = set() + self.__all_contexts: Set[Graph] = set() # default context information for triples - self.__defaultContexts = None - - def add(self, triple, context, quoted=False): + self.__defaultContexts: Optional[Dict[Optional[str], bool]] = None + + def add( + self, + triple: "_TripleType", + context: Graph, + quoted: bool = False, + ) -> None: """\ Add a triple to the store of triples. 
""" @@ -287,7 +368,9 @@ def add(self, triple, context, quoted=False): p = sp[subject] = {} p[predicate] = 1 - def remove(self, triple_pattern, context=None): + def remove( + self, triple_pattern: "_TriplePatternType", context: Optional[Graph] = None + ) -> None: req_ctx = self.__ctx_to_str(context) for triple, c in self.triples(triple_pattern, context=context): subject, predicate, object_ = triple @@ -321,11 +404,18 @@ def remove(self, triple_pattern, context=None): # remove the whole context self.__all_contexts.remove(context) - def triples(self, triple_pattern, context=None): + def triples( + self, + triple_pattern: "_TriplePatternType", + context: Optional["Graph"] = None, + ) -> Generator[ + Tuple["_TripleType", Generator[Optional[Graph], None, None]], None, None + ]: """A generator over all the triples matching""" req_ctx = self.__ctx_to_str(context) subject, predicate, object_ = triple_pattern + # triple: Union[_Triple, _TriplePattern] # all triples case (no triple parts given as pattern) if subject is None and predicate is None and object_ is None: # Just dump all known triples from the given graph @@ -336,7 +426,10 @@ def triples(self, triple_pattern, context=None): # optimize "triple in graph" case (all parts given) elif subject is not None and predicate is not None and object_ is not None: - triple = triple_pattern + # type error: Incompatible types in assignment (expression has type "Tuple[Optional[IdentifiedNode], Optional[IdentifiedNode], Optional[Identifier]]", variable has type "Tuple[IdentifiedNode, IdentifiedNode, Identifier]") + # NOTE on type error: at this point, all elements of triple_pattern + # is not None, so it has the same type as triple + triple = triple_pattern # type: ignore[assignment] try: _ = self.__spo[subject][predicate][object_] if self.__triple_has_context(triple, req_ctx): @@ -418,12 +511,13 @@ def triples(self, triple_pattern, context=None): if self.__triple_has_context(triple, req_ctx): yield triple, self.__contexts(triple) - def bind(self, prefix, namespace, override=True): + def bind(self, prefix: str, namespace: URIRef, override: bool = True) -> None: # should be identical to `SimpleMemory.bind` bound_namespace = self.__namespace.get(prefix) bound_prefix = _coalesce( self.__prefix.get(namespace), - self.__prefix.get(bound_namespace), + # type error: error: Argument 1 to "get" of "Mapping" has incompatible type "Optional[URIRef]"; expected "URIRef" + self.__prefix.get(bound_namespace), # type: ignore[arg-type] ) if override: if bound_prefix is not None: @@ -433,24 +527,30 @@ def bind(self, prefix, namespace, override=True): self.__prefix[namespace] = prefix self.__namespace[prefix] = namespace else: - self.__prefix[_coalesce(bound_namespace, namespace)] = _coalesce( + # type error: Invalid index type "Optional[URIRef]" for "Dict[URIRef, str]"; expected type "URIRef" + # type error: Incompatible types in assignment (expression has type "Optional[str]", target has type "str") + self.__prefix[_coalesce(bound_namespace, namespace)] = _coalesce( # type: ignore[index, assignment] bound_prefix, prefix ) - self.__namespace[_coalesce(bound_prefix, prefix)] = _coalesce( + # type error: Invalid index type "Optional[str]" for "Dict[str, URIRef]"; expected type "str" + # type error: Incompatible types in assignment (expression has type "Optional[URIRef]", target has type "URIRef") + self.__namespace[_coalesce(bound_prefix, prefix)] = _coalesce( # type: ignore[index, assignment] bound_namespace, namespace ) - def namespace(self, prefix): + def namespace(self, 
prefix: str) -> Optional[URIRef]: return self.__namespace.get(prefix, None) - def prefix(self, namespace): + def prefix(self, namespace: URIRef) -> Optional[str]: return self.__prefix.get(namespace, None) - def namespaces(self): + def namespaces(self) -> Iterator[Tuple[str, URIRef]]: for prefix, namespace in self.__namespace.items(): yield prefix, namespace - def contexts(self, triple=None): + def contexts( + self, triple: Optional["_TripleType"] = None + ) -> Generator[Graph, None, None]: if triple is None or triple == (None, None, None): return (context for context in self.__all_contexts) @@ -461,30 +561,37 @@ def contexts(self, triple=None): except KeyError: return (_ for _ in []) - def __len__(self, context=None): + def __len__(self, context: Optional[Graph] = None) -> int: ctx = self.__ctx_to_str(context) if ctx not in self.__contextTriples: return 0 return len(self.__contextTriples[ctx]) - def add_graph(self, graph): + def add_graph(self, graph: Graph): if not self.graph_aware: Store.add_graph(self, graph) else: self.__all_contexts.add(graph) - def remove_graph(self, graph): + def remove_graph(self, graph: Optional[Graph]): if not self.graph_aware: Store.remove_graph(self, graph) else: self.remove((None, None, None), graph) try: - self.__all_contexts.remove(graph) + # type error: Argument 1 to "remove" of "set" has incompatible type "Optional[Graph]"; expected "Graph" + self.__all_contexts.remove(graph) # type: ignore[arg-type] except KeyError: pass # we didn't know this graph, no problem # internal utility methods below - def __add_triple_context(self, triple, triple_exists, context, quoted): + def __add_triple_context( + self, + triple: "_TripleType", + triple_exists: bool, + context: Optional[Graph], + quoted: bool, + ): """add the given context to the set of contexts for the triple""" ctx = self.__ctx_to_str(context) quoted = bool(quoted) @@ -495,9 +602,10 @@ def __add_triple_context(self, triple, triple_exists, context, quoted): except KeyError: # triple exists with default ctx info # start with a copy of the default ctx info + # type error: Item "None" of "Optional[Dict[Optional[str], bool]]" has no attribute "copy" triple_context = self.__tripleContexts[ triple - ] = self.__defaultContexts.copy() + ] = self.__defaultContexts.copy() # type: ignore[union-attr] triple_context[ctx] = quoted @@ -530,24 +638,30 @@ def __add_triple_context(self, triple, triple_exists, context, quoted): if triple_context == self.__defaultContexts: del self.__tripleContexts[triple] - def __get_context_for_triple(self, triple, skipQuoted=False): # noqa: N803 + def __get_context_for_triple( + self, triple: "_TripleType", skipQuoted: bool = False # noqa: N803 + ) -> Collection[Optional[str]]: """return a list of contexts (str) for the triple, skipping quoted contexts if skipQuoted==True""" ctxs = self.__tripleContexts.get(triple, self.__defaultContexts) if not skipQuoted: - return ctxs.keys() + # type error: Item "None" of "Optional[Dict[Optional[str], bool]]" has no attribute "keys" + return ctxs.keys() # type: ignore[union-attr] - return [ctx for ctx, quoted in ctxs.items() if not quoted] + # type error: Item "None" of "Optional[Dict[Optional[str], bool]]" has no attribute "items" + return [ctx for ctx, quoted in ctxs.items() if not quoted] # type: ignore[union-attr] - def __triple_has_context(self, triple, ctx): + def __triple_has_context(self, triple: "_TripleType", ctx: Optional[str]) -> bool: """return True if the triple exists in the given context""" - return ctx in 
self.__tripleContexts.get(triple, self.__defaultContexts) + # type error: Unsupported right operand type for in ("Optional[Dict[Optional[str], bool]]") + return ctx in self.__tripleContexts.get(triple, self.__defaultContexts) # type: ignore[operator] - def __remove_triple_context(self, triple, ctx): + def __remove_triple_context(self, triple: "_TripleType", ctx): """remove the context from the triple""" - ctxs = self.__tripleContexts.get(triple, self.__defaultContexts).copy() + # type error: Item "None" of "Optional[Dict[Optional[str], bool]]" has no attribute "copy" + ctxs = self.__tripleContexts.get(triple, self.__defaultContexts).copy() # type: ignore[union-attr] del ctxs[ctx] if ctxs == self.__defaultContexts: del self.__tripleContexts[triple] @@ -555,7 +669,15 @@ def __remove_triple_context(self, triple, ctx): self.__tripleContexts[triple] = ctxs self.__contextTriples[ctx].remove(triple) - def __ctx_to_str(self, ctx): + @overload + def __ctx_to_str(self, ctx: Graph) -> str: + ... + + @overload + def __ctx_to_str(self, ctx: None) -> None: + ... + + def __ctx_to_str(self, ctx: Optional[Graph]) -> Optional[str]: if ctx is None: return None try: @@ -565,19 +687,23 @@ def __ctx_to_str(self, ctx): return ctx_str except AttributeError: # otherwise, ctx should be a URIRef or BNode or str - if isinstance(ctx, str): - ctx_str = "{}:{}".format(ctx.__class__.__name__, ctx) + # NOTE on type errors: this code path is never exercised with a str-valued ctx in any unit test, so it likely should be removed. + # type error: Subclass of "Graph" and "str" cannot exist: would have incompatible method signatures + if isinstance(ctx, str): # type: ignore[unreachable] + # type error: Statement is unreachable + ctx_str = "{}:{}".format(ctx.__class__.__name__, ctx) # type: ignore[unreachable] if ctx_str in self.__context_obj_map: return ctx_str self.__context_obj_map[ctx_str] = ctx return ctx_str raise RuntimeError("Cannot use that type of object as a Graph context") - def __contexts(self, triple): + def __contexts(self, triple: "_TripleType") -> Generator[Graph, None, None]: """return a generator for all the non-quoted contexts (dereferenced) the encoded triple appears in""" + # type error: Argument 2 to "get" of "Mapping" has incompatible type "str"; expected "Optional[Graph]" return ( - self.__context_obj_map.get(ctx_str, ctx_str) + self.__context_obj_map.get(ctx_str, ctx_str) # type: ignore[arg-type] for ctx_str in self.__get_context_for_triple(triple, skipQuoted=True) if ctx_str is not None ) diff --git a/rdflib/plugins/stores/sparqlconnector.py b/rdflib/plugins/stores/sparqlconnector.py index 1af3b369e0..60921ab8e3 100644 --- a/rdflib/plugins/stores/sparqlconnector.py +++ b/rdflib/plugins/stores/sparqlconnector.py @@ -6,8 +6,8 @@ from urllib.parse import urlencode from urllib.request import Request, urlopen -from rdflib import BNode from rdflib.query import Result +from rdflib.term import BNode log = logging.getLogger(__name__) diff --git a/rdflib/plugins/stores/sparqlstore.py b/rdflib/plugins/stores/sparqlstore.py index 967dbaf68e..241b049fb0 100644 --- a/rdflib/plugins/stores/sparqlstore.py +++ b/rdflib/plugins/stores/sparqlstore.py @@ -9,11 +9,10 @@ import re from typing import Any, Callable, Dict, Optional, Tuple, Union -from rdflib import BNode, Variable from rdflib.graph import DATASET_DEFAULT_GRAPH_ID from rdflib.plugins.stores.regexmatching import NATIVE_REGEX from rdflib.store import Store -from rdflib.term import Node +from rdflib.term import BNode, Node, Variable from
.sparqlconnector import SPARQLConnector diff --git a/rdflib/query.py b/rdflib/query.py index d5ff9b251e..f49218940e 100644 --- a/rdflib/query.py +++ b/rdflib/query.py @@ -5,7 +5,21 @@ import types import warnings from io import BytesIO -from typing import IO, TYPE_CHECKING, List, Optional, Union, cast +from typing import ( + IO, + TYPE_CHECKING, + Any, + BinaryIO, + Dict, + Iterator, + List, + Mapping, + MutableSequence, + Optional, + Tuple, + Union, + cast, +) from urllib.parse import urlparse __all__ = [ @@ -20,8 +34,10 @@ ] if TYPE_CHECKING: - from rdflib.graph import Graph - from rdflib.term import Variable + # from rdflib._typing import _TripleType + from rdflib.graph import Graph, _TripleType + from rdflib.plugins.sparql.sparql import Query, Update + from rdflib.term import Identifier, Variable class Processor(object): @@ -34,10 +50,16 @@ class Processor(object): """ - def __init__(self, graph): + def __init__(self, graph: "Graph"): pass - def query(self, strOrQuery, initBindings={}, initNs={}, DEBUG=False): + def query( + self, + strOrQuery: Union[str, "Query"], + initBindings: Mapping["Variable", "Identifier"] = {}, + initNs: Mapping[str, str] = {}, + DEBUG: bool = False, + ) -> Mapping[str, Any]: pass @@ -54,10 +76,15 @@ class update method. """ - def __init__(self, graph): + def __init__(self, graph: "Graph"): pass - def update(self, strOrQuery, initBindings={}, initNs={}): + def update( + self, + strOrQuery: Union[str, "Update"], + initBindings: Mapping["Variable", "Identifier"] = {}, + initNs: Mapping[str, str] = {}, + ) -> None: pass @@ -73,7 +100,7 @@ class EncodeOnlyUnicode(object): """ - def __init__(self, stream): + def __init__(self, stream: BinaryIO): self.__stream = stream def write(self, arg): @@ -82,11 +109,11 @@ def write(self, arg): else: self.__stream.write(arg) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: return getattr(self.__stream, name) -class ResultRow(tuple): +class ResultRow(Tuple["Identifier", ...]): """ a single result row allows accessing bindings as attributes or with [] @@ -122,34 +149,44 @@ class ResultRow(tuple): """ - def __new__(cls, values, labels): + labels: Mapping[str, int] - instance = super(ResultRow, cls).__new__(cls, (values.get(v) for v in labels)) + def __new__( + cls, values: Mapping["Variable", "Identifier"], labels: List["Variable"] + ): + # type error: Generator has incompatible item type "Optional[Any]"; expected "_T_co" + instance = super(ResultRow, cls).__new__(cls, (values.get(v) for v in labels)) # type: ignore[misc] instance.labels = dict((str(x[1]), x[0]) for x in enumerate(labels)) return instance - def __getattr__(self, name): + def __getattr__(self, name: str) -> "Identifier": if name not in self.labels: raise AttributeError(name) return tuple.__getitem__(self, self.labels[name]) - def __getitem__(self, name): + # type error: Signature of "__getitem__" incompatible with supertype "tuple" + # type error: Signature of "__getitem__" incompatible with supertype "Sequence" + def __getitem__(self, name: Union[str, int, Any]) -> "Identifier": # type: ignore[override] try: - return tuple.__getitem__(self, name) + # type error: Invalid index type "Union[str, int, Any]" for "tuple"; expected type "int" + return tuple.__getitem__(self, name) # type: ignore[index] except TypeError: if name in self.labels: - return tuple.__getitem__(self, self.labels[name]) + # type error: Invalid index type "Union[str, int, slice, Any]" for "Mapping[str, int]"; expected type "str" + return tuple.__getitem__(self, 
self.labels[name]) # type: ignore[index] if str(name) in self.labels: # passing in variable object return tuple.__getitem__(self, self.labels[str(name)]) raise KeyError(name) - def get(self, name, default=None): + def get( + self, name: str, default: Optional["Identifier"] = None + ) -> Optional["Identifier"]: try: return self[name] except KeyError: return default - def asdict(self): + def asdict(self) -> Dict[str, "Identifier"]: return dict((v, self[v]) for v in self.labels if self[v] is not None) @@ -180,10 +217,10 @@ def __init__(self, type_: str): self.type = type_ #: variables contained in the result. self.vars: Optional[List["Variable"]] = None - self._bindings = None - self._genbindings = None - self.askAnswer: bool = None # type: ignore[assignment] - self.graph: "Graph" = None # type: ignore[assignment] + self._bindings: MutableSequence[Mapping["Variable", "Identifier"]] = None # type: ignore[assignment] + self._genbindings: Optional[Iterator[Mapping["Variable", "Identifier"]]] = None + self.askAnswer: Optional[bool] = None + self.graph: Optional["Graph"] = None @property def bindings(self): @@ -206,11 +243,11 @@ def bindings(self, b): @staticmethod def parse( - source=None, + source: Optional[IO] = None, format: Optional[str] = None, content_type: Optional[str] = None, **kwargs, - ): + ) -> "Result": from rdflib import plugin if format: @@ -222,7 +259,10 @@ def parse( parser = plugin.get(plugin_key, ResultParser)() - return parser.parse(source, content_type=content_type, **kwargs) + # type error: Argument 1 to "parse" of "ResultParser" has incompatible type "Optional[IO[Any]]"; expected "IO[Any]" + return parser.parse( + source, content_type=content_type, **kwargs # type:ignore[arg-type] + ) def serialize( self, @@ -248,7 +288,9 @@ def serialize( :return: bytes """ if self.type in ("CONSTRUCT", "DESCRIBE"): - return self.graph.serialize( # type: ignore[return-value] + # type error: Item "None" of "Optional[Graph]" has no attribute "serialize" + # type error: Incompatible return value type (got "Union[bytes, str, Graph, Any]", expected "Optional[bytes]") + return self.graph.serialize( # type: ignore[union-attr,return-value] destination, encoding=encoding, format=format, **args ) @@ -259,7 +301,8 @@ def serialize( if destination is None: streamb: BytesIO = BytesIO() stream2 = EncodeOnlyUnicode(streamb) - serializer.serialize(stream2, encoding=encoding, **args) # type: ignore + # type error: Argument 1 to "serialize" of "ResultSerializer" has incompatible type "EncodeOnlyUnicode"; expected "IO[Any]" + serializer.serialize(stream2, encoding=encoding, **args) # type: ignore[arg-type] return streamb.getvalue() if hasattr(destination, "write"): stream = cast(IO[bytes], destination) @@ -283,26 +326,32 @@ def serialize( os.remove(name) return None - def __len__(self): + def __len__(self) -> int: if self.type == "ASK": return 1 elif self.type == "SELECT": return len(self.bindings) else: - return len(self.graph) + # type error: Argument 1 to "len" has incompatible type "Optional[Graph]"; expected "Sized" + return len(self.graph) # type: ignore[arg-type] - def __bool__(self): + def __bool__(self) -> bool: if self.type == "ASK": - return self.askAnswer + # type error: Incompatible return value type (got "Optional[bool]", expected "bool") + return self.askAnswer # type: ignore[return-value] else: return len(self) > 0 - def __iter__(self): + def __iter__( + self, + ) -> Iterator[Union["_TripleType", bool, ResultRow]]: if self.type in ("CONSTRUCT", "DESCRIBE"): - for t in self.graph: + # type 
error: Item "None" of "Optional[Graph]" has no attribute "__iter__" (not iterable) + for t in self.graph: # type: ignore[union-attr] yield t elif self.type == "ASK": - yield self.askAnswer + # type error: Incompatible types in "yield" (actual type "Optional[bool]", expected type "Union[Tuple[Identifier, Identifier, Identifier], bool, ResultRow]") [misc] + yield self.askAnswer # type: ignore[misc] elif self.type == "SELECT": # this iterates over ResultRows of variable bindings @@ -310,16 +359,19 @@ def __iter__(self): for b in self._genbindings: if b: # don't add a result row in case of empty binding {} self._bindings.append(b) - yield ResultRow(b, self.vars) + # type error: Argument 2 to "ResultRow" has incompatible type "Optional[List[Variable]]"; expected "List[Variable]" + yield ResultRow(b, self.vars) # type: ignore[arg-type] self._genbindings = None else: for b in self._bindings: if b: # don't add a result row in case of empty binding {} - yield ResultRow(b, self.vars) + # type error: Argument 2 to "ResultRow" has incompatible type "Optional[List[Variable]]"; expected "List[Variable]" + yield ResultRow(b, self.vars) # type: ignore[arg-type] - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if self.type in ("CONSTRUCT", "DESCRIBE") and self.graph is not None: - return self.graph.__getattr__(self, name) + # type error: "Graph" has no attribute "__getattr__" + return self.graph.__getattr__(self, name) # type: ignore[attr-defined] elif self.type == "SELECT" and name == "result": warnings.warn( "accessing the 'result' attribute is deprecated." @@ -328,11 +380,12 @@ def __getattr__(self, name): stacklevel=2, ) # copied from __iter__, above - return [(tuple(b[v] for v in self.vars)) for b in self.bindings] + # type error: Item "None" of "Optional[List[Variable]]" has no attribute "__iter__" (not iterable) + return [(tuple(b[v] for v in self.vars)) for b in self.bindings] # type: ignore[union-attr] else: raise AttributeError("'%s' object has no attribute '%s'" % (self, name)) - def __eq__(self, other): + def __eq__(self, other) -> bool: try: if self.type != other.type: return False @@ -350,7 +403,7 @@ class ResultParser(object): def __init__(self): pass - def parse(self, source, **kwargs): + def parse(self, source: IO, **kwargs: Any) -> Result: """return a Result object""" pass # abstract @@ -359,6 +412,6 @@ class ResultSerializer(object): def __init__(self, result: Result): self.result = result - def serialize(self, stream: IO, encoding: str = "utf-8", **kwargs): + def serialize(self, stream: IO, encoding: str = "utf-8", **kwargs: Any) -> None: """return a string properly serialized""" pass # abstract diff --git a/rdflib/store.py b/rdflib/store.py index 8367cf67b0..9bc8dc658e 100644 --- a/rdflib/store.py +++ b/rdflib/store.py @@ -1,12 +1,34 @@ import pickle from io import BytesIO -from typing import TYPE_CHECKING, Iterable, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generator, + Iterable, + Iterator, + List, + Mapping, + Optional, + Tuple, + Union, + overload, +) from rdflib.events import Dispatcher, Event if TYPE_CHECKING: - from rdflib.graph import Graph - from rdflib.term import IdentifiedNode, Node, URIRef + from rdflib.graph import ( + Graph, + _ObjectType, + _PredicateType, + _QuadType, + _SubjectType, + _TriplePatternType, + _TripleType, + ) + from rdflib.plugins.sparql.sparql import Query, Update + from rdflib.term import Identifier, Node, URIRef, Variable """ ============ @@ -90,36 +112,42 @@ class TripleRemovedEvent(Event): class 
diff --git a/rdflib/store.py b/rdflib/store.py
index 8367cf67b0..9bc8dc658e 100644
--- a/rdflib/store.py
+++ b/rdflib/store.py
@@ -1,12 +1,34 @@
 import pickle
 from io import BytesIO
-from typing import TYPE_CHECKING, Iterable, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    Iterable,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+    overload,
+)

 from rdflib.events import Dispatcher, Event

 if TYPE_CHECKING:
-    from rdflib.graph import Graph
-    from rdflib.term import IdentifiedNode, Node, URIRef
+    from rdflib.graph import (
+        Graph,
+        _ObjectType,
+        _PredicateType,
+        _QuadType,
+        _SubjectType,
+        _TriplePatternType,
+        _TripleType,
+    )
+    from rdflib.plugins.sparql.sparql import Query, Update
+    from rdflib.term import Identifier, Node, URIRef, Variable

 """
 ============
@@ -90,36 +112,42 @@ class TripleRemovedEvent(Event):


 class NodePickler(object):
     def __init__(self):
-        self._objects = {}
-        self._ids = {}
+        self._objects: Dict[str, Any] = {}
+        self._ids: Dict[Any, str] = {}
         self._get_object = self._objects.__getitem__

-    def _get_ids(self, key):
+    def _get_ids(self, key: Any) -> Optional[str]:
         try:
             return self._ids.get(key)
         except TypeError:
             return None

-    def register(self, object, id):
+    def register(self, object: Any, id: str) -> None:
         self._objects[id] = object
         self._ids[object] = id

-    def loads(self, s):
+    def loads(self, s: bytes) -> "Node":
         up = Unpickler(BytesIO(s))
-        up.persistent_load = self._get_object
+        # NOTE on type error: https://github.com/python/mypy/issues/2427
+        # type error: Cannot assign to a method
+        up.persistent_load = self._get_object  # type: ignore[assignment]
         try:
             return up.load()
         except KeyError as e:
             raise UnpicklingError("Could not find Node class for %s" % e)

-    def dumps(self, obj, protocol=None, bin=None):
+    def dumps(
+        self, obj: "Node", protocol: Optional[Any] = None, bin: Optional[Any] = None
+    ):
         src = BytesIO()
         p = Pickler(src)
-        p.persistent_id = self._get_ids
+        # NOTE on type error: https://github.com/python/mypy/issues/2427
+        # type error: Cannot assign to a method
+        p.persistent_id = self._get_ids  # type: ignore[assignment]
         p.dump(obj)
         return src.getvalue()

-    def __getstate__(self):
+    def __getstate__(self) -> Mapping[str, Any]:
         state = self.__dict__.copy()
         del state["_get_object"]
         state.update(
@@ -127,7 +155,7 @@ def __getstate__(self):
         )
         return state

-    def __setstate__(self, state):
+    def __setstate__(self, state: Mapping[str, Any]) -> None:
         self.__dict__.update(state)
         self._ids = dict(self._ids)
         self._objects = dict(self._objects)
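NodePickler relies on pickle's persistent-ID hooks, and the ignores above work around mypy's inability to model assigning a function to what it sees as a method (python/mypy#2427). A standalone sketch of the same pattern, with an invented registry standing in for rdflib's node table:

    import pickle
    from io import BytesIO
    from typing import Any, Optional

    registry = {"node-1": ["shared", "object"]}


    def get_id(obj: Any) -> Optional[str]:
        # Return a stable id for registered objects, None to pickle normally.
        for key, value in registry.items():
            if value is obj:
                return key
        return None


    buffer = BytesIO()
    pickler = pickle.Pickler(buffer)
    # Assigning the hook works at runtime, but mypy models persistent_id
    # as a method and rejects the assignment (python/mypy#2427), the same
    # reason store.py needs its ignores.
    pickler.persistent_id = get_id  # type: ignore[assignment]
    pickler.dump(registry["node-1"])

    unpickler = pickle.Unpickler(BytesIO(buffer.getvalue()))
    unpickler.persistent_load = registry.__getitem__  # type: ignore[assignment]
    assert unpickler.load() is registry["node-1"]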
""" - def gc(self): + def gc(self) -> None: """ Allows the store to perform any needed garbage collection """ @@ -206,10 +238,10 @@ def gc(self): # RDF APIs def add( self, - triple: Tuple["Node", "Node", "Node"], - context: Optional["Graph"], + triple: "_TripleType", + context: "Graph", quoted: bool = False, - ): + ) -> None: """ Adds the given statement to a specific context or to the model. The quoted argument is interpreted by formula-aware stores to indicate @@ -220,9 +252,7 @@ def add( """ self.dispatcher.dispatch(TripleAddedEvent(triple=triple, context=context)) - def addN( # noqa: N802 - self, quads: Iterable[Tuple["Node", "Node", "Node", "Graph"]] - ): + def addN(self, quads: Iterable["_QuadType"]) -> None: # noqa: N802 """ Adds each item in the list of statements to a specific context. The quoted argument is interpreted by formula-aware stores to indicate this @@ -237,11 +267,68 @@ def addN( # noqa: N802 ) self.add((s, p, o), c) - def remove(self, triple, context=None): + def remove( + self, + triple: "_TriplePatternType", + context: Optional["Graph"] = None, + ) -> None: """Remove the set of triples matching the pattern from the store""" self.dispatcher.dispatch(TripleRemovedEvent(triple=triple, context=context)) - def triples_choices(self, triple, context=None): + @overload + def triples_choices( + self, + triple: Tuple[List["_SubjectType"], "_PredicateType", "_ObjectType"], + context: Optional["Graph"] = None, + ) -> Generator[Tuple["_TripleType", Iterator[Optional["Graph"]]], None, None,]: + ... + + @overload + def triples_choices( + self, + triple: Tuple["_SubjectType", List["_PredicateType"], "_ObjectType"], + context: Optional["Graph"] = None, + ) -> Generator[ + Tuple[ + Tuple["_SubjectType", "_PredicateType", "_ObjectType"], + Iterator[Optional["Graph"]], + ], + None, + None, + ]: + ... + + @overload + def triples_choices( + self, + triple: Tuple["_SubjectType", "_PredicateType", List["_ObjectType"]], + context: Optional["Graph"] = None, + ) -> Generator[ + Tuple[ + Tuple["_SubjectType", "_PredicateType", "_ObjectType"], + Iterator[Optional["Graph"]], + ], + None, + None, + ]: + ... + + def triples_choices( + self, + triple: Tuple[ + Union["_SubjectType", List["_SubjectType"]], + Union["_PredicateType", List["_PredicateType"]], + Union["_ObjectType", List["_ObjectType"]], + ], + context: Optional["Graph"] = None, + ) -> Generator[ + Tuple[ + Tuple["_SubjectType", "_PredicateType", "_ObjectType"], + Iterator[Optional["Graph"]], + ], + None, + None, + ]: """ A variant of triples that can take a list of terms instead of a single term in any slot. Stores can implement this to optimize the response @@ -290,13 +377,12 @@ def triples_choices(self, triple, context=None): for (s1, p1, o1), cg in self.triples((subject, None, object_), context): yield (s1, p1, o1), cg - def triples( + # type error: Missing return statement + def triples( # type: ignore[return] self, - triple_pattern: Tuple[ - Optional["IdentifiedNode"], Optional["IdentifiedNode"], Optional["Node"] - ], - context=None, - ): + triple_pattern: "_TriplePatternType", + context: Optional["Graph"] = None, + ) -> Iterator[Tuple["_TripleType", Iterator[Optional["Graph"]]]]: """ A generator over all the triples matching the pattern. 
@@ -311,7 +397,7 @@ def triples(
     # variants of triples will be done if / when optimization is needed

-    def __len__(self, context=None):
+    def __len__(self, context: Optional["Graph"] = None) -> int:
         """
         Number of statements in the store. This should only account for non-
         quoted (asserted) statements if the context is not specified,
@@ -321,7 +407,9 @@ def __len__(self, context=None):

         :param context: a graph instance to query or None
         """

-    def contexts(self, triple=None):
+    def contexts(
+        self, triple: Optional["_TripleType"] = None
+    ) -> Generator["Graph", None, None]:
         """
         Generator over all contexts in the graph. If triple is specified,
         a generator over all contexts the triple is in.
@@ -331,7 +419,15 @@ def contexts(self, triple=None):

         :returns: a generator over Nodes
         """

-    def query(self, query, initNs, initBindings, queryGraph, **kwargs):  # noqa: N803
+    # TODO FIXME: the result of query is inconsistent.
+    def query(
+        self,
+        query: Union["Query", str],
+        initNs: Dict[str, str],  # noqa: N803
+        initBindings: Dict["Variable", "Identifier"],  # noqa: N803
+        queryGraph: "Identifier",  # noqa: N803
+        **kwargs: Any,
+    ):
         """
         If stores provide their own SPARQL implementation, override this.
@@ -347,7 +443,14 @@ def query(self, query, initNs, initBindings, queryGraph, **kwargs):  # noqa: N80

         raise NotImplementedError

-    def update(self, update, initNs, initBindings, queryGraph, **kwargs):  # noqa: N803
+    def update(
+        self,
+        update: Union["Update", str],
+        initNs: Dict[str, str],  # noqa: N803
+        initBindings: Dict["Variable", "Identifier"],  # noqa: N803
+        queryGraph: "Identifier",  # noqa: N803
+        **kwargs: Any,
+    ) -> None:
         """
         If stores provide their own (SPARQL) Update implementation,
         override this.
@@ -377,25 +480,25 @@ def prefix(self, namespace: "URIRef") -> Optional["str"]:
     def namespace(self, prefix: str) -> Optional["URIRef"]:
         """ """

-    def namespaces(self):
+    def namespaces(self) -> Iterator[Tuple[str, "URIRef"]]:
         """ """
         # This is here so that the function becomes an empty generator.
         # See https://stackoverflow.com/q/13243766 and
         # https://www.python.org/dev/peps/pep-0255/#why-a-new-keyword-for-yield-why-not-a-builtin-function-instead
         if False:
-            yield None
+            yield None  # type: ignore[unreachable]

     # Optional Transactional methods
-    def commit(self):
+    def commit(self) -> None:
         """ """

-    def rollback(self):
+    def rollback(self) -> None:
         """ """

     # Optional graph methods
-    def add_graph(self, graph):
+    def add_graph(self, graph: "Graph") -> None:
         """
         Add a graph to the store, no effect if the graph already exists.
@@ -403,7 +506,7 @@ def add_graph(self, graph):
         """
         raise Exception("Graph method called on non-graph_aware store")

-    def remove_graph(self, graph):
+    def remove_graph(self, graph: Optional["Graph"]) -> None:
         """
         Remove a graph from the store, this should also remove all
         triples in the graph
diff --git a/rdflib/term.py b/rdflib/term.py
index e68f1a7dce..8fbaca0dd0 100644
--- a/rdflib/term.py
+++ b/rdflib/term.py
@@ -35,6 +35,7 @@
     "Variable",
 ]

+import abc
 import logging
 import math
 import warnings
@@ -230,7 +231,7 @@ def startswith(self, prefix: str, start=..., end=...) -> bool:  # type: ignore[o
     __hash__ = str.__hash__


-class IdentifiedNode(Identifier):
+class IdentifiedNode(Identifier, abc.ABC):
     """
     An abstract class, primarily defined to identify Nodes that
     are not Literals.
@@ -243,6 +244,10 @@ def __getnewargs__(self) -> Tuple[str]:
     def toPython(self) -> str:  # noqa: N802
         return str(self)

+    @abc.abstractmethod
+    def n3(self, namespace_manager: Optional["NamespaceManager"] = None) -> str:
+        ...
+

 class URIRef(IdentifiedNode):
     """
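Declaring IdentifiedNode abstract and giving it an abstract n3() means type checkers can rely on every concrete subclass providing the method. A reduced sketch of the pattern (a plain ABC with invented names, not rdflib's actual class layout):

    import abc
    from typing import Optional


    class Term(abc.ABC):
        """Declares, but does not implement, n3()."""

        @abc.abstractmethod
        def n3(self, namespace_manager: Optional[object] = None) -> str:
            ...


    class Ref(Term):
        def __init__(self, value: str) -> None:
            self.value = value

        def n3(self, namespace_manager: Optional[object] = None) -> str:
            return "<%s>" % self.value


    print(Ref("http://example.org/x").n3())  # <http://example.org/x>
    # Term() would raise TypeError: Can't instantiate abstract class Term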
diff --git a/rdflib/tools/csv2rdf.py b/rdflib/tools/csv2rdf.py
index 0179629f3b..c8fbc66da8 100644
--- a/rdflib/tools/csv2rdf.py
+++ b/rdflib/tools/csv2rdf.py
@@ -20,8 +20,7 @@
 from urllib.parse import quote

 import rdflib
-from rdflib import RDF, RDFS
-from rdflib.namespace import split_uri
+from rdflib.namespace import RDF, RDFS, split_uri

 __all__ = ["CSV2RDF"]

diff --git a/rdflib/tools/defined_namespace_creator.py b/rdflib/tools/defined_namespace_creator.py
index 2cfe99f295..9fbf18a106 100644
--- a/rdflib/tools/defined_namespace_creator.py
+++ b/rdflib/tools/defined_namespace_creator.py
@@ -17,7 +17,7 @@

 sys.path.append(str(Path(__file__).parent.absolute().parent.parent))

-from rdflib import Graph
+from rdflib.graph import Graph
 from rdflib.namespace import DCTERMS, OWL, RDFS, SKOS
 from rdflib.util import guess_format

diff --git a/rdflib/tools/graphisomorphism.py b/rdflib/tools/graphisomorphism.py
index 004b567b82..f1bba4b77e 100644
--- a/rdflib/tools/graphisomorphism.py
+++ b/rdflib/tools/graphisomorphism.py
@@ -5,7 +5,8 @@

 from itertools import combinations

-from rdflib import BNode, Graph
+from rdflib.graph import Graph
+from rdflib.term import BNode


 class IsomorphicTestableGraph(Graph):
diff --git a/rdflib/tools/rdf2dot.py b/rdflib/tools/rdf2dot.py
index c59f78b88b..8670d9c852 100644
--- a/rdflib/tools/rdf2dot.py
+++ b/rdflib/tools/rdf2dot.py
@@ -15,7 +15,7 @@

 import rdflib
 import rdflib.extras.cmdlineutils
-from rdflib import XSD
+from rdflib.namespace import XSD

 LABEL_PROPERTIES = [
     rdflib.RDFS.label,
diff --git a/rdflib/tools/rdfs2dot.py b/rdflib/tools/rdfs2dot.py
index 69ecfba581..78eccb7002 100644
--- a/rdflib/tools/rdfs2dot.py
+++ b/rdflib/tools/rdfs2dot.py
@@ -14,7 +14,7 @@
 import sys

 import rdflib.extras.cmdlineutils
-from rdflib import RDF, RDFS, XSD
+from rdflib.namespace import RDF, RDFS, XSD

 XSDTERMS = [
     XSD[x]
diff --git a/rdflib/void.py b/rdflib/void.py
index e0ac04f81e..54c3d7cb7c 100644
--- a/rdflib/void.py
+++ b/rdflib/void.py
@@ -1,7 +1,8 @@
 import collections

-from rdflib import Graph, Literal, URIRef
+from rdflib.graph import Graph
 from rdflib.namespace import RDF, VOID
+from rdflib.term import Literal, URIRef


 def generateVoID(  # noqa: N802
diff --git a/test/test_misc/test_plugins.py b/test/test_misc/test_plugins.py
index 46147f5a1c..7263bc738b 100644
--- a/test/test_misc/test_plugins.py
+++ b/test/test_misc/test_plugins.py
@@ -6,13 +6,14 @@
 import warnings
 from contextlib import ExitStack, contextmanager
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, List
+from typing import Any, Callable, Dict, Generator, List, cast

 import rdflib.plugin
 import rdflib.plugins.sparql
 import rdflib.plugins.sparql.evaluate
 from rdflib import Graph
 from rdflib.parser import Parser
+from rdflib.query import ResultRow

 TEST_DIR = Path(__file__).parent.parent
 TEST_PLUGINS_DIR = TEST_DIR / "plugins"
@@ -92,7 +93,7 @@ def test_sparqleval(tmp_path: Path, no_cover: None) -> None:
     logging.debug("query_string = %s", query_string)
     result = graph.query(query_string)
     assert result.type == "SELECT"
-    rows = list(result)
+    rows = cast(List[ResultRow], list(result))
     logging.debug("rows = %s", rows)
     assert len(rows) == 1
     assert len(rows[0]) == 1
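test_plugins.py opts for cast() rather than per-row isinstance checks: once result.type is known to be SELECT, the whole row list can be asserted at once. A sketch of that style, with invented data:

    from typing import List, cast

    from rdflib import Graph
    from rdflib.query import ResultRow

    g = Graph()
    g.parse(
        data="<http://example.org/s> <http://example.org/p> <http://example.org/o> .",
        format="turtle",
    )
    result = g.query("SELECT ?s WHERE { ?s ?p ?o }")

    # cast() tells the type checker what we already know from result.type;
    # unlike an isinstance() assert it performs no runtime check.
    assert result.type == "SELECT"
    rows = cast(List[ResultRow], list(result))
    assert len(rows) == 1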
diff --git a/test/test_sparql/test_forward_slash_escapes.py b/test/test_sparql/test_forward_slash_escapes.py
index 4400c003ae..a33e438324 100644
--- a/test/test_sparql/test_forward_slash_escapes.py
+++ b/test/test_sparql/test_forward_slash_escapes.py
@@ -28,6 +28,7 @@
 from rdflib import Graph
 from rdflib.plugins.sparql.processor import prepareQuery
 from rdflib.plugins.sparql.sparql import Query
+from rdflib.query import ResultRow

 query_string_expanded = r"""
 SELECT ?nIndividual
@@ -113,6 +114,7 @@ def _test_escapes_and_query(
     assert expected_query_compiled == query_compiled

     for result in graph.query(query_object):
+        assert isinstance(result, ResultRow)
         computed.add(str(result[0]))

     assert expected == computed
diff --git a/test/test_sparql/test_sparql.py b/test/test_sparql/test_sparql.py
index d62c345e10..0c63a56588 100644
--- a/test/test_sparql/test_sparql.py
+++ b/test/test_sparql/test_sparql.py
@@ -18,7 +18,7 @@
 from rdflib.plugins.sparql.parser import parseQuery
 from rdflib.plugins.sparql.parserutils import prettify_parsetree
 from rdflib.plugins.sparql.sparql import SPARQLError
-from rdflib.query import Result
+from rdflib.query import Result, ResultRow
 from rdflib.term import Identifier, Variable


@@ -303,6 +303,7 @@ def test_call_function() -> None:
     assert result.type == "SELECT"
     rows = list(result)
     assert len(rows) == 1
+    assert isinstance(rows[0], ResultRow)
     assert len(rows[0]) == 1
     assert rows[0][0] == Literal("a + b")

@@ -353,6 +354,7 @@ def custom_eval(ctx: Any, part: Any) -> Any:
     assert result.type == "SELECT"
     rows = list(result)
     assert len(rows) == 1
+    assert isinstance(rows[0], ResultRow)
     assert len(rows[0]) == 2
     assert rows[0][0] == Literal("a + b")
     assert rows[0][1] == custom_function_result
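The same narrowing applies when queries are compiled ahead of time, as test_forward_slash_escapes.py does with prepareQuery. A small usage sketch, with invented data and query:

    from rdflib import Graph
    from rdflib.plugins.sparql.processor import prepareQuery
    from rdflib.query import ResultRow

    g = Graph()
    g.parse(
        data="<http://example.org/s> <http://example.org/p> <http://example.org/o> .",
        format="turtle",
    )

    # Compile once and reuse the Query object; iteration still yields
    # the Result union, so rows are narrowed the same way as above.
    q = prepareQuery("SELECT ?s WHERE { ?s ?p ?o }")
    for row in g.query(q):
        assert isinstance(row, ResultRow)
        print(row["s"])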
diff --git a/test/test_typing.py b/test/test_typing.py
index 2ee06d2a4c..8598fb0361 100644
--- a/test/test_typing.py
+++ b/test/test_typing.py
@@ -27,7 +27,8 @@
 # TODO Bug - rdflib.plugins.sparql.prepareQuery() will run fine if this
 # test is run, but mypy can't tell the symbol is exposed.
 import rdflib.plugins.sparql.processor
-from rdflib.term import Node
+from rdflib.query import ResultRow
+from rdflib.term import IdentifiedNode, Identifier, Node


 def test_rdflib_query_exercise() -> None:
@@ -78,8 +79,12 @@ def test_rdflib_query_exercise() -> None:
         kb_https_uriref,
         kb_urn_uriref,
     }
-    computed_one_usage: Set[rdflib.IdentifiedNode] = set()
+    computed_one_usage: Set[Identifier] = set()
     for one_usage_result in graph.query(one_usage_query):
+        # NOTE on cast: A query result can be a graph (Iterable of Triples,
+        # bool or a row) so a cast is needed to disambiguate
+        assert isinstance(one_usage_result, ResultRow)
+        # one_usage_result = cast(ResultRow, one_usage_result)
         computed_one_usage.add(one_usage_result[0])
     assert expected_one_usage == computed_one_usage

@@ -95,19 +100,16 @@ def test_rdflib_query_exercise() -> None:
     }
     """

-    expected_two_usage: Set[
-        Tuple[
-            rdflib.IdentifiedNode,
-            rdflib.IdentifiedNode,
-        ]
-    ] = {(kb_https_uriref, predicate_p), (kb_https_uriref, predicate_q)}
-    computed_two_usage: Set[
-        Tuple[
-            rdflib.IdentifiedNode,
-            rdflib.IdentifiedNode,
-        ]
-    ] = set()
+    expected_two_usage: Set[Tuple[Identifier, ...]] = {
+        (kb_https_uriref, predicate_p),
+        (kb_https_uriref, predicate_q),
+    }
+    computed_two_usage: Set[Tuple[Identifier, ...]] = set()
     for two_usage_result in graph.query(two_usage_query):
+        # NOTE on cast: A query result can be a graph (Iterable of Triples,
+        # bool or a row) so a cast is needed to disambiguate
+        # two_usage_result = cast(ResultRow, two_usage_result)
+        assert isinstance(two_usage_result, ResultRow)
         computed_two_usage.add(two_usage_result)
     assert expected_two_usage == computed_two_usage

@@ -116,12 +118,17 @@ def test_rdflib_query_exercise() -> None:
     prepared_one_usage_query = rdflib.plugins.sparql.processor.prepareQuery(
         one_usage_query, initNs=nsdict
     )
-    computed_one_usage_from_prepared_query: Set[rdflib.IdentifiedNode] = set()
+    computed_one_usage_from_prepared_query: Set[Identifier] = set()
     for prepared_one_usage_result in graph.query(prepared_one_usage_query):
+        # NOTE on cast: A query result can be a graph (Iterable of Triples,
+        # bool or a row) so a cast is needed to disambiguate
+        # prepared_one_usage_result = cast(ResultRow, prepared_one_usage_result)
+        assert isinstance(prepared_one_usage_result, ResultRow)
         computed_one_usage_from_prepared_query.add(prepared_one_usage_result[0])
     assert expected_one_usage == computed_one_usage_from_prepared_query

     for node_using_one in sorted(computed_one_usage):
+        assert isinstance(node_using_one, IdentifiedNode)
         graph.add((node_using_one, predicate_r, literal_true))

     python_one: int = literal_one.toPython()
diff --git a/test/utils/dawg_manifest.py b/test/utils/dawg_manifest.py
index a256dcad40..4ef53cd02f 100644
--- a/test/utils/dawg_manifest.py
+++ b/test/utils/dawg_manifest.py
@@ -45,12 +45,16 @@ class ManifestEntry:
     result: Optional[IdentifiedNode] = field(init=False)

     def __post_init__(self) -> None:
-        type = self.value(RDF.type, IdentifiedNode)
+        # type error: Only concrete class can be given where "Type[IdentifiedNode]" is expected
+        # NOTE on type ignores: mypy is overly strict when it comes to abstract
+        # types, this is an open issue
+
+        type = self.value(RDF.type, IdentifiedNode)  # type: ignore[misc]
         assert type is not None
         self.type = type
-        self.action = self.value(MF.action, IdentifiedNode)
-        self.result = self.value(MF.result, IdentifiedNode)
+        self.action = self.value(MF.action, IdentifiedNode)  # type: ignore[misc]
+        self.result = self.value(MF.result, IdentifiedNode)  # type: ignore[misc]

     @property
     def graph(self) -> Graph:
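The dawg_manifest.py ignores work around mypy rejecting an abstract class where Type[...] is expected, even when the callee never instantiates it. A standalone sketch of that strictness, with invented names:

    from abc import ABC, abstractmethod
    from typing import Dict, Optional, Type, TypeVar

    _T = TypeVar("_T")


    class Node(ABC):
        @abstractmethod
        def label(self) -> str:
            ...


    def value_of(
        registry: Dict[str, object], key: str, type_: Type[_T]
    ) -> Optional[_T]:
        # Filters existing instances by type; never instantiates type_.
        value = registry.get(key)
        return value if isinstance(value, type_) else None


    # mypy reports "Only concrete class can be given where Type[Node] is
    # expected", although value_of never calls type_(); the ignores in
    # dawg_manifest.py silence exactly this.
    node = value_of({}, "x", Node)  # type: ignore[misc]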
diff --git a/test/utils/sparql_checker.py b/test/utils/sparql_checker.py
index f6375a4ddf..78b26daabf 100644
--- a/test/utils/sparql_checker.py
+++ b/test/utils/sparql_checker.py
@@ -391,6 +391,8 @@ def check_query(monkeypatch: MonkeyPatch, entry: SPARQLEntry) -> None:
     elif result.type == ResultType.ASK:
         assert expected_result.askAnswer == result.askAnswer
     else:
+        assert expected_result.graph is not None
+        assert result.graph is not None
         logging.debug(
             "expected_result.graph = %s, result.graph = %s\n%s",
             expected_result.graph,