diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py
index dce52b973..a3cc4a425 100644
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py
+++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py
@@ -15,7 +15,6 @@
 
 import io
 import tarfile
-from typing import List, Tuple
 from unittest import mock
 
 import pytest
@@ -36,111 +35,18 @@
 from bionemo.testing import megatron_parallel_state_utils
 
 
-hf_model_tag = "facebook/esm2_t6_8M_UR50D"
-nv_model_tag = "esm2/8m:2.0"
-# hf_model_tag = "facebook/esm2_t33_650M_UR50D"
-# nv_model_tag = "esm2/650m:2.0"
-
-
-def reduce_hiddens(hiddens: Tensor, attention_mask: Tensor) -> Tensor:
-    """reduce last layer's hidden values to embeddings
-
-    Args:
-        hiddens: [b, s, h] tensor of hidden values
-        attention_mask: [b, s] attention mask tensor
-
-    Returns:
-        reduced embedding tensor [b, h]
-    """
-    masks = torch.sum(attention_mask, dim=1)
-    embeddings = torch.zeros(
-        size=(hiddens.shape[0], hiddens.shape[2]),
-        dtype=torch.float32,
-        device=torch.cuda.current_device(),
-    )
-    for i, (hidden, mask) in enumerate(zip(hiddens, masks)):
-        embeddings[i, :] = torch.mean(hidden[1 : mask - 1], dim=0)
-    return embeddings
-
-
-@pytest.fixture(scope="module")
-def esm2_config() -> ESM2Config:
-    with megatron_parallel_state_utils.distributed_model_parallel_state():
-        yield ESM2Config()
-
-
-@pytest.fixture(scope="module")
-def esm2_config_w_ckpt() -> ESM2Config:
-    with megatron_parallel_state_utils.distributed_model_parallel_state():
-        yield ESM2Config(initial_ckpt_path=load(nv_model_tag))
-
-
-@pytest.fixture(scope="module")
-def esm2_model(esm2_config) -> ESM2Model:
+def test_esm2_model_initialized():
     with megatron_parallel_state_utils.distributed_model_parallel_state():
         tokenizer = get_tokenizer()
