Skip to content

Commit

Permalink
Merge pull request #337 from OpShin/feat/fast_index_skip
Browse files Browse the repository at this point in the history
Add fast list access implementation
  • Loading branch information
nielstron authored Nov 28, 2024
2 parents d657a22 + 96fc5c6 commit 400757f
Show file tree
Hide file tree
Showing 10 changed files with 795 additions and 546 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ repos:
# supported by your project here, or alternatively use
# pre-commit's default_language_version, see
# https://pre-commit.com/#top_level-default_language_version
language_version: python3.10
language_version: python3.11
52 changes: 32 additions & 20 deletions opshin/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,14 +401,14 @@ def perform_command(args):
print("Starting execution")
print("------------------")
assert isinstance(code, uplc.ast.Program)
try:
ret = uplc.eval(code)
except Exception as e:
raw_ret = uplc.eval(code)
if isinstance(raw_ret.result, Exception):
print("An exception was raised")
ret = e
ret = raw_ret.result
else:
print("Execution succeeded")
ret = uplc.dumps(ret.result)
ret = uplc.dumps(raw_ret.result)
print(f"CPU: {raw_ret.cost.cpu} | MEM: {raw_ret.cost.memory}")
print("------------------")
print(ret)

Expand Down Expand Up @@ -475,21 +475,33 @@ def parse_args():
)
for k, v in ARGPARSE_ARGS.items():
alts = v.pop("__alts__", [])
a.add_argument(
f"-f{k.replace('_', '-')}",
*alts,
**v,
action="store_true",
dest=k,
default=None,
)
a.add_argument(
f"-fno-{k.replace('_', '-')}",
action="store_false",
help=argparse.SUPPRESS,
dest=k,
default=None,
)
type = v.pop("type", None)
if type is None:
a.add_argument(
f"-f{k.replace('_', '-')}",
*alts,
**v,
action="store_true",
dest=k,
default=None,
)
a.add_argument(
f"-fno-{k.replace('_', '-')}",
action="store_false",
help=argparse.SUPPRESS,
dest=k,
default=None,
)
else:
a.add_argument(
f"-f{k.replace('_', '-')}",
*alts,
**v,
type=type,
dest=k,
default=None,
)

a.add_argument(
f"-O",
type=int,
Expand Down
26 changes: 24 additions & 2 deletions opshin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,19 @@ class PlutoCompiler(CompilingNodeTransformer):

step = "Compiling python statements to UPLC"

def __init__(self, force_three_params=False, validator_function_name="validator"):
def __init__(
self,
force_three_params=False,
validator_function_name="validator",
config=DEFAULT_CONFIG,
):
# parameters
self.force_three_params = force_three_params
self.validator_function_name = validator_function_name
self.config = config
assert (
self.config.fast_access_skip is None or self.config.fast_access_skip > 1
), "Parameter fast-access-skip needs to be greater than 1 or omitted"
# marked knowledge during compilation
self.current_function_typ: typing.List[FunctionType] = []

Expand Down Expand Up @@ -654,6 +663,12 @@ def visit_Subscript(self, node: TypedSubscript) -> plt.AST:
assert (
node.slice.typ == IntegerInstanceType
), "Only single element list index access supported"
if isinstance(node.slice, Constant) and node.slice.value >= 0:
index = node.slice.value
return plt.ConstantIndexAccessListFast(
self.visit(node.value),
index,
)
return OLet(
[
(
Expand All @@ -675,7 +690,13 @@ def visit_Subscript(self, node: TypedSubscript) -> plt.AST:
),
),
],
plt.IndexAccessList(OVar("l"), OVar("i")),
(
plt.IndexAccessListFast(self.config.fast_access_skip)(
OVar("l"), OVar("i")
)
if self.config.fast_access_skip is not None
else plt.IndexAccessList(OVar("l"), OVar("i"))
),
)
else:
return OLet(
Expand Down Expand Up @@ -1089,6 +1110,7 @@ def compile(
s = PlutoCompiler(
force_three_params=config.force_three_params,
validator_function_name=validator_function_name,
config=config,
)
prog = s.visit(prog)

Expand Down
6 changes: 6 additions & 0 deletions opshin/compiler_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class CompilationConfig(pluthon.CompilationConfig):
allow_isinstance_anything: Optional[bool] = None
force_three_params: Optional[bool] = None
remove_dead_code: Optional[bool] = None
fast_access_skip: Optional[int] = None


# The default configuration for the compiler
Expand All @@ -35,6 +36,7 @@ class CompilationConfig(pluthon.CompilationConfig):
.update(pluthon.OPT_O2_CONFIG)
.update(
constant_folding=True,
fast_access_skip=5,
)
)
OPT_O3_CONFIG = (
Expand Down Expand Up @@ -64,6 +66,10 @@ class CompilationConfig(pluthon.CompilationConfig):
"remove_dead_code": {
"help": "Removes dead code and variables from the contract. Should be enabled for non-debugging purposes.",
},
"fast_access_skip": {
"help": "How many steps to skip for fast list index access, default None means no steps are skipped (useful if long lists are common).",
"type": int,
},
}
)
for k in ARGPARSE_ARGS:
Expand Down
27 changes: 16 additions & 11 deletions opshin/type_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def attribute(self, attr: str) -> plt.AST:
return OLambda(
["self"],
transform_ext_params_map(attr_typ)(
plt.ConstantNthField(
plt.ConstantNthFieldFast(
OVar("self"),
pos,
),
Expand Down Expand Up @@ -632,7 +632,7 @@ def stringify(self, recursive: bool = False) -> plt.AST:
plt.Apply(
field_type.stringify(recursive=True),
transform_ext_params_map(field_type)(
plt.ConstantNthField(OVar("self"), pos)
plt.ConstantNthFieldFast(OVar("self"), pos)
),
),
map_fields,
Expand All @@ -643,7 +643,7 @@ def stringify(self, recursive: bool = False) -> plt.AST:
plt.Apply(
self.record.fields[0][1].stringify(recursive=True),
transform_ext_params_map(self.record.fields[0][1])(
plt.ConstantNthField(OVar("self"), pos)
plt.ConstantNthFieldFast(OVar("self"), pos)
),
),
map_fields,
Expand Down Expand Up @@ -751,8 +751,16 @@ def attribute(self, attr: str) -> plt.AST:
if not pos_constrs:
pos_decisor = plt.TraceError("Invalid constructor")
else:
pos_decisor = plt.Integer(pos_constrs[-1][0])
pos_decisor = plt.ConstantNthFieldFast(OVar("self"), pos_constrs[-1][0])
pos_constrs = pos_constrs[:-1]
# constr is not needed when there is only one position for all constructors
if not pos_constrs:
return OLambda(
["self"],
transform_ext_params_map(attr_typ)(
pos_decisor,
),
)
for pos, constrs in pos_constrs:
assert constrs, "Found empty constructors for a position"
constr_check = plt.EqualsInteger(
Expand All @@ -765,18 +773,15 @@ def attribute(self, attr: str) -> plt.AST:
)
pos_decisor = plt.Ite(
constr_check,
plt.Integer(pos),
plt.ConstantNthFieldFast(OVar("self"), pos),
pos_decisor,
)
return OLambda(
["self"],
transform_ext_params_map(attr_typ)(
plt.NthField(
OVar("self"),
OLet(
[("constr", plt.Constructor(OVar("self")))],
pos_decisor,
),
OLet(
[("constr", plt.Constructor(OVar("self")))],
pos_decisor,
),
),
)
Expand Down
Loading

0 comments on commit 400757f

Please sign in to comment.