
Add all_models/bert as an example for tensorrt-llm classification models #269

Open

erenup wants to merge 1 commit into main
Conversation

@erenup commented on Dec 31, 2023

Hi @kaiyux @Shixiaowei02,
First of all, thank you very much for the great tensorrtllm_backend!

Pull Request Topic

This PR adds an example of serving TensorRT-LLM classification models. I understand that TensorRT-LLM is mainly aimed at generation tasks, but the community can also benefit from it for classification, which is a fundamental NLP task as well. I hope this is helpful!

Features

  • This feature is closely related to my TensorRT-LLM PR, Add Roberta and a few new tests for Bert. The code in that PR can produce the classification TensorRT engines used by this tensorrtllm_backend PR.
  • I implemented the example classification models under all_models/bert. It contains three sub-directories: preprocessing, tensorrt_llm, and ensemble.
  • preprocessing is similar to all_models/gpt/preprocessing, but I removed unrelated parameters.
  • tensorrt_llm is similar to all_models/gpt/tensorrt_llm, but I load the engine directly in model.py, since classification is simpler than generation (see the sketch after this list).
  • ensemble is similar to all_models/gpt/ensemble, but I removed unrelated parameters.
  • Parameters that need to be modified or mentioned in the README: ${engine_dir} in all_models/bert/tensorrt_llm/config.pbtxt and ${tokenizer_dir} in all_models/bert/preprocessing/config.pbtxt (see the substitution sketch below). I did not modify the main README.md in this repo, since it may be better for you to decide how this fits into tensorrtllm_backend.
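For reference, here is a rough sketch (not the exact code in this PR) of how model.py can load and run a classification engine directly, loosely following the Session API used in TensorRT-LLM's bert example. The tensor names (input_ids, input_lengths, token_type_ids, logits), dtypes, and the engine path are assumptions and depend on how the engine was built:

import tensorrt as trt
import torch
from tensorrt_llm.runtime import Session, TensorInfo

# Load the serialized classification engine (path is only an example).
with open("/engines/bert_classifier/rank0.engine", "rb") as f:
    session = Session.from_serialized_engine(f.read())

def classify(input_ids: torch.Tensor, input_lengths: torch.Tensor, token_type_ids: torch.Tensor):
    # All inputs are int32 CUDA tensors: [batch, seq_len] for ids, [batch] for lengths.
    inputs = {
        "input_ids": input_ids,
        "input_lengths": input_lengths,
        "token_type_ids": token_type_ids,
    }
    # Ask the engine for output shapes given these input shapes.
    output_info = session.infer_shapes([
        TensorInfo("input_ids", trt.DataType.INT32, tuple(input_ids.shape)),
        TensorInfo("input_lengths", trt.DataType.INT32, tuple(input_lengths.shape)),
        TensorInfo("token_type_ids", trt.DataType.INT32, tuple(token_type_ids.shape)),
    ])
    # Assumes the classification head exposes a single float16 "logits" output.
    outputs = {
        t.name: torch.empty(tuple(t.shape), dtype=torch.float16, device="cuda")
        for t in output_info
    }
    stream = torch.cuda.current_stream()
    ok = session.run(inputs, outputs, stream.cuda_stream)
    assert ok, "TensorRT execution failed"
    stream.synchronize()
    return outputs["logits"]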
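And a minimal sketch of filling in the ${engine_dir} and ${tokenizer_dir} placeholders before launching Triton; the paths are examples only, and the repo's tools/fill_template.py may be used for the same purpose:

from pathlib import Path

# Example values; point these at your actual engine and tokenizer locations.
substitutions = {
    "all_models/bert/tensorrt_llm/config.pbtxt": {"${engine_dir}": "/engines/bert_classifier"},
    "all_models/bert/preprocessing/config.pbtxt": {"${tokenizer_dir}": "/models/bert-base-uncased"},
}

for config_path, replacements in substitutions.items():
    path = Path(config_path)
    text = path.read_text()
    for placeholder, value in replacements.items():
        text = text.replace(placeholder, value)
    path.write_text(text)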

Tests

  • I tested the backend on my 4080 GPU, and the speed is excellent.
  • My simple speed-test results with the engine built with --use_gemm_plugin float16 --use_bert_attention_plugin float16 --enable_context_fmha are shown in the screenshot below. My simple test Python script is also provided below.
[screenshot: speed-test results]
  • When use_gemm_plugin and use_bert_attention_plugin are disabled, the speed drops to roughly half.

Simple Speed Tests script

import requests
import multiprocessing
import concurrent.futures
import time

# Configuration
SERVER_URL = "http://localhost:8000/v2/models/ensemble/generate"
NUM_REQUESTS = 1000  # Number of requests to send
MAX_WORKERS = multiprocessing.cpu_count()    # Number of concurrent workers
print(f'MAX_WORKERS: {MAX_WORKERS}')

def send_request():
    # One classification request against the ensemble model's generate endpoint.
    data = '{"text_input": "This is tensorrt-llm for bert and roberta sequence classification models!", "bad_words": "", "stop_words": ""}'
    response = requests.post(SERVER_URL, data=data)
    return response

def main():
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        start_time = time.time()
        futures = [executor.submit(send_request) for _ in range(NUM_REQUESTS)]
        concurrent.futures.wait(futures)
        end_time = time.time()

    # Calculate and print results
    total_time = end_time - start_time
    print(f"Total time for {NUM_REQUESTS} requests: {total_time} seconds")
    print(f"Average time per request: {total_time / NUM_REQUESTS} seconds")
    print(f"Requests per second: {NUM_REQUESTS / total_time}")

if __name__ == "__main__":
    main()
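
For completeness, a single request can be sent like this to inspect the raw response. The exact fields in the returned JSON depend on the output names configured in the ensemble, so treat them as assumptions:

import requests

SERVER_URL = "http://localhost:8000/v2/models/ensemble/generate"
payload = {
    "text_input": "This is tensorrt-llm for bert and roberta sequence classification models!",
    "bad_words": "",
    "stop_words": "",
}
response = requests.post(SERVER_URL, json=payload)
response.raise_for_status()
# Prints the classification output as returned by the ensemble
# (field names depend on the ensemble's config.pbtxt).
print(response.json())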

I hope this PR is useful to the community!

Happy New Year!
