Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ESM-2 to NeMo checkpoint conversion #537

Merged
merged 10 commits into from
Jan 22, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,32 @@
description: >
A pretrained 650M parameter ESM2 model. See https://ngc.nvidia.com/catalog/models/nvidia:clara:esm2nv650m.

- tag: nv_3b:2.1
ngc: "nvidia/clara/esm2nv3b:2.1"
- tag: 8m:2.0
ngc: nvidia/clara/esm2nv8m:2.0
ngc_registry: model
pbss: "s3://general-purpose/esm2/checkpoints/3b/esm2_3b_checkpoint.tar.gz"
sha256: a79327a4054bf8d1d7075e1b3c961dbc503da02d72ed15f707d9cbbd49d181b6 # pragma: allowlist secret
pbss: s3://general-purpose/esm2/checkpoints/converted/8m/esm2_hf_converted_8m_checkpoint.tar.gz
sha256: 2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16 # pragma: allowlist secret
owner: Peter St John <[email protected]>
description: >
An ESM-2 3B model pre-trained on NVIDIA's train/test data split.
A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t6_8M_UR50D model.

- tag: nv_650m:2.1
ngc: "nvidia/clara/esm2nv650m:2.1"
- tag: 650m:2.0
ngc: nvidia/clara/esm2nv650m:2.0
ngc_registry: model
pbss: "s3://general-purpose/esm2/checkpoints/650m/esm2_650m_checkpoint.tar.gz"
sha256: b83e9b5d62f1499b443817c5cd0facd3bdd4013a51a897e05e17228bf650befe # pragma: allowlist secret
owner: Peter St John <pstjohn@nvidia.com>
pbss: "s3://bionemo-ci/models/esm2_650M_nemo2.tar.gz"
sha256: 0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997 # pragma: allowlist secret
owner: Farhad Ramezanghorbani <farhadr@nvidia.com>
description: >
An ESM-2 650M model pre-trained on NVIDIA's train/test data split.
A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t33_650M_UR50D model.

- tag: 3b:2.0
ngc: nvidia/clara/esm2nv3b:2.0
ngc_registry: model
pbss: "s3://bionemo-ci/models/esm2_3B_nemo2.tar.gz"
sha256: a2248cfed1ef39f83bd32a0e08b84c0a8f39325d383e2d92767022ff7f5260ed # pragma: allowlist secret
owner: Farhad Ramezanghorbani <[email protected]>
description: >
A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t36_3B_UR50D model.

# - tag: nv_8m:2.1
# ngc: "nvidia/clara/esm2nv8m:2.1"
Expand All @@ -34,53 +43,44 @@
# description: >
# An ESM-2 8M model pre-trained on NVIDIA's train/test data split.

- tag: 8m:2.0
ngc: "nvidia/clara/esm2nv8m:2.0"
- tag: nv_650m:2.1
ngc: "nvidia/clara/esm2nv650m:2.1"
ngc_registry: model
pbss: "s3://general-purpose/esm2/checkpoints/converted/8m/esm2_hf_converted_8m_checkpoint.tar.gz"
sha256: 2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16 # pragma: allowlist secret
pbss: "s3://general-purpose/esm2/checkpoints/650m/esm2_650m_checkpoint.tar.gz"
sha256: b83e9b5d62f1499b443817c5cd0facd3bdd4013a51a897e05e17228bf650befe # pragma: allowlist secret
owner: Peter St John <[email protected]>
description: >
The original 8M parameter ESM2 model weights converted to the NeMo2 checkpoint format.

- tag: 650m:2.0
ngc: nvidia/clara/esm2nv650m:2.0
ngc_registry: model
pbss: "s3://bionemo-ci/models/esm2_650M_nemo2.tar.gz"
sha256: 0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997 # pragma: allowlist secret
owner: Farhad Ramezanghorbani <[email protected]>
description: >
The original 650M parameter ESM2 model weights converted to the NeMo2 checkpoint format.
An ESM-2 650M model pre-trained on NVIDIA's train/test data split.

- tag: 3b:2.0
ngc: nvidia/clara/esm2nv3b:2.0
- tag: nv_3b:2.1
ngc: "nvidia/clara/esm2nv3b:2.1"
ngc_registry: model
pbss: "s3://bionemo-ci/models/esm2_3B_nemo2.tar.gz"
sha256: a2248cfed1ef39f83bd32a0e08b84c0a8f39325d383e2d92767022ff7f5260ed # pragma: allowlist secret
owner: Farhad Ramezanghorbani <farhadr@nvidia.com>
pbss: "s3://general-purpose/esm2/checkpoints/3b/esm2_3b_checkpoint.tar.gz"
sha256: a79327a4054bf8d1d7075e1b3c961dbc503da02d72ed15f707d9cbbd49d181b6 # pragma: allowlist secret
owner: Peter St John <pstjohn@nvidia.com>
description: >
The original 3B parameter ESM2 model c converted to the NeMo2 checkpoint format.
An ESM-2 3B model pre-trained on NVIDIA's train/test data split.

