In this example, we go through a Truss that serves an LLM and caches the weights at build time. Loading model weights is often the most time-consuming part of starting a model server. Caching the weights at build time bakes them into the Truss image, so they are available immediately when your model replica starts, which makes cold starts significantly faster.

Implementing the Model class

You don’t have to change anything about how the Model class is implemented to take advantage of weight caching.

model/model.py
from typing import Dict

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant."

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
CHECKPOINT = "NousResearch/Llama-2-7b-chat-hf"


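# Wrap the user prompt (and system prompt) in Llama 2's [INST] / <<SYS>> chat template.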
def format_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    return f"{B_INST} {B_SYS} {system_prompt} {E_SYS} {prompt} {E_INST}"


class Model:
    def __init__(self, **kwargs) -> None:
        self.model = None
        self.tokenizer = None

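    # Load the fp16 model and its tokenizer. With the model_cache config below,
    # the weights are already baked into the image when this runs.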
    def load(self):
        self.model = LlamaForCausalLM.from_pretrained(
            CHECKPOINT, torch_dtype=torch.float16, device_map="auto"
        )
        self.tokenizer = LlamaTokenizer.from_pretrained(CHECKPOINT)

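    # Apply the chat template to the prompt, sample up to 100 new tokens on the GPU,
    # and decode the generated text.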
    def predict(self, request: Dict) -> Dict[str, str]:
        prompt = request.pop("prompt")
        input_ids = self.tokenizer(
            format_prompt(prompt), return_tensors="pt"
        ).input_ids.cuda()

        outputs = self.model.generate(
            inputs=input_ids, do_sample=True, num_beams=1, max_new_tokens=100
        )
        response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

        return {"response": response}

Setting up the config.yaml

The config.yaml file is where you make the changes that actually cache the weights at build time.

config.yaml
environment_variables: {}
external_package_dirs: []
model_metadata:
  example_model_input: {"prompt": "What is the meaning of life?"}
model_name: Llama with Cached Weights
python_version: py39
requirements:
- accelerate==0.21.0
- safetensors==0.3.2
- torch==2.0.1
- transformers==4.34.0
- sentencepiece==0.1.99
- protobuf==4.24.4

Configuring the model_cache

To cache model weights, set the model_cache key. The repo_id field specifies a Hugging Face repo to pull down and cache at build time, and the ignore_patterns field specifies files to skip. With this configured, the repo doesn’t have to be pulled at runtime.

Check out the guide for more info.

config.yaml
model_cache:
- repo_id: "NousResearch/Llama-2-7b-chat-hf"
  ignore_patterns:
  - "*.bin"

The remaining config options are, again, similar to what you would configure for the model without weight caching.

config.yaml
resources:
  cpu: "4"
  memory: 30Gi
  use_gpu: True
  accelerator: A10G
secrets: {}

Deploy the model

Deploy the model as you would any other Truss, with:

$ truss push

The build step will take longer than for the normal Llama Truss, since the model weights are now bundled during the build. The deploy step and subsequent scale-ups will be much faster with this approach.

You can then invoke the model with:

$ truss predict -d '{"prompt": "What is a large language model?"}'