Skip to content

Commit

Permalink
formmating
Browse files Browse the repository at this point in the history
  • Loading branch information
edwinnglabs committed Dec 25, 2023
1 parent a720484 commit cbd778e
Showing 1 changed file with 9 additions and 20 deletions.
29 changes: 9 additions & 20 deletions orbit/utils/stan.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pickle

# import pkg_resources
import importlib_resources

Expand All @@ -9,7 +10,7 @@

logger = get_logger("orbit")

# Old approach
# Old approach
# def set_compiled_stan_path(parent, child="stan_compiled"):
# """
# Set the path for compiled stan models.
Expand Down Expand Up @@ -37,8 +38,6 @@
# CompiledStanModelPath.PARENT,
# "{}/{}.pkl".format(CompiledStanModelPath.CHILD, stan_model_name),
# )

<<<<<<< HEAD
# # updated for py3
# os.makedirs(os.path.dirname(compiled_model), exist_ok=True)
# # compile if compiled file does not exist or stan source has changed (with later datestamp than compiled)
Expand All @@ -52,20 +51,6 @@
# )
# )
# sm = CmdStanModel(stan_file=source_model)
=======
# updated for py3
os.makedirs(os.path.dirname(compiled_model), exist_ok=True)
# compile if compiled file does not exist or stan source has changed (with later datestamp than compiled)
if not os.path.isfile(compiled_model) or os.path.getmtime(
compiled_model
) < os.path.getmtime(source_model):
logger.info(
"First time in running stan model:{}. Expect 3 - 5 minutes for compilation.".format(
stan_model_name
)
)
sm = CmdStanModel(stan_file=source_model)
>>>>>>> 7919a8c474dd4880539a9fddb54f9c1888b1704c

# with open(compiled_model, "wb") as f:
# pickle.dump(sm, f, protocol=pickle.HIGHEST_PROTOCOL)
Expand All @@ -77,13 +62,17 @@ def get_compiled_stan_model(stan_model_name):
"""
Load compiled Stan model
"""
# Old approach
# Old approach
# compiled_model = compile_stan_model(stan_model_name)
# with open(compiled_model, "rb") as f:
# return pickle.load(f)

# New approach
model_file = importlib_resources.files("orbit") / "stan_compiled" / "{}.bin".format(stan_model_name)
model_file = (
importlib_resources.files("orbit")
/ "stan_compiled"
/ "{}.bin".format(stan_model_name)
)
return CmdStanModel(exe_file=str(model_file))


Expand Down

0 comments on commit cbd778e

Please sign in to comment.