diff --git a/dmoj/cptbox/_cptbox.pyi b/dmoj/cptbox/_cptbox.pyi index 7a52be6bf..d7124f890 100644 --- a/dmoj/cptbox/_cptbox.pyi +++ b/dmoj/cptbox/_cptbox.pyi @@ -102,3 +102,6 @@ bsd_get_proc_fdno: Callable[[int, int], str] memory_fd_create: Callable[[], int] memory_fd_seal: Callable[[int], None] + +class BufferProxy: + def _get_real_buffer(self): ... diff --git a/dmoj/cptbox/_cptbox.pyx b/dmoj/cptbox/_cptbox.pyx index 3f296ea2c..03c9937d6 100644 --- a/dmoj/cptbox/_cptbox.pyx +++ b/dmoj/cptbox/_cptbox.pyx @@ -1,5 +1,6 @@ # cython: language_level=3 from cpython.exc cimport PyErr_NoMemory, PyErr_SetFromErrno +from cpython.buffer cimport PyObject_GetBuffer from cpython.bytes cimport PyBytes_AsString, PyBytes_FromStringAndSize from libc.stdio cimport FILE, fopen, fclose, fgets, sprintf from libc.stdlib cimport malloc, free, strtoul @@ -600,3 +601,11 @@ cdef class Process: if not self._exited: return None return self._exitcode + + +cdef class BufferProxy: + def _get_real_buffer(self): + raise NotImplementedError + + def __getbuffer__(self, Py_buffer *buffer, int flags): + PyObject_GetBuffer(self._get_real_buffer(), buffer, flags) diff --git a/dmoj/cptbox/lazy_bytes.py b/dmoj/cptbox/lazy_bytes.py new file mode 100644 index 000000000..b6b3cd8f7 --- /dev/null +++ b/dmoj/cptbox/lazy_bytes.py @@ -0,0 +1,88 @@ +# Based off https://github.com/django/django/blob/main/django/utils/functional.py, licensed under 3-clause BSD. +from functools import total_ordering + +from dmoj.cptbox._cptbox import BufferProxy + +_SENTINEL = object() + + +@total_ordering +class LazyBytes(BufferProxy): + """ + Encapsulate a function call and act as a proxy for methods that are + called on the result of that function. The function is not evaluated + until one of the methods on the result is called. + """ + + def __init__(self, func): + self.__func = func + self.__value = _SENTINEL + + def __get_value(self): + if self.__value is _SENTINEL: + self.__value = self.__func() + return self.__value + + @classmethod + def _create_promise(cls, method_name): + # Builds a wrapper around some magic method + def wrapper(self, *args, **kw): + # Automatically triggers the evaluation of a lazy value and + # applies the given magic method of the result type. + res = self.__get_value() + return getattr(res, method_name)(*args, **kw) + + return wrapper + + def __cast(self): + return bytes(self.__get_value()) + + def _get_real_buffer(self): + return self.__cast() + + def __bytes__(self): + return self.__cast() + + def __repr__(self): + return repr(self.__cast()) + + def __str__(self): + return str(self.__cast()) + + def __eq__(self, other): + if isinstance(other, LazyBytes): + other = other.__cast() + return self.__cast() == other + + def __lt__(self, other): + if isinstance(other, LazyBytes): + other = other.__cast() + return self.__cast() < other + + def __hash__(self): + return hash(self.__cast()) + + def __mod__(self, rhs): + return self.__cast() % rhs + + def __add__(self, other): + return self.__cast() + other + + def __radd__(self, other): + return other + self.__cast() + + def __deepcopy__(self, memo): + # Instances of this class are effectively immutable. It's just a + # collection of functions. So we don't need to do anything + # complicated for copying. + memo[id(self)] = self + return self + + +for type_ in bytes.mro(): + for method_name in type_.__dict__: + # All __promise__ return the same wrapper method, they + # look up the correct implementation when called. + if hasattr(LazyBytes, method_name): + continue + setattr(LazyBytes, method_name, LazyBytes._create_promise(method_name)) diff --git a/dmoj/graders/standard.py b/dmoj/graders/standard.py index 3729f3400..013418b15 100644 --- a/dmoj/graders/standard.py +++ b/dmoj/graders/standard.py @@ -3,6 +3,7 @@ from dmoj.checkers import CheckerOutput from dmoj.cptbox import TracedPopen +from dmoj.cptbox.lazy_bytes import LazyBytes from dmoj.error import OutputLimitExceeded from dmoj.executors import executors from dmoj.executors.base_executor import BaseExecutor @@ -60,7 +61,7 @@ def check_result(self, case: TestCase, result: Result) -> CheckerOutput: result.proc_output, case.output_data(), submission_source=self.source, - judge_input=case.input_data(), + judge_input=LazyBytes(case.input_data), point_value=case.points, case_position=case.position, batch=case.batch,