- tag: fulldata_esm2_pretrain:2.0
ngc: nvidia/clara/esm2_pretrain_nemo2_data:1.0
ngc_registry: resource
pbss: "s3://general-purpose/esm2/pretrain/2024_03.tar.gz"
sha256: 404d0ad8de58fa8aae96f8d9f54263a088bc7e4f7d668215afbe04c28416151b # pragma: allowlist secret
sha256: 404d0ad8de58fa8aae96f8d9f54263a088bc7e4f7d668215afbe04c28416151b # pragma: allowlist secret
owner: Peter St John <[email protected]>
description: Full data for ESM2 pretraining.

- tag: testdata_esm2_pretrain:2.0
ngc: nvidia/clara/esm2_pretrain_nemo2_testdata:1.0
ngc_registry: resource
pbss: "s3://general-purpose/esm2/pretrain/2024_03_sanity.tar.gz"
sha256: 006911f92bbc0ded7ea302bbdbfab4c694b409e699c32fd49de1c527a99dba3e # pragma: allowlist secret
sha256: 006911f92bbc0ded7ea302bbdbfab4c694b409e699c32fd49de1c527a99dba3e # pragma: allowlist secret
owner: Peter St John <[email protected]>
description: Test data for ESM2 pretraining.

- tag: esm2_inference_testdata:2.0
ngc: nvidia/clara/esm2_inference_testdata:2.0 # TODO: upload to NGC
ngc: nvidia/clara/esm2_inference_testdata:2.0 # TODO: upload to NGC
ngc_registry: resource
pbss: "s3://bionemo-ci/test_data/esm2/artificial_protein_sequences.csv"
sha256: 14ae3acfbf82218bc9e3e53d21a5b0594ba7c0369e169c9f1034e3fe4378d175 # pragma: allowlist secret
sha256: 14ae3acfbf82218bc9e3e53d21a5b0594ba7c0369e169c9f1034e3fe4378d175 # pragma: allowlist secret
owner: Farhad Ramezanghorbani <[email protected]>
description: Test data for ESM2 inference.
179 changes: 179 additions & 0 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/model/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from pathlib import Path

import torch
from nemo.lightning import io, teardown
from nemo.lightning.pytorch.utils import dtype_from_hf
from transformers import AutoConfig as HFAutoConfig
from transformers import AutoModelForMaskedLM

from bionemo.esm2.data.tokenizer import BioNeMoESMTokenizer, get_tokenizer
from bionemo.esm2.model.model import ESM2Config
from bionemo.llm.lightning import BionemoLightningModule
from bionemo.llm.model.biobert.lightning import biobert_lightning_module


@io.model_importer(BionemoLightningModule, "hf")
class HFESM2Importer(io.ModelConnector[AutoModelForMaskedLM, BionemoLightningModule]):
"""Converts a Hugging Face ESM-2 model to a NeMo ESM-2 model."""

def init(self) -> BionemoLightningModule:
"""Initialize the converted model."""
return biobert_lightning_module(self.config, tokenizer=self.tokenizer)

def apply(self, output_path: Path) -> Path:
"""Applies the transformation.

Largely inspired by
https://docs.nvidia.com/nemo-framework/user-guide/latest/nemo-2.0/features/hf-integration.html
"""
source = AutoModelForMaskedLM.from_pretrained(str(self), trust_remote_code=True, torch_dtype="auto")
target = self.init()
trainer = self.nemo_setup(target)
self.convert_state(source, target)
self.nemo_save(output_path, trainer)

print(f"Converted ESM-2 model to Nemo, model saved to {output_path}")

teardown(trainer, target)
del trainer, target

return output_path

def convert_state(self, source, target):
"""Converting HF state dict to NeMo state dict."""
mapping = {
# "esm.encoder.layer.0.attention.self.rotary_embeddings.inv_freq": "rotary_pos_emb.inv_freq",
"esm.encoder.layer.*.attention.output.dense.weight": "encoder.layers.*.self_attention.linear_proj.weight",
"esm.encoder.layer.*.attention.output.dense.bias": "encoder.layers.*.self_attention.linear_proj.bias",
"esm.encoder.layer.*.attention.LayerNorm.weight": "encoder.layers.*.self_attention.linear_qkv.layer_norm_weight",
"esm.encoder.layer.*.attention.LayerNorm.bias": "encoder.layers.*.self_attention.linear_qkv.layer_norm_bias",
"esm.encoder.layer.*.intermediate.dense.weight": "encoder.layers.*.mlp.linear_fc1.weight",
"esm.encoder.layer.*.intermediate.dense.bias": "encoder.layers.*.mlp.linear_fc1.bias",
"esm.encoder.layer.*.output.dense.weight": "encoder.layers.*.mlp.linear_fc2.weight",
"esm.encoder.layer.*.output.dense.bias": "encoder.layers.*.mlp.linear_fc2.bias",
"esm.encoder.layer.*.LayerNorm.weight": "encoder.layers.*.mlp.linear_fc1.layer_norm_weight",
"esm.encoder.layer.*.LayerNorm.bias": "encoder.layers.*.mlp.linear_fc1.layer_norm_bias",
"esm.encoder.emb_layer_norm_after.weight": "encoder.final_layernorm.weight",
"esm.encoder.emb_layer_norm_after.bias": "encoder.final_layernorm.bias",
"lm_head.dense.weight": "lm_head.dense.weight",
"lm_head.dense.bias": "lm_head.dense.bias",
"lm_head.layer_norm.weight": "lm_head.layer_norm.weight",
"lm_head.layer_norm.bias": "lm_head.layer_norm.bias",
}

