diff --git a/README.md b/README.md
index 14281b1..1969550 100644
--- a/README.md
+++ b/README.md
@@ -2,41 +2,50 @@
 Attention-based network for microbial sequencing data.
 
-# Installation
+# Installation Instructions
+IMPORTANT: If installing on a server cluster, spawn an instance with a GPU before proceeding with environment setup.
+
+First, create a new conda environment with UniFrac:
 
-## Install Requirements
-Requires tensorflow==2.14 and tf-models-official==2.14.2
+`conda create --name aam -c conda-forge -c bioconda unifrac python=3.9 cython`
 
-`pip install tensorflow==2.14 tf-models-official==2.14.2`
+`conda activate aam`
 
-or
+## GPU Support
 
- `pip install tensorflow[and-cuda]==2.14 tf-models-official==2.14.2`
+Install CUDA 11.8:
 
-for GPU support.
+`conda install nvidia/label/cuda-11.8.0::cuda-toolkit`
 
-Tensorboard is an optional dependency to visualize training losses/metrics.
+Verify the NVIDIA GPU drivers are on your path:
 
-`pip install tensorboard`
+`nvidia-smi`
+
+Please see the [TensorFlow install guide](https://www.tensorflow.org/install) for more information.
 
 ## Install AAM
+
+For the latest version:
 
 `pip install git+https://github.com/kwcantrell/attention-all-microbes.git`
 
-or
+or install a specific version:
 
 `pip install git+https://github.com/kwcantrell/attention-all-microbes.git@v0.1.0`
 
-for a specific tagged version.
+## Developers
+
+`git clone git@github.com:kwcantrell/attention-all-microbes.git`
+
+`cd attention-all-microbes`
+`pip install -e .`
 
 # Training
 
-Classifiers and Regressors are trained use cross-validation
+Classifiers and regressors are trained using cross-validation:
 
-`python attention_cli.py --help`
+`attention --help`
diff --git a/aam/__init__.py b/aam/__init__.py
index e69de29..ecd5f11 100644
--- a/aam/__init__.py
+++ b/aam/__init__.py
@@ -0,0 +1,14 @@
+from .cv_utils import CVModel, EnsembleModel
+from .transfer_data_utils import load_data
+from .transfer_nuc_model import TransferLearnNucleotideModel
+from .unifrac_data_utils import load_data as _load_unifrac_data
+from .unifrac_model import UnifracModel
+
+__all__ = [
+    "UnifracModel",
+    "_load_unifrac_data",
+    "load_data",
+    "TransferLearnNucleotideModel",
+    "CVModel",
+    "EnsembleModel",
+]
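The new `aam/__init__.py` gives the package a top-level public API. As a quick illustration (not part of the diff itself), the re-exported names can now be imported directly from the package:

```python
# Illustrative only: these imports are enabled by the __all__ exports above.
from aam import CVModel, EnsembleModel, TransferLearnNucleotideModel, UnifracModel
```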
diff --git a/attention_cli.py b/aam/attention_cli.py
similarity index 87%
rename from attention_cli.py
rename to aam/attention_cli.py
index 3ff7252..f530691 100644
--- a/attention_cli.py
+++ b/aam/attention_cli.py
@@ -7,7 +7,12 @@
 from biom import load_table
 from sklearn.model_selection import KFold, StratifiedKFold
 
-from aam.callbacks import SaveModel, _confusion_matrix, _mean_absolute_error
+from aam.callbacks import (
+    ConfusionMatrx,
+    SaveModel,
+    _confusion_matrix,
+    _mean_absolute_error,
+)
 from aam.cv_utils import CVModel, EnsembleModel
 from aam.losses import ImbalancedCategoricalCrossEntropy, ImbalancedMSE
 from aam.transfer_nuc_model import TransferLearnNucleotideModel
@@ -32,43 +37,42 @@ class cli:
 @click.option("--i-table", required=True, type=click.Path(exists=True), help=TABLE_DESC)
 @click.option("--i-tree", required=True, type=click.Path(exists=True))
 @click.option("--p-max-bp", required=True, type=int)
+@click.option("--p-batch-size", default=8, type=int, show_default=True)
 @click.option("--p-epochs", default=1000, show_default=True, type=int)
-@click.option("--p-dropout", default=0.01, show_default=True, type=float)
+@click.option("--p-dropout", default=0.0, show_default=True, type=float)
 @click.option("--p-ff-d-model", default=128, show_default=True, type=int)
 @click.option("--p-pca-heads", default=8, show_default=True, type=int)
 @click.option("--p-enc-layers", default=2, show_default=True, type=int)
 @click.option("--p-enc-heads", default=8, show_default=True, type=int)
-@click.option("--p-output-dir", required=True)
+@click.option("--output-dir", required=True)
 def fit_unifrac_regressor(
     i_table: str,
     i_tree: str,
     p_max_bp: int,
+    p_batch_size: int,
     p_epochs: int,
     p_dropout: float,
     p_ff_d_model: int,
     p_pca_heads: int,
     p_enc_layers: int,
     p_enc_heads: int,
-    p_output_dir: str,
+    output_dir: str,
 ):
     from aam.unifrac_data_utils import load_data
 
     tf.keras.mixed_precision.set_global_policy("mixed_float16")
-    if not os.path.exists(p_output_dir):
-        os.makedirs(p_output_dir)
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
 
-    figure_path = os.path.join(p_output_dir, "figures")
+    figure_path = os.path.join(output_dir, "figures")
     if not os.path.exists(figure_path):
         os.makedirs(figure_path)
 
-    data_obj = load_data(
-        i_table,
-        tree_path=i_tree,
-    )
+    data_obj = load_data(i_table, tree_path=i_tree, batch_size=p_batch_size)
 
-    load_model = True
+    load_model = False
     if load_model:
-        model = tf.keras.models.load_model(f"{p_output_dir}/model.keras")
+        model = tf.keras.models.load_model(f"{output_dir}/model.keras")
     else:
         model = UnifracModel(
             p_ff_d_model,
@@ -91,16 +95,19 @@ def fit_unifrac_regressor(
             run_eagerly=False,
         )
     model.summary()
-    log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
-    log_dir = os.path.join(p_output_dir, log_dir)
+    log_dir = "logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+    log_dir = os.path.join(output_dir, log_dir)
     if not os.path.exists(log_dir):
         os.makedirs(log_dir)
+    model_save_path = os.path.join(output_dir, "model.keras")
+    model_saver = SaveModel(model_save_path, 1)
     core_callbacks = [
         tf.keras.callbacks.TensorBoard(
             log_dir=log_dir,
             histogram_freq=0,
         ),
-        SaveModel(p_output_dir, 1),
+        tf.keras.callbacks.EarlyStopping("val_loss", patience=5, start_from_epoch=5),
+        model_saver,
     ]
     model.fit(
         data_obj["training_dataset"],
@@ -108,6 +115,8 @@ def fit_unifrac_regressor(
         callbacks=[*core_callbacks],
         epochs=p_epochs,
     )
+    model.set_weights(model_saver.best_weights)
+    model.save(model_save_path, save_format="keras")
 
 
 @cli.command()
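The training change above is worth a gloss: `SaveModel` tracks the best weights seen during training, and the two added lines restore those weights before the final `model.save`. A minimal sketch of the same pattern using only stock Keras pieces (`model`, `train_ds`, and `val_ds` are assumed to exist; `EarlyStopping(restore_best_weights=True)` stands in for the repo's `SaveModel`):

```python
import tensorflow as tf

# restore_best_weights approximates the diff's
# SaveModel.best_weights + model.set_weights() combination.
early = tf.keras.callbacks.EarlyStopping(
    "val_loss", patience=5, start_from_epoch=5, restore_best_weights=True
)
model.fit(train_ds, validation_data=val_ds, epochs=1000, callbacks=[early])
model.save("model.keras", save_format="keras")  # weights are already the best seen
```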
@@ -289,6 +298,7 @@ def _get_fold(indices, shuffle):
 )
 @click.option("--p-patience", default=10, show_default=True, type=int)
 @click.option("--p-early-stop-warmup", default=50, show_default=True, type=int)
+@click.option("--p-report-back", default=1, show_default=True, type=int)
 @click.option("--output-dir", required=True, type=click.Path(exists=False))
 def fit_sample_classifier(
     i_table: str,
@@ -304,6 +314,7 @@ def fit_sample_classifier(
     p_stratify: bool,
     p_patience: int,
     p_early_stop_warmup: int,
+    p_report_back: int,
     output_dir: str,
 ):
     from aam.transfer_data_utils import (
@@ -326,10 +337,12 @@ def fit_sample_classifier(
     table = load_table(i_table)
     df = pd.read_csv(m_metadata_file, sep="\t", index_col=0)[[m_metadata_column]]
+    print(df)
     ids, table, df = validate_metadata(table, df, p_missing_samples)
     table, df = shuffle(table, df)
     num_ids = len(ids)
-
+    categories = df[m_metadata_column].astype("category").cat.categories
+    print("int", categories)
     fold_indices = [i for i in range(num_ids)]
     if p_test_size > 0:
         test_size = int(num_ids * p_test_size)
@@ -371,7 +384,7 @@ def _get_fold(indices, shuffle):
             penalty=p_penalty,
             num_classes=train_data["num_classes"],
         )
-        loss = ImbalancedCategoricalCrossEntropy(list(train_data["cat_counts"]))
+        loss = ImbalancedCategoricalCrossEntropy(train_data["cat_counts"])
         fold_label = i + 1
         model_cv = CVModel(
             model,
@@ -387,6 +400,16 @@ def _get_fold(indices, shuffle):
             metric="target_loss",
             patience=p_patience,
             early_stop_warmup=p_early_stop_warmup,
+            callbacks=[
+                ConfusionMatrx(
+                    dataset=val_data["dataset"],
+                    output_dir=os.path.join(
+                        figure_path, f"model-f{fold_label}-val.png"
+                    ),
+                    report_back=p_report_back,
+                    labels=categories,
+                )
+            ],
         )
         models.append(model_cv)
diff --git a/aam/callbacks.py b/aam/callbacks.py
index 489c998..e66e89a 100644
--- a/aam/callbacks.py
+++ b/aam/callbacks.py
@@ -34,7 +34,7 @@ def _mean_absolute_error(pred_val, true_val, fname, labels=None):
     plt.close()
 
 
-def _confusion_matrix(pred_val, true_val, fname, labels):
+def _confusion_matrix(pred_val, true_val, fname, cat_labels=None):
     cf_matrix = tf.math.confusion_matrix(true_val, pred_val).numpy()
     group_counts = ["{0:0.0f}".format(value) for value in cf_matrix.flatten()]
     group_percentages = [
@@ -46,8 +46,8 @@ def _confusion_matrix(pred_val, true_val, fname, cat_labels=None):
     ax = sns.heatmap(
         cf_matrix,
         annot=labels,
-        xticklabels=labels,
-        yticklabels=labels,
+        xticklabels=cat_labels,
+        yticklabels=cat_labels,
         fmt="",
     )
     import textwrap
@@ -67,8 +67,21 @@ def wrap_labels(ax, width, break_long_words=False):
     plt.close()
 
 
+class ConfusionMatrx(tf.keras.callbacks.Callback):
+    def __init__(self, dataset, output_dir, report_back, labels, **kwargs):
+        super().__init__(**kwargs)
+        self.output_dir = output_dir
+        self.report_back = report_back
+        self.dataset = dataset
+        self.labels = labels
+
+    def on_epoch_end(self, epoch, logs=None):
+        y_pred, y_true = self.model.predict(self.dataset)
+        _confusion_matrix(y_pred, y_true, self.output_dir, self.labels)
+
+
 class SaveModel(tf.keras.callbacks.Callback):
-    def __init__(self, output_dir, report_back, monitor, **kwargs):
+    def __init__(self, output_dir, report_back, monitor="val_loss", **kwargs):
         super().__init__(**kwargs)
         self.output_dir = output_dir
         self.report_back = report_back
diff --git a/aam/cv_utils.py b/aam/cv_utils.py
index 1b37e5c..6c917ea 100644
--- a/aam/cv_utils.py
+++ b/aam/cv_utils.py
@@ -31,6 +31,7 @@ def fit_fold(
         metric="mae",
         patience=10,
         early_stop_warmup=50,
+        callbacks=[],
     ):
         if not os.path.exists(self.log_dir):
             os.makedirs(self.log_dir)
@@ -51,7 +52,7 @@ def fit_fold(
         self.model.fit(
             self.train_dataset,
             validation_data=self.val_dataset,
-            callbacks=[*core_callbacks],
+            callbacks=[*callbacks, *core_callbacks],
             epochs=epochs,
             # verbose=0,
         )
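The `ConfusionMatrx` callback added above (the class name is spelled without the second "i" throughout the diff) re-predicts a fixed dataset at the end of every epoch and hands the results to `_confusion_matrix`; note that `report_back` is stored but not consulted in the lines shown. A hypothetical wiring outside of `CVModel.fit_fold` might look like this (`model`, the datasets, and the label names are placeholders):

```python
from aam.callbacks import ConfusionMatrx

cm_callback = ConfusionMatrx(
    dataset=val_ds,                      # dataset whose predictions are plotted
    output_dir="figures/model-val.png",  # despite the name, this is a file path
    report_back=1,
    labels=["control", "case"],          # hypothetical category names
)
model.fit(train_ds, validation_data=val_ds, epochs=10, callbacks=[cm_callback])
```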
diff --git a/aam/layers.py b/aam/layers.py
index 5d07e8c..b520730 100644
--- a/aam/layers.py
+++ b/aam/layers.py
@@ -55,7 +55,7 @@ def __init__(
         self.dropout_rate = dropout_rate
 
         self.base_tokens = 6
-        self.num_tokens = self.base_tokens * 150 + 2
+        self.num_tokens = self.base_tokens * self.max_bp + 2
         self.emb_layer = tf.keras.layers.Embedding(
             self.num_tokens,
             self.token_dim,
@@ -63,14 +63,19 @@ def __init__(
             embeddings_initializer=tf.keras.initializers.GlorotNormal(),
         )
         self.avs_attention = NucleotideAttention(
-            128, num_heads=2, num_layers=3, dropout=0.0
+            128, max_bp=self.max_bp, num_heads=2, num_layers=3, dropout=0.0
         )
         self.asv_token = self.num_tokens - 1
-        self.nucleotide_position = tf.range(0, 4 * 150, 4, dtype=tf.int32)
+        self.nucleotide_position = tf.range(
+            0, self.base_tokens * self.max_bp, self.base_tokens, dtype=tf.int32
+        )
 
     def call(self, inputs, nuc_mask=None, training=False):
         seq = inputs
+        seq_mask = float_mask(seq, dtype=tf.int32)
+        seq = seq + self.nucleotide_position
+        seq = seq * seq_mask
         if nuc_mask is not None:
             seq = seq * nuc_mask
@@ -147,7 +152,7 @@ def __init__(
             intermediate_size=self.attention_ff,
             norm_first=True,
             activation="relu",
-            dropout_rate=0.1,
+            dropout_rate=self.dropout_rate,
         )
         self.sample_token = self.add_weight(
             "sample_token",
@@ -199,9 +204,10 @@ def get_config(self):
 
 @tf.keras.saving.register_keras_serializable(package="NucleotideAttention")
 class NucleotideAttention(tf.keras.layers.Layer):
-    def __init__(self, hidden_dim, num_heads, num_layers, dropout, **kwargs):
+    def __init__(self, hidden_dim, max_bp, num_heads, num_layers, dropout, **kwargs):
         super(NucleotideAttention, self).__init__(**kwargs)
         self.hidden_dim = hidden_dim
+        self.max_bp = max_bp
         self.num_heads = num_heads
         self.num_layers = num_layers
         self.dropout = dropout
@@ -210,7 +216,7 @@ def __init__(self, hidden_dim, num_heads, num_layers, dropout, **kwargs):
         self.compress_df = tf.keras.layers.Dense(64)
         self.decompress_df = tf.keras.layers.Dense(128)
         self.pos_emb = tfm.nlp.layers.PositionEmbedding(
-            max_length=151, seq_axis=2, name="nuc_pos"
+            max_length=self.max_bp + 1, seq_axis=2, name="nuc_pos"
         )
         self.attention_layers = []
         for i in range(self.num_layers):
@@ -243,6 +249,7 @@ def get_config(self):
         config.update(
             {
                 "hidden_dim": self.hidden_dim,
+                "max_bp": self.max_bp,
                 "num_heads": self.num_heads,
                 "num_layers": self.num_layers,
                 "dropout": self.dropout,
diff --git a/aam/losses.py b/aam/losses.py
index 0d28058..1f1e6a2 100644
--- a/aam/losses.py
+++ b/aam/losses.py
@@ -82,11 +82,11 @@ def get_config(self):
 class ImbalancedCategoricalCrossEntropy(tf.keras.losses.Loss):
     def __init__(self, adjustment_weights=[0.1, 0.2, 0.3], reduction="none", **kwargs):
         super().__init__(reduction=reduction, **kwargs)
+        self.num_classes = len(adjustment_weights)
         adjustment_weights = tf.constant(adjustment_weights)
         adjustment_weights = tf.reduce_sum(adjustment_weights) / adjustment_weights
         adjustment_weights = tf.expand_dims(adjustment_weights, axis=-1)
         self.adjustment_weights = adjustment_weights
-        self.num_classes = len(adjustment_weights)
 
     def call(self, y_true, y_pred):
         y_true = tf.cast(y_true, dtype=tf.int32)
@@ -95,7 +95,7 @@ def call(self, y_true, y_pred):
         y_true = tf.one_hot(y_true, self.num_classes)
 
         loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
-        return loss * weights
+        return weights * loss
 
     def get_config(self):
         config = super().get_config()
diff --git a/aam/transfer_data_utils.py b/aam/transfer_data_utils.py
index a817167..060e92a 100644
--- a/aam/transfer_data_utils.py
+++ b/aam/transfer_data_utils.py
@@ -70,7 +70,7 @@ def _get_table_data(table_data):
         cat_counts = target_data.value_counts()
         cat_counts = cat_counts.reindex(cat_labels).to_numpy().astype(np.float32)
         target_data = target_data.cat.codes
-        num_classes = np.max(target_data) + 1
+        num_classes = len(cat_labels)
         target_data = tf.expand_dims(target_data, axis=-1)
         target_dataset = tf.data.Dataset.from_tensor_slices(target_data)
     else:
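One consequence of the reordered `ImbalancedCategoricalCrossEntropy.__init__` above: `num_classes` is now computed from the raw count list before the weights become a column tensor, and the per-class weights are inverse-frequency (`total / count`). A toy check with made-up class counts:

```python
import tensorflow as tf

cat_counts = [10.0, 30.0, 60.0]  # hypothetical per-class sample counts
w = tf.constant(cat_counts)
w = tf.reduce_sum(w) / w         # rarer classes get proportionally larger weights
print(w.numpy())                 # -> [10.0, 3.3333, 1.6667]
```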
diff --git a/aam/transfer_nuc_model.py b/aam/transfer_nuc_model.py
index 1682e4c..8ba19b2 100644
--- a/aam/transfer_nuc_model.py
+++ b/aam/transfer_nuc_model.py
@@ -20,11 +20,11 @@ def __init__(
         super(TransferLearnNucleotideModel, self).__init__(**kwargs)
 
         self.token_dim = 128
-        self.mask_percent = 25
+        self.mask_percent = mask_percent
         self.num_classes = num_classes
         self.shift = shift
         self.scale = scale
-        self.penalty = 5000
+        self.penalty = tf.constant(penalty, dtype=tf.float32)
         self.loss_tracker = tf.keras.metrics.Mean()
         self.target_tracker = tf.keras.metrics.Mean()
         self.reg_tracker = tf.keras.metrics.Mean()
@@ -93,15 +93,14 @@ def _compute_loss(self, y_true, outputs):
         # count mask
         count_mask = float_mask(counts)
         num_counts = tf.reduce_sum(count_mask, axis=-1, keepdims=True)
-        # count mse
-        count_loss = tf.math.square(counts - count_pred)
-        count_loss = tf.reduce_sum(count_loss * count_mask, axis=-1, keepdims=True)
-        count_loss = self.penalty * tf.reduce_mean(count_loss / num_counts)
+        count_loss = tf.math.square(counts - count_pred) * count_mask
+        count_loss = tf.reduce_sum(count_loss, axis=-1, keepdims=True) / num_counts
+        count_loss = self.penalty * tf.reduce_mean(count_loss)
 
         target_loss = self.loss(y_true, y_pred)
-        reg_loss = tf.reduce_sum(self.losses)
-        return target_loss + count_loss + reg_loss, target_loss, count_loss, reg_loss
+        reg_loss = tf.reduce_mean(self.losses)
+        return target_loss + count_loss, target_loss, count_loss, reg_loss
 
     def _compute_metric(self, y_true, outputs):
         _, _, y_pred, _ = outputs
@@ -186,13 +185,17 @@ def call(self, inputs, training=False):
         count_mask = float_mask(extended_counts, dtype=self.compute_dtype)
         if self.trainable and training:
             random_mask = tf.random.uniform(
-                tf.shape(extended_counts), minval=0, maxval=1, dtype=self.compute_dtype
+                tf.shape(extended_counts),
+                minval=0,
+                maxval=100,
+                dtype=self.compute_dtype,
             )
             random_mask = tf.cast(
-                tf.less_equal(random_mask, 0.75), dtype=self.compute_dtype
+                tf.greater_equal(random_mask, self.mask_percent),
+                dtype=self.compute_dtype,
             )
             extended_counts = extended_counts * random_mask
-            # asv_embeddings = asv_embeddings * random_mask
+            asv_embeddings = asv_embeddings * random_mask
 
         # up project counts
         count_embeddings = self.count_encoder(
diff --git a/aam/unifrac_data_utils.py b/aam/unifrac_data_utils.py
index 084a780..5a0e8b6 100644
--- a/aam/unifrac_data_utils.py
+++ b/aam/unifrac_data_utils.py
@@ -12,6 +12,7 @@ def load_data(
     shuffle_samples=True,
     tree_path=None,
     max_token_per_sample=300,
+    batch_size=8,
     temp_table_path="temp_table.biom",
 ):
     def _get_unifrac_data(table_path, tree_path):
@@ -29,6 +30,7 @@ def _get_table_data(table_path):
 
     def _preprocess_table(table_path):
         table = load_table(table_path)
+        table = table.remove_empty()
         table_data = table.matrix_data.tocoo()
         counts, (row, col) = table_data.data, table_data.coords
@@ -97,7 +99,7 @@ def filter(samples, table_data, unifrac_data):
         ds = ds.cache()
         if shuffle_samples and not val:
             ds = ds.shuffle(shuffle_buf)
-        ds = ds.padded_batch(8)
+        ds = ds.padded_batch(batch_size)
         ds = ds.map(filter, num_parallel_calls=tf.data.AUTOTUNE)
         return ds
diff --git a/aam/unifrac_model.py b/aam/unifrac_model.py
index 5b58380..fc91e0f 100644
--- a/aam/unifrac_model.py
+++ b/aam/unifrac_model.py
@@ -156,7 +156,9 @@ def call(self, inputs, training=False):
         return asv_embeddings, sample_embeddings, nucleotides, inputs
 
     def build(self, input_shape=None):
-        super(UnifracModel, self).build(tf.TensorShape([None, None, 150]))
+        input = tf.keras.Input([None, self.max_bp])
+        self.call(input)
+        super(UnifracModel, self).build(tf.TensorShape([None, None, self.max_bp]))
 
     def train_step(self, data):
         inputs, y = data
@@ -179,7 +181,7 @@ def train_step(self, data):
         self.accuracy.update_state(self._compute_accuracy(tokens, logits))
         return {
             "loss": self.loss_tracker.result(),
-            "mae": self.metric_traker.result(),
+            "mse": self.metric_traker.result(),
             "entropy": self.entropy.result(),
             "accuracy": self.accuracy.result(),
         }
@@ -199,7 +201,7 @@ def test_step(self, data):
         self.accuracy.update_state(self._compute_accuracy(tokens, logits))
         return {
             "loss": self.loss_tracker.result(),
-            "mae": self.metric_traker.result(),
+            "mse": self.metric_traker.result(),
             "entropy": self.entropy.result(),
             "accuracy": self.accuracy.result(),
         }
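The reworked count masking in `transfer_nuc_model.py` draws uniform noise in [0, 100) and keeps only positions whose draw is at least `mask_percent`, so roughly `mask_percent` percent of the count and ASV-embedding positions are zeroed each training step (the old code masked a hard-coded fraction). A toy reproduction of just the mask logic, with invented shapes:

```python
import tensorflow as tf

mask_percent = 25.0                              # percent of positions to hide
counts = tf.random.uniform((2, 8), maxval=10.0)  # stand-in for per-ASV counts
draw = tf.random.uniform(tf.shape(counts), minval=0, maxval=100)
keep = tf.cast(tf.greater_equal(draw, mask_percent), counts.dtype)
masked = counts * keep                           # ~25% of entries zeroed on average
```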
diff --git a/pyproject.toml b/pyproject.toml
index 209c288..a5535b7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,10 +1,11 @@
 [build-system]
-requires = ["setuptools>=42", "wheel"]
+requires = ["setuptools", "wheel", "setuptools_scm"]
 build-backend = "setuptools.build_meta"
+
 [project]
 name = "aam"
-requires-python = ">= 3.9"
+requires-python = ">= 3.9, < 3.10"
 dynamic = ["version"]
 description = "Deep Learning Method for Microbial Sequencing Data"
 readme = "README.md"
@@ -22,37 +23,42 @@ classifiers = [
 dependencies = [
     "numpy",
     "pandas",
-    "cython",
     "seaborn",
     "biom-format",
-    "scikit-bio",
-    "scikit-learn",
-    "scipy",
+    "scikit-bio >= 0.6",
+    "scikit-learn >= 1.3",
+    "scipy >= 1.13.0",
     "unifrac",
-    "click"
+    "click",
+    "tensorflow[and-cuda] >= 2.14.0, < 2.15",
+    "tf-models-official >= 2.14.2",
+    "tensorboard"
 ]
+keywords = ["aam"]
 
 [project.scripts]
-aam-cli = "aam:attention_cli"
+attention = "aam.attention_cli:main"
 
 [project.urls]
 "Bug Tracker" = "https://github.com/kwcantrell/attention-all-microbes/issues"
 "Source Code" = "https://github.com/kwcantrell/attention-all-microbes"
+
 [tool.setuptools.packages.find]
-where = ["aam"]
+where = ["."]
 
-[tool.setuptools]
-package-dir = {"" = "aam"}
 
 [project.optional-dependencies]
 dev = ["pytest", "ruff"]
 docs = ["sphinx", "myst-parser"]
+
 [tool.ruff]
 line-length = 128
 ignore = ["F841", "F401"]
+
 [tool.setuptools_scm]
 version_scheme = "guess-next-dev"
 local_scheme = "node-and-timestamp"
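Given the new `tensorflow[and-cuda]` pin and the CUDA 11.8 instructions in the README, a quick post-install sanity check may be useful (illustrative only, not part of the repo):

```python
import tensorflow as tf

print(tf.__version__)                          # expect a 2.14.x release
print(tf.config.list_physical_devices("GPU"))  # non-empty once CUDA is wired up
```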