Elastic FLUX.1-schnell


Overview

This notebook demonstrates how to use the elastic_models library to run inference with an optimized FLUX.1-schnell model.

For each model, elastic_models provides a series of optimized variants:

  • XL: Mathematically equivalent neural network, optimized with our DNN compiler.

  • L: Near-lossless model, with less than 1% degradation on the corresponding benchmarks.

  • M: Faster model, with quality degradation less than 1.5%.

  • S: The fastest model, with quality degradation less than 2%.
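The trade-off between the variants can be written as a small lookup. The sketch below is hypothetical (pick_mode and its tables are not part of elastic_models). It encodes the degradation bounds quoted above and assumes the speed order S, M, L, XL, which the list implies but does not state explicitly for L versus XL.

```python
# Hypothetical helper, not part of elastic_models.
# Upper bound on quality degradation for each variant, as quoted above
# (XL is mathematically equivalent to the original, so its bound is 0).
MAX_DEGRADATION = {"XL": 0.0, "L": 0.01, "M": 0.015, "S": 0.02}

# Assumed ordering from fastest to slowest variant.
FASTEST_FIRST = ["S", "M", "L", "XL"]


def pick_mode(tolerance: float) -> str:
    """Return the fastest variant whose quoted degradation bound fits within tolerance."""
    for mode in FASTEST_FIRST:
        if MAX_DEGRADATION[mode] <= tolerance:
            return mode
    # Only reachable for a negative tolerance, since XL's bound is 0.
    raise ValueError("tolerance must be non-negative")
```

For example, a 1.5% quality budget selects M, while a zero budget falls back to XL.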

To compare model quality for FLUX.1-schnell, we used the FID score between outputs of the original and compressed models.
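For reference, the standard FID definition fits a Gaussian to feature embeddings of each image set (conventionally Inception-v3 pool features) and measures the Fréchet distance between the two Gaussians. The notebook does not state the feature extractor, prompt set, or sample count it used, so this is the textbook formula rather than the exact evaluation setup:

```latex
\mathrm{FID} = \lVert \mu_o - \mu_c \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_o + \Sigma_c - 2 \left( \Sigma_o \Sigma_c \right)^{1/2} \right)
```

Here \mu_o, \Sigma_o and \mu_c, \Sigma_c are the feature means and covariances for the original and compressed model outputs. Identical distributions give 0; lower is better.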

Setup environment

Let’s start by installing the elastic_models and thestage libraries.

Note

Access to elastic_models requires an API token from TheStage AI Platform.

!pip install elastic_models thestage

Generate a TheStage API token. See https://docs.thestage.ai/platform/src/thestage-ai-ssh-keys-and-api-tokens.html

Set up the generated token:

!thestage config set --api-token <YOUR-THESTAGE-TOKEN>
import torch
from elastic_models.diffusers import DiffusionPipeline
import matplotlib.pyplot as plt

Loading model

Below we load the FLUX.1-schnell pipeline by its Hugging Face model name.

We set mode="XL" to use the variant that is mathematically equivalent to the original model, optimized with our DNN compiler.

model_name = "black-forest-labs/FLUX.1-schnell"
device = torch.device("cuda")
pipeline = DiffusionPipeline.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    cache_dir="/mount/huggingface_cache",
    mode="XL"
)
pipeline.to(device)
xl_output = pipeline(
    prompt=["An astronaut stands knee-deep in the sea"],
    num_inference_steps=4,
    # Seeded generator so the result is reproducible and comparable across variants
    generator=torch.Generator("cpu").manual_seed(42)
)
xl_output[0][0]
100%|██████████| 4/4 [00:00<00:00, 19.25it/s]
[Image: XL output for “An astronaut stands knee-deep in the sea”]

Pre-compiled engines support the following image shapes: 512x512, 768x768, 1024x1024.
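A small guard can catch unsupported resolutions before the pipeline is called. check_shape is hypothetical and not part of elastic_models; what the library itself does for other shapes is not covered here, so this only encodes the list above.

```python
# Hypothetical guard, not part of elastic_models.
# Resolutions listed above for the pre-compiled engines, as (height, width).
SUPPORTED_SHAPES = {(512, 512), (768, 768), (1024, 1024)}


def check_shape(height: int, width: int) -> None:
    """Raise ValueError if (height, width) is not one of the pre-compiled shapes."""
    if (height, width) not in SUPPORTED_SHAPES:
        supported = ", ".join(f"{w}x{h}" for h, w in sorted(SUPPORTED_SHAPES))
        raise ValueError(
            f"{width}x{height} has no pre-compiled engine; use one of: {supported}"
        )
```

Calling check_shape(height, width) with the same values passed to the pipeline fails fast instead of relying on whatever fallback the library may or may not provide.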

Let’s generate a 512x512 image:

pipeline(
    prompt=["Kitten eating a burger"],
    height=512,
    width=512,
    num_inference_steps=4
)[0][0]
100%|██████████| 4/4 [00:02<00:00,  1.67it/s]
[Image: 512x512 output for “Kitten eating a burger”]

Engines also support batched inference, with batch sizes 1, 4, and 8.
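Prompt lists whose length is not 1, 4, or 8 must be split or padded to match an engine. Below is a hypothetical helper (split_into_batches is not part of elastic_models) that cuts the prompts into chunks of at most 8 and pads the last chunk, by repeating its final prompt, up to the smallest supported batch size that can hold it.

```python
# Hypothetical helper, not part of elastic_models.
SUPPORTED_BATCH_SIZES = (1, 4, 8)


def split_into_batches(prompts, batch_sizes=SUPPORTED_BATCH_SIZES):
    """Split prompts into chunks whose lengths are supported batch sizes.

    Returns a list of (chunk, n_real) pairs, where n_real is the number of
    genuine prompts at the front of the chunk; the rest is padding.
    """
    max_size = max(batch_sizes)
    batches = []
    for start in range(0, len(prompts), max_size):
        chunk = list(prompts[start:start + max_size])
        n_real = len(chunk)
        # Pad up to the smallest supported batch size that can hold the chunk.
        target = min(s for s in batch_sizes if s >= n_real)
        chunk += [chunk[-1]] * (target - n_real)
        batches.append((chunk, n_real))
    return batches
```

Each chunk would be passed as prompt=chunk, keeping only the first n_real images of each output. With 11 prompts this yields one batch of 8 and one padded batch of 4.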

Below we generate a batch of 8 images:

prompts = [
    "Firefighter rescuing cats from a giant yarn ball",
    "Librarian wrestling an octopus in a bookstore",
    "Chef being chased by giant angry vegetables",
    "Dentist examining a dragon's teeth",
    "Mailman delivering letters to confused aliens",
    "Barber cutting a yeti's hair",
    "Teacher explaining math to sleepy pandas",
    "Plumber fixing pipes in a haunted house"
]

output = pipeline(
    prompt=prompts,
    height=512,
    width=512,
    num_inference_steps=4
)

# Display generated images
images = output[0]
plt.figure(figsize=(16, 8))

for i, image in enumerate(images[:8]):
    plt.subplot(2, 4, i+1)
    plt.imshow(image)
    plt.axis('off')

plt.tight_layout()
plt.show()
100%|██████████| 4/4 [00:00<00:00,  9.33it/s]
[Image: 2x4 grid of the eight generated 512x512 images]

Benchmark inference time

Let’s benchmark inference time and memory usage of the image generation pipeline.

To measure peak GPU memory usage, we use the pynvml library and poll the device's used memory while the pipeline runs:

import pynvml


def monitor_gpu_memory(queue, running_flag):
    """Poll GPU 0 until running_flag.value is falsy, then put the peak used memory (MiB) on queue."""
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # Assuming GPU 0
    max_memory_usage = 0

    try:
        while running_flag.value:
            # Device-wide usage: also counts memory held by other processes on this GPU
            info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            memory_usage = info.used / (1024 * 1024)  # Convert bytes to MiB
            max_memory_usage = max(max_memory_usage, memory_usage)
    finally:
        pynvml.nvmlShutdown()
        queue.put(max_memory_usage)
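The notebook does not show how monitor_gpu_memory is launched alongside the pipeline. Its (queue, running_flag) signature suggests a background worker. Below is a minimal sketch, assuming a background thread and a SimpleNamespace flag in place of multiprocessing.Value; run_with_memory_monitor is a hypothetical helper.

```python
import queue
import threading
from types import SimpleNamespace


def run_with_memory_monitor(fn, monitor):
    """Run fn() while monitor(queue, running_flag) polls in a background thread.

    `monitor` follows the same contract as monitor_gpu_memory: loop while
    running_flag.value is truthy, then put its peak reading on the queue.
    Returns (result of fn, peak value reported by the monitor).
    """
    results = queue.Queue()
    running_flag = SimpleNamespace(value=True)  # stand-in for multiprocessing.Value
    watcher = threading.Thread(target=monitor, args=(results, running_flag), daemon=True)
    watcher.start()
    try:
        output = fn()
    finally:
        running_flag.value = False  # tell the polling loop to exit
        watcher.join()
    return output, results.get()
```

For example, run_with_memory_monitor(lambda: pipeline(prompt=prompts, height=512, width=512, num_inference_steps=4), monitor_gpu_memory) would return the pipeline output together with the peak in MiB. Because the polling loop never sleeps, it may compete with the pipeline for the GIL; running it in a separate process (multiprocessing.Process with a multiprocessing.Value flag) avoids that, while the thread version is simply easier to show.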

When benchmarking inference time, we perform a warm-up call, then time several repetitions and report the best per-call average (the minimum over repeats), synchronizing CUDA so that all GPU work is included in the measurement.

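import timeit

import torch


def benchmark_time(pipeline, prompt, number=3, repeat=3, **kwargs):
    """Return the best per-call latency (sec) over `repeat` rounds of `number` calls."""

    def run_func():
        pipeline(prompt=prompt, **kwargs)
        torch.cuda.synchronize()  # wait for all GPU work before the timer stops

    with torch.no_grad():
        run_func()  # warm-up call, excluded from timing
        runs = timeit.repeat(
            run_func,
            number=number,
            repeat=repeat,
            setup="import torch; torch.cuda.synchronize()",
        )
    return min(runs) / number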

XL Model Performance Benchmarks

Batch Size | Inference Time (sec) | Max Memory Usage (GB)
1          | 0.70255              | 50.2
4          | 2.723                | 67.9
8          | 5.4082               | 81.5
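The notebook reports these results without the driver loop. One way to produce such a table is sketched below. measure is a hypothetical callable you would build from benchmark_time and monitor_gpu_memory above (converting MiB to GB), and the prompt is simply repeated to fill each batch; this is an assumption about the setup, not the script that generated the numbers.

```python
def benchmark_batch_sizes(measure, prompt, batch_sizes=(1, 4, 8)):
    """Call measure() on a batch of repeated prompts for each batch size.

    `measure` is a hypothetical callable: it takes a list of prompts and
    returns (seconds_per_call, peak_memory_gb).
    """
    rows = []
    for batch_size in batch_sizes:
        seconds, peak_gb = measure([prompt] * batch_size)
        rows.append((batch_size, seconds, peak_gb))
    return rows


def format_table(rows):
    """Render (batch_size, seconds, peak_gb) rows as a plain-text table."""
    lines = ["Batch Size | Inference Time (sec) | Max Memory Usage (GB)"]
    for batch_size, seconds, peak_gb in rows:
        lines.append(f"{batch_size} | {seconds:.4f} | {peak_gb:.1f}")
    return "\n".join(lines)
```

Keeping measurement and formatting separate makes it easy to run the same loop for the XL and S pipelines and compare the resulting rows.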

Using compressed model

Set mode='S' to load the fastest model variant.

model_name = "black-forest-labs/FLUX.1-schnell"
device = torch.device("cuda")
pipeline = DiffusionPipeline.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    cache_dir="/mount/huggingface_cache",
    mode="S"
)
pipeline.to(device);

Let’s generate the same image with the compressed model and compare it with the original one.

s_output = pipeline(
    prompt=["An astronaut stands knee-deep in the sea"],
    num_inference_steps=4,
    # Same seed as the XL run so the two outputs are directly comparable
    generator=torch.Generator("cpu").manual_seed(42)
)

plt.figure(figsize=(16, 8))

for i, image in enumerate([xl_output[0][0], s_output[0][0]]):
    plt.subplot(1, 2, i+1)
    plt.imshow(image)
    plt.title(["XL", "S"][i])
    plt.axis('off')
plt.tight_layout()
plt.show()
100%|██████████| 4/4 [00:00<00:00,  9.13it/s]
[Image: side-by-side comparison of the XL (left) and S (right) outputs]
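The side-by-side view is qualitative. For a quick numeric check on a single pair, a pixel-level metric such as PSNR can be computed with NumPy. This is a swapped-in illustration, not the FID evaluation behind the quality figures above, and it is only meaningful when both images come from the same prompt and seed.

```python
import numpy as np


def psnr(reference, candidate, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two same-shaped images (8-bit range by default)."""
    ref = np.asarray(reference, dtype=np.float64)
    cand = np.asarray(candidate, dtype=np.float64)
    if ref.shape != cand.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {cand.shape}")
    mse = np.mean((ref - cand) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_value ** 2 / mse))
```

psnr(xl_output[0][0], s_output[0][0]) would accept the PIL images directly, since np.asarray converts them to HxWx3 arrays. Higher values mean the two outputs are closer pixel by pixel.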

Now let’s benchmark inference time and memory usage of the S model for different batch sizes:

S Model Performance Benchmarks

Batch Size | Inference Time (sec) | Max Memory Usage (GB)
1          | 0.4983               | 32.8
4          | 1.9589               | 50.5
8          | 3.936                | 74.9
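The two tables can be compared directly. The snippet below only re-uses the numbers reported above to compute the XL-to-S latency speedup and the peak-memory saving for each batch size.

```python
# (inference time in sec, max memory in GB) per batch size, copied from the tables above
XL_RESULTS = {1: (0.70255, 50.2), 4: (2.723, 67.9), 8: (5.4082, 81.5)}
S_RESULTS = {1: (0.4983, 32.8), 4: (1.9589, 50.5), 8: (3.936, 74.9)}


def compare(baseline, candidate):
    """Return {batch_size: (speedup, memory_saved_gb)} of candidate relative to baseline."""
    summary = {}
    for batch_size, (base_time, base_mem) in baseline.items():
        cand_time, cand_mem = candidate[batch_size]
        summary[batch_size] = (base_time / cand_time, base_mem - cand_mem)
    return summary


summary = compare(XL_RESULTS, S_RESULTS)
for batch_size, (speedup, saved) in summary.items():
    print(f"batch {batch_size}: {speedup:.2f}x faster, {saved:.1f} GB less peak memory")
```

With these figures, S is about 1.37x to 1.41x faster than XL, while the memory saving shrinks from 17.4 GB at batch size 1 to 6.6 GB at batch size 8.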