From d616e0f4f95d0f7d1abc97179077213d6a1d5f15 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Tue, 21 Jan 2025 00:16:43 +0000 Subject: [PATCH 01/10] hotfix for some failing python tests due to NGC files being moved around Signed-off-by: Peter St. John --- .../bionemo-core/src/bionemo/core/data/resources/esm2.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml b/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml index d7749aa78..58d0ef0a6 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml +++ b/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml @@ -34,7 +34,7 @@ # description: > # An ESM-2 8M model pre-trained on NVIDIA's train/test data split. -- tag: 8m:2.0 +- tag: nv_8m:2.0 ngc: "nvidia/clara/esm2nv8m:2.0" ngc_registry: model pbss: "s3://general-purpose/esm2/checkpoints/converted/8m/esm2_hf_converted_8m_checkpoint.tar.gz" From 70419e5d6ce273ef3a018c8f81c07aef1ac6620c Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Tue, 21 Jan 2025 15:40:16 +0000 Subject: [PATCH 02/10] rename esm 8m checkpoint and move load calls Signed-off-by: Peter St. John --- .../bionemo-core/src/bionemo/core/data/resources/esm2.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml b/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml index 58d0ef0a6..d7749aa78 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml +++ b/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml @@ -34,7 +34,7 @@ # description: > # An ESM-2 8M model pre-trained on NVIDIA's train/test data split. -- tag: nv_8m:2.0 +- tag: 8m:2.0 ngc: "nvidia/clara/esm2nv8m:2.0" ngc_registry: model pbss: "s3://general-purpose/esm2/checkpoints/converted/8m/esm2_hf_converted_8m_checkpoint.tar.gz" From 20094195ac0540ee98087e2f6f2d1dc158674a7b Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Tue, 21 Jan 2025 16:27:26 +0000 Subject: [PATCH 03/10] temp commit Signed-off-by: Peter St. 
John --- .../tests/bionemo/esm2/model/test_model.py | 56 +++++++++++-------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py index 7d0d20b46..10595c049 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py @@ -24,7 +24,7 @@ import torch from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from torch import Tensor -from transformers import EsmForMaskedLM +from transformers import AutoModelForMaskedLM from bionemo.core.data.load import load from bionemo.core.utils.dtypes import get_autocast_dtype @@ -38,6 +38,12 @@ from bionemo.testing import megatron_parallel_state_utils +hf_model_tag = "facebook/esm2_t6_8M_UR50D" +nv_model_tag = "esm2/8m:2.0" +# hf_model_tag = "facebook/esm2_t33_650M_UR50D" +# nv_model_tag = "esm2/650m:2.0" + + def reduce_hiddens(hiddens: Tensor, attention_mask: Tensor) -> Tensor: """reduce last layer's hidden values to embeddings @@ -66,9 +72,9 @@ def esm2_config() -> ESM2Config: @pytest.fixture(scope="module") -def esm2_650M_config_w_ckpt() -> ESM2Config: +def esm2_config_w_ckpt() -> ESM2Config: with megatron_parallel_state_utils.distributed_model_parallel_state(): - yield ESM2Config(initial_ckpt_path=load("esm2/650m:2.0")) + yield ESM2Config(initial_ckpt_path=load(nv_model_tag)) @pytest.fixture(scope="module") @@ -136,7 +142,7 @@ def test_esm2_model_initialized(esm2_model): assert isinstance(esm2_model.embedding, ESM2Embedding) -def test_esm2_650m_checkpoint(esm2_model): +def test_esm2_nemo1_checkpoint(esm2_model): with tarfile.open(load("esm2/nv_650m:1.0"), "r") as ckpt, torch.no_grad(): ckpt_file = ckpt.extractfile("./model_weights.ckpt") @@ -176,20 +182,19 @@ def test_esm2_650m_checkpoint(esm2_model): assert not missing_old_keys, "There are keys in the old checkpoint that are missing from the new model." -def test_esm2_golden_values(esm2_650M_config_w_ckpt, sample_data): +def test_esm2_golden_values(esm2_config_w_ckpt, sample_data): tokenizer = AutoTokenizer(pretrained_model_name="facebook/esm2_t33_650M_UR50D") tokens = tokenizer.tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True).to("cuda") input_ids = tokens["input_ids"] attention_mask = tokens["attention_mask"] - # HF 650M model - hf_model = EsmForMaskedLM.from_pretrained( - "facebook/esm2_t33_650M_UR50D", torch_dtype=get_autocast_dtype(32) - ).cuda() + # HF model + hf_model = AutoModelForMaskedLM.from_pretrained(hf_model_tag, torch_dtype=get_autocast_dtype(32)).cuda() with torch.no_grad(): hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True) hf_logits = hf_output_all.logits * attention_mask.unsqueeze(-1) + hf_hiddens = hf_output_all.hidden_states[-1] hf_embeddings = reduce_hiddens(hf_output_all.hidden_states[-1], attention_mask) # free GPU RAM @@ -198,7 +203,7 @@ def test_esm2_golden_values(esm2_650M_config_w_ckpt, sample_data): torch.cuda.empty_cache() # configure the model to return logits - model = esm2_650M_config_w_ckpt.configure_model(get_tokenizer()).cuda() + model = esm2_config_w_ckpt.configure_model(get_tokenizer()).cuda() model.eval() result = model(input_ids, attention_mask) # token_logits is s,b and for simplicity here let's transpose to b,s. In general this reduces performance. 
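The golden-value test above has to reconcile two output conventions: Megatron-style models return token_logits in sequence-first layout with a vocabulary padded for tensor parallelism, while the HuggingFace reference is batch-first with the exact tokenizer vocabulary. A minimal, self-contained sketch of that reshaping step, using illustrative shapes rather than the real model dimensions:

import torch

# Illustrative sizes: sequence length 8, batch 2, padded vocab 64, tokenizer vocab 33.
seq_len, batch, padded_vocab, vocab_size = 8, 2, 64, 33

# Megatron-style output is [s, b, padded_vocab].
token_logits = torch.randn(seq_len, batch, padded_vocab)

# Transpose to batch-first and drop the padded vocabulary entries: [b, s, vocab_size].
logits = token_logits.transpose(0, 1).contiguous()[..., :vocab_size]
assert logits.shape == (batch, seq_len, vocab_size)

# Zero out padded token positions before comparing against a reference, mirroring the
# attention-mask multiplication used in the test above.
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
masked_logits = logits * attention_mask.unsqueeze(-1)
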
@@ -211,18 +216,27 @@ def test_esm2_golden_values(esm2_650M_config_w_ckpt, sample_data): torch.cuda.empty_cache() # configure the model to return hiddens - esm2_650M_config_hiddens = deepcopy(esm2_650M_config_w_ckpt) - esm2_650M_config_hiddens.set_hparam("return_only_hidden_states", True) - model = esm2_650M_config_hiddens.configure_model(get_tokenizer()).cuda() + esm2_config_hiddens = deepcopy(esm2_config_w_ckpt) + esm2_config_hiddens.set_hparam("return_only_hidden_states", True) + model = esm2_config_hiddens.configure_model(get_tokenizer()).cuda() model.eval() hiddens = model(input_ids, attention_mask) embeddings = reduce_hiddens(torch.transpose(hiddens, 0, 1).float(), attention_mask) - torch.testing.assert_close(logits, hf_logits, atol=0.2, rtol=0.0) - torch.testing.assert_close(embeddings, hf_embeddings, atol=5e-3, rtol=0.0) + # Rather than directly comparing the logit or hidden state tensors, we compare their cosine similarity. These + # should be essentially 1 if the outputs are equivalent, but is less sensitive to small numerical differences. + # We don't care about the padding tokens, so we only compare the non-padding tokens. + logit_similarity = torch.nn.functional.cosine_similarity(logits, hf_logits, dim=2) + logit_similarity = logit_similarity[attention_mask == 1] + + hidden_state_similarity = torch.nn.functional.cosine_similarity(hiddens, hf_hiddens, dim=2) + hidden_state_similarity = hidden_state_similarity[attention_mask == 1] + + torch.testing.assert_close(logit_similarity, torch.ones_like(logit_similarity)) + torch.testing.assert_close(hidden_state_similarity, torch.ones_like(hidden_state_similarity)) -def test_esm2_loss(esm2_650M_config_w_ckpt, dummy_protein_dataset, dummy_parquet_train_val_inputs): +def test_esm2_loss(esm2_config_w_ckpt, dummy_protein_dataset, dummy_parquet_train_val_inputs): train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs compute_hf_reference: bool = True @@ -235,8 +249,8 @@ def test_esm2_loss(esm2_650M_config_w_ckpt, dummy_protein_dataset, dummy_parquet ): tokenizer = get_tokenizer() - # ESM2 model initialized with 650M params - model = esm2_650M_config_w_ckpt.configure_model(tokenizer).cuda() + # ESM2 model initialized with params + model = esm2_config_w_ckpt.configure_model(tokenizer).cuda() # Initialize the data module. data_module = ESMDataModule( @@ -269,10 +283,8 @@ def test_esm2_loss(esm2_650M_config_w_ckpt, dummy_protein_dataset, dummy_parquet mean_loss = _compute_loss(model, train_dataloader, vocab_size=tokenizer.vocab_size) if compute_hf_reference: - # HF model initialized with 650M params - hf_model = EsmForMaskedLM.from_pretrained( - "facebook/esm2_t33_650M_UR50D", torch_dtype=get_autocast_dtype(32) - ).cuda() + # HF model initialized with params + hf_model = AutoModelForMaskedLM.from_pretrained(hf_model_tag, torch_dtype=get_autocast_dtype(32)).cuda() hf_mean_loss = _compute_loss(hf_model, train_dataloader) print(f"hf_mean_loss: {hf_mean_loss}") else: From d7973961e3041716c8516e5d244a17ea6fa8ffc3 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Tue, 14 Jan 2025 07:08:24 -0800 Subject: [PATCH 04/10] Add conversion script for ESM-2 model from HF to NeMo Signed-off-by: Peter St. 
John --- .../src/bionemo/esm2/model/convert.py | 173 ++++++++++++++++++ .../tests/bionemo/esm2/model/test_convert.py | 82 +++++++++ 2 files changed, 255 insertions(+) create mode 100644 sub-packages/bionemo-esm2/src/bionemo/esm2/model/convert.py create mode 100644 sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/convert.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/convert.py new file mode 100644 index 000000000..b392936f6 --- /dev/null +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/convert.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from pathlib import Path + +import torch +from nemo.lightning import io, teardown +from nemo.lightning.pytorch.utils import dtype_from_hf +from transformers import AutoConfig as HFAutoConfig +from transformers import AutoModelForMaskedLM + +from bionemo.esm2.data.tokenizer import BioNeMoESMTokenizer, get_tokenizer +from bionemo.esm2.model.model import ESM2Config +from bionemo.llm.lightning import BionemoLightningModule +from bionemo.llm.model.biobert.lightning import biobert_lightning_module + + +@io.model_importer(BionemoLightningModule, "hf") +class HFESM2Importer(io.ModelConnector[AutoModelForMaskedLM, BionemoLightningModule]): + """Converts a Hugging Face ESM-2 model to a NeMo ESM-2 model.""" + + def init(self) -> BionemoLightningModule: + """Initialize the converted model.""" + return biobert_lightning_module(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + """Applies the transformation. 
+ + Largely inspired by + https://docs.nvidia.com/nemo-framework/user-guide/latest/nemo-2.0/features/hf-integration.html + """ + source = AutoModelForMaskedLM.from_pretrained(str(self), trust_remote_code=True, torch_dtype="auto") + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + print(f"Converted ESM-2 model to Nemo, model saved to {output_path}") + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source, target): + """Converting HF state dict to NeMo state dict.""" + mapping = { + # "esm.encoder.layer.0.attention.self.rotary_embeddings.inv_freq": "rotary_pos_emb.inv_freq", + "esm.encoder.layer.*.attention.output.dense.weight": "encoder.layers.*.self_attention.linear_proj.weight", + "esm.encoder.layer.*.attention.output.dense.bias": "encoder.layers.*.self_attention.linear_proj.bias", + "esm.encoder.layer.*.attention.LayerNorm.weight": "encoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "esm.encoder.layer.*.attention.LayerNorm.bias": "encoder.layers.*.self_attention.linear_qkv.layer_norm_bias", + "esm.encoder.layer.*.intermediate.dense.weight": "encoder.layers.*.mlp.linear_fc1.weight", + "esm.encoder.layer.*.intermediate.dense.bias": "encoder.layers.*.mlp.linear_fc1.bias", + "esm.encoder.layer.*.output.dense.weight": "encoder.layers.*.mlp.linear_fc2.weight", + "esm.encoder.layer.*.output.dense.bias": "encoder.layers.*.mlp.linear_fc2.bias", + "esm.encoder.layer.*.LayerNorm.weight": "encoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "esm.encoder.layer.*.LayerNorm.bias": "encoder.layers.*.mlp.linear_fc1.layer_norm_bias", + "esm.encoder.emb_layer_norm_after.weight": "encoder.final_layernorm.weight", + "esm.encoder.emb_layer_norm_after.bias": "encoder.final_layernorm.bias", + "lm_head.dense.weight": "lm_head.dense.weight", + "lm_head.dense.bias": "lm_head.dense.bias", + "lm_head.layer_norm.weight": "lm_head.layer_norm.weight", + "lm_head.layer_norm.bias": "lm_head.layer_norm.bias", + } + + # lm_head.bias + return io.apply_transforms( + source, + target, + mapping=mapping, + transforms=[_pad_embeddings, _pad_bias, _import_qkv_weight, _import_qkv_bias], + ) + + @property + def tokenizer(self) -> BioNeMoESMTokenizer: + """We just have the one tokenizer for ESM-2.""" + return get_tokenizer() + + @property + def config(self) -> ESM2Config: + """Returns the transformed ESM-2 config given the model tag.""" + source = HFAutoConfig.from_pretrained(str(self), trust_remote_code=True) + output = ESM2Config( + num_layers=source.num_hidden_layers, + hidden_size=source.hidden_size, + ffn_hidden_size=source.intermediate_size, + position_embedding_type="rope", + num_attention_heads=source.num_attention_heads, + seq_length=source.max_position_embeddings, + fp16=(dtype_from_hf(source) == torch.float16), + bf16=(dtype_from_hf(source) == torch.bfloat16), + params_dtype=dtype_from_hf(source), + ) + + return output + + +@io.state_transform( + source_key="esm.embeddings.word_embeddings.weight", + target_key="embedding.word_embeddings.weight", +) +def _pad_embeddings(ctx: io.TransformCTX, source_embed): + """Pad the embedding layer to the new input dimension.""" + nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by + hf_embedding_dimension = source_embed.size(0) + num_padding_rows = nemo_embedding_dimension - hf_embedding_dimension + padding_rows = torch.zeros(num_padding_rows, source_embed.size(1)) + return torch.cat((source_embed, 
padding_rows), dim=0) + + +@io.state_transform( + source_key="lm_head.bias", + target_key="output_layer.bias", +) +def _pad_bias(ctx: io.TransformCTX, source_bias): + """Pad the embedding layer to the new input dimension.""" + nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by + hf_embedding_dimension = source_bias.size(0) + output_bias = torch.zeros(nemo_embedding_dimension, dtype=source_bias.dtype, device=source_bias.device) + output_bias[:hf_embedding_dimension] = source_bias + return output_bias + + +@io.state_transform( + source_key=( + "esm.encoder.layer.*.attention.self.query.weight", + "esm.encoder.layer.*.attention.self.key.weight", + "esm.encoder.layer.*.attention.self.value.weight", + ), + target_key="encoder.layers.*.self_attention.linear_qkv.weight", +) +def _import_qkv_weight(ctx: io.TransformCTX, query, key, value): + """Pad the embedding layer to the new input dimension.""" + concat_weights = torch.cat((query, key, value), dim=0) + input_shape = concat_weights.size() + np = ctx.target.config.num_attention_heads + concat_weights = concat_weights.view(3, np, -1, query.size()[-1]) + concat_weights = concat_weights.transpose(0, 1).contiguous() + concat_weights = concat_weights.view(*input_shape) + return concat_weights + + +@io.state_transform( + source_key=( + "esm.encoder.layer.*.attention.self.query.bias", + "esm.encoder.layer.*.attention.self.key.bias", + "esm.encoder.layer.*.attention.self.value.bias", + ), + target_key="encoder.layers.*.self_attention.linear_qkv.bias", +) +def _import_qkv_bias(ctx: io.TransformCTX, query, key, value): + """Pad the embedding layer to the new input dimension.""" + concat_biases = torch.cat((query, key, value), dim=0) + input_shape = concat_biases.size() + np = ctx.target.config.num_attention_heads + concat_biases = concat_biases.view(3, np, -1) + concat_biases = concat_biases.transpose(0, 1).contiguous() + concat_biases = concat_biases.view(*input_shape) + return concat_biases diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py new file mode 100644 index 000000000..a183b1b2f --- /dev/null +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
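The _import_qkv_weight and _import_qkv_bias transforms above do more than concatenate the three projections: they reorder the rows so that each attention head's query, key, and value weights sit next to each other, which is the layout Megatron's fused linear_qkv expects. A small sketch with toy dimensions (2 heads of size 3, values chosen only to make the check readable) that verifies the reordering:

import torch

num_heads, head_dim = 2, 3
hidden = num_heads * head_dim

# Stand-ins for the separate HF projection weights, filled with recognizable constants.
query = torch.full((hidden, hidden), 1.0)
key = torch.full((hidden, hidden), 2.0)
value = torch.full((hidden, hidden), 3.0)

# Same reshaping as _import_qkv_weight: stacked [q; k; v] -> per-head [q_h; k_h; v_h] blocks.
concat = torch.cat((query, key, value), dim=0)        # (3 * hidden, hidden)
interleaved = (
    concat.view(3, num_heads, head_dim, hidden)       # split into (qkv, head, head_dim, hidden)
    .transpose(0, 1)                                  # move the head axis to the front
    .contiguous()
    .view(3 * hidden, hidden)                         # flatten back to the fused layout
)

# The first head's block is now its query, key, and value rows back to back.
assert torch.equal(interleaved[:head_dim], query[:head_dim])
assert torch.equal(interleaved[head_dim : 2 * head_dim], key[:head_dim])
assert torch.equal(interleaved[2 * head_dim : 3 * head_dim], value[:head_dim])
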
+ + +import torch +from nemo.lightning import io +from transformers import AutoModelForMaskedLM + +from bionemo.core.utils.dtypes import get_autocast_dtype +from bionemo.esm2.data.tokenizer import get_tokenizer +from bionemo.esm2.model.model import ESM2Config +from bionemo.llm.model.biobert.lightning import biobert_lightning_module +from bionemo.testing import megatron_parallel_state_utils + + +def test_convert_esm2_hf_to_nemo(tmp_path): + from bionemo.esm2.model.convert import HFESM2Importer # noqa: F401 + + model_tag = "facebook/esm2_t6_8M_UR50D" + module = biobert_lightning_module(config=ESM2Config()) + io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "output_ckpt") + tokenizer = get_tokenizer() + + test_proteins = [ + "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA", + "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLAGG", + ] + + tokens = tokenizer(test_proteins, return_tensors="pt", padding=True, truncation=True).to("cuda") + input_ids = tokens["input_ids"] + attention_mask = tokens["attention_mask"] + + # HF 650M model + hf_model = AutoModelForMaskedLM.from_pretrained(model_tag, torch_dtype=get_autocast_dtype(32)).cuda() + + with torch.inference_mode(), megatron_parallel_state_utils.distributed_model_parallel_state(): + nemo_model = ( + ESM2Config(initial_ckpt_path=tmp_path / "output_ckpt", include_embeddings=True, include_hiddens=True) + .configure_model(tokenizer) + .to("cuda") + .eval() + ) + + for i in range(len(hf_model.esm.encoder.layer)): + torch.testing.assert_close( + hf_model.esm.encoder.layer[i].attention.self.rotary_embeddings.inv_freq, + nemo_model.rotary_pos_emb.inv_freq, + atol=1e-4, + rtol=1e-6, + ) + + hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True) + + nemo_output = nemo_model(input_ids, attention_mask) + nemo_logits = nemo_output["token_logits"].transpose(0, 1).contiguous()[..., : tokenizer.vocab_size] + + nemo_hidden_state = nemo_output["hidden_states"] + hf_hidden_state = hf_output_all.hidden_states[-1] + + # Rather than directly comparing the logit or hidden state tensors, we compare their cosine similarity. These + # should be essentially 1 if the outputs are equivalent, but is less sensitive to small numerical differences. + # We don't care about the padding tokens, so we only compare the non-padding tokens. + logit_similarity = torch.nn.functional.cosine_similarity(nemo_logits, hf_output_all.logits, dim=2) + logit_similarity = logit_similarity[attention_mask == 1] + + hidden_state_similarity = torch.nn.functional.cosine_similarity(nemo_hidden_state, hf_hidden_state, dim=2) + hidden_state_similarity = hidden_state_similarity[attention_mask == 1] + + torch.testing.assert_close(logit_similarity, torch.ones_like(logit_similarity)) + torch.testing.assert_close(hidden_state_similarity, torch.ones_like(hidden_state_similarity)) From ab7f45926fe38ae2ac0c8551c2848ae98f26523e Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Tue, 14 Jan 2025 07:11:06 -0800 Subject: [PATCH 05/10] add comments from original script Signed-off-by: Peter St. 
John --- sub-packages/bionemo-esm2/src/bionemo/esm2/model/convert.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/convert.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/convert.py index b392936f6..06be1fa0a 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/convert.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/convert.py @@ -148,6 +148,9 @@ def _import_qkv_weight(ctx: io.TransformCTX, query, key, value): concat_weights = torch.cat((query, key, value), dim=0) input_shape = concat_weights.size() np = ctx.target.config.num_attention_heads + # transpose weights + # [sequence length, batch size, num_splits_model_parallel * attention head size * #attention heads] + # --> [sequence length, batch size, attention head size * num_splits_model_parallel * #attention heads] concat_weights = concat_weights.view(3, np, -1, query.size()[-1]) concat_weights = concat_weights.transpose(0, 1).contiguous() concat_weights = concat_weights.view(*input_shape) @@ -167,6 +170,9 @@ def _import_qkv_bias(ctx: io.TransformCTX, query, key, value): concat_biases = torch.cat((query, key, value), dim=0) input_shape = concat_biases.size() np = ctx.target.config.num_attention_heads + # transpose biases + # [num_splits_model_parallel * attention head size * #attention heads] + # --> [attention head size * num_splits_model_parallel * #attention heads] concat_biases = concat_biases.view(3, np, -1) concat_biases = concat_biases.transpose(0, 1).contiguous() concat_biases = concat_biases.view(*input_shape) From bd16b2a7f4cc5bd86f7925901c32c07af6392302 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Tue, 14 Jan 2025 13:23:14 -0800 Subject: [PATCH 06/10] add esm2 checkpoint to pbss, refactor tests for xfail due to bug Signed-off-by: Peter St. John --- .../src/bionemo/core/data/resources/esm2.yaml | 39 ++++++++++++--- .../tests/bionemo/esm2/model/test_convert.py | 48 +++++++++++++------ 2 files changed, 66 insertions(+), 21 deletions(-) diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml b/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml index d7749aa78..bdef7a588 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml +++ b/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml @@ -43,6 +43,33 @@ description: > The original 8M parameter ESM2 model weights converted to the NeMo2 checkpoint format. +- tag: nv_650m:2.1 + ngc: "nvidia/clara/esm2nv650m:2.1" + ngc_registry: model + pbss: "s3://general-purpose/esm2/checkpoints/650m/esm2_650m_checkpoint.tar.gz" + sha256: b83e9b5d62f1499b443817c5cd0facd3bdd4013a51a897e05e17228bf650befe # pragma: allowlist secret + owner: Peter St John + description: > + An ESM-2 650M model pre-trained on NVIDIA's train/test data split. + +- tag: nv_3b:2.1 + ngc: "nvidia/clara/esm2nv3b:2.1" + ngc_registry: model + pbss: "s3://general-purpose/esm2/checkpoints/3b/esm2_3b_checkpoint.tar.gz" + sha256: a79327a4054bf8d1d7075e1b3c961dbc503da02d72ed15f707d9cbbd49d181b6 # pragma: allowlist secret + owner: Peter St John + description: > + An ESM-2 3B model pre-trained on NVIDIA's train/test data split. 
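Each resource entry above pairs an NGC or PBSS location with a sha256 checksum. A generic sketch of how such a checksum can be recomputed to confirm a downloaded tarball matches its YAML entry; the local filename is a placeholder, and this is only an illustration, not the loader bionemo-core actually uses:

import hashlib
from pathlib import Path


def sha256_of_file(path: Path, chunk_size: int = 1 << 20) -> str:
    # Stream the file so multi-gigabyte checkpoint tarballs never have to fit in memory.
    digest = hashlib.sha256()
    with path.open("rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Hypothetical local copy of the converted 8M checkpoint tarball referenced in the entries below.
local_tarball = Path("esm2_hf_converted_8m_checkpoint.tar.gz")
expected = "2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16"
if local_tarball.exists():
    assert sha256_of_file(local_tarball) == expected, "checksum mismatch"
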
+ +- tag: 8m:2.0 + ngc: null + ngc_registry: model + pbss: s3://general-purpose/esm2/checkpoints/converted/8m/esm2_hf_converted_8m_checkpoint.tar.gz + sha256: 2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16 # pragma: allowlist secret + owner: Peter St John + description: > + A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t6_8M_UR50D model. + - tag: 650m:2.0 ngc: nvidia/clara/esm2nv650m:2.0 ngc_registry: model @@ -50,7 +77,7 @@ sha256: 0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997 # pragma: allowlist secret owner: Farhad Ramezanghorbani description: > - The original 650M parameter ESM2 model weights converted to the NeMo2 checkpoint format. + A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t33_650M_UR50D model. - tag: 3b:2.0 ngc: nvidia/clara/esm2nv3b:2.0 @@ -59,13 +86,13 @@ sha256: a2248cfed1ef39f83bd32a0e08b84c0a8f39325d383e2d92767022ff7f5260ed # pragma: allowlist secret owner: Farhad Ramezanghorbani description: > - The original 3B parameter ESM2 model c converted to the NeMo2 checkpoint format. + A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t36_3B_UR50D model. - tag: fulldata_esm2_pretrain:2.0 ngc: nvidia/clara/esm2_pretrain_nemo2_data:1.0 ngc_registry: resource pbss: "s3://general-purpose/esm2/pretrain/2024_03.tar.gz" - sha256: 404d0ad8de58fa8aae96f8d9f54263a088bc7e4f7d668215afbe04c28416151b # pragma: allowlist secret + sha256: 404d0ad8de58fa8aae96f8d9f54263a088bc7e4f7d668215afbe04c28416151b # pragma: allowlist secret owner: Peter St John description: Full data for ESM2 pretraining. @@ -73,14 +100,14 @@ ngc: nvidia/clara/esm2_pretrain_nemo2_testdata:1.0 ngc_registry: resource pbss: "s3://general-purpose/esm2/pretrain/2024_03_sanity.tar.gz" - sha256: 006911f92bbc0ded7ea302bbdbfab4c694b409e699c32fd49de1c527a99dba3e # pragma: allowlist secret + sha256: 006911f92bbc0ded7ea302bbdbfab4c694b409e699c32fd49de1c527a99dba3e # pragma: allowlist secret owner: Peter St John description: Test data for ESM2 pretraining. - tag: esm2_inference_testdata:2.0 - ngc: nvidia/clara/esm2_inference_testdata:2.0 # TODO: upload to NGC + ngc: nvidia/clara/esm2_inference_testdata:2.0 # TODO: upload to NGC ngc_registry: resource pbss: "s3://bionemo-ci/test_data/esm2/artificial_protein_sequences.csv" - sha256: 14ae3acfbf82218bc9e3e53d21a5b0594ba7c0369e169c9f1034e3fe4378d175 # pragma: allowlist secret + sha256: 14ae3acfbf82218bc9e3e53d21a5b0594ba7c0369e169c9f1034e3fe4378d175 # pragma: allowlist secret owner: Farhad Ramezanghorbani description: Test data for ESM2 inference. diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py index a183b1b2f..374faaa6d 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py @@ -14,10 +14,14 @@ # limitations under the License. 
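Patches 04 through 06 register an hf:// importer for ESM-2 and drive it from the tests that follow. Condensed to its essentials, the conversion entry point they exercise looks like the sketch below; the wrapper function name is introduced here purely for illustration:

from pathlib import Path

from nemo.lightning import io

from bionemo.esm2.model.convert import HFESM2Importer  # noqa: F401  # importing registers the "hf" importer
from bionemo.esm2.model.model import ESM2Config
from bionemo.llm.model.biobert.lightning import biobert_lightning_module


def convert_hf_esm2_to_nemo2(model_tag: str, output_dir: Path) -> Path:
    # Build an empty BioNeMo lightning module, then let NeMo's io machinery pull the
    # HuggingFace weights through HFESM2Importer and write a NeMo2 checkpoint.
    module = biobert_lightning_module(config=ESM2Config())
    checkpoint_path = output_dir / "nemo_checkpoint"
    io.import_ckpt(module, f"hf://{model_tag}", checkpoint_path)
    return checkpoint_path
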
+from pathlib import Path + +import pytest import torch from nemo.lightning import io from transformers import AutoModelForMaskedLM +from bionemo.core.data.load import load from bionemo.core.utils.dtypes import get_autocast_dtype from bionemo.esm2.data.tokenizer import get_tokenizer from bionemo.esm2.model.model import ESM2Config @@ -25,42 +29,40 @@ from bionemo.testing import megatron_parallel_state_utils -def test_convert_esm2_hf_to_nemo(tmp_path): +tokenizer = get_tokenizer() + + +def run_esm2_ckpt_conversion_hf_to_nemo(ckpt_path: Path, model_tag: str): from bionemo.esm2.model.convert import HFESM2Importer # noqa: F401 - model_tag = "facebook/esm2_t6_8M_UR50D" module = biobert_lightning_module(config=ESM2Config()) - io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "output_ckpt") - tokenizer = get_tokenizer() + io.import_ckpt(module, f"hf://{model_tag}", ckpt_path / "nemo_checkpoint") + return ckpt_path / "nemo_checkpoint" + +def get_input_tokens(): test_proteins = [ "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA", "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLAGG", ] - tokens = tokenizer(test_proteins, return_tensors="pt", padding=True, truncation=True).to("cuda") input_ids = tokens["input_ids"] attention_mask = tokens["attention_mask"] + return input_ids, attention_mask + - # HF 650M model +def assert_model_equivalence(ckpt_path, model_tag): hf_model = AutoModelForMaskedLM.from_pretrained(model_tag, torch_dtype=get_autocast_dtype(32)).cuda() with torch.inference_mode(), megatron_parallel_state_utils.distributed_model_parallel_state(): nemo_model = ( - ESM2Config(initial_ckpt_path=tmp_path / "output_ckpt", include_embeddings=True, include_hiddens=True) + ESM2Config(initial_ckpt_path=ckpt_path, include_embeddings=True, include_hiddens=True) .configure_model(tokenizer) .to("cuda") .eval() ) - for i in range(len(hf_model.esm.encoder.layer)): - torch.testing.assert_close( - hf_model.esm.encoder.layer[i].attention.self.rotary_embeddings.inv_freq, - nemo_model.rotary_pos_emb.inv_freq, - atol=1e-4, - rtol=1e-6, - ) - + input_ids, attention_mask = get_input_tokens() hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True) nemo_output = nemo_model(input_ids, attention_mask) @@ -80,3 +82,19 @@ def test_convert_esm2_hf_to_nemo(tmp_path): torch.testing.assert_close(logit_similarity, torch.ones_like(logit_similarity)) torch.testing.assert_close(hidden_state_similarity, torch.ones_like(hidden_state_similarity)) + + +@pytest.mark.xfail( + reason="This test is failing due to a bug in nemo global state when run in the same process as previous checkpoint" + "save/load scripts." +) +def test_nemo2_conversion_golden_values(tmp_path): + model_tag = "facebook/esm2_t6_8M_UR50D" + ckpt_path = run_esm2_ckpt_conversion_hf_to_nemo(tmp_path, model_tag) + assert_model_equivalence(ckpt_path, model_tag) + + +def test_pre_converted_checkpoint_golden_values(): + model_tag = "facebook/esm2_t6_8M_UR50D" + ckpt_path = load("esm2/8m:2.0", source="pbss") + assert_model_equivalence(ckpt_path, model_tag) From 3936d534b3309e6e1765bf83aa8e7e40a45bdd79 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Tue, 21 Jan 2025 20:09:15 +0000 Subject: [PATCH 07/10] refactor test_model to use the new model comparison function Signed-off-by: Peter St. 
John --- .../src/bionemo/core/data/resources/esm2.yaml | 67 ++++-------- .../src/bionemo/esm2/testing/__init__.py | 14 +++ .../src/bionemo/esm2/testing/compare.py | 100 ++++++++++++++++++ .../tests/bionemo/esm2/model/test_convert.py | 86 +++------------ .../tests/bionemo/esm2/model/test_model.py | 91 +++++----------- 5 files changed, 176 insertions(+), 182 deletions(-) create mode 100644 sub-packages/bionemo-esm2/src/bionemo/esm2/testing/__init__.py create mode 100644 sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml b/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml index bdef7a588..ddc5033b3 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml +++ b/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml @@ -7,23 +7,32 @@ description: > A pretrained 650M parameter ESM2 model. See https://ngc.nvidia.com/catalog/models/nvidia:clara:esm2nv650m. -- tag: nv_3b:2.1 - ngc: "nvidia/clara/esm2nv3b:2.1" +- tag: 8m:2.0 + ngc: nvidia/clara/esm2nv8m:2.0 ngc_registry: model - pbss: "s3://general-purpose/esm2/checkpoints/3b/esm2_3b_checkpoint.tar.gz" - sha256: a79327a4054bf8d1d7075e1b3c961dbc503da02d72ed15f707d9cbbd49d181b6 # pragma: allowlist secret + pbss: s3://general-purpose/esm2/checkpoints/converted/8m/esm2_hf_converted_8m_checkpoint.tar.gz + sha256: 2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16 # pragma: allowlist secret owner: Peter St John description: > - An ESM-2 3B model pre-trained on NVIDIA's train/test data split. + A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t6_8M_UR50D model. -- tag: nv_650m:2.1 - ngc: "nvidia/clara/esm2nv650m:2.1" +- tag: 650m:2.0 + ngc: nvidia/clara/esm2nv650m:2.0 ngc_registry: model - pbss: "s3://general-purpose/esm2/checkpoints/650m/esm2_650m_checkpoint.tar.gz" - sha256: b83e9b5d62f1499b443817c5cd0facd3bdd4013a51a897e05e17228bf650befe # pragma: allowlist secret - owner: Peter St John + pbss: "s3://bionemo-ci/models/esm2_650M_nemo2.tar.gz" + sha256: 0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997 # pragma: allowlist secret + owner: Farhad Ramezanghorbani description: > - An ESM-2 650M model pre-trained on NVIDIA's train/test data split. + A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t33_650M_UR50D model. + +- tag: 3b:2.0 + ngc: nvidia/clara/esm2nv3b:2.0 + ngc_registry: model + pbss: "s3://bionemo-ci/models/esm2_3B_nemo2.tar.gz" + sha256: a2248cfed1ef39f83bd32a0e08b84c0a8f39325d383e2d92767022ff7f5260ed # pragma: allowlist secret + owner: Farhad Ramezanghorbani + description: > + A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t36_3B_UR50D model. # - tag: nv_8m:2.1 # ngc: "nvidia/clara/esm2nv8m:2.1" @@ -34,15 +43,6 @@ # description: > # An ESM-2 8M model pre-trained on NVIDIA's train/test data split. -- tag: 8m:2.0 - ngc: "nvidia/clara/esm2nv8m:2.0" - ngc_registry: model - pbss: "s3://general-purpose/esm2/checkpoints/converted/8m/esm2_hf_converted_8m_checkpoint.tar.gz" - sha256: 2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16 # pragma: allowlist secret - owner: Peter St John - description: > - The original 8M parameter ESM2 model weights converted to the NeMo2 checkpoint format. - - tag: nv_650m:2.1 ngc: "nvidia/clara/esm2nv650m:2.1" ngc_registry: model @@ -61,33 +61,6 @@ description: > An ESM-2 3B model pre-trained on NVIDIA's train/test data split. 
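Patch 07 also reorganizes esm2.yaml, removing the duplicated 8m/650m/3b entries introduced while the checkpoints were being moved around. A quick duplicate-tag check, sketched here with pyyaml purely as an illustration (it is not part of the patch or of bionemo-core's own validation):

from collections import Counter

import yaml  # pyyaml

with open("sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml") as handle:
    entries = yaml.safe_load(handle)  # the file is a top-level list of resource entries

tag_counts = Counter(entry["tag"] for entry in entries)
duplicates = {tag: count for tag, count in tag_counts.items() if count > 1}
assert not duplicates, f"duplicate resource tags found: {duplicates}"
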
-- tag: 8m:2.0 - ngc: null - ngc_registry: model - pbss: s3://general-purpose/esm2/checkpoints/converted/8m/esm2_hf_converted_8m_checkpoint.tar.gz - sha256: 2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16 # pragma: allowlist secret - owner: Peter St John - description: > - A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t6_8M_UR50D model. - -- tag: 650m:2.0 - ngc: nvidia/clara/esm2nv650m:2.0 - ngc_registry: model - pbss: "s3://bionemo-ci/models/esm2_650M_nemo2.tar.gz" - sha256: 0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997 # pragma: allowlist secret - owner: Farhad Ramezanghorbani - description: > - A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t33_650M_UR50D model. - -- tag: 3b:2.0 - ngc: nvidia/clara/esm2nv3b:2.0 - ngc_registry: model - pbss: "s3://bionemo-ci/models/esm2_3B_nemo2.tar.gz" - sha256: a2248cfed1ef39f83bd32a0e08b84c0a8f39325d383e2d92767022ff7f5260ed # pragma: allowlist secret - owner: Farhad Ramezanghorbani - description: > - A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t36_3B_UR50D model. - - tag: fulldata_esm2_pretrain:2.0 ngc: nvidia/clara/esm2_pretrain_nemo2_data:1.0 ngc_registry: resource diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/__init__.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/__init__.py new file mode 100644 index 000000000..25e6abfbc --- /dev/null +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py new file mode 100644 index 000000000..1ff3d06b0 --- /dev/null +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
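The assert_model_equivalence helper defined next is written as an importable test utility rather than a test, so later test modules in this series (and downstream suites) can call it directly. A typical invocation, mirroring the tests added further down; the test name here is hypothetical:

import pytest

from bionemo.core.data.load import load
from bionemo.esm2.testing.compare import assert_model_equivalence


@pytest.mark.slow
def test_650m_checkpoint_matches_hf():
    # Compare the pre-converted NeMo2 checkpoint against its HuggingFace counterpart.
    ckpt_path = load("esm2/650m:2.0")
    assert_model_equivalence(ckpt_path, "facebook/esm2_t33_650M_UR50D", atol=1e-4, rtol=1e-4)
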
+ + +import gc +from pathlib import Path + +import torch +from transformers import AutoModelForMaskedLM + +from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype +from bionemo.esm2.data.tokenizer import get_tokenizer +from bionemo.esm2.model.model import ESM2Config +from bionemo.testing import megatron_parallel_state_utils + + +def assert_model_equivalence( + ckpt_path: Path | str, + model_tag: str, + precision: PrecisionTypes = "fp32", + rtol: float | None = None, + atol: float | None = None, +) -> None: + """Testing utility to compare the outputs of a NeMo2 checkpoint to the original HuggingFace model weights. + + Compares the cosine similarity of the logit and hidden state outputs of a NeMo2 model checkpoint to the outputs of + the corresponding HuggingFace model. + + Args: + ckpt_path: A path to a NeMo2 checkpoint for an ESM-2 model. + model_tag: The HuggingFace model tag for the model to compare against. + precision: The precision type to use for the comparison. Defaults to "fp32". + rtol: The relative tolerance to use for the comparison. Defaults to None, which chooses the tolerance based on + the precision. + atol: The absolute tolerance to use for the comparison. Defaults to None, which chooses the tolerance based on + the precision. + """ + tokenizer = get_tokenizer() + + test_proteins = [ + "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA", + "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLAGG", + ] + tokens = tokenizer(test_proteins, return_tensors="pt", padding=True, truncation=True).to("cuda") + input_ids = tokens["input_ids"] + attention_mask = tokens["attention_mask"] + + with torch.inference_mode(), megatron_parallel_state_utils.distributed_model_parallel_state(): + nemo_model = ( + ESM2Config( + initial_ckpt_path=str(ckpt_path), + include_embeddings=True, + include_hiddens=True, + params_dtype=get_autocast_dtype(precision), + pipeline_dtype=get_autocast_dtype(precision), + autocast_dtype=get_autocast_dtype(precision), + bf16=True, + ) # setting this speeds things up a lot) + .configure_model(tokenizer) + .to("cuda") + .eval() + ) + + nemo_output = nemo_model(input_ids, attention_mask) + nemo_logits = nemo_output["token_logits"].transpose(0, 1).contiguous()[..., : tokenizer.vocab_size] + nemo_hidden_state = nemo_output["hidden_states"] + + del nemo_model + gc.collect() + torch.cuda.empty_cache() + + hf_model = AutoModelForMaskedLM.from_pretrained(model_tag, torch_dtype=get_autocast_dtype(precision)).cuda() + hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True) + hf_hidden_state = hf_output_all.hidden_states[-1] + + # Rather than directly comparing the logit or hidden state tensors, we compare their cosine similarity. These + # should be essentially 1 if the outputs are equivalent, but is less sensitive to small numerical differences. + # We don't care about the padding tokens, so we only compare the non-padding tokens. 
+ logit_similarity = torch.nn.functional.cosine_similarity(nemo_logits, hf_output_all.logits, dim=2) + logit_similarity = logit_similarity[attention_mask == 1] + + hidden_state_similarity = torch.nn.functional.cosine_similarity(nemo_hidden_state, hf_hidden_state, dim=2) + hidden_state_similarity = hidden_state_similarity[attention_mask == 1] + + torch.testing.assert_close(logit_similarity, torch.ones_like(logit_similarity), rtol=rtol, atol=atol) + torch.testing.assert_close( + hidden_state_similarity, torch.ones_like(hidden_state_similarity), rtol=rtol, atol=atol + ) diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py index 374faaa6d..f2b257074 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py @@ -14,87 +14,33 @@ # limitations under the License. -from pathlib import Path - import pytest -import torch from nemo.lightning import io -from transformers import AutoModelForMaskedLM -from bionemo.core.data.load import load -from bionemo.core.utils.dtypes import get_autocast_dtype -from bionemo.esm2.data.tokenizer import get_tokenizer +from bionemo.esm2.model.convert import HFESM2Importer # noqa: F401 from bionemo.esm2.model.model import ESM2Config +from bionemo.esm2.testing.compare import assert_model_equivalence from bionemo.llm.model.biobert.lightning import biobert_lightning_module -from bionemo.testing import megatron_parallel_state_utils - - -tokenizer = get_tokenizer() - - -def run_esm2_ckpt_conversion_hf_to_nemo(ckpt_path: Path, model_tag: str): - from bionemo.esm2.model.convert import HFESM2Importer # noqa: F401 - - module = biobert_lightning_module(config=ESM2Config()) - io.import_ckpt(module, f"hf://{model_tag}", ckpt_path / "nemo_checkpoint") - return ckpt_path / "nemo_checkpoint" - - -def get_input_tokens(): - test_proteins = [ - "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA", - "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLAGG", - ] - tokens = tokenizer(test_proteins, return_tensors="pt", padding=True, truncation=True).to("cuda") - input_ids = tokens["input_ids"] - attention_mask = tokens["attention_mask"] - return input_ids, attention_mask - - -def assert_model_equivalence(ckpt_path, model_tag): - hf_model = AutoModelForMaskedLM.from_pretrained(model_tag, torch_dtype=get_autocast_dtype(32)).cuda() - - with torch.inference_mode(), megatron_parallel_state_utils.distributed_model_parallel_state(): - nemo_model = ( - ESM2Config(initial_ckpt_path=ckpt_path, include_embeddings=True, include_hiddens=True) - .configure_model(tokenizer) - .to("cuda") - .eval() - ) - - input_ids, attention_mask = get_input_tokens() - hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True) - - nemo_output = nemo_model(input_ids, attention_mask) - nemo_logits = nemo_output["token_logits"].transpose(0, 1).contiguous()[..., : tokenizer.vocab_size] - - nemo_hidden_state = nemo_output["hidden_states"] - hf_hidden_state = hf_output_all.hidden_states[-1] - - # Rather than directly comparing the logit or hidden state tensors, we compare their cosine similarity. These - # should be essentially 1 if the outputs are equivalent, but is less sensitive to small numerical differences. - # We don't care about the padding tokens, so we only compare the non-padding tokens. 
- logit_similarity = torch.nn.functional.cosine_similarity(nemo_logits, hf_output_all.logits, dim=2) - logit_similarity = logit_similarity[attention_mask == 1] - - hidden_state_similarity = torch.nn.functional.cosine_similarity(nemo_hidden_state, hf_hidden_state, dim=2) - hidden_state_similarity = hidden_state_similarity[attention_mask == 1] - - torch.testing.assert_close(logit_similarity, torch.ones_like(logit_similarity)) - torch.testing.assert_close(hidden_state_similarity, torch.ones_like(hidden_state_similarity)) @pytest.mark.xfail( reason="This test is failing due to a bug in nemo global state when run in the same process as previous checkpoint" "save/load scripts." ) -def test_nemo2_conversion_golden_values(tmp_path): +def test_nemo2_conversion_equivalent_8m(tmp_path): model_tag = "facebook/esm2_t6_8M_UR50D" - ckpt_path = run_esm2_ckpt_conversion_hf_to_nemo(tmp_path, model_tag) - assert_model_equivalence(ckpt_path, model_tag) + module = biobert_lightning_module(config=ESM2Config()) + io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint") + assert_model_equivalence(tmp_path / "nemo_checkpoint", model_tag) -def test_pre_converted_checkpoint_golden_values(): - model_tag = "facebook/esm2_t6_8M_UR50D" - ckpt_path = load("esm2/8m:2.0", source="pbss") - assert_model_equivalence(ckpt_path, model_tag) +@pytest.mark.xfail( + reason="This test is failing due to a bug in nemo global state when run in the same process as previous checkpoint" + "save/load scripts." +) +@pytest.mark.slow +def test_nemo2_conversion_equivalent_650m(tmp_path): + model_tag = "facebook/esm2_t33_650M_UR50D" + module = biobert_lightning_module(config=ESM2Config()) + io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint") + assert_model_equivalence(tmp_path / "nemo_checkpoint", model_tag) diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py index 10595c049..dce52b973 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py @@ -13,16 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import gc import io import tarfile -from copy import deepcopy from typing import List, Tuple from unittest import mock import pytest import torch -from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from torch import Tensor from transformers import AutoModelForMaskedLM @@ -33,6 +30,7 @@ from bionemo.esm2.data.datamodule import ESMDataModule from bionemo.esm2.data.tokenizer import get_tokenizer from bionemo.esm2.model.embedding import ESM2Embedding +from bionemo.esm2.testing.compare import assert_model_equivalence from bionemo.llm.model.biobert.model import MegatronBioBertModel from bionemo.llm.utils.weight_utils import nemo1_to_nemo2_biobert_key_mapping from bionemo.testing import megatron_parallel_state_utils @@ -182,64 +180,9 @@ def test_esm2_nemo1_checkpoint(esm2_model): assert not missing_old_keys, "There are keys in the old checkpoint that are missing from the new model." 
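The golden-value test removed just below originally compared logits and pooled embeddings element-wise with hand-tuned absolute tolerances; patch 03 switched it to per-token cosine similarity, and patch 07 now drops it entirely in favor of the shared assert_model_equivalence helper. A tiny synthetic example of why the cosine-similarity check is preferred (the tensors are random stand-ins, not model outputs):

import torch

torch.manual_seed(0)

# Two logit tensors that differ only by a small numerical perturbation, e.g. from
# running the same weights through two different kernel implementations.
reference = torch.randn(2, 8, 33)
candidate = reference + 1e-3 * torch.randn_like(reference)

# An element-wise comparison fails unless the tolerance is tuned by hand...
assert not torch.allclose(reference, candidate, atol=1e-5)

# ...while the per-token cosine similarity stays essentially 1 under the same noise.
similarity = torch.nn.functional.cosine_similarity(reference, candidate, dim=-1)
assert torch.all(similarity > 0.999)
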
-def test_esm2_golden_values(esm2_config_w_ckpt, sample_data): - tokenizer = AutoTokenizer(pretrained_model_name="facebook/esm2_t33_650M_UR50D") - tokens = tokenizer.tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True).to("cuda") - input_ids = tokens["input_ids"] - attention_mask = tokens["attention_mask"] - - # HF model - hf_model = AutoModelForMaskedLM.from_pretrained(hf_model_tag, torch_dtype=get_autocast_dtype(32)).cuda() - - with torch.no_grad(): - hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True) - hf_logits = hf_output_all.logits * attention_mask.unsqueeze(-1) - hf_hiddens = hf_output_all.hidden_states[-1] - hf_embeddings = reduce_hiddens(hf_output_all.hidden_states[-1], attention_mask) - - # free GPU RAM - del hf_model - gc.collect() - torch.cuda.empty_cache() - - # configure the model to return logits - model = esm2_config_w_ckpt.configure_model(get_tokenizer()).cuda() - model.eval() - result = model(input_ids, attention_mask) - # token_logits is s,b and for simplicity here let's transpose to b,s. In general this reduces performance. - logits = result["token_logits"].transpose(0, 1).contiguous()[..., : tokenizer.vocab_size] - logits = logits * attention_mask.unsqueeze(-1) # incorporate masking logic - - # free GPU RAM - del model - gc.collect() - torch.cuda.empty_cache() - - # configure the model to return hiddens - esm2_config_hiddens = deepcopy(esm2_config_w_ckpt) - esm2_config_hiddens.set_hparam("return_only_hidden_states", True) - model = esm2_config_hiddens.configure_model(get_tokenizer()).cuda() - model.eval() - hiddens = model(input_ids, attention_mask) - embeddings = reduce_hiddens(torch.transpose(hiddens, 0, 1).float(), attention_mask) - - # Rather than directly comparing the logit or hidden state tensors, we compare their cosine similarity. These - # should be essentially 1 if the outputs are equivalent, but is less sensitive to small numerical differences. - # We don't care about the padding tokens, so we only compare the non-padding tokens. 
- logit_similarity = torch.nn.functional.cosine_similarity(logits, hf_logits, dim=2) - logit_similarity = logit_similarity[attention_mask == 1] - - hidden_state_similarity = torch.nn.functional.cosine_similarity(hiddens, hf_hiddens, dim=2) - hidden_state_similarity = hidden_state_similarity[attention_mask == 1] - - torch.testing.assert_close(logit_similarity, torch.ones_like(logit_similarity)) - torch.testing.assert_close(hidden_state_similarity, torch.ones_like(hidden_state_similarity)) - - def test_esm2_loss(esm2_config_w_ckpt, dummy_protein_dataset, dummy_parquet_train_val_inputs): train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs - compute_hf_reference: bool = True seed: int = 42 with ( @@ -282,12 +225,30 @@ def test_esm2_loss(esm2_config_w_ckpt, dummy_protein_dataset, dummy_parquet_trai mean_loss = _compute_loss(model, train_dataloader, vocab_size=tokenizer.vocab_size) - if compute_hf_reference: - # HF model initialized with params - hf_model = AutoModelForMaskedLM.from_pretrained(hf_model_tag, torch_dtype=get_autocast_dtype(32)).cuda() - hf_mean_loss = _compute_loss(hf_model, train_dataloader) - print(f"hf_mean_loss: {hf_mean_loss}") - else: - hf_mean_loss = torch.tensor(2.9279041290283203).cuda() + # HF model initialized with params + hf_model = AutoModelForMaskedLM.from_pretrained(hf_model_tag, torch_dtype=get_autocast_dtype(32)).cuda() + hf_mean_loss = _compute_loss(hf_model, train_dataloader) + print(f"hf_mean_loss: {hf_mean_loss}") torch.testing.assert_close(mean_loss, hf_mean_loss, atol=1e-3, rtol=0.0) + + +def test_model_equivalence_with_huggingface_8m(): + model_tag = "facebook/esm2_t6_8M_UR50D" + ckpt_path = load("esm2/8m:2.0") + assert_model_equivalence(ckpt_path, model_tag) + + +@pytest.mark.slow +def test_model_equivalence_with_huggingface_650m(): + model_tag = "facebook/esm2_t33_650M_UR50D" + ckpt_path = load("esm2/650m:2.0") + assert_model_equivalence(ckpt_path, model_tag, atol=1e-4, rtol=1e-4) + + +@pytest.mark.slow +@pytest.mark.skip(reason="This test triggers a large download from huggingface and requires considerable GPU memory.") +def test_model_equivalence_with_huggingface_3b(): + model_tag = "facebook/esm2_t36_3B_UR50D" + ckpt_path = load("esm2/3b:2.0") + assert_model_equivalence(ckpt_path, model_tag, atol=1e-4, rtol=1e-4) From cc62c2bc8d485682f2f6515a25c1206657913640 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Tue, 21 Jan 2025 22:32:25 +0000 Subject: [PATCH 08/10] refactor test_model to remove fixtures Signed-off-by: Peter St. 
John --- .../tests/bionemo/esm2/model/test_model.py | 154 ++++++------------ 1 file changed, 48 insertions(+), 106 deletions(-) diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py index dce52b973..a3cc4a425 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py @@ -15,7 +15,6 @@ import io import tarfile -from typing import List, Tuple from unittest import mock import pytest @@ -36,111 +35,18 @@ from bionemo.testing import megatron_parallel_state_utils -hf_model_tag = "facebook/esm2_t6_8M_UR50D" -nv_model_tag = "esm2/8m:2.0" -# hf_model_tag = "facebook/esm2_t33_650M_UR50D" -# nv_model_tag = "esm2/650m:2.0" - - -def reduce_hiddens(hiddens: Tensor, attention_mask: Tensor) -> Tensor: - """reduce last layer's hidden values to embeddings - - Args: - hiddens: [b, s, h] tensor of hidden values - attention_mask: [b, s] attention mask tensor - - Returns: - reduced embedding tensor [b, h] - """ - masks = torch.sum(attention_mask, dim=1) - embeddings = torch.zeros( - size=(hiddens.shape[0], hiddens.shape[2]), - dtype=torch.float32, - device=torch.cuda.current_device(), - ) - for i, (hidden, mask) in enumerate(zip(hiddens, masks)): - embeddings[i, :] = torch.mean(hidden[1 : mask - 1], dim=0) - return embeddings - - -@pytest.fixture(scope="module") -def esm2_config() -> ESM2Config: - with megatron_parallel_state_utils.distributed_model_parallel_state(): - yield ESM2Config() - - -@pytest.fixture(scope="module") -def esm2_config_w_ckpt() -> ESM2Config: - with megatron_parallel_state_utils.distributed_model_parallel_state(): - yield ESM2Config(initial_ckpt_path=load(nv_model_tag)) - - -@pytest.fixture(scope="module") -def esm2_model(esm2_config) -> ESM2Model: +def test_esm2_model_initialized(): with megatron_parallel_state_utils.distributed_model_parallel_state(): tokenizer = get_tokenizer() - model = esm2_config.configure_model(tokenizer) - yield model - - -@pytest.fixture(scope="module") -def sample_data() -> List[Tuple[str, str]]: - """Generates sample protein sequences for sanity checks, including mask tokens.""" - max_length = 1022 # The maximum length of the protein sequences to be considered. - sample_data = [ - ( - "protein1", - "MNGTEGPNFYVPFSNATGVVRSPFEYPQYYLAEPWQFSMLAAYMFLLIVLGFPINFLTLYVTVQHKKLRTPLNYILLNLAVADLFMVLGGFTSTLYTSLHGYFVFGPTGCNLEGFFATLGGEIALWSLVVLAIERYVVVCKPMSNFRFGENHAIMGVAFTWVMALACAAPPLAGWSRYIPEGLQCSCGIDYYTLKPEVNNESFVIYMFVVHFTIPMIIIFFCYGQLVFTVKEAAAQQQESATTQKAEKEVTRMVIIMVIAFLICWVPYASVAFYIFTHQGSNFGPIFMTIPAFFAKSAAIYNPVIYIMMNKQFRNCMLTTICCGKNPLGDDEASATVSKTETSQVAPA", - ), - ("protein2", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA"), - ( - "protein3", - "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLAGG", - ), - ( - "protein4", - "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLA", - ), - ] - # add another sample protein that uses the maximum length to test this edge case - sample_data.append(("protein5", (sample_data[0][1] * 3)[:max_length])) - yield sample_data - - -def _compute_loss(model, dataloader, vocab_size=None): - loss = 0 - n = 0 - limit_batches = 10 - for i, batch in enumerate(dataloader): - assert isinstance(batch, dict) - result = model(input_ids=batch["text"].cuda(), attention_mask=batch["attention_mask"].cuda()) - - # bionemo ESM2 vocab_size - if vocab_size is not None: - # token_logits is s,b and for simplicity here let's transpose to b,s. 
In general this reduces performance. - logits = result["token_logits"].transpose(0, 1).contiguous()[..., :vocab_size] - else: - logits = result.logits - - loss_mask = batch["loss_mask"].cuda() - target = batch["labels"].cuda() - - loss += torch.nn.functional.cross_entropy(logits[loss_mask].float(), target[loss_mask], reduction="sum") - n += loss_mask.sum() - - if limit_batches is not None and i + 1 >= limit_batches: - break - mean_loss: Tensor = loss / n - return mean_loss + config = ESM2Config() + model = config.configure_model(tokenizer) + assert isinstance(model, MegatronBioBertModel) + assert isinstance(model, ESM2Model) + assert isinstance(model.embedding, ESM2Embedding) -def test_esm2_model_initialized(esm2_model): - assert isinstance(esm2_model, MegatronBioBertModel) - assert isinstance(esm2_model, ESM2Model) - assert isinstance(esm2_model.embedding, ESM2Embedding) - -def test_esm2_nemo1_checkpoint(esm2_model): +def test_esm2_nemo1_checkpoint(): with tarfile.open(load("esm2/nv_650m:1.0"), "r") as ckpt, torch.no_grad(): ckpt_file = ckpt.extractfile("./model_weights.ckpt") @@ -149,10 +55,14 @@ def test_esm2_nemo1_checkpoint(esm2_model): # TODO: update Bionemo checkpoints old_state_dict.pop("model.language_model.rotary_pos_emb.inv_freq") - new_state_dict = esm2_model.state_dict_for_save_checkpoint() + with megatron_parallel_state_utils.distributed_model_parallel_state(): + tokenizer = get_tokenizer() + config = ESM2Config() + model = config.configure_model(tokenizer) + new_state_dict = model.state_dict_for_save_checkpoint() - # Set the new_model_prefix to "" since we are looking at the base megatron model and not the lightning module which stores a copy of - # this model into self.module + # Set the new_model_prefix to "" since we are looking at the base megatron model and not the lightning module + # which stores a copy of this model into self.module old_keys = { nemo1_to_nemo2_biobert_key_mapping(k, new_model_prefix="", te_mapping=True) for k in old_state_dict } @@ -180,7 +90,39 @@ def test_esm2_nemo1_checkpoint(esm2_model): assert not missing_old_keys, "There are keys in the old checkpoint that are missing from the new model." -def test_esm2_loss(esm2_config_w_ckpt, dummy_protein_dataset, dummy_parquet_train_val_inputs): +def _compute_loss(model, dataloader, vocab_size=None): + loss = 0 + n = 0 + limit_batches = 10 + for i, batch in enumerate(dataloader): + assert isinstance(batch, dict) + result = model(input_ids=batch["text"].cuda(), attention_mask=batch["attention_mask"].cuda()) + + # bionemo ESM2 vocab_size + if vocab_size is not None: + # token_logits is s,b and for simplicity here let's transpose to b,s. In general this reduces performance. 
+ logits = result["token_logits"].transpose(0, 1).contiguous()[..., :vocab_size] + else: + logits = result.logits + + loss_mask = batch["loss_mask"].cuda() + target = batch["labels"].cuda() + + loss += torch.nn.functional.cross_entropy(logits[loss_mask].float(), target[loss_mask], reduction="sum") + n += loss_mask.sum() + + if limit_batches is not None and i + 1 >= limit_batches: + break + mean_loss: Tensor = loss / n + return mean_loss + + +def test_esm2_loss(dummy_protein_dataset, dummy_parquet_train_val_inputs): + hf_model_tag = "facebook/esm2_t6_8M_UR50D" + nv_model_tag = "esm2/8m:2.0" + # hf_model_tag = "facebook/esm2_t33_650M_UR50D" + # nv_model_tag = "esm2/650m:2.0" + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs seed: int = 42 @@ -193,7 +135,7 @@ def test_esm2_loss(esm2_config_w_ckpt, dummy_protein_dataset, dummy_parquet_trai tokenizer = get_tokenizer() # ESM2 model initialized with params - model = esm2_config_w_ckpt.configure_model(tokenizer).cuda() + model = ESM2Config(initial_ckpt_path=str(load(nv_model_tag))).configure_model(tokenizer).cuda() # Initialize the data module. data_module = ESMDataModule( From 4de89e74b5ec7d73305d237b384e299879970c72 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Tue, 21 Jan 2025 23:01:39 +0000 Subject: [PATCH 09/10] teardown trainer, pass precision to megatron parallel state Signed-off-by: Peter St. John --- .../src/bionemo/esm2/testing/compare.py | 2 +- .../tests/bionemo/esm2/model/test_model.py | 13 ++++++++ .../testing/megatron_parallel_state_utils.py | 30 ++++++++++++++++--- 3 files changed, 40 insertions(+), 5 deletions(-) diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py index 1ff3d06b0..d93daa08d 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py @@ -57,7 +57,7 @@ def assert_model_equivalence( input_ids = tokens["input_ids"] attention_mask = tokens["attention_mask"] - with torch.inference_mode(), megatron_parallel_state_utils.distributed_model_parallel_state(): + with torch.inference_mode(), megatron_parallel_state_utils.distributed_model_parallel_state(precision=precision): nemo_model = ( ESM2Config( initial_ckpt_path=str(ckpt_path), diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py index a3cc4a425..3f89598bb 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py @@ -181,6 +181,12 @@ def test_model_equivalence_with_huggingface_8m(): assert_model_equivalence(ckpt_path, model_tag) +def test_model_equivalence_with_huggingface_8m_bf16(): + model_tag = "facebook/esm2_t6_8M_UR50D" + ckpt_path = load("esm2/8m:2.0") + assert_model_equivalence(ckpt_path, model_tag, precision="bf16-mixed") + + @pytest.mark.slow def test_model_equivalence_with_huggingface_650m(): model_tag = "facebook/esm2_t33_650M_UR50D" @@ -188,6 +194,13 @@ def test_model_equivalence_with_huggingface_650m(): assert_model_equivalence(ckpt_path, model_tag, atol=1e-4, rtol=1e-4) +@pytest.mark.slow +def test_model_equivalence_with_huggingface_650m_bf16(): + model_tag = "facebook/esm2_t33_650M_UR50D" + ckpt_path = load("esm2/650m:2.0") + assert_model_equivalence(ckpt_path, model_tag, precision="bf16") + + @pytest.mark.slow @pytest.mark.skip(reason="This test triggers a large download from 
huggingface and requires considerable GPU memory.") def test_model_equivalence_with_huggingface_3b(): diff --git a/sub-packages/bionemo-testing/src/bionemo/testing/megatron_parallel_state_utils.py b/sub-packages/bionemo-testing/src/bionemo/testing/megatron_parallel_state_utils.py index 1686de309..2b51cc8b1 100644 --- a/sub-packages/bionemo-testing/src/bionemo/testing/megatron_parallel_state_utils.py +++ b/sub-packages/bionemo-testing/src/bionemo/testing/megatron_parallel_state_utils.py @@ -42,9 +42,12 @@ def my_test(): import torch.distributed from megatron.core import parallel_state from megatron.core.tensor_parallel import random as tp_random +from nemo import lightning as nl from nemo.utils import logging from torch.testing._internal.distributed.fake_pg import FakeStore +from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype + __all__: Sequence[str] = ( "clean_parallel_state_context", @@ -81,12 +84,24 @@ def _initialize_distributed_parallel_state( pipeline_model_parallel_split_rank: int = 0, context_parallel_size: int = 1, interactive: bool = False, -) -> None: + precision: PrecisionTypes = "fp32", +) -> pl.Trainer | None: + trainer = None # initialize pytorch DDP # if not interactive and not torch.distributed.is_initialized(): if not torch.distributed.is_initialized(): - logging.info("pytorch DDP is not initialized. Initializing with pytorch-lightening...") - trainer = pl.Trainer(devices=devices, strategy="ddp" if not interactive else "auto", num_nodes=1) + logging.info("pytorch DDP is not initialized. Initializing with pytorch-lightning...") + trainer = pl.Trainer( + devices=devices, + strategy="ddp" if not interactive else "auto", + num_nodes=1, + plugins=nl.MegatronMixedPrecision( + precision=precision, + params_dtype=get_autocast_dtype(precision), + pipeline_dtype=get_autocast_dtype(precision), + autocast_enabled=False, + ), + ) if trainer.strategy.launcher is not None: trainer.strategy.launcher.launch(_dummy, trainer=trainer) @@ -101,6 +116,8 @@ def _initialize_distributed_parallel_state( context_parallel_size=context_parallel_size, ) + return trainer + @contextmanager def clean_parallel_state_context() -> Iterator[None]: @@ -124,6 +141,7 @@ def distributed_model_parallel_state( pipeline_model_parallel_split_rank: int = 0, context_parallel_size: int = 1, interactive: bool = False, + precision: PrecisionTypes = "fp32", ) -> Iterator[None]: """Context manager for handling creating and cleaning up distributed model parallel state for tests. Use like: @@ -132,16 +150,18 @@ def distributed_model_parallel_state( # After the block your state is cleaned up. """ # noqa: D205 initial_states: Optional[Any] = None + trainer: pl.Trainer | None = None try: _teardown_apex_megatron_cuda() - _initialize_distributed_parallel_state( + trainer = _initialize_distributed_parallel_state( devices=devices, tensor_model_parallel_size=tensor_model_parallel_size, pipeline_model_parallel_size=pipeline_model_parallel_size, pipeline_model_parallel_split_rank=pipeline_model_parallel_split_rank, context_parallel_size=context_parallel_size, interactive=interactive, + precision=precision, ) # Our goal is to set required state on entry, and then restore current state on exit for the RNGs. 
# there are two possibilities that are handled below: @@ -174,6 +194,8 @@ def distributed_model_parallel_state( # Reset to the unset state tp_random.get_cuda_rng_tracker().reset() _teardown_apex_megatron_cuda() + if trainer is not None: + nl.teardown(trainer) @contextmanager From d833763563ab83e16a05cb7ddac349524efc3486 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Wed, 22 Jan 2025 16:27:02 +0000 Subject: [PATCH 10/10] debugging fp16 precision Signed-off-by: Peter St. John --- .../src/bionemo/esm2/testing/compare.py | 83 +++++++++---------- .../tests/bionemo/esm2/model/test_convert.py | 26 ++++-- .../tests/bionemo/esm2/model/test_model.py | 21 +++-- .../src/bionemo/llm/model/config.py | 2 + .../testing/megatron_parallel_state_utils.py | 14 ++-- 5 files changed, 77 insertions(+), 69 deletions(-) diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py index d93daa08d..e8690c1d0 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/testing/compare.py @@ -18,12 +18,12 @@ from pathlib import Path import torch +from megatron.core.transformer.module import Float16Module from transformers import AutoModelForMaskedLM from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype from bionemo.esm2.data.tokenizer import get_tokenizer from bionemo.esm2.model.model import ESM2Config -from bionemo.testing import megatron_parallel_state_utils def assert_model_equivalence( @@ -57,44 +57,43 @@ def assert_model_equivalence( input_ids = tokens["input_ids"] attention_mask = tokens["attention_mask"] - with torch.inference_mode(), megatron_parallel_state_utils.distributed_model_parallel_state(precision=precision): - nemo_model = ( - ESM2Config( - initial_ckpt_path=str(ckpt_path), - include_embeddings=True, - include_hiddens=True, - params_dtype=get_autocast_dtype(precision), - pipeline_dtype=get_autocast_dtype(precision), - autocast_dtype=get_autocast_dtype(precision), - bf16=True, - ) # setting this speeds things up a lot) - .configure_model(tokenizer) - .to("cuda") - .eval() - ) - - nemo_output = nemo_model(input_ids, attention_mask) - nemo_logits = nemo_output["token_logits"].transpose(0, 1).contiguous()[..., : tokenizer.vocab_size] - nemo_hidden_state = nemo_output["hidden_states"] - - del nemo_model - gc.collect() - torch.cuda.empty_cache() - - hf_model = AutoModelForMaskedLM.from_pretrained(model_tag, torch_dtype=get_autocast_dtype(precision)).cuda() - hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True) - hf_hidden_state = hf_output_all.hidden_states[-1] - - # Rather than directly comparing the logit or hidden state tensors, we compare their cosine similarity. These - # should be essentially 1 if the outputs are equivalent, but is less sensitive to small numerical differences. - # We don't care about the padding tokens, so we only compare the non-padding tokens. 
- logit_similarity = torch.nn.functional.cosine_similarity(nemo_logits, hf_output_all.logits, dim=2) - logit_similarity = logit_similarity[attention_mask == 1] - - hidden_state_similarity = torch.nn.functional.cosine_similarity(nemo_hidden_state, hf_hidden_state, dim=2) - hidden_state_similarity = hidden_state_similarity[attention_mask == 1] - - torch.testing.assert_close(logit_similarity, torch.ones_like(logit_similarity), rtol=rtol, atol=atol) - torch.testing.assert_close( - hidden_state_similarity, torch.ones_like(hidden_state_similarity), rtol=rtol, atol=atol - ) + dtype = get_autocast_dtype(precision) + nemo_config = ESM2Config( + initial_ckpt_path=str(ckpt_path), + include_embeddings=True, + include_hiddens=True, + params_dtype=dtype, + pipeline_dtype=dtype, + autocast_dtype=dtype, + bf16=dtype is torch.bfloat16, + fp16=dtype is torch.float16, + ) + + nemo_model = nemo_config.configure_model(tokenizer).to("cuda").eval() + + if dtype is torch.float16 or dtype is torch.bfloat16: + nemo_model = Float16Module(nemo_config, nemo_model) + + nemo_output = nemo_model(input_ids, attention_mask) + nemo_logits = nemo_output["token_logits"].transpose(0, 1).contiguous()[..., : tokenizer.vocab_size] + nemo_hidden_state = nemo_output["hidden_states"] + + del nemo_model + gc.collect() + torch.cuda.empty_cache() + + hf_model = AutoModelForMaskedLM.from_pretrained(model_tag, torch_dtype=get_autocast_dtype(precision)).cuda().eval() + hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True) + hf_hidden_state = hf_output_all.hidden_states[-1] + + # Rather than directly comparing the logit or hidden state tensors, we compare their cosine similarity. These + # should be essentially 1 if the outputs are equivalent, but is less sensitive to small numerical differences. + # We don't care about the padding tokens, so we only compare the non-padding tokens. + logit_similarity = torch.nn.functional.cosine_similarity(nemo_logits, hf_output_all.logits, dim=2) + logit_similarity = logit_similarity[attention_mask == 1] + + hidden_state_similarity = torch.nn.functional.cosine_similarity(nemo_hidden_state, hf_hidden_state, dim=2) + hidden_state_similarity = hidden_state_similarity[attention_mask == 1] + + torch.testing.assert_close(logit_similarity, torch.ones_like(logit_similarity), rtol=rtol, atol=atol) + torch.testing.assert_close(hidden_state_similarity, torch.ones_like(hidden_state_similarity), rtol=rtol, atol=atol) diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py index f2b257074..f3d6d2e69 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py @@ -21,23 +21,31 @@ from bionemo.esm2.model.model import ESM2Config from bionemo.esm2.testing.compare import assert_model_equivalence from bionemo.llm.model.biobert.lightning import biobert_lightning_module +from bionemo.testing import megatron_parallel_state_utils + + +# pytestmark = pytest.mark.xfail( +# reason="These tests are failing due to a bug in nemo global state when run in the same process as previous " +# "checkpoint save/load scripts." +# ) -@pytest.mark.xfail( - reason="This test is failing due to a bug in nemo global state when run in the same process as previous checkpoint" - "save/load scripts." 
-) def test_nemo2_conversion_equivalent_8m(tmp_path): model_tag = "facebook/esm2_t6_8M_UR50D" module = biobert_lightning_module(config=ESM2Config()) io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint") - assert_model_equivalence(tmp_path / "nemo_checkpoint", model_tag) + with megatron_parallel_state_utils.distributed_model_parallel_state(): + assert_model_equivalence(tmp_path / "nemo_checkpoint", model_tag) + + +def test_nemo2_conversion_equivalent_8m_bf16(tmp_path): + model_tag = "facebook/esm2_t6_8M_UR50D" + module = biobert_lightning_module(config=ESM2Config()) + io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint") + with megatron_parallel_state_utils.distributed_model_parallel_state(precision="bf16"): + assert_model_equivalence(tmp_path / "nemo_checkpoint", model_tag, precision="bf16") -@pytest.mark.xfail( - reason="This test is failing due to a bug in nemo global state when run in the same process as previous checkpoint" - "save/load scripts." -) @pytest.mark.slow def test_nemo2_conversion_equivalent_650m(tmp_path): model_tag = "facebook/esm2_t33_650M_UR50D" diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py index 3f89598bb..8895b3719 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py @@ -175,30 +175,28 @@ def test_esm2_loss(dummy_protein_dataset, dummy_parquet_train_val_inputs): torch.testing.assert_close(mean_loss, hf_mean_loss, atol=1e-3, rtol=0.0) -def test_model_equivalence_with_huggingface_8m(): +@pytest.mark.parametrize("precision", ["fp32", "bf16", "fp16", "bf16-mixed"]) +def test_model_equivalence_with_huggingface_8m(precision): model_tag = "facebook/esm2_t6_8M_UR50D" ckpt_path = load("esm2/8m:2.0") - assert_model_equivalence(ckpt_path, model_tag) - - -def test_model_equivalence_with_huggingface_8m_bf16(): - model_tag = "facebook/esm2_t6_8M_UR50D" - ckpt_path = load("esm2/8m:2.0") - assert_model_equivalence(ckpt_path, model_tag, precision="bf16-mixed") + with megatron_parallel_state_utils.distributed_model_parallel_state(precision=precision): + assert_model_equivalence(ckpt_path, model_tag, precision=precision) @pytest.mark.slow def test_model_equivalence_with_huggingface_650m(): model_tag = "facebook/esm2_t33_650M_UR50D" ckpt_path = load("esm2/650m:2.0") - assert_model_equivalence(ckpt_path, model_tag, atol=1e-4, rtol=1e-4) + with megatron_parallel_state_utils.distributed_model_parallel_state(): + assert_model_equivalence(ckpt_path, model_tag, atol=1e-4, rtol=1e-4) @pytest.mark.slow def test_model_equivalence_with_huggingface_650m_bf16(): model_tag = "facebook/esm2_t33_650M_UR50D" ckpt_path = load("esm2/650m:2.0") - assert_model_equivalence(ckpt_path, model_tag, precision="bf16") + with megatron_parallel_state_utils.distributed_model_parallel_state(precision="bf16"): + assert_model_equivalence(ckpt_path, model_tag, precision="bf16") @pytest.mark.slow @@ -206,4 +204,5 @@ def test_model_equivalence_with_huggingface_650m_bf16(): def test_model_equivalence_with_huggingface_3b(): model_tag = "facebook/esm2_t36_3B_UR50D" ckpt_path = load("esm2/3b:2.0") - assert_model_equivalence(ckpt_path, model_tag, atol=1e-4, rtol=1e-4) + with megatron_parallel_state_utils.distributed_model_parallel_state(): + assert_model_equivalence(ckpt_path, model_tag, atol=1e-4, rtol=1e-4) diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/model/config.py 
b/sub-packages/bionemo-llm/src/bionemo/llm/model/config.py index 425726da4..bf48a520c 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/model/config.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/model/config.py @@ -45,6 +45,8 @@ "initial_ckpt_path_ignore_weights", "initial_ckpt_path", "model_cls", + "bf16", + "fp16", ] OVERRIDE_BIONEMO_CONFIG_DEFAULTS = deepcopy(_OVERRIDE_BIONEMO_CONFIG_DEFAULTS) # copy for export diff --git a/sub-packages/bionemo-testing/src/bionemo/testing/megatron_parallel_state_utils.py b/sub-packages/bionemo-testing/src/bionemo/testing/megatron_parallel_state_utils.py index 2b51cc8b1..8ef076223 100644 --- a/sub-packages/bionemo-testing/src/bionemo/testing/megatron_parallel_state_utils.py +++ b/sub-packages/bionemo-testing/src/bionemo/testing/megatron_parallel_state_utils.py @@ -46,7 +46,7 @@ def my_test(): from nemo.utils import logging from torch.testing._internal.distributed.fake_pg import FakeStore -from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype +from bionemo.core.utils.dtypes import PrecisionTypes __all__: Sequence[str] = ( @@ -95,12 +95,12 @@ def _initialize_distributed_parallel_state( devices=devices, strategy="ddp" if not interactive else "auto", num_nodes=1, - plugins=nl.MegatronMixedPrecision( - precision=precision, - params_dtype=get_autocast_dtype(precision), - pipeline_dtype=get_autocast_dtype(precision), - autocast_enabled=False, - ), + # plugins=nl.MegatronMixedPrecision( + # precision=precision, + # params_dtype=get_autocast_dtype(precision), + # pipeline_dtype=get_autocast_dtype(precision), + # autocast_enabled=False, + # ), ) if trainer.strategy.launcher is not None:
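Note on the comparison strategy in the patches above: assert_model_equivalence avoids asserting that raw logits or hidden states match exactly and instead reduces each NeMo/HF output pair to per-token cosine similarities over the non-padding positions, checking that they are close to 1.0; for fp16/bf16 runs the Megatron module is additionally wrapped in Float16Module so its parameters match the requested dtype. The snippet below is a minimal, self-contained sketch of the cosine-similarity check only, runnable with plain PyTorch; the helper name assert_cosine_close, the tensor shapes, and the toy inputs are illustrative assumptions and are not part of the patch series.

import torch


def assert_cosine_close(
    a: torch.Tensor,
    b: torch.Tensor,
    attention_mask: torch.Tensor,
    rtol: float = 1e-4,
    atol: float = 1e-4,
) -> None:
    """Compare two [batch, seq, hidden] tensors via per-token cosine similarity on non-padding tokens."""
    # Cosine similarity along the last (hidden/vocab) dimension yields a [batch, seq] tensor.
    similarity = torch.nn.functional.cosine_similarity(a, b, dim=2)
    # Keep only positions where the attention mask is 1, i.e. drop padding tokens.
    similarity = similarity[attention_mask == 1]
    # Equivalent models should produce per-token similarities of ~1.0 everywhere.
    torch.testing.assert_close(similarity, torch.ones_like(similarity), rtol=rtol, atol=atol)


# Toy usage: identical inputs, so the assertion passes.
logits = torch.randn(2, 8, 33)
mask = torch.ones(2, 8, dtype=torch.long)
assert_cosine_close(logits, logits.clone(), mask)

Comparing cosine similarity rather than raw tensor values keeps the equivalence tests tolerant of small numerical differences between the two implementations while still catching weight-mapping or precision bugs.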