Fast Cold Starts with Cached Weights
In this example, we go through a Truss that serves an LLM and caches the model weights at build time. Loading weights is often the most time-consuming part of starting a model. Caching the weights at build time bakes them into the Truss image, so they are available as soon as your model replica starts, which makes cold starts significantly faster.
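To see why this matters, here is a rough, hypothetical timing sketch of just the weight-loading step for the checkpoint used below. With cached weights this is a local read; without caching it also includes downloading roughly 13 GB from Hugging Face:

import time

import torch
from transformers import LlamaForCausalLM

# Illustration only: measure how long loading the Llama 2 7B weights takes.
start = time.perf_counter()
model = LlamaForCausalLM.from_pretrained(
    "NousResearch/Llama-2-7b-chat-hf",
    torch_dtype=torch.float16,
    device_map="auto",
)
print(f"Weights loaded in {time.perf_counter() - start:.1f}s")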
Implementing the Model class
With weight caching, you don't have to change anything about how the Model class is implemented to take advantage of it.
from typing import Dict

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant."

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

CHECKPOINT = "NousResearch/Llama-2-7b-chat-hf"


def format_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    # Wrap the user prompt and system prompt in the Llama 2 chat template.
    return f"{B_INST} {B_SYS} {system_prompt} {E_SYS} {prompt} {E_INST}"


class Model:
    def __init__(self, **kwargs) -> None:
        self.model = None
        self.tokenizer = None

    def load(self):
        # With the weights cached in the image, this resolves from the local
        # Hugging Face cache instead of downloading at startup.
        self.model = LlamaForCausalLM.from_pretrained(
            CHECKPOINT,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.tokenizer = LlamaTokenizer.from_pretrained(CHECKPOINT)

    def predict(self, request: Dict) -> Dict[str, str]:
        prompt = request.pop("prompt")
        input_ids = self.tokenizer(
            format_prompt(prompt), return_tensors="pt"
        ).input_ids.cuda()
        outputs = self.model.generate(
            inputs=input_ids,
            do_sample=True,
            num_beams=1,
            max_new_tokens=100,
        )
        response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return {"response": response}
Setting up the config.yaml
The config.yaml file is where you need to include the changes to actually cache the weights at build time.
environment_variables: {}
external_package_dirs: []
model_metadata:
  example_model_input: {"prompt": "What is the meaning of life?"}
model_name: Llama with Cached Weights
python_version: py39
requirements:
  - accelerate==0.21.0
  - safetensors==0.3.2
  - torch==2.0.1
  - transformers==4.34.0
  - sentencepiece==0.1.99
  - protobuf==4.24.4
Configuring the model_cache
To cache model weights, set the model_cache key. The repo_id field specifies a Hugging Face repo to pull down and cache at build time, and the ignore_patterns field specifies files to skip. With this set, the repo doesn't have to be pulled at runtime. Check out the guide for more info.
model_cache:
  - repo_id: "NousResearch/Llama-2-7b-chat-hf"
    ignore_patterns:
      - "*.bin"
The remaining config options are similar to what you would configure for the model without weight caching.
resources:
  cpu: "4"
  memory: 30Gi
  use_gpu: true
  accelerator: A10G
secrets: {}
Deploy the model
Deploy the model like you would any other Truss, with:
$ truss push
The build step will take longer than for the normal Llama Truss, since bundling the model weights now happens during the build. In exchange, deploys and scale-ups happen much faster.
You can then invoke the model with:
$ truss predict -d '{"prompt": "What is a large language model?"}'
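The response mirrors what predict returns: a JSON object with a single response key. For a quick local smoke test (assuming a machine with a large enough GPU), the same Model class can also be exercised directly:

# Hypothetical local check; in production, requests go through truss predict.
model = Model()
model.load()
result = model.predict({"prompt": "What is a large language model?"})
print(result["response"])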