refactor test_model to use the new model comparison function
Signed-off-by: Peter St. John <[email protected]>
pstjohn committed Jan 21, 2025
1 parent c3848e7 commit 4407b39
Showing 5 changed files with 176 additions and 182 deletions.
67 changes: 20 additions & 47 deletions sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml
@@ -7,23 +7,32 @@
description: >
A pretrained 650M parameter ESM2 model. See https://ngc.nvidia.com/catalog/models/nvidia:clara:esm2nv650m.
- tag: nv_3b:2.1
ngc: "nvidia/clara/esm2nv3b:2.1"
- tag: 8m:2.0
ngc: nvidia/clara/esm2nv8m:2.0
ngc_registry: model
pbss: "s3://general-purpose/esm2/checkpoints/3b/esm2_3b_checkpoint.tar.gz"
sha256: a79327a4054bf8d1d7075e1b3c961dbc503da02d72ed15f707d9cbbd49d181b6 # pragma: allowlist secret
pbss: s3://general-purpose/esm2/checkpoints/converted/8m/esm2_hf_converted_8m_checkpoint.tar.gz
sha256: 2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16 # pragma: allowlist secret
      owner: Peter St John <pstjohn@nvidia.com>
description: >
An ESM-2 3B model pre-trained on NVIDIA's train/test data split.
A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t6_8M_UR50D model.
- tag: nv_650m:2.1
ngc: "nvidia/clara/esm2nv650m:2.1"
- tag: 650m:2.0
ngc: nvidia/clara/esm2nv650m:2.0
ngc_registry: model
pbss: "s3://general-purpose/esm2/checkpoints/650m/esm2_650m_checkpoint.tar.gz"
sha256: b83e9b5d62f1499b443817c5cd0facd3bdd4013a51a897e05e17228bf650befe # pragma: allowlist secret
owner: Peter St John <pstjohn@nvidia.com>
pbss: "s3://bionemo-ci/models/esm2_650M_nemo2.tar.gz"
sha256: 0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997 # pragma: allowlist secret
owner: Farhad Ramezanghorbani <farhadr@nvidia.com>
description: >
An ESM-2 650M model pre-trained on NVIDIA's train/test data split.
A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t33_650M_UR50D model.
- tag: 3b:2.0
ngc: nvidia/clara/esm2nv3b:2.0
ngc_registry: model
pbss: "s3://bionemo-ci/models/esm2_3B_nemo2.tar.gz"
sha256: a2248cfed1ef39f83bd32a0e08b84c0a8f39325d383e2d92767022ff7f5260ed # pragma: allowlist secret
      owner: Farhad Ramezanghorbani <farhadr@nvidia.com>
description: >
A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t36_3B_UR50D model.
# - tag: nv_8m:2.1
# ngc: "nvidia/clara/esm2nv8m:2.1"
@@ -34,15 +43,6 @@
# description: >
# An ESM-2 8M model pre-trained on NVIDIA's train/test data split.

- tag: 8m:2.0
ngc: "nvidia/clara/esm2nv8m:2.0"
ngc_registry: model
pbss: "s3://general-purpose/esm2/checkpoints/converted/8m/esm2_hf_converted_8m_checkpoint.tar.gz"
sha256: 2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16 # pragma: allowlist secret
    owner: Peter St John <pstjohn@nvidia.com>
description: >
The original 8M parameter ESM2 model weights converted to the NeMo2 checkpoint format.
- tag: nv_650m:2.1
ngc: "nvidia/clara/esm2nv650m:2.1"
ngc_registry: model
@@ -61,33 +61,6 @@
description: >
An ESM-2 3B model pre-trained on NVIDIA's train/test data split.
- tag: 8m:2.0
ngc: null
ngc_registry: model
pbss: s3://general-purpose/esm2/checkpoints/converted/8m/esm2_hf_converted_8m_checkpoint.tar.gz
sha256: 2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16 # pragma: allowlist secret
    owner: Peter St John <pstjohn@nvidia.com>
description: >
A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t6_8M_UR50D model.
- tag: 650m:2.0
ngc: nvidia/clara/esm2nv650m:2.0
ngc_registry: model
pbss: "s3://bionemo-ci/models/esm2_650M_nemo2.tar.gz"
sha256: 0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997 # pragma: allowlist secret
    owner: Farhad Ramezanghorbani <farhadr@nvidia.com>
description: >
A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t33_650M_UR50D model.
- tag: 3b:2.0
ngc: nvidia/clara/esm2nv3b:2.0
ngc_registry: model
pbss: "s3://bionemo-ci/models/esm2_3B_nemo2.tar.gz"
sha256: a2248cfed1ef39f83bd32a0e08b84c0a8f39325d383e2d92767022ff7f5260ed # pragma: allowlist secret
    owner: Farhad Ramezanghorbani <farhadr@nvidia.com>
description: >
A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t36_3B_UR50D model.
- tag: fulldata_esm2_pretrain:2.0
ngc: nvidia/clara/esm2_pretrain_nemo2_data:1.0
ngc_registry: resource
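For context on how these registry entries are consumed: each entry is addressed by a "<file>/<tag>" name through bionemo.core.data.load.load, which is how the test code further down in this diff resolves "esm2/8m:2.0". A minimal sketch of that usage follows; the caching and unpacking behavior is an assumption inferred from how the returned path is used in the tests, not something shown in this diff.

from bionemo.core.data.load import load

# Resolve the converted 650M checkpoint registered above by its "<file>/<tag>" name.
# The call is expected to download the artifact from the requested source ("pbss" here,
# as in the tests below) and return a local path usable as initial_ckpt_path.
ckpt_path = load("esm2/650m:2.0", source="pbss")
print(ckpt_path)
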
14 changes: 14 additions & 0 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/testing/__init__.py
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
100 changes: 100 additions & 0 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py
@@ -0,0 +1,100 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import gc
from pathlib import Path

import torch
from transformers import AutoModelForMaskedLM

from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype
from bionemo.esm2.data.tokenizer import get_tokenizer
from bionemo.esm2.model.model import ESM2Config
from bionemo.testing import megatron_parallel_state_utils


def assert_model_equivalence(
ckpt_path: Path | str,
model_tag: str,
precision: PrecisionTypes = "fp32",
rtol: float | None = None,
atol: float | None = None,
) -> None:
"""Testing utility to compare the outputs of a NeMo2 checkpoint to the original HuggingFace model weights.
Compares the cosine similarity of the logit and hidden state outputs of a NeMo2 model checkpoint to the outputs of
the corresponding HuggingFace model.
Args:
ckpt_path: A path to a NeMo2 checkpoint for an ESM-2 model.
model_tag: The HuggingFace model tag for the model to compare against.
precision: The precision type to use for the comparison. Defaults to "fp32".
rtol: The relative tolerance to use for the comparison. Defaults to None, which chooses the tolerance based on
the precision.
atol: The absolute tolerance to use for the comparison. Defaults to None, which chooses the tolerance based on
the precision.
"""
tokenizer = get_tokenizer()

test_proteins = [
"MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA",
"MKTVRQERLKSI<mask>RILERSKEPVSGAQLAEELS<mask>SRQVIVQDIAYLRSLGYN<mask>VATPRGYVLAGG",
]
tokens = tokenizer(test_proteins, return_tensors="pt", padding=True, truncation=True).to("cuda")
input_ids = tokens["input_ids"]
attention_mask = tokens["attention_mask"]

with torch.inference_mode(), megatron_parallel_state_utils.distributed_model_parallel_state():
nemo_model = (
ESM2Config(
initial_ckpt_path=str(ckpt_path),
include_embeddings=True,
include_hiddens=True,
params_dtype=get_autocast_dtype(precision),
pipeline_dtype=get_autocast_dtype(precision),
autocast_dtype=get_autocast_dtype(precision),
bf16=True,
            )  # setting bf16=True speeds things up a lot
.configure_model(tokenizer)
.to("cuda")
.eval()
)

nemo_output = nemo_model(input_ids, attention_mask)
nemo_logits = nemo_output["token_logits"].transpose(0, 1).contiguous()[..., : tokenizer.vocab_size]
nemo_hidden_state = nemo_output["hidden_states"]

del nemo_model
gc.collect()
torch.cuda.empty_cache()

hf_model = AutoModelForMaskedLM.from_pretrained(model_tag, torch_dtype=get_autocast_dtype(precision)).cuda()
hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True)
hf_hidden_state = hf_output_all.hidden_states[-1]

        # Rather than directly comparing the logit or hidden state tensors, we compare their cosine similarity. These
        # similarities should be essentially 1 if the outputs are equivalent, but the comparison is less sensitive to
        # small numerical differences. We don't care about the padding tokens, so we only compare the non-padding tokens.
logit_similarity = torch.nn.functional.cosine_similarity(nemo_logits, hf_output_all.logits, dim=2)
logit_similarity = logit_similarity[attention_mask == 1]

hidden_state_similarity = torch.nn.functional.cosine_similarity(nemo_hidden_state, hf_hidden_state, dim=2)
hidden_state_similarity = hidden_state_similarity[attention_mask == 1]

torch.testing.assert_close(logit_similarity, torch.ones_like(logit_similarity), rtol=rtol, atol=atol)
torch.testing.assert_close(
hidden_state_similarity, torch.ones_like(hidden_state_similarity), rtol=rtol, atol=atol
)
86 changes: 16 additions & 70 deletions sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py
@@ -14,87 +14,33 @@
# limitations under the License.


from pathlib import Path

import pytest
import torch
from nemo.lightning import io
from transformers import AutoModelForMaskedLM

from bionemo.core.data.load import load
from bionemo.core.utils.dtypes import get_autocast_dtype
from bionemo.esm2.data.tokenizer import get_tokenizer
from bionemo.esm2.model.convert import HFESM2Importer # noqa: F401
from bionemo.esm2.model.model import ESM2Config
from bionemo.esm2.testing.compare import assert_model_equivalence
from bionemo.llm.model.biobert.lightning import biobert_lightning_module
from bionemo.testing import megatron_parallel_state_utils


tokenizer = get_tokenizer()


def run_esm2_ckpt_conversion_hf_to_nemo(ckpt_path: Path, model_tag: str):
from bionemo.esm2.model.convert import HFESM2Importer # noqa: F401

module = biobert_lightning_module(config=ESM2Config())
io.import_ckpt(module, f"hf://{model_tag}", ckpt_path / "nemo_checkpoint")
return ckpt_path / "nemo_checkpoint"


def get_input_tokens():
test_proteins = [
"MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA",
"MKTVRQERLKSI<mask>RILERSKEPVSGAQLAEELS<mask>SRQVIVQDIAYLRSLGYN<mask>VATPRGYVLAGG",
]
tokens = tokenizer(test_proteins, return_tensors="pt", padding=True, truncation=True).to("cuda")
input_ids = tokens["input_ids"]
attention_mask = tokens["attention_mask"]
return input_ids, attention_mask


def assert_model_equivalence(ckpt_path, model_tag):
hf_model = AutoModelForMaskedLM.from_pretrained(model_tag, torch_dtype=get_autocast_dtype(32)).cuda()

with torch.inference_mode(), megatron_parallel_state_utils.distributed_model_parallel_state():
nemo_model = (
ESM2Config(initial_ckpt_path=ckpt_path, include_embeddings=True, include_hiddens=True)
.configure_model(tokenizer)
.to("cuda")
.eval()
)

input_ids, attention_mask = get_input_tokens()
hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True)

nemo_output = nemo_model(input_ids, attention_mask)
nemo_logits = nemo_output["token_logits"].transpose(0, 1).contiguous()[..., : tokenizer.vocab_size]

nemo_hidden_state = nemo_output["hidden_states"]
hf_hidden_state = hf_output_all.hidden_states[-1]

# Rather than directly comparing the logit or hidden state tensors, we compare their cosine similarity. These
# should be essentially 1 if the outputs are equivalent, but is less sensitive to small numerical differences.
# We don't care about the padding tokens, so we only compare the non-padding tokens.
logit_similarity = torch.nn.functional.cosine_similarity(nemo_logits, hf_output_all.logits, dim=2)
logit_similarity = logit_similarity[attention_mask == 1]

hidden_state_similarity = torch.nn.functional.cosine_similarity(nemo_hidden_state, hf_hidden_state, dim=2)
hidden_state_similarity = hidden_state_similarity[attention_mask == 1]

torch.testing.assert_close(logit_similarity, torch.ones_like(logit_similarity))
torch.testing.assert_close(hidden_state_similarity, torch.ones_like(hidden_state_similarity))


@pytest.mark.xfail(
reason="This test is failing due to a bug in nemo global state when run in the same process as previous checkpoint"
"save/load scripts."
)
def test_nemo2_conversion_golden_values(tmp_path):
def test_nemo2_conversion_equivalent_8m(tmp_path):
model_tag = "facebook/esm2_t6_8M_UR50D"
ckpt_path = run_esm2_ckpt_conversion_hf_to_nemo(tmp_path, model_tag)
assert_model_equivalence(ckpt_path, model_tag)
module = biobert_lightning_module(config=ESM2Config())
io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint")
assert_model_equivalence(tmp_path / "nemo_checkpoint", model_tag)


def test_pre_converted_checkpoint_golden_values():
model_tag = "facebook/esm2_t6_8M_UR50D"
ckpt_path = load("esm2/8m:2.0", source="pbss")
assert_model_equivalence(ckpt_path, model_tag)
@pytest.mark.xfail(
reason="This test is failing due to a bug in nemo global state when run in the same process as previous checkpoint"
"save/load scripts."
)
@pytest.mark.slow
def test_nemo2_conversion_equivalent_650m(tmp_path):
model_tag = "facebook/esm2_t33_650M_UR50D"
module = biobert_lightning_module(config=ESM2Config())
io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint")
assert_model_equivalence(tmp_path / "nemo_checkpoint", model_tag)
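For reference, the conversion step both tests exercise reduces to the short standalone sketch below. The unused HFESM2Importer import is presumably what registers the hf:// scheme with NeMo's io.import_ckpt machinery (the "# noqa: F401" import in the diff suggests it exists only for that side effect), and the output path here is a hypothetical location, not one used by the tests.

from pathlib import Path

from nemo.lightning import io

from bionemo.esm2.model.convert import HFESM2Importer  # noqa: F401 -- registers the hf:// importer
from bionemo.esm2.model.model import ESM2Config
from bionemo.esm2.testing.compare import assert_model_equivalence
from bionemo.llm.model.biobert.lightning import biobert_lightning_module

model_tag = "facebook/esm2_t6_8M_UR50D"
output_path = Path("/tmp/esm2_8m_nemo2")  # hypothetical destination for the converted checkpoint

# Convert the HuggingFace checkpoint to NeMo2 format, then verify the outputs still match.
module = biobert_lightning_module(config=ESM2Config())
io.import_ckpt(module, f"hf://{model_tag}", output_path)
assert_model_equivalence(output_path, model_tag)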