Flux Caching Tutorial: Accelerating Elastic Models with cache-dit


This tutorial demonstrates how to apply cache-dit caching techniques to Elastic Flux models for faster inference.

Overview


Elastic Models allow you to trade off quality for speed by using different model sizes (S, XL, original). cache-dit is a caching framework that reduces redundant computations during diffusion model inference.

By combining these two approaches, we can achieve significant speedups while maintaining quality. The condensed snippet after the outline below previews the full recipe.

What we’ll cover:

  1. Loading Elastic Flux models in different modes (S, XL, original)

  2. Applying DualCache (aggressive and conservative) caching strategies

  3. Comparing inference times with and without caching

  4. Evaluating image quality metrics

  5. Visualizing results
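
In compressed form, the recipe looks like this. This is a minimal sketch using the same calls developed in sections 2–3; cache_dir, token, and the elastic __model_path argument are omitted for brevity, and the cache parameters are the tutorial's aggressive preset:

import torch
import cache_dit
from cache_dit import BasicCacheConfig, BlockAdapter, ForwardPattern
from elastic_models.diffusers import DiffusionPipeline

# 1) Load an Elastic Flux variant ('S', 'XL', or omit mode for the original).
pipe = DiffusionPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    torch_dtype=torch.bfloat16,
    mode='S',
).to('cuda')

# 2) Enable DualCache on both Flux block stacks.
cache_dit.enable_cache(
    BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=[
            pipe.transformer.transformer_blocks,
            pipe.transformer.single_transformer_blocks,
        ],
        forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_3],
    ),
    cache_config=BasicCacheConfig(
        Fn_compute_blocks=1, Bn_compute_blocks=0,
        max_warmup_steps=8, max_cached_steps=-1,
        max_continuous_cached_steps=10, residual_diff_threshold=0.2,
    ),
)

# 3) Generate as usual; cached steps skip most transformer blocks.
image = pipe("A majestic lion at sunset", num_inference_steps=28).images[0]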

DualCache Strategies:

DualCache Aggressive - high speedup, more caching:

  • Fn=1, Bn=0 (minimal recomputation)

  • rdt=0.2 (high residual tolerance, so the cache is invalidated less often)

  • max_continuous_cached_steps=10 (long cache reuse)

  • Best for: static scenes, mid-diffusion steps

DualCache Conservative - balanced speedup, safer caching:

  • Fn=4, Bn=0 (more recomputation layers)

  • rdt=0.05 (low residual tolerance, so the cache is refreshed more often)

  • max_continuous_cached_steps=3 (shorter cache reuse)

  • Best for: dynamic scenes, correcting accumulated errors

The conceptual sketch below shows how these parameters interact.
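
Roughly, the first Fn transformer blocks are always recomputed each step, and the change in their output residual is what rdt thresholds. The following is a conceptual sketch of the per-step decision, not cache-dit's actual implementation:

# Conceptual per-step caching decision (illustrative only; the parameter
# defaults shown are the aggressive preset).
def should_reuse_cache(step, residual_diff, continuous_cached_steps,
                       max_warmup_steps=8,
                       residual_diff_threshold=0.2,
                       max_continuous_cached_steps=10):
    if step < max_warmup_steps:
        return False  # warmup: always run the full model
    if residual_diff > residual_diff_threshold:
        return False  # residuals drifted too far: invalidate and recompute
    if continuous_cached_steps >= max_continuous_cached_steps:
        return False  # cap on consecutive reuse: force a refresh
    return True       # residuals stable: reuse cached block outputs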

1. Setup and Imports


⚠️ IMPORTANT: Diffusers Version Compatibility

This tutorial uses cache-dit, which is designed for newer versions of diffusers (with support for the Chroma and HiDream models), while elastic_models may require an older diffusers version.

Two solutions:

  1. Use our compatibility patch (recommended) - see below

  2. Manually edit cache-dit - comment out the imports in:

    • /path/to/cache_dit/cache_factory/patch_functors/functor_chroma.py

    • /path/to/cache_dit/cache_factory/patch_functors/functor_hidream.py

Our utils automatically patch missing classes, so cache-dit works with older diffusers.
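
For reference, here is a minimal sketch of the kind of shim setup_diffusers_compatibility applies. The class names are illustrative; the real list depends on which symbols your cache-dit build imports from diffusers:

# Illustrative compatibility shim (the real implementation lives in
# flux_caching_tutorial_utils). It stubs classes that newer cache-dit
# expects but older diffusers lacks, so cache-dit's imports succeed.
import diffusers

_EXAMPLE_MISSING = ['ChromaTransformer2DModel', 'HiDreamImageTransformer2DModel']

def setup_diffusers_compatibility():
    for name in _EXAMPLE_MISSING:
        if not hasattr(diffusers, name):
            # A bare placeholder class is enough to satisfy the import;
            # the patched functors are never invoked for Flux pipelines.
            setattr(diffusers, name, type(name, (), {}))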

import os
import sys
import torch
import time
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

# Add the current working directory to the path for importing utils
sys.path.insert(0, os.getcwd())

from flux_caching_tutorial_utils import (
    setup_diffusers_compatibility, set_seed,
    visualize_comparison, create_performance_charts,
    print_performance_summary, save_results,
)

# Setup compatibility patches BEFORE importing cache-dit
setup_diffusers_compatibility()

# Now import elastic models and cache-dit
from elastic_models.diffusers import (
    DiffusionPipeline as ElasticDiffusionPipeline
)
import cache_dit
from cache_dit import (
    BasicCacheConfig, BlockAdapter, ForwardPattern
)

# Suppress warnings
import warnings
warnings.filterwarnings('ignore')

print("✓ All imports complete")

2. Configuration


# Device and dtype
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16

# Model configuration
MODEL_NAME = 'black-forest-labs/FLUX.1-dev'
HF_TOKEN = None  # Set your HuggingFace token if needed
HF_CACHE_DIR = '/mount/huggingface_cache'
MODEL_PATH = None

# Generation parameters
WIDTH = 1024
HEIGHT = 1024
NUM_INFERENCE_STEPS = 28
GUIDANCE_SCALE = 3.5
SEED = 42

# Test prompts
TEST_PROMPTS = [
    "A majestic lion standing on a rocky cliff at sunset",
    "A futuristic city skyline with flying cars and neon lights",
    "A beautiful garden with blooming flowers and butterflies",
]

# Output directory
OUTPUT_DIR = Path('./flux_caching_tutorial_results')
OUTPUT_DIR.mkdir(exist_ok=True)

print(f"Device: {device}")
print(f"Model: {MODEL_NAME}")
print(f"Image size: {WIDTH}x{HEIGHT}")
print(f"Inference steps: {NUM_INFERENCE_STEPS}")
print(f"Test prompts: {len(TEST_PROMPTS)}")
Device: cuda
Model: black-forest-labs/FLUX.1-dev
Image size: 1024x1024
Inference steps: 28
Test prompts: 3

3. Core Functions for Elastic Models and Cache-Dit


These are the key functions demonstrating how to use cache-dit with elastic_models.

def load_elastic_pipeline(mode='original', model_path=None):
    """
    Load Elastic Flux pipeline in specified mode.

    This shows how to load elastic models in different configurations:
    - 'original': Full model
    - 'XL': Larger elastic variant
    - 'S': Smaller elastic variant (faster)
    """
    print(f"\n{'='*60}")
    print(f"Loading Elastic Flux pipeline - Mode: {mode}")
    print(f"{'='*60}")

    if mode == 'original':
        pipeline = ElasticDiffusionPipeline.from_pretrained(
            MODEL_NAME,
            torch_dtype=dtype,
            cache_dir=HF_CACHE_DIR,
            token=HF_TOKEN,
            device_map=device,
        )
    else:
        # Load elastic model with specific mode
        pipeline = ElasticDiffusionPipeline.from_pretrained(
            MODEL_NAME,
            torch_dtype=dtype,
            cache_dir=HF_CACHE_DIR,
            token=HF_TOKEN,
            mode=mode,  # 'XL' or 'S'
            device_map=device,
            __model_path=model_path,
        )

    pipeline = pipeline.to(device)
    print(f"✓ Pipeline loaded successfully")
    return pipeline


def enable_dualcache(pipeline, mode='conservative'):
    """
    Enable DualCache on Flux pipeline.

    This demonstrates how to apply cache-dit to elastic models:

    - 'aggressive': More caching, higher speedup
      * Fn=1, rdt=0.2, max_continuous=10
    - 'conservative': Safer caching, balanced speedup
      * Fn=4, rdt=0.05, max_continuous=3
    """
    print(f"\nEnabling DualCache ({mode} mode)...")

    # Configure cache parameters based on mode
    if mode == 'aggressive':
        Fn = 1
        Bn = 0
        max_warmup_steps = 8
        max_continuous_cached_steps = 10
        residual_diff_threshold = 0.2
    else:  # conservative
        Fn = 4
        Bn = 0
        max_warmup_steps = 8
        max_continuous_cached_steps = 3
        residual_diff_threshold = 0.05

    # Create cache configuration
    cache_config = BasicCacheConfig(
        Fn_compute_blocks=Fn,
        Bn_compute_blocks=Bn,
        max_warmup_steps=max_warmup_steps,
        max_cached_steps=-1,
        max_continuous_cached_steps=max_continuous_cached_steps,
        residual_diff_threshold=residual_diff_threshold,
    )

    # Check if blocks are compiled
    first_block = pipeline.transformer.transformer_blocks[0]
    is_compiled = (
        hasattr(first_block, '__wrapped__') or
        'Compiled' in type(first_block).__name__
    )

    # Enable cache using BlockAdapter for Flux architecture
    cache_dit.enable_cache(
        BlockAdapter(
            pipe=pipeline,
            transformer=pipeline.transformer,
            blocks=[
                pipeline.transformer.transformer_blocks,
                pipeline.transformer.single_transformer_blocks,
            ],
            forward_pattern=[
                ForwardPattern.Pattern_1,
                ForwardPattern.Pattern_3,
            ],
            check_forward_pattern=not is_compiled,
        ),
        cache_config=cache_config
    )

    print(f"✓ DualCache enabled: Fn={Fn}, "
          f"rdt={residual_diff_threshold}, "
          f"max_continuous={max_continuous_cached_steps}")


def disable_cache(pipeline):
    """Disable cache on pipeline."""
    cache_dit.disable_cache(pipeline)
    print("✓ Cache disabled")


def generate_and_time(pipeline, prompt, seed=42):
    """Generate image and measure time."""
    set_seed(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    if device == 'cuda':
        torch.cuda.synchronize()  # flush queued GPU work before starting the timer
    start_time = time.time()

    result = pipeline(
        prompt=prompt,
        width=WIDTH,
        height=HEIGHT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        generator=generator,
    )

    if device == 'cuda':
        torch.cuda.synchronize()  # wait for generation to finish on the GPU
    elapsed = time.time() - start_time

    return result.images[0], elapsed

4. Experiment: Mode XL


We’ll compare:

  1. XL mode without caching (baseline)

  2. XL mode with DualCache Aggressive

  3. XL mode with DualCache Conservative

results_xl = {}

print("\n" + "="*70)
print("EXPERIMENT: Mode XL")
print("="*70)

# Load XL pipeline
pipeline_xl = load_elastic_pipeline(
    mode='XL', model_path=MODEL_PATH
)
✓ Pipeline loaded successfully (Mode: XL)

4.1 XL - No Caching (Baseline)

# Warmup
print("\nWarmup...")
_ = generate_and_time(
    pipeline_xl, TEST_PROMPTS[0], seed=SEED
)
print("Warmup complete")

# Generate test images without caching
print("\nGenerating images (no caching)...")
images_xl_nocache = []
times_xl_nocache = []

for i, prompt in enumerate(TEST_PROMPTS):
    print(f"\nPrompt {i+1}/{len(TEST_PROMPTS)}: "
          f"{prompt}")
    image, elapsed = generate_and_time(
        pipeline_xl, prompt, seed=SEED+i
    )
    images_xl_nocache.append(image)
    times_xl_nocache.append(elapsed)
    print(f"Time: {elapsed:.2f}s")

avg_time_xl_nocache = np.mean(times_xl_nocache)
print(f"\nAverage time (XL no caching): "
      f"{avg_time_xl_nocache:.2f}s")

results_xl['no_cache'] = {
    'images': images_xl_nocache,
    'times': times_xl_nocache,
    'avg_time': avg_time_xl_nocache
}
Average time (XL no caching): 4.17s

4.2 XL - DualCache Aggressive

enable_dualcache(pipeline_xl, mode='aggressive')

# Warmup with cache
print("\nWarmup with cache...")
_ = generate_and_time(
    pipeline_xl, TEST_PROMPTS[0], seed=SEED
)
print("Warmup complete")

# Generate test images
print("\nGenerating images (DualCache Aggressive)...")
images_xl_aggressive = []
times_xl_aggressive = []

for i, prompt in enumerate(TEST_PROMPTS):
    print(f"\nPrompt {i+1}/{len(TEST_PROMPTS)}: "
          f"{prompt}")
    image, elapsed = generate_and_time(
        pipeline_xl, prompt, seed=SEED+i
    )
    images_xl_aggressive.append(image)
    times_xl_aggressive.append(elapsed)
    print(f"Time: {elapsed:.2f}s")

avg_time_xl_aggressive = np.mean(times_xl_aggressive)
speedup_xl_aggressive = (
    avg_time_xl_nocache / avg_time_xl_aggressive
)
print(f"\nAverage time (XL aggressive): "
      f"{avg_time_xl_aggressive:.2f}s")
print(f"Speedup: {speedup_xl_aggressive:.2f}x")

cache_stats_xl_aggressive = cache_dit.summary(pipeline_xl)

results_xl['aggressive'] = {
    'images': images_xl_aggressive,
    'times': times_xl_aggressive,
    'avg_time': avg_time_xl_aggressive,
    'speedup': speedup_xl_aggressive,
    'cache_stats': cache_stats_xl_aggressive
}

disable_cache(pipeline_xl)
✓ DualCache enabled: Fn=1, rdt=0.2, max_continuous=10
Average time (XL aggressive): 2.13s
Speedup: 1.96x

4.3 XL - DualCache Conservative

enable_dualcache(pipeline_xl, mode='conservative')

# Warmup with cache
print("\nWarmup with cache...")
_ = generate_and_time(
    pipeline_xl, TEST_PROMPTS[0], seed=SEED
)
print("Warmup complete")

# Generate test images
print("\nGenerating images (DualCache Conservative)...")
images_xl_conservative = []
times_xl_conservative = []

for i, prompt in enumerate(TEST_PROMPTS):
    print(f"\nPrompt {i+1}/{len(TEST_PROMPTS)}: "
          f"{prompt}")
    image, elapsed = generate_and_time(
        pipeline_xl, prompt, seed=SEED+i
    )
    images_xl_conservative.append(image)
    times_xl_conservative.append(elapsed)
    print(f"Time: {elapsed:.2f}s")

avg_time_xl_conservative = np.mean(times_xl_conservative)
speedup_xl_conservative = (
    avg_time_xl_nocache / avg_time_xl_conservative
)
print(f"\nAverage time (XL conservative): "
      f"{avg_time_xl_conservative:.2f}s")
print(f"Speedup: {speedup_xl_conservative:.2f}x")

cache_stats_xl_conservative = cache_dit.summary(pipeline_xl)

results_xl['conservative'] = {
    'images': images_xl_conservative,
    'times': times_xl_conservative,
    'avg_time': avg_time_xl_conservative,
    'speedup': speedup_xl_conservative,
    'cache_stats': cache_stats_xl_conservative
}

disable_cache(pipeline_xl)
del pipeline_xl
torch.cuda.empty_cache()
✓ DualCache enabled: Fn=4, rdt=0.05, max_continuous=3
Average time (XL conservative): 3.33s
Speedup: 1.25x

5. Experiment: Mode S


We’ll compare:

  1. S mode without caching (baseline)

  2. S mode with DualCache Aggressive

  3. S mode with DualCache Conservative

results_s = {}

# Load S pipeline
pipeline_s = load_elastic_pipeline(
    mode='S', model_path=MODEL_PATH
)
✓ Pipeline loaded successfully (Mode: S)

5.1 S - No Caching (Baseline)

# Warmup
print("\nWarmup...")
_ = generate_and_time(
    pipeline_s, TEST_PROMPTS[0], seed=SEED
)
print("Warmup complete")

# Generate test images without caching
print("\nGenerating images (no caching)...")
images_s_nocache = []
times_s_nocache = []

for i, prompt in enumerate(TEST_PROMPTS):
    print(f"\nPrompt {i+1}/{len(TEST_PROMPTS)}: "
          f"{prompt}")
    image, elapsed = generate_and_time(
        pipeline_s, prompt, seed=SEED+i
    )
    images_s_nocache.append(image)
    times_s_nocache.append(elapsed)
    print(f"Time: {elapsed:.2f}s")

avg_time_s_nocache = np.mean(times_s_nocache)
print(f"\nAverage time (S no caching): "
      f"{avg_time_s_nocache:.2f}s")

results_s['no_cache'] = {
    'images': images_s_nocache,
    'times': times_s_nocache,
    'avg_time': avg_time_s_nocache
}
Average time (S no caching): 2.71s

5.2 S - DualCache Aggressive

enable_dualcache(pipeline_s, mode='aggressive')

# Warmup with cache
print("\nWarmup with cache...")
_ = generate_and_time(
    pipeline_s, TEST_PROMPTS[0], seed=SEED
)
print("Warmup complete")

# Generate test images
print("\nGenerating images (DualCache Aggressive)...")
images_s_aggressive = []
times_s_aggressive = []

for i, prompt in enumerate(TEST_PROMPTS):
    print(f"\nPrompt {i+1}/{len(TEST_PROMPTS)}: "
          f"{prompt}")
    image, elapsed = generate_and_time(
        pipeline_s, prompt, seed=SEED+i
    )
    images_s_aggressive.append(image)
    times_s_aggressive.append(elapsed)
    print(f"Time: {elapsed:.2f}s")

avg_time_s_aggressive = np.mean(times_s_aggressive)
speedup_s_aggressive = (
    avg_time_s_nocache / avg_time_s_aggressive
)
print(f"\nAverage time (S aggressive): "
      f"{avg_time_s_aggressive:.2f}s")
print(f"Speedup: {speedup_s_aggressive:.2f}x")

cache_stats_s_aggressive = cache_dit.summary(pipeline_s)

results_s['aggressive'] = {
    'images': images_s_aggressive,
    'times': times_s_aggressive,
    'avg_time': avg_time_s_aggressive,
    'speedup': speedup_s_aggressive,
    'cache_stats': cache_stats_s_aggressive
}

disable_cache(pipeline_s)
✓ DualCache enabled: Fn=1, rdt=0.2, max_continuous=10
Average time (S aggressive): 1.45s
Speedup: 1.87x

5.3 S - DualCache Conservative

enable_dualcache(pipeline_s, mode='conservative')

# Warmup with cache
print("\nWarmup with cache...")
_ = generate_and_time(
    pipeline_s, TEST_PROMPTS[0], seed=SEED
)
print("Warmup complete")

# Generate test images
print("\nGenerating images (DualCache Conservative)...")
images_s_conservative = []
times_s_conservative = []

for i, prompt in enumerate(TEST_PROMPTS):
    print(f"\nPrompt {i+1}/{len(TEST_PROMPTS)}: "
          f"{prompt}")
    image, elapsed = generate_and_time(
        pipeline_s, prompt, seed=SEED+i
    )
    images_s_conservative.append(image)
    times_s_conservative.append(elapsed)
    print(f"Time: {elapsed:.2f}s")

avg_time_s_conservative = np.mean(times_s_conservative)
speedup_s_conservative = (
    avg_time_s_nocache / avg_time_s_conservative
)
print(f"\nAverage time (S conservative): "
      f"{avg_time_s_conservative:.2f}s")
print(f"Speedup: {speedup_s_conservative:.2f}x")

cache_stats_s_conservative = cache_dit.summary(pipeline_s)

results_s['conservative'] = {
    'images': images_s_conservative,
    'times': times_s_conservative,
    'avg_time': avg_time_s_conservative,
    'speedup': speedup_s_conservative,
    'cache_stats': cache_stats_s_conservative
}

disable_cache(pipeline_s)
del pipeline_s
torch.cuda.empty_cache()
✓ DualCache enabled: Fn=4, rdt=0.05, max_continuous=3
Average time (S conservative): 2.59s
Speedup: 1.05x

6. Performance Summary


# Print performance summary using utility function
print_performance_summary(results_xl, results_s)
================================================================================
PERFORMANCE SUMMARY
================================================================================

--------------------------------------------------------------------------------
MODE XL:
--------------------------------------------------------------------------------
Config               Avg Time (s)    Speedup    Cache Hit Rate
--------------------------------------------------------------------------------
No Cache             4.17            1.00x      N/A
Aggressive           2.13            1.96x      78.9%
Conservative         3.33            1.25x      35.0%

--------------------------------------------------------------------------------
MODE S:
--------------------------------------------------------------------------------
Config               Avg Time (s)    Speedup    Cache Hit Rate
--------------------------------------------------------------------------------
No Cache             2.71            1.00x      N/A
Aggressive           1.45            1.87x      78.9%
Conservative         2.59            1.05x      10.0%

--------------------------------------------------------------------------------
OVERALL ANALYSIS:
--------------------------------------------------------------------------------
Fastest configuration: S Aggressive - 1.45s

Best XL speedup from caching: 1.96x
Best S speedup from caching: 1.87x

Recommendations:
  • For maximum speed: S + Aggressive Cache
  • For quality-critical: XL + Conservative Cache
  • For balanced: XL + Aggressive Cache or S + Conservative Cache
================================================================================

7. Visualize Results


# Visualize XL results using utility function
print("\nXL Mode Comparison:")
for i in range(len(TEST_PROMPTS)):
    visualize_comparison(results_xl, 'XL', TEST_PROMPTS, OUTPUT_DIR, prompt_idx=i)
(Figures: per-prompt XL comparison grids: no cache vs. DualCache Aggressive vs. DualCache Conservative)
# Visualize S results using utility function
print("\nS Mode Comparison:")
for i in range(len(TEST_PROMPTS)):
    visualize_comparison(results_s, 'S', TEST_PROMPTS, OUTPUT_DIR, prompt_idx=i)
(Figures: per-prompt S comparison grids: no cache vs. DualCache Aggressive vs. DualCache Conservative)

8. Performance Bar Charts


# Create performance comparison charts using utility function
create_performance_charts(results_xl, results_s, OUTPUT_DIR)
(Figure: performance bar charts comparing average time and speedup across XL and S configurations)

9. Conclusions


Key Findings:

  1. Caching provides significant speedups: Aggressive caching nearly halved generation time (1.96x on XL, 1.87x on S), while conservative caching gave more modest gains (1.25x on XL, 1.05x on S).

  2. Aggressive vs Conservative trade-off:

    • Aggressive: Higher speedup, but may have slightly lower quality in very dynamic scenes

    • Conservative: Balanced speedup with better quality preservation

  3. Elastic modes + Caching = Maximum speedup: Combining a smaller elastic model (S) with aggressive caching provides the best performance with acceptable quality trade-offs.

  4. Visual quality: As seen in the comparisons above, the quality difference between caching strategies is often minimal for most prompts.

Recommendations:

  • For production/quality-critical: Use XL + Conservative Cache

  • For fast iteration/prototyping: Use S + Aggressive Cache

  • For balanced performance: Use XL + Aggressive Cache or S + Conservative Cache

Next Steps:

  1. Try with your own prompts and use cases

  2. Experiment with custom cache configurations (Fn, Bn, rdt parameters); a starting point is sketched below

  3. Measure quality metrics on larger datasets

  4. Combine with other optimizations (quantization, compilation)
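
For step 2, a possible starting point sits between the two presets used above; the values are illustrative, not tuned:

# A middle ground between the aggressive and conservative presets.
custom_config = BasicCacheConfig(
    Fn_compute_blocks=2,             # between aggressive (1) and conservative (4)
    Bn_compute_blocks=0,
    max_warmup_steps=8,              # compute fully for the first 8 steps
    max_cached_steps=-1,             # no global cap on cached steps
    max_continuous_cached_steps=5,   # force a refresh after 5 consecutive reuses
    residual_diff_threshold=0.1,     # mid-range invalidation threshold
)

Pass it to cache_dit.enable_cache via the same BlockAdapter used in enable_dualcache, then re-run the generate/summary loop from sections 4 and 5 to compare speedup and cache hit rate against the two presets.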