# Python 3 version of C++ microsimulation model, model_example.cc.
#
# This is free software published under the GNU General Public License
# version 3.

# You need to install numpy. Everything else is from the Python
# standard library.
#
# This code differs materially from the C++ version as follows: the number of
# agents is far smaller else it simply takes too long to run.

import numpy as np
import math
import random
from enum import IntEnum

# Number of days in a year (365.25 allows for leap years).
YEAR = 365.25
# One day expressed as a fraction of a year. Agent ages are kept in years and
# event_increment_age adds this amount on every time step.
DAY = 1.0 / YEAR

# Random number generator
# NumPy generator used for the age sampling and for the stochastic events.
# No seed is passed, so results differ from run to run. Note that the sex
# draw in Agent and event_shuffle_agents use the standard library's random
# module instead.
rng = np.random.default_rng()

# This is much more of a song and dance using numpy than the equivalent C++
# standard library code. It is used to initialize the ages of the
# microsimulation model's agents. The proportion of males and females for each
# age up to 90 are drawn from the following distribution. The data is from the
# Thembisa model.
# An improved version of this program would read this in from a file.
MALE_DIST = [
    556199, 550200, 554289, 554555, 552631, 550195, 547156, 560931,
    563200, 563224, 559176, 553679, 546729, 543589, 549932, 546043,
    531869, 516847, 502187, 469603, 457234, 458782, 458195, 466585,
    466894, 478880, 499952, 494485, 499192, 503553, 509185, 514547,
    518779, 522765, 524642, 525344, 524028, 521013, 504604, 490965,
    470985, 446012, 418694, 392725, 370384, 353289, 340052, 328700,
    316930, 304013, 289265, 273724, 258566, 245052, 233457, 224203,
    216432, 208929, 200901, 192589, 183617, 174159, 164848, 155780,
    146472, 136949, 127381, 117876, 108654, 99880,  91671,  83926,
    76578,  69446,  62413,  55426,  48685,  42274,  36528,  31740,
    27981,  24914,  22216,  19608,  16967,  14268,  11689,  9404,
    7509,   5910,   17220]
FEMALE_DIST = [
    546583, 540070, 542761, 543436, 541679, 539135, 536397, 550108,
    552589, 552977, 549468, 544628, 538484, 536059, 542961, 539811,
    526372, 511457, 497112, 465517, 453715, 454499, 454057, 462111,
    462002, 473948, 495224, 490080, 496446, 501890, 507712, 512689,
    516280, 519075, 519740, 520070, 519042, 514800, 498903, 485707,
    466135, 441880, 416150, 392858, 374232, 361659, 353606, 347796,
    341605, 333841, 323561, 311601, 299187, 288039, 278985, 272570,
    267692, 262837, 256621, 248861, 239132, 228056, 216780, 205881,
    195007, 184279, 173677, 163111, 152615, 142337, 132336, 122614,
    113283, 104226, 95242,  86302,  77627,  69138,  61368,  54957,
    50070,  46156,  42632,  38899,  34836,  30361,  25828,  21628,
    17972,  14930,  51609]

# Population totals per sex, used to normalise the counts above.
SUM_MALE_AGE = sum(MALE_DIST)
SUM_FEMALE_AGE = sum(FEMALE_DIST)

# AGES_PROB[sex][age] is the share of that sex's population with the given
# age. The outer index matches the Sex enum (0 = male, 1 = female).
AGES_PROB = [
    [count / total for count in counts]
    for counts, total in ((MALE_DIST, SUM_MALE_AGE),
                          (FEMALE_DIST, SUM_FEMALE_AGE))
]

# AGES[sex] lists the candidate ages 0, 1, 2, ... that line up, entry for
# entry, with AGES_PROB[sex].
AGES = [list(range(len(MALE_DIST))), list(range(len(FEMALE_DIST)))]


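# These death rates per sex and age are taken from the Thembisa model. They
# are indexed by age in whole years. An improvement would be to read this in
# from a file.
# NOTE(review): the data is described as covering ages up to 90, but each
# list below holds 73 entries (ages 0 to 72); event_death reuses the last
# entry for older agents. Confirm against the C++ version's data.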
DEATH_RISK = [
    # Males
    [
        0.000076520889, 0.000017192652, 0.000007240219, 0.000005732737,
        0.000004441361, 0.000003385837, 0.000002569825, 0.000001990536,
        0.000001632355, 0.000001482819, 0.000001580916, 0.000001897548,
        0.000002208913, 0.000002665925, 0.000003273968, 0.000003996761,
        0.000004820406, 0.000005777845, 0.000006868745, 0.000008052928,
        0.000009304271, 0.000010546779, 0.000011701298, 0.000012734227,
        0.000013650988, 0.000014490752, 0.000015306001, 0.000016174900,
        0.000017115953, 0.000018095632, 0.000019070014, 0.000019976372,
        0.000020767143, 0.000021415938, 0.000021956049, 0.000022480765,
        0.000023107817, 0.000023928708, 0.000024999344, 0.000026292152,
        0.000027738385, 0.000029284192, 0.000030868925, 0.000032474390,
        0.000034122993, 0.000035866000, 0.000037772632, 0.000039906771,
        0.000042305628, 0.000044983861, 0.000047950067, 0.000051204119,
        0.000054734790, 0.000058521982, 0.000062525288, 0.000066680159,
        0.000070914398, 0.000075190600, 0.000079522922, 0.000083954371,
        0.000088526342, 0.000093290522, 0.000098309032, 0.000103656921,
        0.000109436436, 0.000115770302, 0.000122776669, 0.000130516418,
        0.000138964572, 0.000148092754, 0.000157955894, 0.000168657991,
        0.000180337069],
    # Females
    [
        0.000084085544, 0.000019234231, 0.000008680390, 0.000006692589,
        0.000005033958, 0.000003712643, 0.000002730538, 0.000002065006,
        0.000001679530, 0.000001535762, 0.000001592729, 0.000001838024,
        0.000002061974, 0.000002351475, 0.000002705738, 0.000003106961,
        0.000003547194, 0.000004091870, 0.000004790889, 0.000005609952,
        0.000006501563, 0.000007371587, 0.000008127697, 0.000008734919,
        0.000009208877, 0.000009592531, 0.000009941267, 0.000010321245,
        0.000010761386, 0.000011251573, 0.000011774332, 0.000012300970,
        0.000012807805, 0.000013283732, 0.000013750722, 0.000014257017,
        0.000014854579, 0.000015589684, 0.000016487504, 0.000017539826,
        0.000018716045, 0.000019980015, 0.000021295672, 0.000022643855,
        0.000024027336, 0.000025468550, 0.000027002242, 0.000028667599,
        0.000030499362, 0.000032510646, 0.000034683482, 0.000036983990,
        0.000039380543, 0.000041858866, 0.000044438940, 0.000047186143,
        0.000050185995, 0.000053508553, 0.000057185355, 0.000061208012,
        0.000065538377, 0.000070127481, 0.000074960964, 0.000080083792,
        0.000085535735, 0.000091291510, 0.000097296074, 0.000103503214,
        0.000109895484, 0.000116577835, 0.000123831829, 0.000132015621,
        0.000141420706
    ]
]


# Every agent will have a sex because the death rates between males and
# females will be different. Vaccine uptake will also differ.


class Sex(IntEnum):
    """Sex of an agent.

    The integer values double as the outer index into the per-sex tables
    AGES, AGES_PROB and DEATH_RISK (males first, females second).
    """
    MALE = 0
    FEMALE = 1


# These are the states an agent can be in. In macro models, we'd call these
# compartments.

class State(IntEnum):
    """State an agent is in at a given moment."""
    SUSCEPTIBLE = 0
    EXPOSED = 1
    INFECTIOUS = 2
    RECOVERED = 3
    VACCINATED = 4
    DEAD = 5
    # Not a real state: the number of states above. event_tally_states and
    # event_print_stats use range(State.COUNT) to visit every state.
    COUNT = 6


# An array of instances of the following Agent class is kept in the Model
# class.
class Agent:
    """A single simulated person.

    Attributes set by the constructor:
      id    -- sequence number, unique per Agent created in this process.
      sex   -- Sex.MALE or Sex.FEMALE, each with probability one half.
      age   -- age in years.
      state -- the agent's current State.
    """

    # Next id to hand out; shared by all instances.
    __id__ = 0

    def __init__(self, state=State.SUSCEPTIBLE, age=None):
        """Create an agent in the given state.

        state: initial State of the agent.
        age: age in years. When None, a whole-year age is sampled from the
            age distribution of the agent's sex (AGES / AGES_PROB).
        """
        self.id = Agent.__id__
        Agent.__id__ += 1

        self.sex = Sex.MALE if random.random() < 0.5 else Sex.FEMALE

        # The age distribution depends on sex, so sex must be drawn first.
        self.age = (rng.choice(AGES[self.sex], p=AGES_PROB[self.sex])
                    if age is None else age)
        self.state = state


# This is the class at the heart of our microsimulation.
class Model:
    """Simulation state plus the event loop that drives it.

    Events are callables that take the model as their only argument and read
    or update it in place.
    """

    def __init__(self, parameters, before_events, during_events, after_events):
        """Store the configuration and reset all bookkeeping.

        parameters: mapping of named settings; run() reads 'time_steps'.
        before_events: callables executed once, before the first time step.
        during_events: callables executed, in order, on every time step.
        after_events: callables executed once, after the last time step.
        """
        self.parameters = parameters
        self.before_events = before_events
        self.during_events = during_events
        self.after_events = after_events

        # The population. It starts empty; events add the agents.
        self.agents = []

        # Number of time steps executed so far; used when printing stats.
        self.current_time_step = 0

        # How many agents are in each state. Refreshed by the function
        # event_tally_states.
        self.state_counter = {}

        # Fractional number of births still owed to the population.
        self.birth_tracker = 0.0

        # Number of agents that died while in the infectious state.
        self.deaths_while_infectious = 0

    def _dispatch(self, events):
        """Call every event in `events`, in order, with this model."""
        for handler in events:
            handler(self)

    def run(self):
        """Run the simulation: before events, the time steps, after events.

        The time step counter is advanced before the during events of that
        step are executed.
        """
        self._dispatch(self.before_events)

        # 'time_steps' is looked up once, after the before events have run.
        for _ in range(self.parameters['time_steps']):
            self.current_time_step += 1
            self._dispatch(self.during_events)

        self._dispatch(self.after_events)


# This event creates and initializes the agents.
def event_initialize_agents(model):
    """Append the initial population to model.agents.

    Reads 'num_susceptible' and 'num_exposed' from model.parameters and
    creates that many agents: the susceptible ones first, followed by the
    exposed ones. Ages and sexes are chosen by the Agent constructor.
    """
    susceptible_total = model.parameters['num_susceptible']
    exposed_total = model.parameters['num_exposed']

    # The first susceptible_total positions are susceptible; every position
    # after that is exposed.
    model.agents.extend(
        Agent(State.SUSCEPTIBLE if position < susceptible_total
              else State.EXPOSED)
        for position in range(susceptible_total + exposed_total))


# We'll need to randomize the order of the agents at the beginning of each
# time step to avoid bias.
def event_shuffle_agents(model):
    """Randomly reorder model.agents in place.

    Uses the standard library's random module, not the NumPy generator.
    """
    population = model.agents
    random.shuffle(population)


# This event increments the living agents' ages by a day.
def event_increment_age(model):
    """Add one day (DAY, expressed in years) to the age of each living agent."""
    for person in model.agents:
        if person.state == State.DEAD:
            # The dead do not age.
            continue
        person.age += DAY


# This is our event infection algorithm. It works like this: Each
# susceptible agent, A, is randomly placed in contact with n other
# contacts, where n is a randomly drawn integer from a normal distribution
# with mean num_contacts_avg and standard deviation num_contacts_stdev. Then
# each infectious agent, B, that A comes into contact with, will infect A
# with probability risk_exposure_per_contact.
def event_infect(model):
    """Expose susceptible agents through random contact with infectious ones.

    Parameters read from model.parameters:
      num_contacts_avg          -- mean number of contacts per agent per step.
      num_contacts_stdev        -- standard deviation of that number.
      risk_exposure_per_contact -- probability that a single contact with an
                                   infectious agent exposes the susceptible
                                   agent.

    Each living susceptible agent draws a number of contacts from a normal
    distribution, truncated to an integer between 0 and the number of living
    agents. Contacts are picked uniformly, with replacement, from the living
    agents. An agent becomes EXPOSED at the first contact that transmits, and
    its remaining contacts are skipped.
    """
    mean_contacts = model.parameters['num_contacts_avg']
    stdev = model.parameters['num_contacts_stdev']
    risk_exposure = model.parameters['risk_exposure_per_contact']
    agents = model.agents

    # Indices of all the living agents; they are the potential contacts. The
    # check must look at the agent's state: an Agent object itself never
    # compares equal to State.DEAD, so testing the object would let dead
    # agents through.
    alive_indices = [index for index, agent in enumerate(agents)
                     if agent.state != State.DEAD]
    num_alive = len(alive_indices)

    # We are only going to iterate over living agents.
    for index in alive_indices:
        agent = agents[index]
        if agent.state != State.SUSCEPTIBLE:
            continue

        # Draw this agent's number of contacts and clamp it to at least 0 and
        # at most the number of living agents. The draw is kept in its own
        # variable so that the configured mean stays the same for every agent
        # instead of being overwritten by the previous agent's draw.
        num_contacts = int(max(0, min(num_alive,
                                      rng.normal(mean_contacts, stdev))))

        # To keep things simple, the same agent can be selected more than
        # once as a contact and the agent itself might be its own contact.
        # For our purposes this shortcoming isn't important; it only reduces
        # the number of actual contacts.
        for _ in range(num_contacts):
            contact = agents[alive_indices[rng.integers(0, num_alive)]]
            if (contact.state == State.INFECTIOUS
                    and rng.uniform() < risk_exposure):
                agent.state = State.EXPOSED
                break


# This is used by events that transition an agent from one state to another
# with given probability.
def change_agent_states(model, from_state, to_state, parameter):
    """Move agents from one state to another with a given probability.

    model: the model whose agents are updated in place.
    from_state: only agents currently in this state are considered.
    to_state: the state a selected agent is moved to.
    parameter: key in model.parameters holding the per-call probability.

    One uniform random number is drawn for every agent in from_state.
    """
    probability = model.parameters[parameter]
    for candidate in model.agents:
        if candidate.state == from_state and rng.uniform() < probability:
            candidate.state = to_state


# This moves agents in the exposed state to the infectious state with
# probability risk_exposed_infectious.
def event_exposed_to_infectious(model):
    """Exposed agents become infectious, each with the probability stored in
    the 'risk_exposed_infectious' parameter."""
    change_agent_states(model,
                        from_state=State.EXPOSED,
                        to_state=State.INFECTIOUS,
                        parameter="risk_exposed_infectious")


# This moves agents in the infectious state to the recovered state with
# probability risk_infectious_recovered.
def event_infectious_to_recovered(model):
    """Infectious agents recover, each with the probability stored in the
    'risk_infectious_recovered' parameter."""
    change_agent_states(model,
                        from_state=State.INFECTIOUS,
                        to_state=State.RECOVERED,
                        parameter="risk_infectious_recovered")


# This moves agents in the recovered state to the susceptible state with
# probability risk_recovered_susceptible.
def event_recovered_to_susceptible(model):
    """Recovered agents become susceptible again, each with the probability
    stored in the 'risk_recovered_susceptible' parameter."""
    change_agent_states(model,
                        from_state=State.RECOVERED,
                        to_state=State.SUSCEPTIBLE,
                        parameter="risk_recovered_susceptible")


# This moves agents in the susceptible state to the vaccinated state with
# probability risk_susceptible_vaccinated.
def event_susceptible_to_vaccinated(model):
    """Susceptible agents become vaccinated, each with the probability stored
    in the 'risk_susceptible_vaccinated' parameter."""
    change_agent_states(model,
                        from_state=State.SUSCEPTIBLE,
                        to_state=State.VACCINATED,
                        parameter="risk_susceptible_vaccinated")


# This moves agents in the vaccinated state to the susceptible state with
# probability risk_vaccinated_susceptible.
def event_vaccinated_to_susceptible(model):
    """Vaccinated agents become susceptible again, each with the probability
    stored in the 'risk_vaccinated_susceptible' parameter."""
    change_agent_states(model,
                        from_state=State.VACCINATED,
                        to_state=State.SUSCEPTIBLE,
                        parameter="risk_vaccinated_susceptible")


# This event adds new agents with age 0 to the population based on the given
# birth_rate. If the default time_step is small, say a day, and the
# population is also small, the number of births per day may be less than
# one, but this is a discrete model and in that case there would never be
# any births. So on each time step we accumulate the number of births until
# it reaches at least one (and then add one or more agents to the
# population), then subtract the number of births given from the accumulated
# number (so that a fraction between 0 and 1 remains).
def event_births(model):
    """Add newborn susceptible agents of age 0.0 to the population.

    The 'birth_rate' parameter times the number of living agents is added to
    model.birth_tracker. The whole part of the tracker is turned into new
    agents and the fractional part is carried over to the next call.

    The number of dead agents is taken from model.state_counter, so the
    states must have been tallied before this event runs.
    """
    birth_rate = model.parameters['birth_rate']
    living = len(model.agents) - model.state_counter[State.DEAD]
    model.birth_tracker += birth_rate * living

    newborns = int(model.birth_tracker)
    model.agents.extend(
        Agent(State.SUSCEPTIBLE, 0.0) for _ in range(newborns))

    # Keep only the fraction that has not yet resulted in a birth.
    if model.birth_tracker > 0:
        model.birth_tracker -= newborns


# This event moves agents into the dead state, after which they should not
# be updated by any other events. An alternative way of doing this would be
# to have a second vector that stores dead agents. This would have the
# advantage of the other events not continuously traversing dead agents.
# Then you could efficiently move a dead agent out of the vector of living
# agents by swapping the dead agent with the living agent at the end of the
# vector, then copying it into the vector of dead agents, then reducing the
# size of the vector of living agents by one. But we've gone for a simpler
# solution here, which for our purposes is efficient enough.
def event_death(model):
    """Move living agents to State.DEAD according to their mortality risk.

    The risk is looked up in DEATH_RISK by sex and by age in whole years.
    For infectious agents it is multiplied by the
    'infectious_mortality_factor' parameter, and their deaths are also
    counted in model.deaths_while_infectious. One random number is drawn
    for every living agent.
    """
    for agent in model.agents:
        if agent.state == State.DEAD:
            continue

        # Agents older than the oldest age in the table use its last entry.
        risks = DEATH_RISK[agent.sex]
        bucket = min(int(agent.age), len(risks) - 1)
        risk = risks[bucket]

        # A more sophisticated algorithm might have a separate set of
        # mortality risks for infectious agents instead of a single factor.
        was_infectious = agent.state == State.INFECTIOUS
        if was_infectious:
            risk *= model.parameters['infectious_mortality_factor']

        if rng.random() < risk:
            if was_infectious:
                model.deaths_while_infectious += 1
            agent.state = State.DEAD


# Sorts the agents back into order by id. This is simply so that when we print
# out the agents at the end, they are all in order instead of shuffled. Not
# essential, because we could do this easily in our environment in which we
# analyse the data but since this is only executed once, it is quick.
def event_sort_agents(model):
    """Sort model.agents in place by ascending agent id.

    model: the model whose agents list is reordered. The list object itself
        is kept; only the order of its elements changes.
    """
    # list.sort() reorders the list itself. The built-in sorted() would only
    # return a new sorted list and leave model.agents in shuffled order.
    model.agents.sort(key=lambda agent: agent.id)


# Event to count the number of agents in each state.
def event_tally_states(model):
    """Recount, in model.state_counter, how many agents are in each state.

    Every state value below State.COUNT is reset to zero first, so states
    without any agents are reported as 0.
    """
    tally = model.state_counter
    tally.update(dict.fromkeys(range(State.COUNT), 0))
    for agent in model.agents:
        tally[agent.state] += 1


# Event to print a CSV file header. In this simple implementation the
# agents and demographic outputs are all printed to standard output.  An
# improvement would have them print to their own file.
def event_print_stats_header(_):
    """Print the CSV header for the rows written by event_print_stats.

    The model argument is ignored; it is accepted only so that the function
    can be used as an event.
    """
    header = "#,S,E,I,R,V,D,D_i"
    print(header)


# Event to print the number of agents in each state as well as some other
# useful demographic data, such as the number of infectious agents who
# died.
def event_print_stats(model):
    """Print one CSV row with the current statistics.

    The row holds the time step, the number of agents in each state (in
    State order) and the number of deaths while infectious. Nothing is
    printed unless the current time step is a multiple of the
    'output_frequency' parameter. Set that parameter to 1 to print on every
    time step, though this is likely unnecessary and slows down execution.
    """
    step = model.current_time_step
    if step % model.parameters["output_frequency"] != 0:
        return

    fields = [f"{step}"]
    fields.extend(f"{model.state_counter[i]}" for i in range(State.COUNT))
    fields.append(f"{model.deaths_while_infectious}")
    print(",".join(fields))


# Event to print all the agents. We typically only execute this once before
# and after the model has run. But for debugging or other purposes it may
# be useful to do so in the middle of a simulation.
def event_print_agents(model):
    """Print one line per agent with its id, sex, age and state.

    The age is shown with two decimals. Agents are printed in the order in
    which they are currently stored in model.agents.
    """
    for agent in model.agents:
        label = "Male" if agent.sex == Sex.MALE else "Female"
        print(f"Agent: {agent.id}. Sex: {label}.", end=" ")
        print(f"Age {agent.age:.2f}. State: {agent.state}")


# If this source file is run from the command line with:
# python model_example.py
# then it will execute the following code. Otherwise it can be imported
# as a module.
if __name__ == '__main__':
    # An improvement would be to allow the user to specify the parameters
    # at the command line or in a configuration file.
    simulation_parameters = {
        # Run for 20 years (one time step is one day).
        "time_steps": math.floor(20 * 365.25),
        # Population will be 100 with 10 initially exposed agents.
        # Note the C++ version of the program has 9,990 susceptible agents.
        "num_susceptible": 90,
        "num_exposed": 10,
        # Mean number of contacts per agent per day. You could create even
        # more heterogeneity by making this specific to each agent.
        "num_contacts_avg": 20.0,
        # Standard deviation of number of contacts per agent per day.
        "num_contacts_stdev": 10.0,
        # Risk of moving from susceptible to exposure state per contact.
        "risk_exposure_per_contact": 0.005,
        # Factor by which the mortality risk of an infectious agent is
        # multiplied.
        "infectious_mortality_factor": 8.0,
        # Risks of moving from one state to another per time step (which is
        # one day).
        "risk_exposed_infectious": 0.1,
        "risk_infectious_recovered": 0.005,
        "risk_recovered_susceptible": 0.0001,
        "risk_susceptible_vaccinated": 0.0003,
        "risk_vaccinated_susceptible": 0.0001,
        # Daily number of births per living agent.
        "birth_rate": 0.000055,
        # How often, in time steps, to print the demographic outputs.
        "output_frequency": 20
    }

    # Executed once, before the first time step.
    setup_events = [
        event_initialize_agents, event_print_agents, event_tally_states,
        event_print_stats_header, event_print_stats
    ]

    # Executed in this order on every time step.
    daily_events = [
        event_shuffle_agents, event_increment_age, event_infect,
        event_exposed_to_infectious, event_infectious_to_recovered,
        event_recovered_to_susceptible, event_susceptible_to_vaccinated,
        event_vaccinated_to_susceptible, event_births, event_death,
        event_tally_states, event_print_stats
    ]

    # Executed once, after the last time step.
    wrapup_events = [
        event_tally_states, event_print_stats, event_sort_agents,
        event_print_agents
    ]

    Model(parameters=simulation_parameters,
          before_events=setup_events,
          during_events=daily_events,
          after_events=wrapup_events).run()