From f0c14b1fca4fc4a4646d60fca3f397bc8ea3432b Mon Sep 17 00:00:00 2001 From: XuHao Date: Thu, 13 Jul 2023 08:57:02 +0800 Subject: [PATCH] fix flat_model in Scube --- .gitignore | 1 + SPACEL/Scube/gpr.py | 17 +++++------------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 462453e..f1de8a2 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ __pycache__/ # Distribution / packaging .Python build/ +generated/ develop-eggs/ dist/ downloads/ diff --git a/SPACEL/Scube/gpr.py b/SPACEL/Scube/gpr.py index 4f8b169..c480a16 100644 --- a/SPACEL/Scube/gpr.py +++ b/SPACEL/Scube/gpr.py @@ -109,8 +109,8 @@ def prepare_gpr_model(self, lengthscale_prior=None,outputscale_prior=None,noise_ noise_prior=noise_prior ) self.flat_model = ExactGPModel(self.train_x, self.train_y, likelihood, lengthscale_prior=lengthscale_prior,outputscale_prior=outputscale_prior) - # self.init_model(self.flat_model,lengthscale=torch.tensor(99999999)) - self.init_model(self.flat_model,lengthscale=torch.inf) + # self.init_model(self.flat_model,lengthscale=torch.tensor(99999)) + self.init_model(self.flat_model,lengthscale=torch.tensor(1000)) else: if self.model is None: likelihood = gpytorch.likelihoods.GaussianLikelihood( @@ -148,8 +148,8 @@ def train_single_model(self,model,lr=1,training_iter=500,save=False,save_path=No # "Loss" for GPs - the marginal log likelihood mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model) - best_loss = None - best_model_state = None + best_loss = np.inf + best_model_state = model.state_dict() if optimize_method=='Adam': optimizer = torch.optim.Adam(model.parameters(),lr=lr) for i in range(training_iter): @@ -158,10 +158,7 @@ def train_single_model(self,model,lr=1,training_iter=500,save=False,save_path=No loss = -mll(output, self.train_y) loss.backward() optimizer.step() - if best_loss is None: - best_loss = loss.item() - if best_model_state is None: - best_model_state = model.state_dict() + print(loss.item()) if loss.item() < best_loss: best_model_state = model.state_dict() best_loss = loss.item() @@ -177,10 +174,6 @@ def closure(): return loss for i in range(training_iter): loss = optimizer.step(closure) - if best_loss is None: - best_loss = loss.item() - if best_model_state is None: - best_model_state = model.state_dict() if loss.item() < best_loss: best_model_state = model.state_dict() best_loss = loss.item()