Merge pull request #20 from kwcantrell/add-yml
Add enviornment.yml
jbk708 authored Sep 18, 2024
2 parents 8077bba + efa148c commit 5420cca
Showing 12 changed files with 150 additions and 70 deletions.
33 changes: 21 additions & 12 deletions README.md
@@ -2,41 +2,50 @@

Attention-based network for microbial sequencing data.

# Installation
# Installation Instructions
IMPORTANT: If installing on a server cluster, spawn an instance with a GPU before proceeding with environment setup.
First create a new conda environment with unifrac

## Install Requirements
Requires tensorflow==2.14 and tf-models-official==2.14.2
`conda create --name aam -c conda-forge -c bioconda unifrac python=3.9 cython`

`pip install tensorflow==2.14 tf-models-official==2.14.2`
`conda activate aam`

or
## GPU Support

`pip install tensorflow[and-cuda]==2.14 tf-models-official==2.14.2`
Install CUDA 11.8

for GPU support.
`conda install nvidia/label/cuda-11.8.0::cuda-toolkit`

Tensorboard is an optional dependency to visualize training losses/metrics.
Verify the NVIDIA GPU drivers are on your path

`pip install tensorboard`
`nvidia-smi`

Please see [Tensorflow](https://www.tensorflow.org/install) for more information
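
To double-check that TensorFlow can actually see the GPU (a quick sanity check, not part of the original instructions):

`python -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"`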

## Install AAM


For the latest version

`pip install git+https://github.com/kwcantrell/attention-all-microbes.git`

or
or install a specific version

`pip install git+https://github.com/kwcantrell/[email protected]`

for a specific tagged version.
## Developers

`git clone [email protected]:kwcantrell/attention-all-microbes.git`

`cd attention-all-microbes`

`pip install -e .`

# Training

Classifiers and Regressors are trained using cross-validation

`python attention_cli.py --help`
`attention --help`



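As a sketch of a complete training invocation (the `fit-unifrac-regressor` subcommand name is inferred from the click command defined further down in this diff, and the file paths are placeholders):

`attention fit-unifrac-regressor --i-table table.biom --i-tree tree.nwk --p-max-bp 150 --p-epochs 1000 --output-dir results/`
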
14 changes: 14 additions & 0 deletions aam/__init__.py
@@ -0,0 +1,14 @@
from .cv_utils import CVModel, EnsembleModel
from .transfer_data_utils import load_data
from .transfer_nuc_model import TransferLearnNucleotideModel
from .unifrac_data_utils import load_data as _load_unifrac_data
from .unifrac_model import UnifracModel

__all__ = [
"UnifracModel",
"_load_unifrac_data",
"load_data",
"TransferLearnNucleotideModel",
"CVModel",
"EnsembleModel",
]
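
The new package-level `__init__` re-exports the main entry points, so downstream code can pull them straight from `aam`. A minimal import sketch (constructor signatures are not part of this diff, so none are implied; note that `aam.load_data` resolves to the transfer-learning loader, while the UniFrac loader is kept under the private alias `_load_unifrac_data`):

```python
# Minimal sketch: the public surface exposed by aam/__init__.py.
from aam import (
    CVModel,
    EnsembleModel,
    TransferLearnNucleotideModel,
    UnifracModel,
    load_data,  # resolves to aam.transfer_data_utils.load_data
)
```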
59 changes: 41 additions & 18 deletions attention_cli.py → aam/attention_cli.py
@@ -7,7 +7,12 @@
from biom import load_table
from sklearn.model_selection import KFold, StratifiedKFold

from aam.callbacks import SaveModel, _confusion_matrix, _mean_absolute_error
from aam.callbacks import (
ConfusionMatrx,
SaveModel,
_confusion_matrix,
_mean_absolute_error,
)
from aam.cv_utils import CVModel, EnsembleModel
from aam.losses import ImbalancedCategoricalCrossEntropy, ImbalancedMSE
from aam.transfer_nuc_model import TransferLearnNucleotideModel
@@ -32,43 +37,42 @@ class cli:
@click.option("--i-table", required=True, type=click.Path(exists=True), help=TABLE_DESC)
@click.option("--i-tree", required=True, type=click.Path(exists=True))
@click.option("--p-max-bp", required=True, type=int)
@click.option("--p-batch-size", default=8, type=int, show_default=True)
@click.option("--p-epochs", default=1000, show_default=True, type=int)
@click.option("--p-dropout", default=0.01, show_default=True, type=float)
@click.option("--p-dropout", default=0.0, show_default=True, type=float)
@click.option("--p-ff-d-model", default=128, show_default=True, type=int)
@click.option("--p-pca-heads", default=8, show_default=True, type=int)
@click.option("--p-enc-layers", default=2, show_default=True, type=int)
@click.option("--p-enc-heads", default=8, show_default=True, type=int)
@click.option("--p-output-dir", required=True)
@click.option("--output-dir", required=True)
def fit_unifrac_regressor(
i_table: str,
i_tree: str,
p_max_bp: int,
p_batch_size: int,
p_epochs: int,
p_dropout: float,
p_ff_d_model: int,
p_pca_heads: int,
p_enc_layers: int,
p_enc_heads: int,
p_output_dir: str,
output_dir: str,
):
from aam.unifrac_data_utils import load_data

tf.keras.mixed_precision.set_global_policy("mixed_float16")
if not os.path.exists(p_output_dir):
os.makedirs(p_output_dir)
if not os.path.exists(output_dir):
os.makedirs(output_dir)

figure_path = os.path.join(p_output_dir, "figures")
figure_path = os.path.join(output_dir, "figures")
if not os.path.exists(figure_path):
os.makedirs(figure_path)

data_obj = load_data(
i_table,
tree_path=i_tree,
)
data_obj = load_data(i_table, tree_path=i_tree, batch_size=p_batch_size)

load_model = True
load_model = False
if load_model:
model = tf.keras.models.load_model(f"{p_output_dir}/model.keras")
model = tf.keras.models.load_model(f"{output_dir}/model.keras")
else:
model = UnifracModel(
p_ff_d_model,
@@ -91,23 +95,28 @@ def fit_unifrac_regressor(
run_eagerly=False,
)
model.summary()
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = os.path.join(p_output_dir, log_dir)
log_dir = "logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = os.path.join(output_dir, log_dir)
if not os.path.exists(log_dir):
os.makedirs(log_dir)
model_save_path = os.path.join(output_dir, "model.keras")
model_saver = SaveModel(model_save_path, 1)
core_callbacks = [
tf.keras.callbacks.TensorBoard(
log_dir=log_dir,
histogram_freq=0,
),
SaveModel(p_output_dir, 1),
tf.keras.callbacks.EarlyStopping("val_loss", patience=5, start_from_epoch=5),
model_saver,
]
model.fit(
data_obj["training_dataset"],
validation_data=data_obj["validation_dataset"],
callbacks=[*core_callbacks],
epochs=p_epochs,
)
model.set_weights(model_saver.best_weights)
model.save(model_save_path, save_format="keras")


@cli.command()
@@ -289,6 +298,7 @@ def _get_fold(indices, shuffle):
)
@click.option("--p-patience", default=10, show_default=True, type=int)
@click.option("--p-early-stop-warmup", default=50, show_default=True, type=int)
@click.option("--p-report-back", default=1, show_default=True, type=int)
@click.option("--output-dir", required=True, type=click.Path(exists=False))
def fit_sample_classifier(
i_table: str,
@@ -304,6 +314,7 @@ def fit_sample_classifier(
p_stratify: bool,
p_patience: int,
p_early_stop_warmup: int,
p_report_back: int,
output_dir: str,
):
from aam.transfer_data_utils import (
@@ -326,10 +337,12 @@

table = load_table(i_table)
df = pd.read_csv(m_metadata_file, sep="\t", index_col=0)[[m_metadata_column]]
print(df)
ids, table, df = validate_metadata(table, df, p_missing_samples)
table, df = shuffle(table, df)
num_ids = len(ids)

categories = df[m_metadata_column].astype("category").cat.categories
print("int", categories)
fold_indices = [i for i in range(num_ids)]
if p_test_size > 0:
test_size = int(num_ids * p_test_size)
@@ -371,7 +384,7 @@ def _get_fold(indices, shuffle):
penalty=p_penalty,
num_classes=train_data["num_classes"],
)
loss = ImbalancedCategoricalCrossEntropy(list(train_data["cat_counts"]))
loss = ImbalancedCategoricalCrossEntropy(train_data["cat_counts"])
fold_label = i + 1
model_cv = CVModel(
model,
@@ -387,6 +400,16 @@ def _get_fold(indices, shuffle):
metric="target_loss",
patience=p_patience,
early_stop_warmup=p_early_stop_warmup,
callbacks=[
ConfusionMatrx(
dataset=val_data["dataset"],
output_dir=os.path.join(
figure_path, f"model-f{fold_label}-val.png"
),
report_back=p_report_back,
labels=categories,
)
],
)
models.append(model_cv)

21 changes: 17 additions & 4 deletions aam/callbacks.py
@@ -34,7 +34,7 @@ def _mean_absolute_error(pred_val, true_val, fname, labels=None):
plt.close()


def _confusion_matrix(pred_val, true_val, fname, labels):
def _confusion_matrix(pred_val, true_val, fname, cat_labels=None):
cf_matrix = tf.math.confusion_matrix(true_val, pred_val).numpy()
group_counts = ["{0:0.0f}".format(value) for value in cf_matrix.flatten()]
group_percentages = [
@@ -46,8 +46,8 @@ def _confusion_matrix(pred_val, true_val, fname, labels):
ax = sns.heatmap(
cf_matrix,
annot=labels,
xticklabels=labels,
yticklabels=labels,
xticklabels=cat_labels,
yticklabels=cat_labels,
fmt="",
)
import textwrap
@@ -67,8 +67,21 @@ def wrap_labels(ax, width, break_long_words=False):
plt.close()


class ConfusionMatrx(tf.keras.callbacks.Callback):
def __init__(self, dataset, output_dir, report_back, labels, **kwargs):
super().__init__(**kwargs)
self.output_dir = output_dir
self.report_back = report_back
self.dataset = dataset
self.labels = labels

def on_epoch_end(self, epoch, logs=None):
y_pred, y_true = self.model.predict(self.dataset)
_confusion_matrix(y_pred, y_true, self.output_dir, self.labels)


class SaveModel(tf.keras.callbacks.Callback):
def __init__(self, output_dir, report_back, monitor, **kwargs):
def __init__(self, output_dir, report_back, monitor="val_loss", **kwargs):
super().__init__(**kwargs)
self.output_dir = output_dir
self.report_back = report_back
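The hunk above only shows SaveModel's updated signature (monitor now defaults to "val_loss"), but the fit_unifrac_regressor changes rely on it exposing best_weights after training. A minimal sketch of how such a best-weights-tracking callback typically works (illustrative only; SaveModel's actual body is not part of this diff):

```python
import tensorflow as tf


class BestWeightsSaver(tf.keras.callbacks.Callback):
    """Illustrative stand-in for SaveModel: remember the best weights seen."""

    def __init__(self, output_dir, report_back=1, monitor="val_loss"):
        super().__init__()
        self.output_dir = output_dir
        self.report_back = report_back
        self.monitor = monitor
        self.best = float("inf")
        self.best_weights = None

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is not None and current < self.best:
            # Keep a copy of the best epoch's weights so the caller can restore
            # them after fit(), e.g. model.set_weights(saver.best_weights).
            self.best = float(current)
            self.best_weights = self.model.get_weights()
```
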
3 changes: 2 additions & 1 deletion aam/cv_utils.py
@@ -31,6 +31,7 @@ def fit_fold(
metric="mae",
patience=10,
early_stop_warmup=50,
callbacks=[],
):
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
@@ -51,7 +52,7 @@
self.model.fit(
self.train_dataset,
validation_data=self.val_dataset,
callbacks=[*core_callbacks],
callbacks=[*callbacks, *core_callbacks],
epochs=epochs,
# verbose=0,
)
19 changes: 13 additions & 6 deletions aam/layers.py
@@ -55,22 +55,27 @@ def __init__(
self.dropout_rate = dropout_rate

self.base_tokens = 6
self.num_tokens = self.base_tokens * 150 + 2
self.num_tokens = self.base_tokens * self.max_bp + 2
self.emb_layer = tf.keras.layers.Embedding(
self.num_tokens,
self.token_dim,
input_length=self.max_bp,
embeddings_initializer=tf.keras.initializers.GlorotNormal(),
)
self.avs_attention = NucleotideAttention(
128, num_heads=2, num_layers=3, dropout=0.0
128, max_bp=self.max_bp, num_heads=2, num_layers=3, dropout=0.0
)
self.asv_token = self.num_tokens - 1

self.nucleotide_position = tf.range(0, 4 * 150, 4, dtype=tf.int32)
self.nucleotide_position = tf.range(
0, self.base_tokens * self.max_bp, self.base_tokens, dtype=tf.int32
)

def call(self, inputs, nuc_mask=None, training=False):
seq = inputs
seq_mask = float_mask(seq, dtype=tf.int32)
seq = seq + self.nucleotide_position
seq = seq * seq_mask

if nuc_mask is not None:
seq = seq * nuc_mask
@@ -147,7 +152,7 @@ def __init__(
intermediate_size=self.attention_ff,
norm_first=True,
activation="relu",
dropout_rate=0.1,
dropout_rate=self.dropout_rate,
)
self.sample_token = self.add_weight(
"sample_token",
@@ -199,9 +204,10 @@ def get_config(self):

@tf.keras.saving.register_keras_serializable(package="NucleotideAttention")
class NucleotideAttention(tf.keras.layers.Layer):
def __init__(self, hidden_dim, num_heads, num_layers, dropout, **kwargs):
def __init__(self, hidden_dim, max_bp, num_heads, num_layers, dropout, **kwargs):
super(NucleotideAttention, self).__init__(**kwargs)
self.hidden_dim = hidden_dim
self.max_bp = max_bp
self.num_heads = num_heads
self.num_layers = num_layers
self.dropout = dropout
@@ -210,7 +216,7 @@ def __init__(self, hidden_dim, num_heads, num_layers, dropout, **kwargs):
self.compress_df = tf.keras.layers.Dense(64)
self.decompress_df = tf.keras.layers.Dense(128)
self.pos_emb = tfm.nlp.layers.PositionEmbedding(
max_length=151, seq_axis=2, name="nuc_pos"
max_length=self.max_bp + 1, seq_axis=2, name="nuc_pos"
)
self.attention_layers = []
for i in range(self.num_layers):
@@ -243,6 +249,7 @@ def get_config(self):
config.update(
{
"hidden_dim": self.hidden_dim,
"max_bp": self.max_bp,
"num_heads": self.num_heads,
"num_layers": self.num_layers,
"dropout": self.dropout,
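The layers.py changes above replace the hard-coded read length of 150 with max_bp. A small illustration of the per-position token offsets this produces (toy values; base_tokens = 6 comes from the constructor shown above):

```python
import tensorflow as tf

# With base_tokens = 6 and max_bp = 4 the offsets are [0, 6, 12, 18], so the
# same nucleotide token maps to a distinct embedding id at each read position.
base_tokens, max_bp = 6, 4
nucleotide_position = tf.range(0, base_tokens * max_bp, base_tokens, dtype=tf.int32)
print(nucleotide_position.numpy())  # [ 0  6 12 18]
```
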
4 changes: 2 additions & 2 deletions aam/losses.py
@@ -82,11 +82,11 @@ def get_config(self):
class ImbalancedCategoricalCrossEntropy(tf.keras.losses.Loss):
def __init__(self, adjustment_weights=[0.1, 0.2, 0.3], reduction="none", **kwargs):
super().__init__(reduction=reduction, **kwargs)
self.num_classes = len(adjustment_weights)
adjustment_weights = tf.constant(adjustment_weights)
adjustment_weights = tf.reduce_sum(adjustment_weights) / adjustment_weights
adjustment_weights = tf.expand_dims(adjustment_weights, axis=-1)
self.adjustment_weights = adjustment_weights
self.num_classes = len(adjustment_weights)

def call(self, y_true, y_pred):
y_true = tf.cast(y_true, dtype=tf.int32)
@@ -95,7 +95,7 @@ def call(self, y_true, y_pred):

y_true = tf.one_hot(y_true, self.num_classes)
loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
return loss * weights
return (weights) * loss

def get_config(self):
config = super().get_config()
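A short worked example of the class-weighting scheme used by ImbalancedCategoricalCrossEntropy above (the counts are made up for illustration):

```python
import tensorflow as tf

# With class counts [10, 30, 60], rarer classes receive proportionally larger
# weights, which call() then multiplies into the per-sample cross-entropy.
counts = tf.constant([10.0, 30.0, 60.0])
weights = tf.reduce_sum(counts) / counts
print(weights.numpy())  # approximately [10.0, 3.33, 1.67]
```
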
2 changes: 1 addition & 1 deletion aam/transfer_data_utils.py
@@ -70,7 +70,7 @@ def _get_table_data(table_data):
cat_counts = target_data.value_counts()
cat_counts = cat_counts.reindex(cat_labels).to_numpy().astype(np.float32)
target_data = target_data.cat.codes
num_classes = np.max(target_data) + 1
num_classes = len(cat_labels)
target_data = tf.expand_dims(target_data, axis=-1)
target_dataset = tf.data.Dataset.from_tensor_slices(target_data)
else:
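The one-line change above switches num_classes from np.max(codes) + 1 to len(cat_labels). A small illustration of why the two can differ when a category has no samples in a split (hypothetical data):

```python
import numpy as np
import pandas as pd

# Category "c" is defined but absent from this split.
s = pd.Series(["a", "b", "a"], dtype=pd.CategoricalDtype(["a", "b", "c"]))
print(np.max(s.cat.codes) + 1)  # 2 -- undercounts the label set
print(len(s.cat.categories))    # 3 -- matches the full label set
```
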
(Diffs for the remaining changed files are not rendered here.)
