Skip to content

Commit

Permalink
Run formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed Feb 23, 2024
1 parent ac9332d commit d20c99c
Show file tree
Hide file tree
Showing 23 changed files with 31 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/jnotype/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Exploratory analysis of binary data."""

import jnotype.bmm as bmm
import jnotype.datasets as datasets
import jnotype.sampling as sampling
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/_csp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Cumulative shrinkage prior."""

import jax
import jax.numpy as jnp
import jax.scipy as jsp
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/_factor_analysis/_gibbs_backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Sampling steps for all variables, apart from variances
attributed to latent traits, which are sampled with CSP module."""

from typing import Callable

import jax
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/_factor_analysis/_inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Sampling from the posterior distribution."""

from typing import Callable

from jaxtyping import Float, Array
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/_factor_analysis/_simulate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Simulate data sets."""

from jaxtyping import Float, Array

import jax.numpy as jnp
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
This file should be as small as possible.
Appearing themes should be refactored and placed
into separate modules."""

import jax


Expand Down
1 change: 1 addition & 0 deletions src/jnotype/_variance.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utilities for sampling variances."""

from jax import random
import jax
import jax.numpy as jnp
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/bmm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Bernoulli Mixture Model."""

from jnotype.bmm._em import expectation_maximization
from jnotype.bmm._gibbs import (
sample_mixing,
Expand Down
7 changes: 6 additions & 1 deletion src/jnotype/bmm/_em.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""The Expectation-Maximization algorithm for Bernoulli Mixture Model."""

import dataclasses
import time
from typing import Optional
Expand Down Expand Up @@ -68,7 +69,11 @@ def em_step(
observed: Int[Array, "N K"],
mixing: Float[Array, "K B"],
proportions: Float[Array, " B"],
) -> tuple[Float[Array, "N B"], Float[Array, "K B"], Float[Array, " B"],]:
) -> tuple[
Float[Array, "N B"],
Float[Array, "K B"],
Float[Array, " B"],
]:
"""The E and M step combined, for better JIT compiler optimisation.
Args:
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/bmm/_gibbs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Sampling cluster labels and proportions."""

from typing import Optional, Sequence

import jax
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/checks/_histograms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Plotting histograms of data."""

from typing import Sequence, Union, Literal

import matplotlib.pyplot as plt
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Data sets."""

from jnotype.datasets._simulation import BlockImagesSampler

__all__ = [
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/datasets/_simulation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Simulated data sets."""

from jnotype.datasets._simulation._block_images import BlockImagesSampler

__all__ = [
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/datasets/_simulation/_block_images.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Simulation of binary images using Bernoulli mixture model."""

from typing import Optional

from jaxtyping import Array, Float, Int
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/logistic/_binary_latent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Sample binary latent variables."""

from functools import partial

import jax
Expand Down
7 changes: 4 additions & 3 deletions src/jnotype/logistic/_polyagamma.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Logistic regression sampling utilities using Pólya-Gamma augmentation."""

from jax import random
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -53,9 +54,9 @@ def _sample_coefficients(
precision_matrices: Float[Array, "features covariates covariates"] = jax.vmap(
jnp.diag
)(jnp.reciprocal(prior_variance))
posterior_covariances: Float[
Array, "features covariates covariates"
] = jnp.linalg.inv(x_omega_x + precision_matrices)
posterior_covariances: Float[Array, "features covariates covariates"] = (
jnp.linalg.inv(x_omega_x + precision_matrices)
)

kappa: Float[Array, "points features"] = jnp.asarray(observed, dtype=float) - 0.5

Expand Down
1 change: 1 addition & 0 deletions src/jnotype/logistic/_structure.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Sample structure (spike/slab distinction) variables."""

from typing import Union

import jax
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/logistic/logreg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Logistic regression utilities."""

import jax
import jax.numpy as jnp
from jaxtyping import Int, Float, Array
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/pyramids/_sampler_csp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Sampler for two-layer Bayesian pyramids
with cumulative shrinkage process (CSP) prior
on latent binary codes."""

from typing import Optional, Sequence, Union, NewType

import jax
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/pyramids/_sampler_fixed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Sampler for two-layer Bayesian pyramids
with fixed number of latent binary codes."""

from typing import Optional, Sequence, Union, NewType

import jax
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/sampling/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Generic utilities for sampling."""

from jnotype.sampling._chunker import (
DatasetInterface,
ListDataset,
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/sampling/_chunker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utilities for saving samples in chunks, to limit RAM usage."""

import abc

from datetime import datetime
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/sampling/_sampler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Generic Gibbs sampler."""

import abc
import logging
import time
Expand Down

0 comments on commit d20c99c

Please sign in to comment.