ESM-2 to NeMo checkpoint conversion (#537)
Adds a conversion script to convert ESM-2 checkpoints from Hugging Face to NeMo format.

---------

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
pstjohn authored Jan 22, 2025
1 parent a2fd916 commit 257e918
Showing 8 changed files with 497 additions and 200 deletions.
70 changes: 35 additions & 35 deletions sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml
@@ -7,23 +7,32 @@
   description: >
     A pretrained 650M parameter ESM2 model. See https://ngc.nvidia.com/catalog/models/nvidia:clara:esm2nv650m.
-- tag: nv_3b:2.1
-  ngc: "nvidia/clara/esm2nv3b:2.1"
+- tag: 8m:2.0
+  ngc: nvidia/clara/esm2nv8m:2.0
   ngc_registry: model
-  pbss: "s3://general-purpose/esm2/checkpoints/3b/esm2_3b_checkpoint.tar.gz"
-  sha256: a79327a4054bf8d1d7075e1b3c961dbc503da02d72ed15f707d9cbbd49d181b6 # pragma: allowlist secret
+  pbss: s3://general-purpose/esm2/checkpoints/converted/8m/esm2_hf_converted_8m_checkpoint.tar.gz
+  sha256: 2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16 # pragma: allowlist secret
   owner: Peter St John <pstjohn@nvidia.com>
   description: >
-    An ESM-2 3B model pre-trained on NVIDIA's train/test data split.
+    A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t6_8M_UR50D model.
-- tag: nv_650m:2.1
-  ngc: "nvidia/clara/esm2nv650m:2.1"
+- tag: 650m:2.0
+  ngc: nvidia/clara/esm2nv650m:2.0
   ngc_registry: model
-  pbss: "s3://general-purpose/esm2/checkpoints/650m/esm2_650m_checkpoint.tar.gz"
-  sha256: b83e9b5d62f1499b443817c5cd0facd3bdd4013a51a897e05e17228bf650befe # pragma: allowlist secret
-  owner: Peter St John <pstjohn@nvidia.com>
+  pbss: "s3://bionemo-ci/models/esm2_650M_nemo2.tar.gz"
+  sha256: 0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997 # pragma: allowlist secret
+  owner: Farhad Ramezanghorbani <farhadr@nvidia.com>
   description: >
-    An ESM-2 650M model pre-trained on NVIDIA's train/test data split.
+    A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t33_650M_UR50D model.
+- tag: 3b:2.0
+  ngc: nvidia/clara/esm2nv3b:2.0
+  ngc_registry: model
+  pbss: "s3://bionemo-ci/models/esm2_3B_nemo2.tar.gz"
+  sha256: a2248cfed1ef39f83bd32a0e08b84c0a8f39325d383e2d92767022ff7f5260ed # pragma: allowlist secret
+  owner: Farhad Ramezanghorbani <farhadr@nvidia.com>
+  description: >
+    A NeMo2 compatible checkpoint converted from the huggingface facebook/esm2_t36_3B_UR50D model.
 # - tag: nv_8m:2.1
 #   ngc: "nvidia/clara/esm2nv8m:2.1"
@@ -34,53 +43,44 @@
 #   description: >
 #     An ESM-2 8M model pre-trained on NVIDIA's train/test data split.
 
-- tag: 8m:2.0
-  ngc: "nvidia/clara/esm2nv8m:2.0"
+- tag: nv_650m:2.1
+  ngc: "nvidia/clara/esm2nv650m:2.1"
   ngc_registry: model
-  pbss: "s3://general-purpose/esm2/checkpoints/converted/8m/esm2_hf_converted_8m_checkpoint.tar.gz"
-  sha256: 2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16 # pragma: allowlist secret
+  pbss: "s3://general-purpose/esm2/checkpoints/650m/esm2_650m_checkpoint.tar.gz"
+  sha256: b83e9b5d62f1499b443817c5cd0facd3bdd4013a51a897e05e17228bf650befe # pragma: allowlist secret
   owner: Peter St John <pstjohn@nvidia.com>
   description: >
-    The original 8M parameter ESM2 model weights converted to the NeMo2 checkpoint format.
-- tag: 650m:2.0
-  ngc: nvidia/clara/esm2nv650m:2.0
-  ngc_registry: model
-  pbss: "s3://bionemo-ci/models/esm2_650M_nemo2.tar.gz"
-  sha256: 0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997 # pragma: allowlist secret
-  owner: Farhad Ramezanghorbani <farhadr@nvidia.com>
-  description: >
-    The original 650M parameter ESM2 model weights converted to the NeMo2 checkpoint format.
-- tag: 3b:2.0
-  ngc: nvidia/clara/esm2nv3b:2.0
+    An ESM-2 650M model pre-trained on NVIDIA's train/test data split.
+- tag: nv_3b:2.1
+  ngc: "nvidia/clara/esm2nv3b:2.1"
   ngc_registry: model
-  pbss: "s3://bionemo-ci/models/esm2_3B_nemo2.tar.gz"
-  sha256: a2248cfed1ef39f83bd32a0e08b84c0a8f39325d383e2d92767022ff7f5260ed # pragma: allowlist secret
-  owner: Farhad Ramezanghorbani <farhadr@nvidia.com>
+  pbss: "s3://general-purpose/esm2/checkpoints/3b/esm2_3b_checkpoint.tar.gz"
+  sha256: a79327a4054bf8d1d7075e1b3c961dbc503da02d72ed15f707d9cbbd49d181b6 # pragma: allowlist secret
+  owner: Peter St John <pstjohn@nvidia.com>
   description: >
-    The original 3B parameter ESM2 model converted to the NeMo2 checkpoint format.
+    An ESM-2 3B model pre-trained on NVIDIA's train/test data split.
 - tag: fulldata_esm2_pretrain:2.0
   ngc: nvidia/clara/esm2_pretrain_nemo2_data:1.0
   ngc_registry: resource
   pbss: "s3://general-purpose/esm2/pretrain/2024_03.tar.gz"
-  sha256: 404d0ad8de58fa8aae96f8d9f54263a088bc7e4f7d668215afbe04c28416151b # pragma: allowlist secret
+  sha256: 404d0ad8de58fa8aae96f8d9f54263a088bc7e4f7d668215afbe04c28416151b  # pragma: allowlist secret
   owner: Peter St John <pstjohn@nvidia.com>
   description: Full data for ESM2 pretraining.
 
 - tag: testdata_esm2_pretrain:2.0
   ngc: nvidia/clara/esm2_pretrain_nemo2_testdata:1.0
   ngc_registry: resource
   pbss: "s3://general-purpose/esm2/pretrain/2024_03_sanity.tar.gz"
-  sha256: 006911f92bbc0ded7ea302bbdbfab4c694b409e699c32fd49de1c527a99dba3e # pragma: allowlist secret
+  sha256: 006911f92bbc0ded7ea302bbdbfab4c694b409e699c32fd49de1c527a99dba3e  # pragma: allowlist secret
   owner: Peter St John <pstjohn@nvidia.com>
   description: Test data for ESM2 pretraining.
 
 - tag: esm2_inference_testdata:2.0
-  ngc: nvidia/clara/esm2_inference_testdata:2.0 # TODO: upload to NGC
+  ngc: nvidia/clara/esm2_inference_testdata:2.0  # TODO: upload to NGC
   ngc_registry: resource
   pbss: "s3://bionemo-ci/test_data/esm2/artificial_protein_sequences.csv"
-  sha256: 14ae3acfbf82218bc9e3e53d21a5b0594ba7c0369e169c9f1034e3fe4378d175 # pragma: allowlist secret
+  sha256: 14ae3acfbf82218bc9e3e53d21a5b0594ba7c0369e169c9f1034e3fe4378d175  # pragma: allowlist secret
   owner: Farhad Ramezanghorbani <farhadr@nvidia.com>
   description: Test data for ESM2 inference.
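These registry entries are what the rest of BioNeMo resolves at runtime by a `<registry file>/<tag>` key. A minimal sketch of fetching one of the converted checkpoints, assuming the framework's `bionemo.core.data.load.load` helper and the `650m:2.0` tag defined above:

from bionemo.core.data.load import load

# Resolve the "650m:2.0" tag from esm2.yaml: download from NGC (or PBSS on
# NVIDIA infrastructure), verify the sha256 recorded in the registry, cache,
# and return the local path to the unpacked checkpoint.
checkpoint_path = load("esm2/650m:2.0")
print(checkpoint_path)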
179 changes: 179 additions & 0 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/model/convert.py
@@ -0,0 +1,179 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from pathlib import Path

import torch
from nemo.lightning import io, teardown
from nemo.lightning.pytorch.utils import dtype_from_hf
from transformers import AutoConfig as HFAutoConfig
from transformers import AutoModelForMaskedLM

from bionemo.esm2.data.tokenizer import BioNeMoESMTokenizer, get_tokenizer
from bionemo.esm2.model.model import ESM2Config
from bionemo.llm.lightning import BionemoLightningModule
from bionemo.llm.model.biobert.lightning import biobert_lightning_module


@io.model_importer(BionemoLightningModule, "hf")
class HFESM2Importer(io.ModelConnector[AutoModelForMaskedLM, BionemoLightningModule]):
    """Converts a Hugging Face ESM-2 model to a NeMo ESM-2 model."""

    def init(self) -> BionemoLightningModule:
        """Initialize the converted model."""
        return biobert_lightning_module(self.config, tokenizer=self.tokenizer)

    def apply(self, output_path: Path) -> Path:
        """Applies the transformation.

        Largely inspired by
        https://docs.nvidia.com/nemo-framework/user-guide/latest/nemo-2.0/features/hf-integration.html
        """
        source = AutoModelForMaskedLM.from_pretrained(str(self), trust_remote_code=True, torch_dtype="auto")
        target = self.init()
        trainer = self.nemo_setup(target)
        self.convert_state(source, target)
        self.nemo_save(output_path, trainer)

        print(f"Converted ESM-2 model to NeMo; model saved to {output_path}")

        teardown(trainer, target)
        del trainer, target

        return output_path

    def convert_state(self, source, target):
        """Convert the HF state dict to the NeMo state dict."""
        mapping = {
            # "esm.encoder.layer.0.attention.self.rotary_embeddings.inv_freq": "rotary_pos_emb.inv_freq",
            "esm.encoder.layer.*.attention.output.dense.weight": "encoder.layers.*.self_attention.linear_proj.weight",
            "esm.encoder.layer.*.attention.output.dense.bias": "encoder.layers.*.self_attention.linear_proj.bias",
            "esm.encoder.layer.*.attention.LayerNorm.weight": "encoder.layers.*.self_attention.linear_qkv.layer_norm_weight",
            "esm.encoder.layer.*.attention.LayerNorm.bias": "encoder.layers.*.self_attention.linear_qkv.layer_norm_bias",
            "esm.encoder.layer.*.intermediate.dense.weight": "encoder.layers.*.mlp.linear_fc1.weight",
            "esm.encoder.layer.*.intermediate.dense.bias": "encoder.layers.*.mlp.linear_fc1.bias",
            "esm.encoder.layer.*.output.dense.weight": "encoder.layers.*.mlp.linear_fc2.weight",
            "esm.encoder.layer.*.output.dense.bias": "encoder.layers.*.mlp.linear_fc2.bias",
            "esm.encoder.layer.*.LayerNorm.weight": "encoder.layers.*.mlp.linear_fc1.layer_norm_weight",
            "esm.encoder.layer.*.LayerNorm.bias": "encoder.layers.*.mlp.linear_fc1.layer_norm_bias",
            "esm.encoder.emb_layer_norm_after.weight": "encoder.final_layernorm.weight",
            "esm.encoder.emb_layer_norm_after.bias": "encoder.final_layernorm.bias",
            "lm_head.dense.weight": "lm_head.dense.weight",
            "lm_head.dense.bias": "lm_head.dense.bias",
            "lm_head.layer_norm.weight": "lm_head.layer_norm.weight",
            "lm_head.layer_norm.bias": "lm_head.layer_norm.bias",
        }

        # The word embeddings, the fused QKV parameters, and lm_head.bias are
        # not in `mapping`; they are handled by the state transforms below.
        return io.apply_transforms(
            source,
            target,
            mapping=mapping,
            transforms=[_pad_embeddings, _pad_bias, _import_qkv_weight, _import_qkv_bias],
        )

    @property
    def tokenizer(self) -> BioNeMoESMTokenizer:
        """We just have the one tokenizer for ESM-2."""
        return get_tokenizer()

    @property
    def config(self) -> ESM2Config:
        """Returns the transformed ESM-2 config given the model tag."""
        source = HFAutoConfig.from_pretrained(str(self), trust_remote_code=True)
        output = ESM2Config(
            num_layers=source.num_hidden_layers,
            hidden_size=source.hidden_size,
            ffn_hidden_size=source.intermediate_size,
            position_embedding_type="rope",
            num_attention_heads=source.num_attention_heads,
            seq_length=source.max_position_embeddings,
            fp16=(dtype_from_hf(source) == torch.float16),
            bf16=(dtype_from_hf(source) == torch.bfloat16),
            params_dtype=dtype_from_hf(source),
        )

        return output
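# For illustration (values assumed from the public facebook/esm2_t6_8M_UR50D
# config, not part of this commit): the property above would produce
# num_layers=6, hidden_size=320, ffn_hidden_size=1280, num_attention_heads=20,
# and seq_length=1026 for the 8M model.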


@io.state_transform(
    source_key="esm.embeddings.word_embeddings.weight",
    target_key="embedding.word_embeddings.weight",
)
def _pad_embeddings(ctx: io.TransformCTX, source_embed):
    """Pad the word embedding matrix out to the padded NeMo vocabulary size."""
    # NeMo pads the vocabulary so its size is divisible by
    # make_vocab_size_divisible_by; for ESM-2's small vocab the padded size
    # equals that divisor.
    nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by
    hf_embedding_dimension = source_embed.size(0)
    num_padding_rows = nemo_embedding_dimension - hf_embedding_dimension
    padding_rows = torch.zeros(
        num_padding_rows, source_embed.size(1), dtype=source_embed.dtype, device=source_embed.device
    )
    return torch.cat((source_embed, padding_rows), dim=0)
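# Worked example of the padding above (assuming the Megatron-style default of
# make_vocab_size_divisible_by=128): ESM-2's HF vocabulary has 33 tokens, so a
# (33, hidden_size) embedding matrix gains 95 zero rows to become
# (128, hidden_size).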


@io.state_transform(
    source_key="lm_head.bias",
    target_key="output_layer.bias",
)
def _pad_bias(ctx: io.TransformCTX, source_bias):
    """Pad the output-layer bias out to the padded NeMo vocabulary size."""
    nemo_embedding_dimension = ctx.target.config.make_vocab_size_divisible_by
    hf_embedding_dimension = source_bias.size(0)
    output_bias = torch.zeros(nemo_embedding_dimension, dtype=source_bias.dtype, device=source_bias.device)
    output_bias[:hf_embedding_dimension] = source_bias
    return output_bias


@io.state_transform(
    source_key=(
        "esm.encoder.layer.*.attention.self.query.weight",
        "esm.encoder.layer.*.attention.self.key.weight",
        "esm.encoder.layer.*.attention.self.value.weight",
    ),
    target_key="encoder.layers.*.self_attention.linear_qkv.weight",
)
def _import_qkv_weight(ctx: io.TransformCTX, query, key, value):
    """Interleave the HF Q, K, and V weight matrices into NeMo's fused QKV weight."""
    concat_weights = torch.cat((query, key, value), dim=0)
    input_shape = concat_weights.size()
    np = ctx.target.config.num_attention_heads
    # Regroup the rows from [all q heads; all k heads; all v heads] into the
    # per-head interleaving [q_head0; k_head0; v_head0; q_head1; ...] that the
    # fused linear_qkv layer expects.
    concat_weights = concat_weights.view(3, np, -1, query.size()[-1])
    concat_weights = concat_weights.transpose(0, 1).contiguous()
    concat_weights = concat_weights.view(*input_shape)
    return concat_weights
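# Shape walk-through (illustrative, using the 8M model's assumed
# hidden_size=320 and 20 heads, i.e. head_dim=16): q, k, and v are each
# (320, 320); cat -> (960, 320); view -> (3, 20, 16, 320); transpose ->
# (20, 3, 16, 320); the final view flattens back to (960, 320) with each
# head's q/k/v rows adjacent.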


@io.state_transform(
    source_key=(
        "esm.encoder.layer.*.attention.self.query.bias",
        "esm.encoder.layer.*.attention.self.key.bias",
        "esm.encoder.layer.*.attention.self.value.bias",
    ),
    target_key="encoder.layers.*.self_attention.linear_qkv.bias",
)
def _import_qkv_bias(ctx: io.TransformCTX, query, key, value):
    """Interleave the HF Q, K, and V biases into NeMo's fused QKV bias."""
    concat_biases = torch.cat((query, key, value), dim=0)
    input_shape = concat_biases.size()
    np = ctx.target.config.num_attention_heads
    # Same regrouping as the weights: from [all q heads; all k heads;
    # all v heads] to the per-head interleaving expected by linear_qkv.
    concat_biases = concat_biases.view(3, np, -1)
    conc_biases = concat_biases.transpose(0, 1).contiguous()
    concat_biases = concat_biases.transpose(0, 1).contiguous()
    concat_biases = concat_biases.view(*input_shape)
    return concat_biases
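Because the connector above is registered for BionemoLightningModule under the "hf" scheme, conversion can be driven through NeMo 2.0's checkpoint-import machinery rather than called directly. A minimal usage sketch; the io.import_ckpt entry point and the module construction here follow NeMo's HF-integration docs and are assumptions, not part of this commit:

from pathlib import Path

from nemo.lightning import io

from bionemo.esm2.model.convert import HFESM2Importer  # noqa: F401  (registers the "hf" importer)
from bionemo.esm2.model.model import ESM2Config
from bionemo.llm.model.biobert.lightning import biobert_lightning_module

# Build a target module, then let NeMo dispatch the "hf://" source to the
# registered HFESM2Importer and write a NeMo2-format checkpoint.
module = biobert_lightning_module(ESM2Config())
io.import_ckpt(module, "hf://facebook/esm2_t6_8M_UR50D", Path("/tmp/esm2_8m_nemo2"))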
14 changes: 14 additions & 0 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/testing/__init__.py
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.