Skip to content

Commit

Permalink
Integrate NNLS reg CV into code structure #44
Browse files Browse the repository at this point in the history
  • Loading branch information
JoJas102 committed Dec 21, 2023
1 parent 7ec6064 commit 4c7b26f
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions src/fit/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import math
from scipy import signal
from scipy.sparse import diags
from scipy.linalg import norm
from functools import partial
from typing import Callable
import json
Expand Down Expand Up @@ -595,13 +596,13 @@ def __init__(
self.tol = tol
self.reg_order: reg_order
self.mu = None
self.fit_function = Model.NNLSregCV.fit
# self.fit_function = Model.NNLSregCV.fit

# @staticmethod
# def _get_G(basis_CV, H, In, mu, signal_CV):
# """Determining lambda function G with cross-validation method."""
# fit = NNLS_reg_fit(basis_CV, H, mu, signal_CV)
# # fit = model.NNLS.fit(basis_CV, H, mu, signal_CV)
# fit = Model.NNLS.fit(1, signal_CV, np.concatenate((basis_CV, mu * H)))
# # fit = NNLS_reg_fit(basis, H, mu, signal)
#
# # Calculating G with CrossValidation method
# G = (
Expand Down Expand Up @@ -636,23 +637,23 @@ def __init__(
# )
#
# # Identity matrix
# In = np.identity(len(signal)) # = 16 = len(b_values)
# In = np.identity(len(self.b_values))
#
# Lambda_left = 0.00001
# Lambda_right = 8
# midpoint = (Lambda_right + Lambda_left) / 2
#
# # Function (+ delta) and derivative f at left point
# G_left = get_G(basis, H, In, Lambda_left, signal)
# G_leftDiff = get_G(basis, H, In, Lambda_left + tol, signal)
# G_left = self._get_G(basis, H, In, Lambda_left, signal)
# G_leftDiff = self._get_G(basis, H, In, Lambda_left + tol, signal)
# f_left = (G_leftDiff - G_left) / tol
#
# count = 0
# while abs(Lambda_right - Lambda_left) > tol:
# midpoint = (Lambda_right + Lambda_left) / 2
# # Function (+ delta) and derivative f at middle point
# G_middle = get_G(basis, H, In, midpoint, signal)
# G_middleDiff = get_G(basis, H, In, midpoint + tol, signal)
# G_middle = self._get_G(basis, H, In, midpoint, signal)
# G_middleDiff = self._get_G(basis, H, In, midpoint + tol, signal)
# f_middle = (G_middleDiff - G_middle) / tol
#
# if count > 100:
Expand Down

0 comments on commit 4c7b26f

Please sign in to comment.