Elastic Models: Fast flexible-size models¶
Warning
Access to Elastic models requires an API token from the TheStage AI Platform.
Elastic models are the models produced by TheStage AI ANNA: Automated Neural Networks Analyzer. ANNA allows you to control model size, latency and quality with a simple slider movement.
Elastic models come in four tiers: S, M, L and XL, listed below from highest quality to fastest:
Tiers
XL: Mathematically equivalent neural network, optimized with our DNN compiler.
L: Near-lossless model, with less than 0.5% quality degradation on the corresponding benchmarks.
M: Faster model, with accuracy degradation of less than 1%.
S: The fastest model, with accuracy degradation of around 2%.
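The tier is selected with the mode argument when loading a model. A minimal sketch, assuming the transformers-style interface shown in the examples later in this document:

import torch
from elastic_models.transformers import AutoModelForCausalLM

# Pick a tier: 'XL' (lossless), 'L', 'M' or 'S' (fastest)
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.3",
    torch_dtype=torch.bfloat16,
    token="",   # your Hugging Face token
    mode="S",
).to("cuda")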
Supported models
| Model Type | Models | GPUs |
|---|---|---|
| Text-to-Text | Mistral, Llama, Qwen, DeepSeek-Distill | L40S, H100 |
| Text-to-Video | Mochi | H100, B200 |
| Text-to-Music | MusicGen | H100, L40S |
| Text-to-Image | Flux, SDXL | L40S, H100, B200 |
| ASR | Whisper | L40S, H100 |
Main features
Supports LLMs, VLMs and diffusion models.
All models are provided through the Hugging Face transformers and diffusers libraries.
Fast cold starts, as models are pre-compiled and do not use JIT compilation.
The underlying inference engine supports fp16, bf16, int8, fp8, int4 and 2:4 sparsity inference.
To control model quality we use ANNA: Automated Neural Networks Analyzer.
For each target number of bit operations or model size, ANNA finds the best-quality solution using the supported hardware acceleration techniques.
No dependencies on TensorRT-LLM, SGLang or vLLM. Simple setup through PyPI.
Goals
Provide flexibility in the cost-vs-quality trade-off for inference.
Provide clear quality and latency benchmarks.
Provide the familiar interface of the Hugging Face transformers and diffusers libraries, switchable with a single line of code (see the sketch after this list).
Provide models supported on a wide range of hardware, pre-compiled and requiring no JIT.
Provide the best models and service for self-hosting.
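Switching an existing Hugging Face pipeline to its elastic counterpart is a one-line import change. A minimal sketch based on the Flux example later in this document:

# Standard diffusers:
# from diffusers import FluxPipeline
# Elastic Models drop-in replacement:
from elastic_models.diffusers import FluxPipeline

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    mode="S",  # choose the tier here
)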
Installation¶
System requirements
Python version 3.10 - 3.12
x86-64 CPU architecture
NVIDIA GPU with CUDA 11.8 or higher
PyTorch 2.4 or higher
Make sure you have installed the thestage package and set your API token:
pip install thestage
thestage config set --api-token <YOUR_API_TOKEN>
To install packages for NVIDIA GPUs prior to the Blackwell generation:
pip install thestage_elastic_models[nvidia]
# additional dependencies
pip install flash_attn==2.8.2 --no-build-isolation
For Blackwell GPUs and Python 3.10:
pip install thestage_elastic_models[blackwell]
# additional dependencies
pip install torch==2.7.0+cu128 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
# dependencies for transformers on RTX 50s; diffusers does not use flash attention
wget https://github.com/Zarrac/flashattention-blackwell-wheels-whl-ONLY-5090-5080-5070-5060-flash-attention-/releases/download/FlashAttention/flash_attn-2.7.4.post1-rtx5090-torch2.7.0cu128cxx11abiTRUE-cp310-linux_x86_64.whl
pip install flash_attn-2.7.4.post1-0rtx5090torch270cu128cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
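Before loading any models, it can be useful to confirm that the expected PyTorch, CUDA and flash-attention builds are the ones actually being imported. A quick sanity check, assuming flash_attn was installed as above:

import torch
import flash_attn

print("torch:", torch.__version__, "cuda:", torch.version.cuda)
print("flash_attn:", flash_attn.__version__)
print("device:", torch.cuda.get_device_name(0))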
Test your setup:
import elastic_models
elastic_models.print_available_models()
Output:
----------------------------------------------------------------------------------------------------------------------------------
Model | B200 | RTX-4090 | RTX-5090 | H100 | L40S
----------------------------------------------------------------------------------------------------------------------------------
DavidAU/MN-GRAND-Gutenberg-Lyra4-Lyra-12B-DARKNESS | | S, M | S, M, L | |
Qwen/Qwen2.5-14B-Instruct | | | | S, M, L, XL | S, M, L, XL
Qwen/Qwen2.5-7B-Instruct | | | | S, M, L, XL | S, M, L, XL
black-forest-labs/FLUX.1-dev | S, M, L, XL | | S | S, M, L, XL | S, M, L, XL
black-forest-labs/FLUX.1-schnell | S, M, L, XL | | S | S, M, L, XL | S, M, L, XL
deepseek-ai/DeepSeek-R1-Distill-Qwen-14B | | | | S, M, L, XL | S, M, L, XL
deepseek-ai/DeepSeek-R1-Distill-Qwen-7B | | | | S, M, L, XL | S, M, L, XL
facebook/musicgen-large | | | | S, M, L, XL | S, M, L, XL
genmo/mochi-1-preview | S, XL | | | S, XL |
meta-llama/Llama-3.1-8B-Instruct | | | | S, M, L, XL | S, M, L, XL
meta-llama/Llama-3.2-1B-Instruct | | | | S, M, L, XL | S, M, L, XL
mistralai/Mistral-7B-Instruct-v0.3 | | | | S, M, L, XL | S, M, L, XL
mistralai/Mistral-Nemo-Instruct-2407 | | | | S, M, L, XL | S, M, L, XL
mistralai/Mistral-Small-3.1-24B-Instruct-2503 | | | | S, M, L, XL | S, M, L
openai/whisper-large-v3 | | | | S | S
stabilityai/stable-diffusion-xl-base-1.0 | | | | XL | XL
----------------------------------------------------------------------------------------------------------------------------------
Example: Flux-model usage¶
import torch

# Replace `diffusers` with `elastic_models.diffusers`
from elastic_models.diffusers import FluxPipeline

model_name = 'black-forest-labs/FLUX.1-schnell'
hf_token = ''
device = torch.device("cuda")

pipeline = FluxPipeline.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    token=hf_token,
    # model sizes: S, M, L, XL
    mode='S'
)
pipeline.to(device)

# Use the pipeline through the standard Hugging Face interface
prompts = ["Kitten eating a banana"]
output = pipeline(prompt=prompts)

for prompt, output_image in zip(prompts, output.images):
    output_image.save(prompt.replace(' ', '_') + '.png')
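Because every tier exposes the same interface, comparing latency between tiers only requires reloading the pipeline with a different mode value. The rough timing sketch below is illustrative only; the time_tier helper is not part of elastic_models:

import time
import torch
from elastic_models.diffusers import FluxPipeline

def time_tier(mode, prompt="Kitten eating a banana", hf_token=''):
    # Load the pipeline at the requested tier (S, M, L or XL)
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        torch_dtype=torch.bfloat16,
        token=hf_token,
        mode=mode,
    ).to("cuda")
    pipe(prompt=[prompt])          # warm-up run
    torch.cuda.synchronize()
    start = time.perf_counter()
    pipe(prompt=[prompt])
    torch.cuda.synchronize()
    return time.perf_counter() - start

for tier in ("S", "XL"):
    print(tier, f"{time_tier(tier):.2f} s")
    torch.cuda.empty_cache()       # release memory before loading the next tier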
Example: Mistral-7B-Instruct usage¶
import torch
from transformers import AutoTokenizer
from elastic_models.transformers import AutoModelForCausalLM

# Currently your HF token is required, since the original weights
# are used for part of the layers, along with the model configuration.
model_name = "mistralai/Mistral-Nemo-Instruct-2407"
hf_token = ''
device = torch.device("cuda")

# Create the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    model_name, token=hf_token
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=hf_token,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    # S, M, L, XL
    mode='S'
).to(device)
model.generation_config.pad_token_id = tokenizer.eos_token_id

# Inference is as simple as with the transformers library
prompt = "Describe basics of DNNs quantization."
messages = [
    {
        "role": "system",
        "content": "You are a search bot, answer user text queries."
    },
    {
        "role": "user",
        "content": prompt
    }
]

chat_prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)

inputs = tokenizer(chat_prompt, return_tensors="pt")
inputs.to(device)

with torch.inference_mode():
    generate_ids = model.generate(**inputs, max_length=500)

# Strip the prompt tokens and decode only the generated answer
input_len = inputs['input_ids'].shape[1]
generate_ids = generate_ids[:, input_len:]
output = tokenizer.batch_decode(
    generate_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False
)[0]

# Print the question and answer
print(f"# Q:\n{prompt}\n")
print(f"# A:\n{output}\n")