Elastic Models: Fast, flexible-size models


Warning

Access to Elastic models requires an API token from the TheStage AI Platform.

Elastic models are produced by TheStage AI ANNA (Automated Neural Networks Accelerator). ANNA lets you control model size, latency, and quality with a simple slider movement.

Elastic models come in four tiers: S, M, L, and XL, listed below from highest quality to fastest:

Tiers

  • XL: Mathematically equivalent neural network, optimized with our DNN compiler.

  • L: Near-lossless model, with less than 0.5% degradation on the corresponding benchmarks.

  • M: Faster model, with accuracy degradation less than 1%.

  • S: The fastest model, with accuracy degradation of up to ~2%.

Supported models

Model Type    | Models                                 | GPUs
--------------|----------------------------------------|------------------
Text-to-Text  | Mistral, Llama, Qwen, DeepSeek-Distill | L40S, H100
Text-to-Video | Mochi                                  | H100, B200
Text-to-Music | MusicGen                               | H100, L40S
Text-to-Image | Flux, SDXL                             | L40S, H100, B200
ASR           | Whisper                                | L40S, H100

Main features

  • Supports LLMs, VLMs, and diffusion models.

  • All models are provided through the Hugging Face transformers and diffusers libraries.

  • Fast cold starts: models are pre-compiled and do not use JIT compilation.

  • The underlying inference engine supports fp16, bf16, int8, fp8, and int4 precisions, as well as 2:4 sparsity.

  • Model quality is controlled using ANNA.

  • For each target point (a given number of bit operations or a model size), ANNA finds the best-quality configuration using the supported hardware acceleration techniques.

  • No dependencies on TensorRT-LLM, SGLang, or vLLM. Simple setup through PyPI.

Goals

  • Provide flexibility in cost vs. quality selection for inference.

  • Provide clear quality and latency benchmarks.

  • Provide the Hugging Face transformers and diffusers interfaces with a single line of code changed (see the sketch after this list).

  • Provide models supported on a wide range of hardware, pre-compiled and requiring no JIT.

  • Provide the best models and service for self-hosting.
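
For example, an existing transformers script switches to Elastic Models by changing only the import line and adding the mode argument. The snippet below is a minimal sketch; the model name is illustrative, and full examples follow after the installation section:

import torch

# from transformers import AutoModelForCausalLM               # plain Hugging Face
from elastic_models.transformers import AutoModelForCausalLM  # Elastic Models drop-in

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.3",  # illustrative model from the supported list
    torch_dtype=torch.bfloat16,
    token="<YOUR_HF_TOKEN>",  # HF token is required (see the examples below)
    mode="S",                 # the extra argument: tier S, M, L, or XL
)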

Installation


System requirements

  • Python version 3.10 - 3.12

  • x86-64 CPU architecture

  • NVIDIA GPU with CUDA 11.8 or higher

  • PyTorch 2.4 or higher

Make sure you have installed the thestage package and set your API token:

pip install thestage
thestage config set --api-token <YOUR_API_TOKEN>

To install the packages for NVIDIA GPUs prior to Blackwell:

pip install thestage_elastic_models[nvidia]
# additional dependencies
pip install flash_attn==2.8.2 --no-build-isolation

For Blackwell GPUs and Python 3.10:

pip install thestage_elastic_models[blackwell]
# additional dependencies
pip install torch==2.7.0+cu128 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
# flash attention is needed for transformers models on RTX 50-series GPUs; diffusers does not use it
wget https://github.com/Zarrac/flashattention-blackwell-wheels-whl-ONLY-5090-5080-5070-5060-flash-attention-/releases/download/FlashAttention/flash_attn-2.7.4.post1-rtx5090-torch2.7.0cu128cxx11abiTRUE-cp310-linux_x86_64.whl
pip install flash_attn-2.7.4.post1-rtx5090-torch2.7.0cu128cxx11abiTRUE-cp310-linux_x86_64.whl
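
After installation, a quick sanity check of the environment can be run with plain PyTorch calls (this is a generic check, not an Elastic Models API):

import torch

# A CUDA build of PyTorch and a visible GPU are required
print(torch.__version__, torch.version.cuda)
print(torch.cuda.is_available(), torch.cuda.get_device_name(0))

# flash_attn is only needed for transformers models
try:
    import flash_attn
    print("flash_attn", flash_attn.__version__)
except ImportError:
    print("flash_attn is not installed")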

Test your setup:

import elastic_models

elastic_models.print_available_models()

Output:

----------------------------------------------------------------------------------------------------------------------------------
Model                                              | B200        | RTX-4090 | RTX-5090 | H100        | L40S
----------------------------------------------------------------------------------------------------------------------------------
DavidAU/MN-GRAND-Gutenberg-Lyra4-Lyra-12B-DARKNESS |             | S, M     | S, M, L  |             |
Qwen/Qwen2.5-14B-Instruct                          |             |          |          | S, M, L, XL | S, M, L, XL
Qwen/Qwen2.5-7B-Instruct                           |             |          |          | S, M, L, XL | S, M, L, XL
black-forest-labs/FLUX.1-dev                       | S, M, L, XL |          | S        | S, M, L, XL | S, M, L, XL
black-forest-labs/FLUX.1-schnell                   | S, M, L, XL |          | S        | S, M, L, XL | S, M, L, XL
deepseek-ai/DeepSeek-R1-Distill-Qwen-14B           |             |          |          | S, M, L, XL | S, M, L, XL
deepseek-ai/DeepSeek-R1-Distill-Qwen-7B            |             |          |          | S, M, L, XL | S, M, L, XL
facebook/musicgen-large                            |             |          |          | S, M, L, XL | S, M, L, XL
genmo/mochi-1-preview                              | S, XL       |          |          | S, XL       |
meta-llama/Llama-3.1-8B-Instruct                   |             |          |          | S, M, L, XL | S, M, L, XL
meta-llama/Llama-3.2-1B-Instruct                   |             |          |          | S, M, L, XL | S, M, L, XL
mistralai/Mistral-7B-Instruct-v0.3                 |             |          |          | S, M, L, XL | S, M, L, XL
mistralai/Mistral-Nemo-Instruct-2407               |             |          |          | S, M, L, XL | S, M, L, XL
mistralai/Mistral-Small-3.1-24B-Instruct-2503      |             |          |          | S, M, L, XL | S, M, L
openai/whisper-large-v3                            |             |          |          | S           | S
stabilityai/stable-diffusion-xl-base-1.0           |             |          |          | XL          | XL
----------------------------------------------------------------------------------------------------------------------------------

Example: FLUX model usage


import torch
# replace diffusers with elastic_models.diffusers
from elastic_models.diffusers import FluxPipeline

model_name = 'black-forest-labs/FLUX.1-schnell'
hf_token = ''
device = torch.device("cuda")

pipeline = FluxPipeline.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    token=hf_token,
    # model sizes: S, M, L, XL
    mode='S'
)
pipeline.to(device)

# Use the pipeline through the standard HF interface
prompts = ["Kitten eating a banana"]
output = pipeline(prompt=prompts)

for prompt, output_image in zip(prompts, output.images):
    output_image.save((prompt.replace(' ', '_') + '.png'))
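
Since every tier exposes the same pipeline interface, a rough way to explore the quality/latency trade-off is to reload the pipeline with a different mode and time a generation. The loop below is an illustrative sketch, not a rigorous benchmark (single prompt, no warm-up runs):

import time
import torch
from elastic_models.diffusers import FluxPipeline

hf_token = ''
prompt = "Kitten eating a banana"

for mode in ["S", "XL"]:  # fastest tier vs. mathematically equivalent tier
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        torch_dtype=torch.bfloat16,
        token=hf_token,
        mode=mode,
    ).to("cuda")

    torch.cuda.synchronize()
    start = time.perf_counter()
    pipeline(prompt=[prompt])
    torch.cuda.synchronize()
    print(f"mode={mode}: {time.perf_counter() - start:.2f} s")

    # free GPU memory before loading the next tier
    del pipeline
    torch.cuda.empty_cache()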

Example: Mistral-Nemo-Instruct usage


import torch
from transformers import AutoTokenizer
from elastic_models.transformers import AutoModelForCausalLM

# An HF token is currently required, because we use the original
# weights for part of the layers, as well as the original model
# configuration
model_name = "mistralai/Mistral-Nemo-Instruct-2407"
hf_token = ''
device = torch.device("cuda")

# Create the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(
    model_name, token=hf_token
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=hf_token,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    # S, M, L, XL
    mode='S'
).to(device)
model.generation_config.pad_token_id = tokenizer.eos_token_id

# Inference is as simple as with the transformers library
prompt = "Describe basics of DNNs quantization."
messages = [
    {
        "role": "system",
        "content": "You are a search bot, answer on user text queries."
    },
    {
        "role": "user",
        "content": prompt
    }
]

chat_prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)

inputs = tokenizer(chat_prompt, return_tensors="pt")
inputs.to(device)

with torch.inference_mode():
    generate_ids = model.generate(**inputs, max_length=500)

input_len = inputs['input_ids'].shape[1]
generate_ids = generate_ids[:, input_len:]
output = tokenizer.batch_decode(
    generate_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False
)[0]

# Print the question and the answer
print(f"# Q:\n{prompt}\n")
print(f"# A:\n{output}\n")