add esm2 checkpoint to pbss, refactor tests for xfail due to bug
Signed-off-by: Peter St. John <[email protected]>
pstjohn committed Jan 14, 2025
1 parent 800066e commit 1e42a13
Showing 2 changed files with 58 additions and 31 deletions.
41 changes: 25 additions & 16 deletions sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml
@@ -7,14 +7,14 @@
   description: >
     A pretrained 650M parameter ESM2 model. See https://ngc.nvidia.com/catalog/models/nvidia:clara:esm2nv650m.
-- tag: nv_3b:2.1
-  ngc: "nvidia/clara/esm2nv3b:2.1"
+- tag: nv_8m:2.0
+  ngc: "nvidia/clara/esm2nv8m:2.0"
   ngc_registry: model
-  pbss: "s3://general-purpose/esm2/checkpoints/3b/esm2_3b_checkpoint.tar.gz"
-  sha256: a79327a4054bf8d1d7075e1b3c961dbc503da02d72ed15f707d9cbbd49d181b6 # pragma: allowlist secret
+  pbss: "s3://general-purpose/esm2/checkpoints/8m/esm2_8m_checkpoint.tar.gz"
+  sha256: b4ea4d52eea8a25d2c2838617ff678f0da22d384cee195b0c192686816078dcd # pragma: allowlist secret
   owner: Peter St John <[email protected]>
   description: >
-    An ESM-2 3B model pre-trained on NVIDIA's train/test data split.
+    An ESM-2 8M model pre-trained on NVIDIA's train/test data split.
 - tag: nv_650m:2.1
   ngc: "nvidia/clara/esm2nv650m:2.1"
@@ -25,14 +25,23 @@
   description: >
     An ESM-2 650M model pre-trained on NVIDIA's train/test data split.
-- tag: nv_8m:2.0
-  ngc: "nvidia/clara/esm2nv8m:2.0"
+- tag: nv_3b:2.1
+  ngc: "nvidia/clara/esm2nv3b:2.1"
   ngc_registry: model
-  pbss: "s3://general-purpose/esm2/checkpoints/8m/esm2_8m_checkpoint.tar.gz"
-  sha256: b4ea4d52eea8a25d2c2838617ff678f0da22d384cee195b0c192686816078dcd # pragma: allowlist secret
+  pbss: "s3://general-purpose/esm2/checkpoints/3b/esm2_3b_checkpoint.tar.gz"
+  sha256: a79327a4054bf8d1d7075e1b3c961dbc503da02d72ed15f707d9cbbd49d181b6 # pragma: allowlist secret
   owner: Peter St John <[email protected]>
   description: >
-    An ESM-2 8M model pre-trained on NVIDIA's train/test data split.
+    An ESM-2 3B model pre-trained on NVIDIA's train/test data split.
+- tag: 8m:2.0
+  ngc: null
+  ngc_registry: model
+  pbss: s3://general-purpose/esm2/checkpoints/converted/8m/esm2_hf_converted_8m_checkpoint.tar.gz
+  sha256: 2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16 # pragma: allowlist secret
+  owner: Peter St John <[email protected]>
+  description: >
+    A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t6_8M_UR50D model.
 - tag: 650m:2.0
   ngc: nvidia/clara/esm2nv650m:2.0
@@ -41,7 +50,7 @@
   sha256: 0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997 # pragma: allowlist secret
   owner: Farhad Ramezanghorbani <[email protected]>
   description: >
-    A pretrained 650M parameter ESM2 model. See https://ngc.nvidia.com/catalog/models/nvidia:clara:esm2nv650m.
+    A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t33_650M_UR50D model.
 - tag: 3b:2.0
   ngc: nvidia/clara/esm2nv3b:2.0
@@ -50,28 +59,28 @@
   sha256: a2248cfed1ef39f83bd32a0e08b84c0a8f39325d383e2d92767022ff7f5260ed # pragma: allowlist secret
   owner: Farhad Ramezanghorbani <[email protected]>
   description: >
-    A pretrained 3B parameter ESM2 model. See https://ngc.nvidia.com/catalog/models/nvidia:clara:esm2nv3b.
+    A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t36_3B_UR50D model.
 - tag: fulldata_esm2_pretrain:2.0
   ngc: nvidia/clara/esm2_pretrain_nemo2_data:1.0
   ngc_registry: resource
   pbss: "s3://general-purpose/esm2/pretrain/2024_03.tar.gz"
-  sha256: 404d0ad8de58fa8aae96f8d9f54263a088bc7e4f7d668215afbe04c28416151b # pragma: allowlist secret
+  sha256: 404d0ad8de58fa8aae96f8d9f54263a088bc7e4f7d668215afbe04c28416151b # pragma: allowlist secret
   owner: Peter St John <[email protected]>
   description: Full data for ESM2 pretraining.

 - tag: testdata_esm2_pretrain:2.0
   ngc: nvidia/clara/esm2_pretrain_nemo2_testdata:1.0
   ngc_registry: resource
   pbss: "s3://general-purpose/esm2/pretrain/2024_03_sanity.tar.gz"
-  sha256: 006911f92bbc0ded7ea302bbdbfab4c694b409e699c32fd49de1c527a99dba3e # pragma: allowlist secret
+  sha256: 006911f92bbc0ded7ea302bbdbfab4c694b409e699c32fd49de1c527a99dba3e # pragma: allowlist secret
   owner: Peter St John <[email protected]>
   description: Test data for ESM2 pretraining.

 - tag: esm2_inference_testdata:2.0
-  ngc: nvidia/clara/esm2_inference_testdata:2.0 # TODO: upload to NGC
+  ngc: nvidia/clara/esm2_inference_testdata:2.0 # TODO: upload to NGC
   ngc_registry: resource
   pbss: "s3://bionemo-ci/test_data/esm2/artificial_protein_sequences.csv"
-  sha256: 14ae3acfbf82218bc9e3e53d21a5b0594ba7c0369e169c9f1034e3fe4378d175 # pragma: allowlist secret
+  sha256: 14ae3acfbf82218bc9e3e53d21a5b0594ba7c0369e169c9f1034e3fe4378d175 # pragma: allowlist secret
   owner: Farhad Ramezanghorbani <[email protected]>
   description: Test data for ESM2 inference.
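
With the 8m:2.0 entry in place, the converted checkpoint can be fetched by its resource tag through the load helper exercised in the updated test below. A minimal sketch, assuming load downloads the PBSS artifact, verifies it against the sha256 recorded above, and returns the local path of the unpacked checkpoint:

from bionemo.core.data.load import load

# Fetch the converted 8M checkpoint from PBSS by resource tag; the returned
# path can then be passed to ESM2Config(initial_ckpt_path=...) as in the
# updated test below.
ckpt_path = load("esm2/8m:2.0", source="pbss")
print(ckpt_path)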
48 changes: 33 additions & 15 deletions sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py
@@ -14,53 +14,55 @@
 # limitations under the License.


+from pathlib import Path
+
+import pytest
 import torch
 from nemo.lightning import io
 from transformers import AutoModelForMaskedLM

+from bionemo.core.data.load import load
 from bionemo.core.utils.dtypes import get_autocast_dtype
 from bionemo.esm2.data.tokenizer import get_tokenizer
 from bionemo.esm2.model.model import ESM2Config
 from bionemo.llm.model.biobert.lightning import biobert_lightning_module
 from bionemo.testing import megatron_parallel_state_utils


-def test_convert_esm2_hf_to_nemo(tmp_path):
+tokenizer = get_tokenizer()
+
+
+def run_esm2_ckpt_conversion_hf_to_nemo(ckpt_path: Path, model_tag: str):
     from bionemo.esm2.model.convert import HFESM2Importer  # noqa: F401

-    model_tag = "facebook/esm2_t6_8M_UR50D"
     module = biobert_lightning_module(config=ESM2Config())
-    io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "output_ckpt")
-    tokenizer = get_tokenizer()
+    io.import_ckpt(module, f"hf://{model_tag}", ckpt_path / "nemo_checkpoint")
+    return ckpt_path / "nemo_checkpoint"

+
+def get_input_tokens():
     test_proteins = [
         "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA",
         "MKTVRQERLKSI<mask>RILERSKEPVSGAQLAEELS<mask>SRQVIVQDIAYLRSLGYN<mask>VATPRGYVLAGG",
     ]

     tokens = tokenizer(test_proteins, return_tensors="pt", padding=True, truncation=True).to("cuda")
     input_ids = tokens["input_ids"]
     attention_mask = tokens["attention_mask"]
+    return input_ids, attention_mask

-    # HF 650M model
+
+def assert_model_equivalence(ckpt_path, model_tag):
     hf_model = AutoModelForMaskedLM.from_pretrained(model_tag, torch_dtype=get_autocast_dtype(32)).cuda()

     with torch.inference_mode(), megatron_parallel_state_utils.distributed_model_parallel_state():
         nemo_model = (
-            ESM2Config(initial_ckpt_path=tmp_path / "output_ckpt", include_embeddings=True, include_hiddens=True)
+            ESM2Config(initial_ckpt_path=ckpt_path, include_embeddings=True, include_hiddens=True)
             .configure_model(tokenizer)
             .to("cuda")
             .eval()
         )

-        for i in range(len(hf_model.esm.encoder.layer)):
-            torch.testing.assert_close(
-                hf_model.esm.encoder.layer[i].attention.self.rotary_embeddings.inv_freq,
-                nemo_model.rotary_pos_emb.inv_freq,
-                atol=1e-4,
-                rtol=1e-6,
-            )
-
+        input_ids, attention_mask = get_input_tokens()
         hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True)

         nemo_output = nemo_model(input_ids, attention_mask)
@@ -80,3 +82,19 @@ def test_convert_esm2_hf_to_nemo(tmp_path):

     torch.testing.assert_close(logit_similarity, torch.ones_like(logit_similarity))
     torch.testing.assert_close(hidden_state_similarity, torch.ones_like(hidden_state_similarity))
+
+
+@pytest.mark.xfail(
+    reason="This test is failing due to a bug in nemo global state when run in the same process as previous "
+    "checkpoint save/load scripts."
+)
+def test_nemo2_conversion_golden_values(tmp_path):
+    model_tag = "facebook/esm2_t6_8M_UR50D"
+    ckpt_path = run_esm2_ckpt_conversion_hf_to_nemo(tmp_path, model_tag)
+    assert_model_equivalence(ckpt_path, model_tag)
+
+
+def test_pre_converted_checkpoint_golden_values():
+    model_tag = "facebook/esm2_t6_8M_UR50D"
+    ckpt_path = load("esm2/8m:2.0", source="pbss")
+    assert_model_equivalence(ckpt_path, model_tag)
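
As a usage note, here is a minimal sketch of running just this module programmatically; the repository-relative path and flags are assumptions, equivalent to invoking pytest on the file from the repo root. The xfail-marked conversion test is reported as an expected failure rather than breaking the run, while the new pre-converted-checkpoint test exercises the PBSS artifact added above.

import pytest

# Run only this test module; equivalent to `pytest -v <path>` from the repo
# root. test_nemo2_conversion_golden_values is expected to XFAIL (the nemo
# global-state bug named in the marker), while
# test_pre_converted_checkpoint_golden_values should pass on a machine with a
# CUDA device and access to PBSS.
pytest.main(["-v", "sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py"])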
