Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
fix pylint
Browse files Browse the repository at this point in the history
Signed-off-by: changwangss <[email protected]>
  • Loading branch information
changwangss committed Jul 4, 2024
1 parent 33e15d8 commit 54c8157
Showing 1 changed file with 6 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -366,26 +366,19 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if use_vllm is not None:
logger.info("The backend is vLLM.")
from vllm import LLM # pylint: disable=E1101
from vllm.model_executor.model_loader import (
get_model_loader,
) # pylint: disable=E0611
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
) # pylint: disable=E0401 disable=E0611
from vllm.model_executor.model_loader import get_model_loader # pylint: disable=E0611
from vllm.model_executor.model_loader.weight_utils import default_weight_loader # pylint: disable=E0401 disable=E0611
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ColumnParallelLinear,
RowParallelLinear,
) # pylint: disable=E1101
RowParallelLinear) # pylint: disable=E1101

os.environ["backend"] = "use_vllm"
llm = LLM(
model=pretrained_model_name_or_path, trust_remote_code=True
) # Create a vLLM instance.
model = (
llm.llm_engine.model_executor.driver_worker.model_runner.model
) # pylint: disable=E1101
model = llm.llm_engine.model_executor.driver_worker.model_runner.model # pylint: disable=E1101
print("Original model =", model)

original_parameter_memo = dict()
Expand Down Expand Up @@ -447,19 +440,15 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]:
)

print("Optimized model =", model)
loader = get_model_loader(
llm.llm_engine.load_config
) # pylint: disable=E1101
loader = get_model_loader(llm.llm_engine.load_config) # pylint: disable=E1101

weights_iterator = loader._get_weights_iterator(
llm.llm_engine.model_config.model,
llm.llm_engine.model_config.revision,
fall_back_to_pt=True,
)

from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
) # pylint: disable=E0401 disable=E0611
from vllm.model_executor.model_loader.weight_utils import default_weight_loader # pylint: disable=E0401 disable=E0611

params_dict = dict(model.named_parameters(remove_duplicate=False))
for name in params_dict.keys():
Expand Down

0 comments on commit 54c8157

Please sign in to comment.