diff --git a/mushi/ksfs.py b/mushi/ksfs.py index e4557f7..b1181e2 100644 --- a/mushi/ksfs.py +++ b/mushi/ksfs.py @@ -3,10 +3,9 @@ import mushi.optimization as opt import mushi.composition as cmp -from jax.config import config import numpy as onp import jax.numpy as np -from jax import jit, grad +from jax import jit, grad, config from jax.scipy.special import expit, logit from scipy.stats import poisson from typing import Union, List, Dict, Tuple