refactor test_model to remove fixtures
Signed-off-by: Peter St. John <[email protected]>
pstjohn committed Jan 21, 2025
1 parent 3936d53 commit cc62c2b
Showing 1 changed file with 48 additions and 106 deletions.
154 changes: 48 additions & 106 deletions sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py
@@ -15,7 +15,6 @@

import io
import tarfile
-from typing import List, Tuple
from unittest import mock

import pytest
@@ -36,111 +35,18 @@
from bionemo.testing import megatron_parallel_state_utils


-hf_model_tag = "facebook/esm2_t6_8M_UR50D"
-nv_model_tag = "esm2/8m:2.0"
-# hf_model_tag = "facebook/esm2_t33_650M_UR50D"
-# nv_model_tag = "esm2/650m:2.0"


-def reduce_hiddens(hiddens: Tensor, attention_mask: Tensor) -> Tensor:
-    """reduce last layer's hidden values to embeddings
-    Args:
-        hiddens: [b, s, h] tensor of hidden values
-        attention_mask: [b, s] attention mask tensor
-    Returns:
-        reduced embedding tensor [b, h]
-    """
-    masks = torch.sum(attention_mask, dim=1)
-    embeddings = torch.zeros(
-        size=(hiddens.shape[0], hiddens.shape[2]),
-        dtype=torch.float32,
-        device=torch.cuda.current_device(),
-    )
-    for i, (hidden, mask) in enumerate(zip(hiddens, masks)):
-        embeddings[i, :] = torch.mean(hidden[1 : mask - 1], dim=0)
-    return embeddings


-@pytest.fixture(scope="module")
-def esm2_config() -> ESM2Config:
-    with megatron_parallel_state_utils.distributed_model_parallel_state():
-        yield ESM2Config()


-@pytest.fixture(scope="module")
-def esm2_config_w_ckpt() -> ESM2Config:
-    with megatron_parallel_state_utils.distributed_model_parallel_state():
-        yield ESM2Config(initial_ckpt_path=load(nv_model_tag))


-@pytest.fixture(scope="module")
-def esm2_model(esm2_config) -> ESM2Model:
+def test_esm2_model_initialized():
    with megatron_parallel_state_utils.distributed_model_parallel_state():
        tokenizer = get_tokenizer()
-        model = esm2_config.configure_model(tokenizer)
-        yield model


-@pytest.fixture(scope="module")
-def sample_data() -> List[Tuple[str, str]]:
-    """Generates sample protein sequences for sanity checks, including mask tokens."""
-    max_length = 1022  # The maximum length of the protein sequences to be considered.
-    sample_data = [
-        (
-            "protein1",
-            "MNGTEGPNFYVPFSNATGVVRSPFEYPQYYLAEPWQFSMLAAYMFLLIVLGFPINFLTLYVTVQHKKLRTPLNYILLNLAVADLFMVLGGFTSTLYTSLHGYFVFGPTGCNLEGFFATLGGEIALWSLVVLAIERYVVVCKPMSNFRFGENHAIMGVAFTWVMALACAAPPLAGWSRYIPEGLQCSCGIDYYTLKPEVNNESFVIYMFVVHFTIPMIIIFFCYGQLVFTVKEAAAQQQESATTQKAEKEVTRMVIIMVIAFLICWVPYASVAFYIFTHQGSNFGPIFMTIPAFFAKSAAIYNPVIYIMMNKQFRNCMLTTICCGKNPLGDDEASATVSKTETSQVAPA",
-        ),
-        ("protein2", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA"),
-        (
-            "protein3",
-            "MKTVRQERLKSI<mask>RILERSKEPVSGAQLAEELS<mask>SRQVIVQDIAYLRSLGYN<mask>VATPRGYVLAGG",
-        ),
-        (
-            "protein4",
-            "MKTVRQERLKSI<mask>RILERSKEPVSGAQLAEELS<mask>SRQVIVQDIAYLRSLGYN<mask>VATPRGYVLA",
-        ),
-    ]
-    # add another sample protein that uses the maximum length to test this edge case
-    sample_data.append(("protein5", (sample_data[0][1] * 3)[:max_length]))
-    yield sample_data


-def _compute_loss(model, dataloader, vocab_size=None):
-    loss = 0
-    n = 0
-    limit_batches = 10
-    for i, batch in enumerate(dataloader):
-        assert isinstance(batch, dict)
-        result = model(input_ids=batch["text"].cuda(), attention_mask=batch["attention_mask"].cuda())
-
-        # bionemo ESM2 vocab_size
-        if vocab_size is not None:
-            # token_logits is s,b and for simplicity here let's transpose to b,s. In general this reduces performance.
-            logits = result["token_logits"].transpose(0, 1).contiguous()[..., :vocab_size]
-        else:
-            logits = result.logits
-
-        loss_mask = batch["loss_mask"].cuda()
-        target = batch["labels"].cuda()
-
-        loss += torch.nn.functional.cross_entropy(logits[loss_mask].float(), target[loss_mask], reduction="sum")
-        n += loss_mask.sum()
-
-        if limit_batches is not None and i + 1 >= limit_batches:
-            break
-    mean_loss: Tensor = loss / n
-    return mean_loss
+        config = ESM2Config()
+        model = config.configure_model(tokenizer)

+        assert isinstance(model, MegatronBioBertModel)
+        assert isinstance(model, ESM2Model)
+        assert isinstance(model.embedding, ESM2Embedding)

-def test_esm2_model_initialized(esm2_model):
-    assert isinstance(esm2_model, MegatronBioBertModel)
-    assert isinstance(esm2_model, ESM2Model)
-    assert isinstance(esm2_model.embedding, ESM2Embedding)


-def test_esm2_nemo1_checkpoint(esm2_model):
+def test_esm2_nemo1_checkpoint():
    with tarfile.open(load("esm2/nv_650m:1.0"), "r") as ckpt, torch.no_grad():
        ckpt_file = ckpt.extractfile("./model_weights.ckpt")

@@ -149,10 +55,14 @@ def test_esm2_nemo1_checkpoint(esm2_model):
        # TODO: update Bionemo checkpoints
        old_state_dict.pop("model.language_model.rotary_pos_emb.inv_freq")

-        new_state_dict = esm2_model.state_dict_for_save_checkpoint()
+        with megatron_parallel_state_utils.distributed_model_parallel_state():
+            tokenizer = get_tokenizer()
+            config = ESM2Config()
+            model = config.configure_model(tokenizer)
+            new_state_dict = model.state_dict_for_save_checkpoint()

-        # Set the new_model_prefix to "" since we are looking at the base megatron model and not the lightning module which stores a copy of
-        # this model into self.module
+        # Set the new_model_prefix to "" since we are looking at the base megatron model and not the lightning module
+        # which stores a copy of this model into self.module
        old_keys = {
            nemo1_to_nemo2_biobert_key_mapping(k, new_model_prefix="", te_mapping=True) for k in old_state_dict
        }
@@ -180,7 +90,39 @@ def test_esm2_nemo1_checkpoint(esm2_model):
        assert not missing_old_keys, "There are keys in the old checkpoint that are missing from the new model."


-def test_esm2_loss(esm2_config_w_ckpt, dummy_protein_dataset, dummy_parquet_train_val_inputs):
+def _compute_loss(model, dataloader, vocab_size=None):
+    loss = 0
+    n = 0
+    limit_batches = 10
+    for i, batch in enumerate(dataloader):
+        assert isinstance(batch, dict)
+        result = model(input_ids=batch["text"].cuda(), attention_mask=batch["attention_mask"].cuda())
+
+        # bionemo ESM2 vocab_size
+        if vocab_size is not None:
+            # token_logits is s,b and for simplicity here let's transpose to b,s. In general this reduces performance.
+            logits = result["token_logits"].transpose(0, 1).contiguous()[..., :vocab_size]
+        else:
+            logits = result.logits
+
+        loss_mask = batch["loss_mask"].cuda()
+        target = batch["labels"].cuda()
+
+        loss += torch.nn.functional.cross_entropy(logits[loss_mask].float(), target[loss_mask], reduction="sum")
+        n += loss_mask.sum()
+
+        if limit_batches is not None and i + 1 >= limit_batches:
+            break
+    mean_loss: Tensor = loss / n
+    return mean_loss
+
+
+def test_esm2_loss(dummy_protein_dataset, dummy_parquet_train_val_inputs):
+    hf_model_tag = "facebook/esm2_t6_8M_UR50D"
+    nv_model_tag = "esm2/8m:2.0"
+    # hf_model_tag = "facebook/esm2_t33_650M_UR50D"
+    # nv_model_tag = "esm2/650m:2.0"
+
    train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs

    seed: int = 42
@@ -193,7 +135,7 @@ def test_esm2_loss(esm2_config_w_ckpt, dummy_protein_dataset, dummy_parquet_trai
    tokenizer = get_tokenizer()

    # ESM2 model initialized with params
-    model = esm2_config_w_ckpt.configure_model(tokenizer).cuda()
+    model = ESM2Config(initial_ckpt_path=str(load(nv_model_tag))).configure_model(tokenizer).cuda()

    # Initialize the data module.
    data_module = ESMDataModule(
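For readers tracing `_compute_loss` above: the quantity it returns is a masked-token cross entropy, i.e. the cross entropy summed over the positions selected by `loss_mask`, divided by the number of selected tokens, accumulated over at most `limit_batches` batches. Below is a minimal, self-contained sketch of that reduction with toy tensors; it is not part of the commit, and the shapes and vocabulary size are illustrative only.

```python
import torch
import torch.nn.functional as F

# Toy batch: 2 sequences of length 5 over a 33-token vocabulary.
logits = torch.randn(2, 5, 33)            # [b, s, v] per-token logits
labels = torch.randint(0, 33, (2, 5))     # [b, s] target token ids
loss_mask = torch.zeros(2, 5, dtype=torch.bool)
loss_mask[0, 1] = loss_mask[1, 3] = True  # only masked positions contribute to the loss

# Sum the per-token cross entropy over masked positions, then divide by the
# number of masked tokens -- the same reduction _compute_loss applies per batch.
loss_sum = F.cross_entropy(logits[loss_mask].float(), labels[loss_mask], reduction="sum")
mean_loss = loss_sum / loss_mask.sum()
print(mean_loss)
```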

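As for the refactor itself: the commit drops the module-scoped fixtures and has each test build its own tokenizer, config, and model inside the `distributed_model_parallel_state()` context manager. A hypothetical new test written in this post-refactor style might look like the sketch below; the import paths are assumptions inferred from the names used in the diff (only the `bionemo.testing` import is visible in the shown hunks), so treat them as placeholders rather than part of this commit.

```python
# Hypothetical example of the fixture-free pattern adopted by this commit.
# Import paths are assumed; adjust to the actual module layout of bionemo-esm2.
from bionemo.esm2.api import ESM2Config, ESM2Model
from bionemo.esm2.data.tokenizer import get_tokenizer
from bionemo.testing import megatron_parallel_state_utils


def test_esm2_model_state_dict_is_nonempty():
    # Everything the test needs is created here, scoped to this test's
    # parallel-state context, instead of coming from shared fixtures.
    with megatron_parallel_state_utils.distributed_model_parallel_state():
        tokenizer = get_tokenizer()
        model = ESM2Config().configure_model(tokenizer)

        assert isinstance(model, ESM2Model)
        # state_dict_for_save_checkpoint() is the same accessor exercised in
        # test_esm2_nemo1_checkpoint above.
        assert len(model.state_dict_for_save_checkpoint()) > 0
```

A natural consequence of this pattern is that each test pays for its own model construction, but no module-scoped state is shared between tests, so they can be run or reordered independently.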