# lm_head.bias
return io.apply_transforms(
source,
target,
mapping=mapping,
transforms=[_pad_embeddings, _pad_bias, _import_qkv_weight, _import_qkv_bias],
)

@property
def tokenizer(self) -> BioNeMoESMTokenizer:
"""We just have the one tokenizer for ESM-2."""
return get_tokenizer()

@property
def config(self) -> ESM2Config:
"""Returns the transformed ESM-2 config given the model tag."""
source = HFAutoConfig.from_pretrained(str(self), trust_remote_code=True)
output = ESM2Config(
num_layers=source.num_hidden_layers,
hidden_size=source.hidden_size,
ffn_hidden_size=source.intermediate_size,
position_embedding_type="rope",
num_attention_heads=source.num_attention_heads,
seq_length=source.max_position_embeddings,
fp16=(dtype_from_hf(source) == torch.float16),
bf16=(dtype_from_hf(source) == torch.bfloat16),
params_dtype=dtype_from_hf(source),
)

return output


@io.state_transform(
source_key="esm.embeddings.word_embeddings.weight",
target_key="embedding.word_embeddings.weight",
)
def _pad_embeddings(ctx: io.TransformCTX, source_embed):
"""Pad the embedding layer to the new input dimension."""
nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by
hf_embedding_dimension = source_embed.size(0)
num_padding_rows = nemo_embedding_dimension - hf_embedding_dimension
padding_rows = torch.zeros(num_padding_rows, source_embed.size(1))
return torch.cat((source_embed, padding_rows), dim=0)


@io.state_transform(
source_key="lm_head.bias",
target_key="output_layer.bias",
)
def _pad_bias(ctx: io.TransformCTX, source_bias):
"""Pad the embedding layer to the new input dimension."""
nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by
hf_embedding_dimension = source_bias.size(0)
output_bias = torch.zeros(nemo_embedding_dimension, dtype=source_bias.dtype, device=source_bias.device)
output_bias[:hf_embedding_dimension] = source_bias
return output_bias


@io.state_transform(
source_key=(
"esm.encoder.layer.*.attention.self.query.weight",
"esm.encoder.layer.*.attention.self.key.weight",
"esm.encoder.layer.*.attention.self.value.weight",
),
target_key="encoder.layers.*.self_attention.linear_qkv.weight",
)
def _import_qkv_weight(ctx: io.TransformCTX, query, key, value):
"""Pad the embedding layer to the new input dimension."""
concat_weights = torch.cat((query, key, value), dim=0)
input_shape = concat_weights.size()
np = ctx.target.config.num_attention_heads
# transpose weights
# [sequence length, batch size, num_splits_model_parallel * attention head size * #attention heads]
# --> [sequence length, batch size, attention head size * num_splits_model_parallel * #attention heads]
concat_weights = concat_weights.view(3, np, -1, query.size()[-1])
concat_weights = concat_weights.transpose(0, 1).contiguous()
concat_weights = concat_weights.view(*input_shape)
return concat_weights


@io.state_transform(
source_key=(
"esm.encoder.layer.*.attention.self.query.bias",
"esm.encoder.layer.*.attention.self.key.bias",
"esm.encoder.layer.*.attention.self.value.bias",
),
target_key="encoder.layers.*.self_attention.linear_qkv.bias",
)
def _import_qkv_bias(ctx: io.TransformCTX, query, key, value):
"""Pad the embedding layer to the new input dimension."""
concat_biases = torch.cat((query, key, value), dim=0)
input_shape = concat_biases.size()
np = ctx.target.config.num_attention_heads
# transpose biases
# [num_splits_model_parallel * attention head size * #attention heads]
# --> [attention head size * num_splits_model_parallel * #attention heads]
concat_biases = concat_biases.view(3, np, -1)
concat_biases = concat_biases.transpose(0, 1).contiguous()
concat_biases = concat_biases.view(*input_shape)
return concat_biases
14 changes: 14 additions & 0 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading
Loading