Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fast list access implementation #337

Merged
merged 10 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ repos:
# supported by your project here, or alternatively use
# pre-commit's default_language_version, see
# https://pre-commit.com/#top_level-default_language_version
language_version: python3.10
language_version: python3.11
52 changes: 32 additions & 20 deletions opshin/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,14 +401,14 @@ def perform_command(args):
print("Starting execution")
print("------------------")
assert isinstance(code, uplc.ast.Program)
try:
ret = uplc.eval(code)
except Exception as e:
raw_ret = uplc.eval(code)
if isinstance(raw_ret.result, Exception):
print("An exception was raised")
ret = e
ret = raw_ret.result
else:
print("Execution succeeded")
ret = uplc.dumps(ret.result)
ret = uplc.dumps(raw_ret.result)
print(f"CPU: {raw_ret.cost.cpu} | MEM: {raw_ret.cost.memory}")
print("------------------")
print(ret)

Expand Down Expand Up @@ -475,21 +475,33 @@ def parse_args():
)
for k, v in ARGPARSE_ARGS.items():
alts = v.pop("__alts__", [])
a.add_argument(
f"-f{k.replace('_', '-')}",
*alts,
**v,
action="store_true",
dest=k,
default=None,
)
a.add_argument(
f"-fno-{k.replace('_', '-')}",
action="store_false",
help=argparse.SUPPRESS,
dest=k,
default=None,
)
type = v.pop("type", None)
if type is None:
a.add_argument(
f"-f{k.replace('_', '-')}",
*alts,
**v,
action="store_true",
dest=k,
default=None,
)
a.add_argument(
f"-fno-{k.replace('_', '-')}",
action="store_false",
help=argparse.SUPPRESS,
dest=k,
default=None,
)
else:
a.add_argument(
f"-f{k.replace('_', '-')}",
*alts,
**v,
type=type,
dest=k,
default=None,
)

a.add_argument(
f"-O",
type=int,
Expand Down
26 changes: 24 additions & 2 deletions opshin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,19 @@ class PlutoCompiler(CompilingNodeTransformer):

step = "Compiling python statements to UPLC"

def __init__(self, force_three_params=False, validator_function_name="validator"):
def __init__(
self,
force_three_params=False,
validator_function_name="validator",
config=DEFAULT_CONFIG,
):
# parameters
self.force_three_params = force_three_params
self.validator_function_name = validator_function_name
self.config = config
assert (
self.config.fast_access_skip is None or self.config.fast_access_skip > 1
), "Parameter fast-access-skip needs to be greater than 1 or omitted"
# marked knowledge during compilation
self.current_function_typ: typing.List[FunctionType] = []

Expand Down Expand Up @@ -654,6 +663,12 @@ def visit_Subscript(self, node: TypedSubscript) -> plt.AST:
assert (
node.slice.typ == IntegerInstanceType
), "Only single element list index access supported"
if isinstance(node.slice, Constant) and node.slice.value >= 0:
index = node.slice.value
return plt.ConstantIndexAccessListFast(
self.visit(node.value),
index,
)
return OLet(
[
(
Expand All @@ -675,7 +690,13 @@ def visit_Subscript(self, node: TypedSubscript) -> plt.AST:
),
),
],
plt.IndexAccessList(OVar("l"), OVar("i")),
(
plt.IndexAccessListFast(self.config.fast_access_skip)(
OVar("l"), OVar("i")
)
if self.config.fast_access_skip is not None
else plt.IndexAccessList(OVar("l"), OVar("i"))
),
)
else:
return OLet(
Expand Down Expand Up @@ -1089,6 +1110,7 @@ def compile(
s = PlutoCompiler(
force_three_params=config.force_three_params,
validator_function_name=validator_function_name,
config=config,
)
prog = s.visit(prog)

Expand Down
6 changes: 6 additions & 0 deletions opshin/compiler_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class CompilationConfig(pluthon.CompilationConfig):
allow_isinstance_anything: Optional[bool] = None
force_three_params: Optional[bool] = None
remove_dead_code: Optional[bool] = None
fast_access_skip: Optional[int] = None


# The default configuration for the compiler
Expand All @@ -35,6 +36,7 @@ class CompilationConfig(pluthon.CompilationConfig):
.update(pluthon.OPT_O2_CONFIG)
.update(
constant_folding=True,
fast_access_skip=5,
)
)
OPT_O3_CONFIG = (
Expand Down Expand Up @@ -64,6 +66,10 @@ class CompilationConfig(pluthon.CompilationConfig):
"remove_dead_code": {
"help": "Removes dead code and variables from the contract. Should be enabled for non-debugging purposes.",
},
"fast_access_skip": {
"help": "How many steps to skip for fast list index access, default None means no steps are skipped (useful if long lists are common).",
"type": int,
},
}
)
for k in ARGPARSE_ARGS:
Expand Down
27 changes: 16 additions & 11 deletions opshin/type_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def attribute(self, attr: str) -> plt.AST:
return OLambda(
["self"],
transform_ext_params_map(attr_typ)(
plt.ConstantNthField(
plt.ConstantNthFieldFast(
OVar("self"),
pos,
),
Expand Down Expand Up @@ -632,7 +632,7 @@ def stringify(self, recursive: bool = False) -> plt.AST:
plt.Apply(
field_type.stringify(recursive=True),
transform_ext_params_map(field_type)(
plt.ConstantNthField(OVar("self"), pos)
plt.ConstantNthFieldFast(OVar("self"), pos)
),
),
map_fields,
Expand All @@ -643,7 +643,7 @@ def stringify(self, recursive: bool = False) -> plt.AST:
plt.Apply(
self.record.fields[0][1].stringify(recursive=True),
transform_ext_params_map(self.record.fields[0][1])(
plt.ConstantNthField(OVar("self"), pos)
plt.ConstantNthFieldFast(OVar("self"), pos)
),
),
map_fields,
Expand Down Expand Up @@ -751,8 +751,16 @@ def attribute(self, attr: str) -> plt.AST:
if not pos_constrs:
pos_decisor = plt.TraceError("Invalid constructor")
else:
pos_decisor = plt.Integer(pos_constrs[-1][0])
pos_decisor = plt.ConstantNthFieldFast(OVar("self"), pos_constrs[-1][0])
pos_constrs = pos_constrs[:-1]
# constr is not needed when there is only one position for all constructors
if not pos_constrs:
return OLambda(
["self"],
transform_ext_params_map(attr_typ)(
pos_decisor,
),
)
for pos, constrs in pos_constrs:
assert constrs, "Found empty constructors for a position"
constr_check = plt.EqualsInteger(
Expand All @@ -765,18 +773,15 @@ def attribute(self, attr: str) -> plt.AST:
)
pos_decisor = plt.Ite(
constr_check,
plt.Integer(pos),
plt.ConstantNthFieldFast(OVar("self"), pos),
pos_decisor,
)
return OLambda(
["self"],
transform_ext_params_map(attr_typ)(
plt.NthField(
OVar("self"),
OLet(
[("constr", plt.Constructor(OVar("self")))],
pos_decisor,
),
OLet(
[("constr", plt.Constructor(OVar("self")))],
pos_decisor,
),
),
)
Expand Down
Loading