Elastic FLUX.1-schnell¶
Overview¶
This notebook demonstrates how to use the elastic_models library to run inference with an optimized FLUX.1-schnell model.
For each model, elastic_models provides a series of optimized variants:
- XL: a mathematically equivalent neural network, optimized with our DNN compiler.
- L: a near-lossless model, with less than 1% degradation on the corresponding benchmarks.
- M: a faster model, with less than 1.5% quality degradation.
- S: the fastest model, with less than 2% quality degradation.
To compare model quality for FLUX.1-schnell, we use the FID score between outputs of the original and compressed models.
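For illustration, a comparison of this kind can be sketched with the FrechetInceptionDistance metric from torchmetrics (an assumption of this sketch, not necessarily the tooling behind the official numbers); reference_images and compressed_images are placeholder lists of PIL images produced from the same prompts by the original and compressed pipelines.

import numpy as np
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

def to_uint8_batch(pil_images):
    # Stack a list of PIL images into an (N, 3, H, W) uint8 tensor.
    arrays = [np.asarray(img.convert("RGB")) for img in pil_images]
    return torch.from_numpy(np.stack(arrays)).permute(0, 3, 1, 2)

# reference_images / compressed_images: placeholder lists of PIL images
# generated by the original and compressed models.
fid = FrechetInceptionDistance(feature=2048)
fid.update(to_uint8_batch(reference_images), real=True)
fid.update(to_uint8_batch(compressed_images), real=False)
print(f"FID: {fid.compute().item():.2f}")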
Setup environment¶
Let’s start by installing the elastic_models and thestage libraries.
Note: Access to elastic_models requires an API token from the TheStage AI Platform.
!pip install elastic_models thestage
Generate a TheStage API token (see https://docs.thestage.ai/platform/src/thestage-ai-ssh-keys-and-api-tokens.html), then set it up:
!thestage config set --api-token <YOUR-THESTAGE-TOKEN>
import torch
from elastic_models.diffusers import DiffusionPipeline
import matplotlib.pyplot as plt
Loading model¶
Below we load the FLUX.1-schnell pipeline by its Hugging Face name. We set mode="XL" to use the model variant that is mathematically equivalent to the original, optimized with our DNN compiler.
model_name = "black-forest-labs/FLUX.1-schnell"
device = torch.device("cuda")

pipeline = DiffusionPipeline.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    cache_dir="/mount/huggingface_cache",
    mode="XL",
)
pipeline.to(device)
xl_output = pipeline(
    prompt=["An astronaut stands knee-deep in the sea"],
    num_inference_steps=4,
    generator=torch.Generator(device).manual_seed(42),
)
xl_output[0][0]
100%|██████████| 4/4 [00:00<00:00, 19.25it/s]

Pre-compiled engines support the following image shapes: 512x512, 768x768, and 1024x1024.
Let’s generate an image of shape 512x512:
pipeline(
    prompt=["Kitten eating a burger"],
    height=512,
    width=512,
    num_inference_steps=4,
)[0][0]
100%|██████████| 4/4 [00:02<00:00, 1.67it/s]

Engines also support batched inference; batch sizes 1, 4, and 8 are available.
Below we run generation for a batch of 8 images:
prompts = [
    "Firefighter rescuing cats from a giant yarn ball",
    "Librarian wrestling an octopus in a bookstore",
    "Chef being chased by giant angry vegetables",
    "Dentist examining a dragon's teeth",
    "Mailman delivering letters to confused aliens",
    "Barber cutting a yeti's hair",
    "Teacher explaining math to sleepy pandas",
    "Plumber fixing pipes in a haunted house",
]

output = pipeline(
    prompt=prompts,
    height=512,
    width=512,
    num_inference_steps=4,
)
# Display generated images in a 2x4 grid
images = output[0]
plt.figure(figsize=(16, 8))
for i, image in enumerate(images[:8]):
    plt.subplot(2, 4, i + 1)
    plt.imshow(image)
    plt.axis('off')
plt.tight_layout()
plt.show()
100%|██████████| 4/4 [00:00<00:00, 9.33it/s]

Benchmark inference time¶
Let’s benchmark inference time and memory usage of the image generation pipeline.
To measure max memory usage we will use pynvml
lib
import pynvml

def monitor_gpu_memory(queue, running_flag):
    """Poll GPU memory usage until running_flag is cleared, then report the peak (in MiB) via queue."""
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # Assuming GPU 0
    max_memory_usage = 0
    try:
        while running_flag.value:
            info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            memory_usage = info.used / (1024 * 1024)  # Convert to MiB
            max_memory_usage = max(max_memory_usage, memory_usage)
    finally:
        pynvml.nvmlShutdown()
        queue.put(max_memory_usage)
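As an illustrative usage sketch (not part of the original notebook), the monitor can run in a separate process while the pipeline generates, and then report the peak usage it observed through the queue; the prompt below is arbitrary.

import multiprocessing as mp

queue = mp.Queue()
running_flag = mp.Value("b", True)  # shared flag the monitor polls

monitor = mp.Process(target=monitor_gpu_memory, args=(queue, running_flag))
monitor.start()

# Workload to profile (prompt is illustrative).
pipeline(prompt=["A lighthouse in a storm"], height=512, width=512, num_inference_steps=4)

# Stop the monitor and read the peak usage it recorded (in MiB).
running_flag.value = False
monitor.join()
print(f"Peak GPU memory: {queue.get() / 1024:.1f} GB")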
When benchmarking inference time, we perform a warm-up call and multiple timed repetitions with proper CUDA synchronization, then report the best average execution time per call.
import timeit

def benchmark_time(pipeline, prompt, number=3, repeat=3, **kwargs):
    """Return the best average time (in seconds) of a single pipeline call."""
    run_func = lambda: pipeline(prompt=prompt, **kwargs)
    run_func()  # warm-up call
    with torch.no_grad():
        runs = timeit.repeat(
            run_func,
            number=number,
            repeat=repeat,
            setup="import torch; torch.cuda.synchronize()",
        )
    return min(runs) / number
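For example, timings for the supported batch sizes can be collected along these lines (an illustrative loop; the prompt and image shape are assumptions, not necessarily those used for the measurements below):

for batch_size in (1, 4, 8):
    seconds = benchmark_time(
        pipeline,
        prompt=["An astronaut stands knee-deep in the sea"] * batch_size,
        height=512,
        width=512,
        num_inference_steps=4,
    )
    print(f"batch_size={batch_size}: {seconds:.4f} s per call")

The results measured for the XL pipeline: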
Batch Size | Inference Time (sec) | Max Memory Usage (GB)
---|---|---
1 | 0.70255 | 50.2
4 | 2.723 | 67.9
8 | 5.4082 | 81.5
Using the compressed model¶
Set mode="S" to load the fastest model variant.
model_name = "black-forest-labs/FLUX.1-schnell"
device = torch.device("cuda")

pipeline = DiffusionPipeline.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    cache_dir="/mount/huggingface_cache",
    mode="S",
)
pipeline.to(device);
Let’s generate the same image with the compressed model and compare it with the original one.
s_output = pipeline(
    prompt=["An astronaut stands knee-deep in the sea"],
    num_inference_steps=4,
    generator=torch.Generator(device).manual_seed(42),
)

plt.figure(figsize=(16, 8))
for i, image in enumerate([xl_output[0][0], s_output[0][0]]):
    plt.subplot(1, 2, i + 1)
    plt.imshow(image)
    plt.title(["XL", "S"][i])
    plt.axis('off')
plt.tight_layout()
plt.show()
100%|██████████| 4/4 [00:00<00:00, 9.13it/s]

Now let’s benchmark the inference time and memory usage of the S model for different batch sizes:
Batch Size | Inference Time (sec) | Max Memory Usage (GB)
---|---|---
1 | 0.4983 | 32.8
4 | 1.9589 | 50.5
8 | 3.936 | 74.9
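Comparing with the XL numbers above, the S engine is roughly 1.4x faster at each batch size while also using noticeably less peak memory:

# Speedup of the S engine over XL, computed from the two tables above.
xl_times = {1: 0.70255, 4: 2.723, 8: 5.4082}
s_times = {1: 0.4983, 4: 1.9589, 8: 3.936}
for batch_size, xl_t in xl_times.items():
    print(f"batch {batch_size}: {xl_t / s_times[batch_size]:.2f}x faster")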