Wrong predictions from PINN for SEIR Model #345

Open
verobianca opened this issue Sep 13, 2024 · 4 comments
Labels
help wanted Extra attention is needed

@verobianca

The objective
Hello,
I've been trying to use the library for the last couple of days, but I'm running into difficulties with my problem.
I'm trying to model pandemic dynamics with a SEIR model (susceptible, exposed, infected, recovered). My problem has 4 outputs, 4 coupled equations, and the initial conditions. Although the PINN successfully minimizes the loss function, the resulting predictions are wrong. Since the loss is minimized, I'm not sure whether the problem is related to the hyperparameters, but I've tried several hyperparameter combinations without any improvement. I also noticed that even in the simple ODE example of tutorial 1, correct results become much harder to obtain as soon as the domain interval is enlarged. In my problem, I'm using the temporal domain [0, 200].

First of all, I wanted to ask whether I am using the API correctly. I would appreciate any suggestions that could help resolve these issues. Thank you.
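For reference, the residuals implemented in susceptible(), exposed(), infected() and recovered() in the code below correspond to the first system in this block; N = S + E + I + R stays constant because the four right-hand sides sum to zero. One common way to keep a long time interval and populations of order 10^3 away from the network's inputs and outputs, offered here only as an assumption worth testing, is to rescale time by T = 200 and every compartment by N. Since dS/dt = (N/T) ds/dτ, the same system becomes the second one. Note that T now multiplies the coefficients (for example Tσ = 100 with the values used here), so rescaling alone is not guaranteed to fix training.

```latex
% Original system, t \in [0, T], T = 200
\begin{aligned}
\frac{dS}{dt} &= -\beta \frac{S I}{N}, &
\frac{dE}{dt} &= \beta \frac{S I}{N} - \sigma E, \\
\frac{dI}{dt} &= \sigma E - \gamma I, &
\frac{dR}{dt} &= \gamma I, \\
S(0) &= S_0,\; E(0) = E_0, & I(0) &= I_0,\; R(0) = R_0.
\end{aligned}

% Rescaled variables: \tau = t/T, s = S/N, e = E/N, i = I/N, r = R/N, \tau \in [0, 1]
\begin{aligned}
\frac{ds}{d\tau} &= -T\beta\, s\, i, &
\frac{de}{d\tau} &= T\left(\beta\, s\, i - \sigma e\right), \\
\frac{di}{d\tau} &= T\left(\sigma e - \gamma i\right), &
\frac{dr}{d\tau} &= T\gamma\, i, \\
s(0) &= S_0/N,\; e(0) = E_0/N, & i(0) &= I_0/N,\; r(0) = R_0/N.
\end{aligned}
```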

Already tried tests
This is my code. The function plot_real_results() plots the pandemic dynamics I'm trying to replicate with the PINN.

import torch
from pina.problem import TimeDependentProblem
from pina.operators import grad
from pina import Condition
from pina.geometry import CartesianDomain
from pina.equation import SystemEquation
import argparse
from torch.nn import Softplus
import matplotlib.pyplot as plt
from pina import Plotter, Trainer
from pina.model import FeedForward
from pina.solvers import PINN
from pina.callbacks import MetricTracker

class SEIR(TimeDependentProblem):

    # assign output/ temporal variables
    output_variables = ['S', 'E', 'I', 'R']
    temporal_domain = CartesianDomain({'t': [0, 200]})

    # define the SEIR equations
    # (beta, sigma, gamma and N, like S0, E0, I0 and R0 in the initial
    # conditions, are module-level globals assigned in the __main__ block;
    # they are looked up when the residuals are evaluated)
    def susceptible(input_, output_):
        S_t = grad(output_.extract(['S']), input_, components=['S'], d=['t'])
        return S_t - (- beta * output_.extract(['S']) * output_.extract(['I'])/N)

    def exposed(input_, output_):
        E_t = grad(output_.extract(['E']), input_, components=['E'], d=['t'])
        return E_t - (beta * output_.extract(['S']) * output_.extract(['I'])/N - sigma * output_.extract(['E']))

    def infected(input_, output_):
        I_t = grad(output_.extract(['I']), input_, components=['I'], d=['t'])
        return I_t - (sigma * output_.extract(['E']) - gamma * output_.extract(['I']))

    def recovered(input_, output_):
        R_t = grad(output_.extract(['R']), input_, components=['R'], d=['t'])
        return R_t - (gamma * output_.extract(['I']))

    # define initial conditions
    def initial_S(input_, output_):
        return output_.extract(['S']) - S0

    def initial_E(input_, output_):
        return output_.extract(['E']) - E0

    def initial_I(input_, output_):
        return output_.extract(['I']) - I0

    def initial_R(input_, output_):
        return output_.extract(['R']) - R0

    # problem condition statement
    conditions = {
        'initial': Condition(location=CartesianDomain({'t': 0}), equation=SystemEquation([initial_S, initial_E, initial_I, initial_R])),
        'domain': Condition(location=CartesianDomain({'t': [0, 200]}), equation=SystemEquation([susceptible, exposed, infected, recovered]))
    }


def plot_real_results(timesteps, S0, E0, I0, R0, beta, sigma, gamma):
    # Reference dynamics: explicit (forward) Euler update with a step of 1 time unit
    S = torch.tensor(S0)  # Convert initial S to a tensor
    E = torch.tensor(E0)
    I = torch.tensor(I0)
    R = torch.tensor(R0)
    N = S + E + I + R
    tt = torch.empty(timesteps)
    history = torch.empty(timesteps, 4)  # named so it is not confused with the SEIR class
    history[0] = torch.stack([S, E, I, R])
    tt[0] = 0.

    for t in range(1, timesteps):
        # All increments are computed from the old state before updating
        nS = S - beta * S * I/N
        nE = E + beta * S * I/N - sigma * E
        nI = I + sigma * E - gamma * I
        nR = R + gamma * I

        S, E, I, R = nS, nE, nI, nR

        history[t] = torch.stack([S, E, I, R])
        tt[t] = float(t)

    S = history[:, 0]
    E = history[:, 1]
    I = history[:, 2]
    R = history[:, 3]
    # Plot the results
    plt.figure(figsize=(10, 6))
    plt.plot(tt, S, label='Susceptible')
    plt.plot(tt, E, label='Exposed')
    plt.plot(tt, I, label='Infected')
    plt.plot(tt, R, label='Recovered')

    plt.xlabel('Time')
    plt.ylabel('Population')
    plt.title('SEIR Model')
    plt.legend()
    plt.grid(True)
    plt.show()



if __name__ == "__main__":
    N = 1001.  # Total population, S0 + E0 + I0 + R0
    S0 = 1000.  # Initial susceptible population
    E0 = 0.  # Initial exposed population
    I0 = 1.  # Initial infected population
    R0 = 0.  # Initial recovered population
    beta = 0.2  # Infection rate
    sigma = 0.5  # Incubation rate
    gamma = 0.1  # Recovery rate

    plot_real_results(200, S0, E0, I0, R0, beta, sigma, gamma)

    parser = argparse.ArgumentParser(description="Run PINA")
    parser.add_argument("--load", help="directory to save or load files (currently unused)", type=str)
    parser.add_argument("--epochs", help="number of training epochs", type=int, default=2000)
    args = parser.parse_args()

    # create problem and discretise domain

    seir_problem = SEIR()
    seir_problem.discretise_domain(10, 'grid', locations=['initial'])  # points at t = 0 for the initial conditions
    seir_problem.discretise_domain(2000, 'lh', locations=['domain'])  # Latin hypercube samples in [0, 200] for the ODE residuals

    # make the model
    model = FeedForward(
        layers=[10, 10, 10],
        output_dimensions=len(seir_problem.output_variables),
        input_dimensions=len(seir_problem.input_variables),
        func=Softplus,
    )


    # make the pinn
    pinn = PINN(
        seir_problem,
        model,
        optimizer_kwargs={'lr': 0.001}
    )

    # create trainer
    directory = 'pina.seir'
    trainer = Trainer(solver=pinn, callbacks=[MetricTracker()], accelerator='cpu', max_epochs=args.epochs, default_root_dir=directory)


    trainer.train()
    # inspecting final loss
    print(trainer.logged_metrics)
    # plotting the solution
    plotter = Plotter()
    plotter.plot_loss(trainer=trainer, label='mean_loss', logy=True)

    plotter.plot(solver=pinn, components='S')
    plotter.plot(solver=pinn, components='E')
    plotter.plot(solver=pinn, components='I')
    plotter.plot(solver=pinn, components='R')
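Note that plot_real_results() advances the compartments with a forward Euler step of one time unit, whereas the PINN residuals encode the continuous ODE, so the reference curve is itself only an approximation of what the PINN is asked to learn. A quick, dependency-free way to check how much that matters is to integrate the same system more finely. The sketch below is plain Python; the helper names (seir_rhs, euler_daily, rk4) and the choice of classical RK4 with 10 substeps per unit time are illustrative assumptions, not part of PINA or of the script above.

```python
def seir_rhs(y, beta, sigma, gamma, n):
    """Right-hand side of the continuous SEIR system (same terms as the PINN residuals)."""
    s, e, i, r = y
    infection = beta * s * i / n
    return (-infection, infection - sigma * e, sigma * e - gamma * i, gamma * i)


def euler_daily(y0, timesteps, beta, sigma, gamma, n):
    """Forward Euler with step 1, i.e. the update used in plot_real_results()."""
    traj = [tuple(y0)]
    for _ in range(timesteps - 1):
        y = traj[-1]
        dy = seir_rhs(y, beta, sigma, gamma, n)
        traj.append(tuple(a + b for a, b in zip(y, dy)))
    return traj


def rk4(y0, timesteps, substeps, beta, sigma, gamma, n):
    """Classical RK4 with `substeps` steps per unit time; returns states at t = 0..timesteps-1."""
    dt = 1.0 / substeps

    def f(state):
        return seir_rhs(state, beta, sigma, gamma, n)

    def shift(state, k, h):
        # state + h * k, component-wise
        return tuple(a + h * b for a, b in zip(state, k))

    y = tuple(y0)
    traj = [y]
    for _ in range(timesteps - 1):
        for _ in range(substeps):
            k1 = f(y)
            k2 = f(shift(y, k1, dt / 2))
            k3 = f(shift(y, k2, dt / 2))
            k4 = f(shift(y, k3, dt))
            y = tuple(
                a + dt / 6 * (b1 + 2 * b2 + 2 * b3 + b4)
                for a, b1, b2, b3, b4 in zip(y, k1, k2, k3, k4)
            )
        traj.append(y)
    return traj


if __name__ == "__main__":
    params = dict(beta=0.2, sigma=0.5, gamma=0.1, n=1001.0)
    y0 = (1000.0, 0.0, 1.0, 0.0)  # S0, E0, I0, R0 from the issue
    coarse = euler_daily(y0, 200, **params)
    fine = rk4(y0, 200, 10, **params)
    for k, label in enumerate("SEIR"):
        gap = max(abs(yc[k] - yf[k]) for yc, yf in zip(coarse, fine))
        print(f"{label}: max |Euler(dt=1) - RK4(dt=0.1)| = {gap:.1f}")
```

If the two trajectories differ noticeably, the finer one is the fairer target for the PINN, since its loss encodes the continuous equations rather than the one-day Euler map.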

@verobianca verobianca added the help wanted Extra attention is needed label Sep 13, 2024
@dario-coscia
Collaborator

Hi 👋🏻 @verobianca, thanks for the issue! Your code looks good, and your use of the API is perfect 👍🏻. I don't think there is any problem with the software.

I think the problem you are facing is due to the complexity of the equations you are trying to solve, which call for specific ad hoc considerations. For example, the neural net you are using currently outputs 4 variables whose ranges are very different (and this is usually hard for a single network). Maybe Fourier Random Features or MultiFeedForward approaches could help... Here are some nice tips on how to train PINNs for complex equations!
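The point about four outputs living on very different ranges can be made concrete. The sketch below shows one possible way to combine input/output scaling with random Fourier features in plain PyTorch. The class name ScaledFourierSEIR, the softmax output head (which builds S + E + I + R = N into the network, consistent with the four right-hand sides summing to zero), and all sizes and the frequency scale are assumptions for illustration; this is not PINA's own Fourier-feature or MultiFeedForward implementation, and whether it can be passed unchanged as `model` to PINA's PINN, given its LabelTensor inputs, has not been verified here.

```python
import math

import torch
from torch import nn


class ScaledFourierSEIR(nn.Module):
    """Hypothetical t -> (S, E, I, R) network with input/output scaling.

    * time is rescaled to tau = t / t_max, so inputs lie in [0, 1]
    * tau is lifted with fixed random Fourier features
      [sin(2*pi*tau*B), cos(2*pi*tau*B)], with B ~ N(0, scale^2)
    * a softmax over 4 logits gives fractions that sum to 1; multiplying by
      n_total enforces S + E + I + R = N by construction
    """

    def __init__(self, t_max=200.0, n_total=1001.0, n_features=16, scale=2.0, hidden=32):
        super().__init__()
        self.t_max = t_max
        self.n_total = n_total
        # Frequencies are sampled once and kept fixed (a buffer, not a trainable parameter)
        self.register_buffer("B", torch.randn(1, n_features) * scale)
        self.mlp = nn.Sequential(
            nn.Linear(2 * n_features, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 4),
        )

    def forward(self, t):
        tau = t / self.t_max                   # (n, 1), values in [0, 1]
        proj = 2 * math.pi * tau @ self.B      # (n, n_features)
        feats = torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
        fractions = torch.softmax(self.mlp(feats), dim=-1)
        return self.n_total * fractions        # (n, 4): columns S, E, I, R


if __name__ == "__main__":
    torch.manual_seed(0)
    model = ScaledFourierSEIR()
    t = torch.linspace(0.0, 200.0, 5).reshape(-1, 1).requires_grad_(True)
    y = model(t)
    # dS/dt via autograd, as a PINN residual needs it; each row of y depends
    # only on its own t, so differentiating the sum gives per-point derivatives
    dS_dt = torch.autograd.grad(y[:, 0].sum(), t, create_graph=True)[0]
    print(y.sum(dim=-1))  # every entry is 1001 up to float32 rounding
    print(dS_dt.shape)    # torch.Size([5, 1])
```

A design caveat: the softmax keeps every compartment strictly positive, so E(0) = R(0) = 0 can only be approached, not matched exactly, and the initial-condition residuals are still needed. Residuals written in the original units can also remain large (of order N per unit time), so loss weighting may still need attention.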

@verobianca
Author

Hi @dario-coscia, thanks for the reply.
I will have a look at your suggestions :)

@dario-coscia
Collaborator

Hi @verobianca! Did you manage to solve your problem? I think it would greatly benefit the community (if you're interested, maybe we could also make a tutorial on it 😁)

@verobianca
Author

Hi @dario-coscia, I was on holiday for almost a month and got back to work yesterday 😅 I will start working on it again and let you know if I solve the problem 😊
