From b7955f030778a29f79c092c1dbd1677c4edf1ad5 Mon Sep 17 00:00:00 2001 From: Zeyun Date: Mon, 8 Jul 2024 02:07:35 -0700 Subject: [PATCH] enhance codes --- sushie/infer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sushie/infer.py b/sushie/infer.py index 3d9ae19..349074b 100644 --- a/sushie/infer.py +++ b/sushie/infer.py @@ -469,11 +469,11 @@ def infer_sushie( ) elbo_last = elbo_tracker[o_iter] elbo_tracker = jnp.append(elbo_tracker, elbo_cur) - elbo_increase = not ( - elbo_cur < elbo_last and (not jnp.isclose(elbo_cur, elbo_last, atol=1e-8)) + elbo_increase = elbo_cur >= elbo_last or jnp.isclose( + elbo_cur, elbo_last, atol=1e-8 ) - if (not elbo_increase) or jnp.isnan(elbo_cur): + if not elbo_increase: log.logger.warning( f"Optimization concludes after {o_iter + 1} iterations." + f" ELBO decreases. Final ELBO score: {elbo_cur}. Return last iteration's results."