TGI, or Text Generation Inference, is a high-performance model server for language models built by Hugging Face. In this doc, we’ll cover when to use TGI, how to tune TGI for performance, and how to deploy it to Baseten.

What is TGI?

TGI consists of 2 parts:

  1. A high-performance, Rust-based server
  2. A set of optimized model implementations that outperform generic implementations

To optimize GPU utilization and improve throughput and latency, it’s common to implement optimizations at both the server level and the model level, which is why TGI includes both parts.

It’s worth noting that TGI can only be used for a certain subset of models. As of September 2023, this includes:

  • Mistral 7B
  • Llama V2
  • Llama
  • MPT
  • Code Llama
  • Falcon 40B
  • Falcon 7B
  • FLAN-T5
  • Galactica
  • GPT-NeoX
  • OPT
  • SantaCoder
  • StarCoder

You should use TGI if you care about model latency and throughput and want to be able to stream results. It’s worth noting that tuning TGI for maximum performance takes dedicated effort, and the right settings differ for every model and for the metric you want to optimize.

How to use TGI

To define a TGI truss, we’ll use the truss CLI to generate the scaffold. Run the following in your terminal:

$ truss init ./my-tgi-truss --backend TGI

Let’s walk through what this command is doing:

- "truss init": This CLI command initializes an empty scaffold for one of the model backends that truss supports (TGI, VLLM, Triton and TrussServer)
- "./my-tgi-truss": The `init` command requires a `target_directory` which is where this empty scaffold will be generated
- "--backend TGI": This tells truss that we are specifically interested in a TGI truss, which instantiates a different scaffold than the default (TrussServer)

After running that command, you’ll be prompted to pass a model_id.

$ truss init ./my-tgi-truss --backend TGI
? model_id

One thing that makes TGI support in Truss a bit different from the default is that there isn’t any model instantiation or model inference code to write. Because we leverage TGI under the hood, TGI handles loading the model and running inference. All you need to do is pass in the model ID as it appears on Hugging Face.

After typing in your model ID, you’ll be prompted for the endpoint you’d like to use.

$ truss init ./my-tgi-truss --backend TGI
? model_id facebook/opt-125M
? endpoint

TGI supports 2 endpoints:

  • generate which returns the entire generated response upon completion
  • generate_stream which streams the response as it’s being generated

You can press the tab key on any of these dialogues to see options for values.

Finally, you’ll be asked for the name of your model.
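To make the difference between the two endpoints concrete, here is a minimal Python sketch of how a client might reassemble a streamed response into the text that generate would return in one piece. It assumes that generate_stream delivers server-sent-event lines of the form `data:{...}` with a `token.text` field in each event. The sample events are made up for illustration, and `collect_stream` is a hypothetical helper, not part of Truss or TGI.

```python
import json


def collect_stream(lines):
    """Join the token texts from SSE-style 'data:' lines into one string."""
    pieces = []
    for line in lines:
        line = line.strip()
        if not line.startswith("data:"):
            continue  # skip blank separator lines and anything that is not an event
        event = json.loads(line[len("data:"):])
        token = event.get("token", {})
        # Special tokens (for example end-of-sequence markers) are not user-visible text.
        if not token.get("special", False):
            pieces.append(token.get("text", ""))
    return "".join(pieces)


# Illustrative events only; real events carry more fields.
sample_events = [
    'data:{"token": {"text": "A", "special": false}, "generated_text": null}',
    "",
    'data:{"token": {"text": " model", "special": false}, "generated_text": "A model"}',
]

streamed_text = collect_stream(sample_events)
print(streamed_text)
```

With generate, the client would simply wait for one response containing the full text. With generate_stream, it can display each piece as soon as it arrives.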

Deploying your TGI model

Now that we have a TGI model, let’s deploy this model to Baseten and see how it performs.

You’ll need an API key to deploy your model. You can get one by navigating to your Baseten settings page. To push the model to Baseten, run the following command:

$ truss push --publish

Let’s walk through what this command is doing:

- "truss push": This CLI command zips and uploads your Truss to Baseten
- "--publish": This tells Baseten that you want a Production endpoint (vs. a development environment)

After running this command, you’ll be prompted to pass in your API key. Once you’ve done that, you can view your deployment status from the Baseten dashboard. Once the model has been deployed, you can invoke it by running the following command:

$ truss predict -d '{"inputs": "What is a large language model?", "parameters": {"max_new_tokens": 128, "sample": true}}' --published

Let’s walk through what this command is doing:

- "truss predict": This CLI command invokes your model
- "-d": This tells the CLI that the next argument is the data you want to pass to your model
- '{"inputs": "What is a large language model?", "parameters": {"max_new_tokens": 128, "sample": true}}': This is the data you want to pass to your model. TGI expects a JSON object with 2 keys: `inputs` and `parameters` (optional). `inputs` is the prompt you want to pass to your model and `parameters` is a JSON object of parameters you want to pass to your model.
- "--published": This tells the CLI that you want to invoke the production endpoint (vs. a development endpoint)
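Quoting the JSON payload by hand is easy to get wrong, so it can help to generate the command. The following is a minimal Python sketch that builds the payload described above and assembles a shell-safe command line. `build_predict_command` is a hypothetical helper written for this doc, and the parameter names are simply the ones used in the example above.

```python
import json
import shlex


def build_predict_command(prompt, published=True, **parameters):
    """Return a shell-safe `truss predict` command line for a TGI truss."""
    payload = {"inputs": prompt}
    if parameters:
        payload["parameters"] = parameters  # optional key, as described above
    args = ["truss", "predict", "-d", json.dumps(payload)]
    if published:
        args.append("--published")  # the flag stays outside the quoted JSON
    return shlex.join(args)


command = build_predict_command(
    "What is a large language model?", max_new_tokens=128, sample=True
)
print(command)
```

Because `shlex.join` quotes the JSON as a single argument, the `--published` flag cannot accidentally end up inside the data string.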

Tuning your TGI server

After deploying your model, you may notice that you’re not getting the performance you’d like out of TGI. If you navigate to the target_directory from above, you’ll find a config.yaml file that contains a build key. The following arguments can be passed under the build key to tune TGI for maximum performance:

  1. max_input_length (default: 1024)

This parameter is the maximum allowed input length, expressed in tokens. If you expect longer sequences as input, you may need to raise it, keeping in mind that longer sequences increase the memory required to handle the load. If you know that your prompts will be shorter than the maximum length the model allows, it’s advantageous to set this parameter to that smaller value. When TGI starts up, it calculates how much memory to reserve per request, and setting this value lets TGI allocate the right amount. In short, higher values take up more GPU memory but let you send longer prompts.

  2. max_total_tokens (default: 2048)

This is a crucial parameter defining the “memory budget” of each request. It caps the sum of a request’s input tokens and its max_new_tokens. For example, with a value of 2048, users can send either a prompt of 1000 tokens and ask for 1048 new tokens, or send a prompt of 1 token and ask for 2047 max_new_tokens.

If you encounter memory limitations or need to optimize memory usage based on client requirements, adjusting this parameter can help. Higher values take up more GPU memory but allow a longer combined length of prompt and generated text. The tradeoff here is that higher values will increase throughput but will increase individual request latency.
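The interaction between max_input_length and max_total_tokens can be sketched as a simple check. This is an illustration of the arithmetic described above using the default values from this doc, not TGI’s actual validation code, and `fits_budget` is a hypothetical name.

```python
def fits_budget(prompt_tokens, max_new_tokens,
                max_input_length=1024, max_total_tokens=2048):
    """Return True if a request fits both per-request limits."""
    if prompt_tokens > max_input_length:
        return False  # the prompt alone is too long
    # The prompt plus the requested new tokens must stay within the budget.
    return prompt_tokens + max_new_tokens <= max_total_tokens


print(fits_budget(1000, 1048))  # prompt of 1000, 1048 new tokens
print(fits_budget(1, 2047))     # prompt of 1, 2047 new tokens
print(fits_budget(1000, 1049))  # one token over the budget
```

The first two calls correspond to the examples above and fit exactly. The third exceeds the 2048-token budget by one token.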

  3. max_batch_prefill_tokens (default: 4096)

TGI splits up the generation process into 2 phases: prefill and generation. The prefill phase reads in the input tokens and generates the first output token. This parameter controls, in a given batch of requests, how many tokens can be processed at once during the prefill phase. This value should be at least the value for max_input_length.

As with max_input_length, if your input tokens are constrained, this is worth setting as a function of (constrained input length) * (max batch size your hardware can handle). This setting is also worth defining when you want stricter control over resource usage during prefill, especially with models that have a large footprint or on constrained hardware.
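As a rough starting point, the rule of thumb above can be written out as follows. This is a sketch under the assumptions stated in the text: the batch size is whatever your hardware can handle (you have to measure it), and the result is never allowed to drop below max_input_length. The function name is hypothetical.

```python
def suggest_max_batch_prefill_tokens(constrained_input_length,
                                     max_batch_size,
                                     max_input_length):
    """Apply (constrained input length) * (max batch size), floored at max_input_length."""
    suggestion = constrained_input_length * max_batch_size
    # The text above says this value should be at least max_input_length.
    return max(suggestion, max_input_length)


# Example: prompts of at most 256 tokens, hardware that handles 8 requests per batch.
print(suggest_max_batch_prefill_tokens(256, 8, max_input_length=1024))
```

Treat the result as an initial value to benchmark, not as a guaranteed optimum.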

  4. max_batch_total_tokens

Similar to max_batch_prefill_tokens, this represents the entire token count across a batch: the total input tokens plus the total generated tokens. In short, this value should be the upper bound on the number of tokens that can fit on the GPU after the model has been loaded.

This value is particularly important when maximizing GPU utilization. The tradeoff here is that higher values will increase throughput but will increase individual request latency.
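To see how this budget translates into concurrency, the sketch below estimates how many requests of a given size fit into one batch. It is plain integer arithmetic on the definition above and deliberately ignores any other limits the server may apply. `max_concurrent_requests` is a hypothetical helper.

```python
def max_concurrent_requests(max_batch_total_tokens, tokens_per_request):
    """Estimate how many requests of equal size fit into one batch."""
    if tokens_per_request <= 0:
        raise ValueError("tokens_per_request must be positive")
    return max_batch_total_tokens // tokens_per_request


# Example: a 16384-token batch budget and requests that each use the full
# 2048-token per-request budget (prompt plus generated text).
print(max_concurrent_requests(16384, 2048))
```

Shorter requests let more of them share the batch, which is where the throughput gain comes from.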

  5. max_waiting_tokens

This setting defines how many tokens the running batch can generate before waiting queries are forced into the batch. Adjusting this value helps optimize overall latency for end users by controlling how quickly waiting queries are given a slot in the running batch.

Tune this parameter when optimizing for end-user latency and to manage the compute allocation between prefill and decode operations efficiently. The tradeoff here is that higher values will increase throughput but will increase individual request latency because of the additional time it takes to fill up the batch.
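The effect of this setting can be illustrated with a toy model. The sketch below is not TGI’s scheduler. It assumes a batch that is always busy, generates one token per step, and admits everything in the queue once max_waiting_tokens steps have passed since the last admission. All names are hypothetical.

```python
def simulate_admission(arrival_steps, max_waiting_tokens, total_steps):
    """Toy model: return the step at which each waiting request joins the batch."""
    admitted_at = {}
    queue = []
    steps_since_admit = 0
    for step in range(total_steps):
        # Requests that arrive at this step start waiting.
        queue.extend(name for name, at in arrival_steps.items() if at == step)
        if queue and steps_since_admit >= max_waiting_tokens:
            # Pause decoding and let the waiting requests into the batch.
            for name in queue:
                admitted_at[name] = step
            queue.clear()
            steps_since_admit = 0
        else:
            # Otherwise the running batch decodes one more token.
            steps_since_admit += 1
    return admitted_at


arrivals = {"a": 0, "b": 3}
print(simulate_admission(arrivals, max_waiting_tokens=4, total_steps=12))
print(simulate_admission(arrivals, max_waiting_tokens=2, total_steps=12))
```

In this toy model, a value of 4 admits both requests together at step 4, while a value of 2 admits request "a" earlier, at step 2, at the cost of interrupting decoding more often.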

  6. sharded

This setting defines whether to shard the model across multiple GPUs. Use this when you are deploying on a multi-GPU setup and want to improve resource utilization and potentially the throughput of the server. You’d want to disable this if you’re loading other models on another GPU.