Skip to content

Commit

Permalink
Support space multiplication (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtth authored May 30, 2023
1 parent a8e2485 commit 1cadf1a
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 56 deletions.
4 changes: 2 additions & 2 deletions opvious/client/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ async def inspect_solve_instructions(
The LP formatted output will be fully annotated with matching keys and
labels:
.. code::
.. code-block::
minimize
+1 inventory$1 \\ [day=0]
Expand Down Expand Up @@ -348,7 +348,7 @@ async def run_solve(
The returned response exposes both metadata (status, objective value,
etc.) and solution data (if the solve was feasible):
.. code:: python
.. code-block:: python
response = await client.run_solve(
specification=opvious.RemoteSpecification.example(
Expand Down
2 changes: 2 additions & 0 deletions opvious/modeling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Domain,
Expression,
ExpressionLike,
IterableSpace,
Predicate,
Projection,
Quantifiable,
Expand Down Expand Up @@ -65,6 +66,7 @@
"total",
# Quantification
"Cross",
"IterableSpace",
"Domain",
"Projection",
"Quantifiable",
Expand Down
122 changes: 95 additions & 27 deletions opvious/modeling/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,18 @@
import dataclasses
import itertools
import math
from typing import Any, cast, Iterable, Optional, Sequence, TypeVar, Union
from typing import (
Any,
cast,
Iterable,
Iterator,
Mapping,
Optional,
Protocol,
Sequence,
TypeVar,
Union,
)

from ..common import untuple
from .identifiers import (
Expand Down Expand Up @@ -261,7 +272,7 @@ def render(self) -> str:
groups = []
outer = itertools.groupby(self.quantifiers, _quantifier_grouping_key)
for (_id, key), outer_qs in outer:
if isinstance(key, Space):
if isinstance(key, ScalarSpace):
names = ", ".join(q.format() for q in outer_qs)
groups.append(f"{names} \\in {key.render()}")
else:
Expand All @@ -285,10 +296,10 @@ def render(self) -> str:

def _quantifier_grouping_key(
q: QuantifierIdentifier,
) -> tuple[int, Union[Space, QuantifierGroup]]:
) -> tuple[int, Union[ScalarSpace, QuantifierGroup]]:
# We add the ID to prevent `__eq__` from being called on equations
sp = q.space
if not isinstance(sp, Space):
if not isinstance(sp, ScalarSpace):
raise TypeError(f"Unexpected space: {sp}")
if not q.groups:
return (id(sp), sp)
Expand Down Expand Up @@ -318,7 +329,7 @@ def render(self, _precedence=0) -> str:
with local_formatting_scope(qs):
if len(qs) == 1 and self.domain.mask is None:
_id, key = _quantifier_grouping_key(qs[0])
if isinstance(key, Space):
if isinstance(key, ScalarSpace):
sp = key.render()
else:
sp = render_identifier(key.alias, *key.subscripts)
Expand Down Expand Up @@ -349,6 +360,25 @@ def render(self, precedence=0) -> str:


class Space:
"""Base quantification
This class provides support for generating cross-products with the `*`
operator (see :func:`~opvious.modeling.cross`):
.. code-block:: python
space1 * space2 # Equivalent to cross(space1, space2)
"""

def __mul__(self, other: Quantifiable) -> Quantification:
return cross(self, other)

def __rmul__(self, left: Quantifiable) -> Quantification:
return cross(left, self)


class ScalarSpace(Space):
def __iter__(self) -> Quantified[Quantifier]:
return (untuple(t) for t in cross(self))

Expand All @@ -357,7 +387,7 @@ def render(self) -> str:


@dataclasses.dataclass(frozen=True)
class QuantifiableReference(Space):
class QuantifiableReference(ScalarSpace):
identifier: AliasIdentifier
subscripts: tuple[Expression, ...]
quantifiers: tuple[QuantifierIdentifier, ...]
Expand All @@ -382,7 +412,30 @@ def render(self, _precedence=0) -> str:
return self.identifier.format()


def expression_space(expr: Expression) -> Optional[Space]:
_Q = TypeVar(
"_Q", bound=Union[Quantifier, tuple[Quantifier, ...]], covariant=True
)


class IterableSpace(Protocol[_Q]):
"""Base protocol for spaces which can also be directly iterated on
It is exposed mostly as a typing convenience for typing model fragments.
:class:`~opvious.modeling.Space` is typically used for providing the
underlying implementation.
"""

def __mul__(self, other: Quantifiable) -> Quantification:
raise NotImplementedError()

def __rmul__(self, other: Quantifiable) -> Quantification:
raise NotImplementedError()

def __iter__(self) -> Iterator[_Q]:
raise NotImplementedError()


def expression_space(expr: Expression) -> Optional[ScalarSpace]:
"""Returns the underlying scalar quantifiable for an expression if any"""
if isinstance(expr, Quantifier):
return expr.identifier.space
Expand Down Expand Up @@ -512,7 +565,7 @@ def domain(
names: Optional[Iterable[Name]] = None,
) -> Domain:
"""Creates a domain from a quantifiable"""
return _domain_from_quantified(cross(quantifiable, names=names))
return _domain_from_quantified(iter(cross(quantifiable, names=names)))


def _domain_from_quantified(
Expand Down Expand Up @@ -552,11 +605,36 @@ def lifted(self) -> tuple[Quantifier, ...]:
raise Exception("Unlifted cross-product")
return self._lifted

def __getitem__(self, ix) -> Quantifier:
return self._quantifiers[ix]

def __iter__(self):
return iter(self._quantifiers)


Quantification = Quantified[Cross]
@dataclasses.dataclass(frozen=True)
class Quantification(Space):
quantifiables: tuple[Quantifiable, ...]
names: Mapping[int, Name]
projection: Projection
lift: bool

def __iter__(self) -> Quantified[Cross]:
projected: list[Quantifier] = []
lifted: list[Quantifier] = []
for i, d in enumerate(self.quantifiables):
project = (1 << i) & self.projection
if not project and not self.lift:
continue
j0 = len(projected)
quants = list(
Quantifier(declare(iden.named(self.names.get(j0 + j))))
for j, iden in enumerate(_quantifier_identifiers(d))
)
lifted.extend(quants)
if project:
projected.extend(quants)
yield Cross(tuple(projected), tuple(lifted))


def cross(
Expand All @@ -571,28 +649,18 @@ def cross(
quantifiables: One or more quantifiables
names: Optional names for the generated quantifiers
projection: Quantifiable selection mask
lift: Returns a lifted :class:`~opvious.modeling.Cross` instance.
lift: Returns lifted :class:`~opvious.modeling.Cross` instances.
Setting this option will include all masks present in the original
quantifiable, even if they are not projected.
This function is the core building block for quantifying values.
"""
names_by_index = dict(enumerate(names or []))
projected: list[Quantifier] = []
lifted: list[Quantifier] = []
for i, q in enumerate(quantifiables):
project = (1 << i) & projection
if not project and not lift:
continue
j0 = len(projected)
quants = list(
Quantifier(declare(iden.named(names_by_index.get(j0 + j))))
for j, iden in enumerate(_quantifier_identifiers(q))
)
lifted.extend(quants)
if project:
projected.extend(quants)
yield Cross(tuple(projected), tuple(lifted))
return Quantification(
quantifiables=quantifiables,
names=dict(enumerate(names or [])),
projection=projection,
lift=lift,
)


def _quantifier_identifiers(
Expand All @@ -610,7 +678,7 @@ def _quantifier_identifiers(
)
for q in qs:
yield q.grouped_within(group)
elif isinstance(quantifiable, Space):
elif isinstance(quantifiable, ScalarSpace):
yield QuantifierIdentifier.base(quantifiable)
else: # domain or quantified
if isinstance(quantifiable, Domain):
Expand Down
25 changes: 14 additions & 11 deletions opvious/modeling/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Any,
Callable,
Iterable,
Iterator,
Literal,
Optional,
Sequence,
Expand All @@ -25,17 +26,19 @@
Expression,
ExpressionLike,
ExpressionReference,
literal,
IterableSpace,
Predicate,
Space,
Quantifiable,
QuantifiableReference,
Quantifier,
QuantifierIdentifier,
ScalarSpace,
Space,
cross,
domain,
expression_space,
is_literal,
literal,
render_identifier,
to_expression,
within_domain,
Expand All @@ -55,7 +58,7 @@
_logger = logging.getLogger(__name__)


class Dimension(Definition, Space):
class Dimension(Definition, ScalarSpace):
"""An abstract collection of values
Args:
Expand Down Expand Up @@ -115,7 +118,7 @@ def render_statement(self, label: Label) -> Optional[str]:


@dataclasses.dataclass(frozen=True)
class _Interval(Space):
class _Interval(ScalarSpace):
lower_bound: Expression
upper_bound: Expression

Expand All @@ -139,7 +142,7 @@ def interval(
lower_bound: ExpressionLike,
upper_bound: ExpressionLike,
name: Optional[Name] = None,
) -> Iterable[Quantifier]:
) -> IterableSpace[Quantifier]:
"""A range of values
Args:
Expand All @@ -152,13 +155,13 @@ def interval(
upper_bound=to_expression(upper_bound),
)

class _Fragment(ModelFragment):
class _Fragment(ModelFragment, Space):
@property
@alias(name)
def interval(self):
return interval

def __iter__(self) -> Quantified[Quantifier]:
def __iter__(self) -> Iterator[Quantifier]:
return iter(self.interval)

return _Fragment()
Expand Down Expand Up @@ -204,7 +207,7 @@ class Tensor(Definition):
Calling a tensor returns an :class:`~.opvious.modeling.Expression` with any
arguments as subscripts. For example:
.. code:: python
.. code-block:: python
class ProductModel(Model):
products = Dimension()
Expand Down Expand Up @@ -364,7 +367,7 @@ class Parameter(Tensor):
Consider instantiating parameters via one of the various
:class:`~opvious.modeling.Tensor` convenience class methods, for example:
.. code:: python
.. code-block:: python
p1 = Parameter.continuous() # Real-valued parameter
p2 = Parameter.natural() # # Parameter with values in {0, 1...}
Expand Down Expand Up @@ -392,7 +395,7 @@ class Variable(Tensor):
various :class:`~opvious.modeling.Tensor` convenience class methods, for
example:
.. code:: python
.. code-block:: python
v1 = Variable.unit() # Variable with value within [0, 1]
v2 = Variable.non_negative() # # Variable with value at least 0
Expand All @@ -408,7 +411,7 @@ class Variable(Tensor):

@dataclasses.dataclass(frozen=True)
class _Aliased:
quantifiables: Sequence[Optional[Space]]
quantifiables: Sequence[Optional[ScalarSpace]]
quantifiers: Union[
None, QuantifierIdentifier, tuple[QuantifierIdentifier, ...]
]
Expand Down
2 changes: 1 addition & 1 deletion opvious/modeling/identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class QuantifierGroup:


class QuantifierIdentifier(Identifier):
space: Any # Space
space: Any # ScalarSpace
groups: Sequence[QuantifierGroup]
name: Optional[Name]

Expand Down
4 changes: 2 additions & 2 deletions opvious/modeling/quantified.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import contextvars
import dataclasses
import itertools
from typing import Any, Generator, Tuple, TypeVar
from typing import Any, Iterator, Tuple, TypeVar


_V = TypeVar("_V")


Quantified = Generator[_V, None, None]
Quantified = Iterator[_V]


def _run_quantified(quantified: Quantified[_V]) -> _V:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "opvious"
version = "0.12.5rc3"
version = "0.12.6rc1"
description = "Opvious Python SDK"
authors = ["Opvious Engineering <[email protected]>"]
readme = "README.md"
Expand Down
Loading

0 comments on commit 1cadf1a

Please sign in to comment.