Skip to content

Commit

Permalink
Reintroduce rime caching by passing the rime spec hash in as the 1st …
Browse files Browse the repository at this point in the history
…literal arg
  • Loading branch information
sjperkins committed Jan 29, 2024
1 parent 4f36818 commit 730dbaa
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions africanus/experimental/rime/fused/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,26 @@


def rime_impl_factory(terms, transformers, ncorr):
def rime_impl(*args):
raise NotImplementedError

@njit(nogil=True)
@njit(**JIT_OPTIONS)
def rime(*args):
return rime_impl(*args)

def rime_impl(*args):
raise NotImplementedError

@overload(rime_impl, jit_options=JIT_OPTIONS)
def nb_rime(*args):
if not len(args) % 2 == 0:
raise TypeError(f"len(args) {len(args)} is not divisible by 2")
if not len(args) > 0:
raise TypeError(f"rime must be called with at least the signature argument")

if not isinstance(args[0], types.Literal):
raise TypeError(f"Signature hash ({args[0]}) must be a literal")

if not len(args) % 2 == 1:
raise TypeError(f"Length of named arguments {len(args)} is not divisible by 2")

argstart = len(args) // 2
names = args[:argstart]
argstart = 1 + (len(args) - 1) // 2
names = args[1:argstart]

if not all(isinstance(n, types.Literal) for n in names):
raise TypeError(f"{names} must be a Tuple of Literal strings")
Expand Down Expand Up @@ -199,7 +205,7 @@ def __call__(self, time, antenna1, antenna2, feed1, feed2, **kwargs):

args = keys + (time, antenna1, antenna2, feed1,
feed2) + tuple(kwargs.values())
return self.impl(*args)
return self.impl(types.literal(self.rime_spec.spec_hash), *args)


def consolidate_args(args, kw):
Expand Down

0 comments on commit 730dbaa

Please sign in to comment.