Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sample from fitted LatentCalendar #51

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion latent_calendar/const.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Constants used to create the full vocabulary of the dataset."""
"""Constants used to create the full vocabulary of the dataset.
"""
import calendar
from itertools import product
from typing import Dict, List, Union
Expand Down
92 changes: 73 additions & 19 deletions latent_calendar/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,22 @@
from typing import Optional, Tuple, Union

import numpy as np
from numpy import typing as npt
import pandas as pd

from latent_calendar.const import FULL_VOCAB
from latent_calendar.model.latent_calendar import LatentCalendar

try:
import pymc as pm
from pytensor.tensor import TensorVariable
from pytensor import tensor as pt
except ImportError:

class TensorVariable:
pass

class PyMC:
class MockModule:

Check warning on line 20 in latent_calendar/generate.py

View check run for this annotation

Codecov / codecov/patch

latent_calendar/generate.py#L20

Added line #L20 was not covered by tests
def __getattr__(self, name):
msg = (
"PyMC is not installed."
Expand All @@ -23,7 +26,8 @@
)
raise ImportError(msg)

pm = PyMC()
pm = MockModule()
pt = MockModule()

Check warning on line 30 in latent_calendar/generate.py

View check run for this annotation

Codecov / codecov/patch

latent_calendar/generate.py#L29-L30

Added lines #L29 - L30 were not covered by tests


def wide_format_dataframe(
Expand Down Expand Up @@ -57,31 +61,24 @@
return travel_style_user, time_slots


def sample_from_lda(
components_prior: Union[np.ndarray, TensorVariable],
components_time_slots_prior: Union[np.ndarray, TensorVariable],
n_samples: np.ndarray,
random_state: Optional[int] = None,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""Sample from LDA model.
N_SAMPLES = Union[npt.NDArray[np.int_], int]

Args:
components_prior: prior probability of each component (n_components, )
components_time_slots_prior: prior for time slots (n_components, n_time_slots)
n_samples: number of samples for each user (n_user, )
random_state: random state for sampling
SAMPLE_RESULT = Tuple[pd.DataFrame, pd.DataFrame]

Returns:
probability DataFrame (n_user, n_components) and event count DataFrame with (n_user, n_time_slots) with each row summing up to `n`

"""
def _sample_lda(
travel_style: npt.NDArray[np.float_],
time_slot_styles: npt.NDArray[np.float_],
n_samples: N_SAMPLES,
random_state: Optional[int] = None,
) -> SAMPLE_RESULT:
rng = np.random.default_rng(random_state)

user_travel_style_data = []
user_time_slot_data = []

travel_style = pm.Dirichlet.dist(components_prior)
time_slot_styles = pm.Dirichlet.dist(components_time_slots_prior)
if isinstance(n_samples, int):
n_samples = [n_samples]

Check warning on line 81 in latent_calendar/generate.py

View check run for this annotation

Codecov / codecov/patch

latent_calendar/generate.py#L80-L81

Added lines #L80 - L81 were not covered by tests

for n in n_samples:
_, user_time_slots = define_single_user_samples(
Expand All @@ -99,3 +96,60 @@
df_user_time_slots = pd.DataFrame(user_time_slot_data)

return df_user_travel_style, df_user_time_slots


def sample_from_lda(
components_prior: Union[np.ndarray, TensorVariable],
components_time_slots_prior: Union[np.ndarray, TensorVariable],
n_samples: N_SAMPLES,
random_state: Optional[int] = None,
) -> SAMPLE_RESULT:
"""Sample from LDA model.

Args:
components_prior: prior probability of each component (n_components, )
components_time_slots_prior: prior for time slots (n_components, n_time_slots)
n_samples: number of samples for all users or for each user (n_user, )
random_state: random state for sampling

Returns:
probability DataFrame (n_user, n_components) and event count DataFrame with (n_user, n_time_slots) with each row summing up to `n`

"""

travel_style = pm.Dirichlet.dist(components_prior)
time_slot_styles = pm.Dirichlet.dist(components_time_slots_prior)

Check warning on line 121 in latent_calendar/generate.py

View check run for this annotation

Codecov / codecov/patch

latent_calendar/generate.py#L120-L121

Added lines #L120 - L121 were not covered by tests

return _sample_lda(

Check warning on line 123 in latent_calendar/generate.py

View check run for this annotation

Codecov / codecov/patch

latent_calendar/generate.py#L123

Added line #L123 was not covered by tests
travel_style=travel_style,
time_slot_styles=time_slot_styles,
n_samples=n_samples,
random_state=random_state,
)


def sample_from_latent_calendar(
latent_calendar: LatentCalendar,
n_samples: Union,
random_state: Optional[int] = None,
) -> SAMPLE_RESULT:
"""Sample from a latent calendar model.

Args:
latent_calendar: fitted latent calendar model
n_samples: number of rows to sample
random_state: random state for reproducibility

Returns:
probability DataFrame (n_user, n_components) and event count DataFrame with (n_user, n_time_slots) with each row summing up to `n`

"""
# TODO: Figure out how to best recreate based on the population
travel_style = pm.Dirichlet.dist(latent_calendar.component_distribution_)
time_slot_styles = pm.Dirichlet.dist(latent_calendar.components_)
return _sample_lda(

Check warning on line 150 in latent_calendar/generate.py

View check run for this annotation

Codecov / codecov/patch

latent_calendar/generate.py#L148-L150

Added lines #L148 - L150 were not covered by tests
travel_style=travel_style,
time_slot_styles=time_slot_styles,
n_samples=n_samples,
random_state=random_state,
)
Loading