Skip to content

Commit

Permalink
Use LazyBytes to avoid converting memfd to bytes unless needed
Browse files Browse the repository at this point in the history
  • Loading branch information
quantum5 committed Dec 28, 2024
1 parent 6144b60 commit c88ef56
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 1 deletion.
3 changes: 3 additions & 0 deletions dmoj/cptbox/_cptbox.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,6 @@ bsd_get_proc_fdno: Callable[[int, int], str]

memory_fd_create: Callable[[], int]
memory_fd_seal: Callable[[int], None]

class BufferProxy:
def _get_real_buffer(self): ...
9 changes: 9 additions & 0 deletions dmoj/cptbox/_cptbox.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# cython: language_level=3
from cpython.exc cimport PyErr_NoMemory, PyErr_SetFromErrno
from cpython.buffer cimport PyObject_GetBuffer
from cpython.bytes cimport PyBytes_AsString, PyBytes_FromStringAndSize
from libc.stdio cimport FILE, fopen, fclose, fgets, sprintf
from libc.stdlib cimport malloc, free, strtoul
Expand Down Expand Up @@ -600,3 +601,11 @@ cdef class Process:
if not self._exited:
return None
return self._exitcode


cdef class BufferProxy:
def _get_real_buffer(self):
raise NotImplementedError

def __getbuffer__(self, Py_buffer *buffer, int flags):
PyObject_GetBuffer(self._get_real_buffer(), buffer, flags)
88 changes: 88 additions & 0 deletions dmoj/cptbox/lazy_bytes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Based off https://github.com/django/django/blob/main/django/utils/functional.py, licensed under 3-clause BSD.
from functools import total_ordering

from dmoj.cptbox._cptbox import BufferProxy

_SENTINEL = object()


@total_ordering
class LazyBytes(BufferProxy):
"""
Encapsulate a function call and act as a proxy for methods that are
called on the result of that function. The function is not evaluated
until one of the methods on the result is called.
"""

def __init__(self, func):
self.__func = func
self.__value = _SENTINEL

def __get_value(self):
if self.__value is _SENTINEL:
self.__value = self.__func()
return self.__value

@classmethod
def _create_promise(cls, method_name):
# Builds a wrapper around some magic method
def wrapper(self, *args, **kw):
# Automatically triggers the evaluation of a lazy value and
# applies the given magic method of the result type.
res = self.__get_value()
return getattr(res, method_name)(*args, **kw)

return wrapper

def __cast(self):
return bytes(self.__get_value())

def _get_real_buffer(self):
return self.__cast()

def __bytes__(self):
return self.__cast()

def __repr__(self):
return repr(self.__cast())

def __str__(self):
return str(self.__cast())

def __eq__(self, other):
if isinstance(other, LazyBytes):
other = other.__cast()
return self.__cast() == other

def __lt__(self, other):
if isinstance(other, LazyBytes):
other = other.__cast()
return self.__cast() < other

def __hash__(self):
return hash(self.__cast())

def __mod__(self, rhs):
return self.__cast() % rhs

def __add__(self, other):
return self.__cast() + other

def __radd__(self, other):
return other + self.__cast()

def __deepcopy__(self, memo):
# Instances of this class are effectively immutable. It's just a
# collection of functions. So we don't need to do anything
# complicated for copying.
memo[id(self)] = self
return self


for type_ in bytes.mro():
for method_name in type_.__dict__:
# All __promise__ return the same wrapper method, they
# look up the correct implementation when called.
if hasattr(LazyBytes, method_name):
continue
setattr(LazyBytes, method_name, LazyBytes._create_promise(method_name))
3 changes: 2 additions & 1 deletion dmoj/graders/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from dmoj.checkers import CheckerOutput
from dmoj.cptbox import TracedPopen
from dmoj.cptbox.lazy_bytes import LazyBytes
from dmoj.error import OutputLimitExceeded
from dmoj.executors import executors
from dmoj.executors.base_executor import BaseExecutor
Expand Down Expand Up @@ -60,7 +61,7 @@ def check_result(self, case: TestCase, result: Result) -> CheckerOutput:
result.proc_output,
case.output_data(),
submission_source=self.source,
judge_input=case.input_data(),
judge_input=LazyBytes(case.input_data),
point_value=case.points,
case_position=case.position,
batch=case.batch,
Expand Down

0 comments on commit c88ef56

Please sign in to comment.