-        model = esm2_config.configure_model(tokenizer)
-        yield model
-
-
-@pytest.fixture(scope="module")
-def sample_data() -> List[Tuple[str, str]]:
-    """Generates sample protein sequences for sanity checks, including mask tokens."""
-    max_length = 1022  # The maximum length of the protein sequences to be considered.
-    sample_data = [
-        (
-            "protein1",
-            "MNGTEGPNFYVPFSNATGVVRSPFEYPQYYLAEPWQFSMLAAYMFLLIVLGFPINFLTLYVTVQHKKLRTPLNYILLNLAVADLFMVLGGFTSTLYTSLHGYFVFGPTGCNLEGFFATLGGEIALWSLVVLAIERYVVVCKPMSNFRFGENHAIMGVAFTWVMALACAAPPLAGWSRYIPEGLQCSCGIDYYTLKPEVNNESFVIYMFVVHFTIPMIIIFFCYGQLVFTVKEAAAQQQESATTQKAEKEVTRMVIIMVIAFLICWVPYASVAFYIFTHQGSNFGPIFMTIPAFFAKSAAIYNPVIYIMMNKQFRNCMLTTICCGKNPLGDDEASATVSKTETSQVAPA",
-        ),
-        ("protein2", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA"),
-        (
-            "protein3",
-            "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLAGG",
-        ),
-        (
-            "protein4",
-            "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLA",
-        ),
-    ]
-    # add another sample protein that uses the maximum length to test this edge case
-    sample_data.append(("protein5", (sample_data[0][1] * 3)[:max_length]))
-    yield sample_data
-
-
-def _compute_loss(model, dataloader, vocab_size=None):
-    loss = 0
-    n = 0
-    limit_batches = 10
-    for i, batch in enumerate(dataloader):
-        assert isinstance(batch, dict)
-        result = model(input_ids=batch["text"].cuda(), attention_mask=batch["attention_mask"].cuda())
-
-        # bionemo ESM2 vocab_size
-        if vocab_size is not None:
-            # token_logits is s,b and for simplicity here let's transpose to b,s. In general this reduces performance.
-            logits = result["token_logits"].transpose(0, 1).contiguous()[..., :vocab_size]
-        else:
-            logits = result.logits
-
-        loss_mask = batch["loss_mask"].cuda()
-        target = batch["labels"].cuda()
-
-        loss += torch.nn.functional.cross_entropy(logits[loss_mask].float(), target[loss_mask], reduction="sum")
-        n += loss_mask.sum()
-
-        if limit_batches is not None and i + 1 >= limit_batches:
-            break
-    mean_loss: Tensor = loss / n
-    return mean_loss
+        config = ESM2Config()
+        model = config.configure_model(tokenizer)
 
+        assert isinstance(model, MegatronBioBertModel)
+        assert isinstance(model, ESM2Model)
+        assert isinstance(model.embedding, ESM2Embedding)
 
-def test_esm2_model_initialized(esm2_model):
-    assert isinstance(esm2_model, MegatronBioBertModel)
-    assert isinstance(esm2_model, ESM2Model)
-    assert isinstance(esm2_model.embedding, ESM2Embedding)
 
-
-def test_esm2_nemo1_checkpoint(esm2_model):
+def test_esm2_nemo1_checkpoint():
     with tarfile.open(load("esm2/nv_650m:1.0"), "r") as ckpt, torch.no_grad():
         ckpt_file = ckpt.extractfile("./model_weights.ckpt")
 
@@ -149,10 +55,14 @@ def test_esm2_nemo1_checkpoint(esm2_model):
 
         # TODO: update Bionemo checkpoints
        old_state_dict.pop("model.language_model.rotary_pos_emb.inv_freq")
-        new_state_dict = esm2_model.state_dict_for_save_checkpoint()
+        with megatron_parallel_state_utils.distributed_model_parallel_state():
+            tokenizer = get_tokenizer()
+            config = ESM2Config()
+            model = config.configure_model(tokenizer)
+            new_state_dict = model.state_dict_for_save_checkpoint()
 
-        # Set the new_model_prefix to "" since we are looking at the base megatron model and not the lightning module which stores a copy of
-        # this model into self.module
+        # Set the new_model_prefix to "" since we are looking at the base megatron model and not the lightning module
+        # which stores a copy of this model into self.module
         old_keys = {
             nemo1_to_nemo2_biobert_key_mapping(k, new_model_prefix="", te_mapping=True) for k in old_state_dict
         }
@@ -180,7 +90,39 @@ def test_esm2_nemo1_checkpoint(esm2_model):
     assert not missing_old_keys, "There are keys in the old checkpoint that are missing from the new model."
 
 
-def test_esm2_loss(esm2_config_w_ckpt, dummy_protein_dataset, dummy_parquet_train_val_inputs):
+def _compute_loss(model, dataloader, vocab_size=None):
+    loss = 0
+    n = 0
+    limit_batches = 10
+    for i, batch in enumerate(dataloader):
+        assert isinstance(batch, dict)
+        result = model(input_ids=batch["text"].cuda(), attention_mask=batch["attention_mask"].cuda())
+
+        # bionemo ESM2 vocab_size
+        if vocab_size is not None:
+            # token_logits is s,b and for simplicity here let's transpose to b,s. In general this reduces performance.
+            logits = result["token_logits"].transpose(0, 1).contiguous()[..., :vocab_size]
+        else:
+            logits = result.logits
+
+        loss_mask = batch["loss_mask"].cuda()
+        target = batch["labels"].cuda()
+
+        loss += torch.nn.functional.cross_entropy(logits[loss_mask].float(), target[loss_mask], reduction="sum")
+        n += loss_mask.sum()
+
+        if limit_batches is not None and i + 1 >= limit_batches:
+            break
+    mean_loss: Tensor = loss / n
+    return mean_loss
+
+
+def test_esm2_loss(dummy_protein_dataset, dummy_parquet_train_val_inputs):
+    hf_model_tag = "facebook/esm2_t6_8M_UR50D"
+    nv_model_tag = "esm2/8m:2.0"
+    # hf_model_tag = "facebook/esm2_t33_650M_UR50D"
+    # nv_model_tag = "esm2/650m:2.0"
+
     train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs
 
     seed: int = 42
@@ -193,7 +135,7 @@ def test_esm2_loss(esm2_config_w_ckpt, dummy_protein_dataset, dummy_parquet_trai
         tokenizer = get_tokenizer()
 
         # ESM2 model initialized with params
-        model = esm2_config_w_ckpt.configure_model(tokenizer).cuda()
+        model = ESM2Config(initial_ckpt_path=str(load(nv_model_tag))).configure_model(tokenizer).cuda()
 
         # Initialize the data module.
         data_module = ESMDataModule(