FLUX.1-lite-8B (1024x1024) - Neuron TP4

Compiled FLUX.1-lite-8B model for AWS Inferentia2 with Tensor Parallelism degree 4 (TP=4).

Model Details

  • Resolution: 1024x1024
  • Tensor Parallelism: 4 cores per inference
  • Recommended Instance: inf2.24xlarge (12 cores = 3 workers) or inf2.48xlarge (24 cores = 6 workers); each TP=4 worker needs 4 NeuronCores (2 Inferentia2 chips), so inf2.8xlarge (1 chip = 2 cores) cannot host a TP=4 worker
  • Guidance Scale: 3.5 (recommended)
  • Steps: 28 (recommended)

Installation

pip install "optimum[neuronx]" torch-neuronx neuronx-cc --extra-index-url=https://pip.repos.neuron.amazonaws.com
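
To confirm the stack is importable before loading the model, a quick sanity check (assumes the packages above installed cleanly; prints a short confirmation):

python -c "import torch_neuronx; from optimum.neuron import NeuronFluxPipeline; print('Neuron stack OK')"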

Usage

Single Inference (Non-Parallel)

Simple single-image generation:

#!/usr/bin/env python3
# single_inference_tp4.py
import os
import time
import torch
from optimum.neuron import NeuronFluxPipeline

MODEL_DIR = "/home/ubuntu/flux.1-lite-8B_1024x1024_neuronx_tp4"
GUIDANCE = 3.5
STEPS = 28
OUT_DIR = "outputs_single"

os.makedirs(OUT_DIR, exist_ok=True)

# Load the model
print("Loading model...")
t0 = time.perf_counter()
pipe = NeuronFluxPipeline.from_pretrained(MODEL_DIR)
t1 = time.perf_counter()
print(f"Model loaded in {t1 - t0:.2f}s")

# Generate image
prompt = "A futuristic city skyline at golden hour, ultra-detailed, cinematic lighting"
seed = 42
generator = torch.Generator(device="cpu").manual_seed(seed)

print(f"Generating image...")
t2 = time.perf_counter()
image = pipe(
    prompt=prompt,
    num_images_per_prompt=1,
    num_inference_steps=STEPS,
    guidance_scale=GUIDANCE,
    generator=generator,
).images[0]
t3 = time.perf_counter()

out_path = os.path.join(OUT_DIR, "output.png")
image.save(out_path)
print(f"Image generated in {t3 - t2:.2f}s -> {out_path}")

Parallel Inference (Multi-Worker)

For inf2.24xlarge (12 cores = 3 workers) or inf2.48xlarge (24 cores = 6 workers), run multiple workers in parallel:

#!/usr/bin/env python3
# parallel_flux_tp4_batch.py
import os, time, multiprocessing as mp
from typing import Sequence

MODEL_DIR = "/home/ubuntu/flux.1-lite-8B_1024x1024_neuronx_tp4"
GUIDANCE = 3.5
STEPS = 28

# 3 workers × 4 cores each = 12 cores (6 chips) on inf2.24xlarge (TP=4 per worker)
# For inf2.48xlarge (24 cores), use: ["0-3", "4-7", "8-11", "12-15", "16-19", "20-23"]
CORE_GROUPS: Sequence[str] = ["0-3", "4-7", "8-11"]

# 10 prompts (edit as you like)
PROMPTS: list[str] = [
    "A futuristic city skyline at golden hour, ultra-detailed, cinematic lighting",
    "A serene mossy forest with sunbeams and floating dust motes, photorealistic",
    "Cyberpunk alley in the rain, neon reflections, puddles, street vendors",
    "Watercolor painting of koi fish swirling in a pond, soft brush strokes",
    "Isometric cozy bedroom, warm light, plants, bookshelf, minimalist decor",
    "Macro photo of a dewdrop on a leaf, extreme detail, shallow depth of field",
    "Spaceship approaching a ringed planet, volumetric clouds, dramatic scale",
    "Minimal product photo of wireless earbuds on marble, studio lighting",
    "Vaporwave beach at night with palm trees and a giant moon on the horizon",
    "Ancient library with towering shelves, ladders, soft shafts of light",
]

IMAGES_PER_WORKER = 10  # generate 10 images sequentially in each worker
OUT_DIR = "outputs_flux_parallel_tp4"  # outputs/<worker>_<idx>.png

def worker(idx: int, core_range: str, barrier: mp.Barrier):
    # ---- isolate this worker BEFORE importing neuron/torch ----
    os.environ["NEURON_RT_VISIBLE_CORES"] = core_range
    os.environ["NEURON_RT_ROOT_COMM_ID"] = f"127.0.0.1:{62080 + idx}"  # unique per worker
    # make sure the runtime execution barrier is not disabled by a stale env override
    if os.environ.get("NEURON_RT_DISABLE_EXECUTION_BARRIER") == "1":
        del os.environ["NEURON_RT_DISABLE_EXECUTION_BARRIER"]
    os.environ.setdefault("FI_PROVIDER", "tcp")      # libfabric transport for collectives
    os.environ.setdefault("OMP_NUM_THREADS", "1")    # avoid CPU oversubscription across workers
    os.environ.setdefault("MALLOC_ARENA_MAX", "8")   # cap glibc malloc arenas to limit host RAM

    import torch
    from optimum.neuron import NeuronFluxPipeline

    os.makedirs(OUT_DIR, exist_ok=True)

    # ---- LOAD ONCE ----
    t0 = time.perf_counter()
    pipe = NeuronFluxPipeline.from_pretrained(MODEL_DIR)
    t1 = time.perf_counter()
    print(f"[{core_range}] load: {t1 - t0:.2f}s", flush=True)

    # ---- WAIT UNTIL ALL WORKERS LOADED ----
    barrier.wait()

    # ---- GENERATE SEQUENTIALLY ----
    # make seeds distinct per worker & image for reproducibility
    gen_start = time.perf_counter()
    for i in range(IMAGES_PER_WORKER):
        prompt = PROMPTS[i % len(PROMPTS)]
        seed = 1000 * idx + i
        g = torch.Generator(device="cpu").manual_seed(seed)

        t2 = time.perf_counter()
        img = pipe(
            prompt=prompt,
            num_images_per_prompt=1,
            num_inference_steps=STEPS,
            guidance_scale=GUIDANCE,
            generator=g,
        ).images[0]
        t3 = time.perf_counter()

        out_path = os.path.join(OUT_DIR, f"w{idx}_i{i:02d}.png")
        img.save(out_path)
        print(f"[{core_range}] img {i+1}/{IMAGES_PER_WORKER} | {t3 - t2:.2f}s -> {out_path}", flush=True)

    gen_end = time.perf_counter()
    print(f"[{core_range}] total gen time: {gen_end - gen_start:.2f}s "
          f"(avg {(gen_end - gen_start)/IMAGES_PER_WORKER:.2f}s/img)", flush=True)

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)

    # barrier so all workers start generating only after every pipeline is loaded
    barrier = mp.Barrier(len(CORE_GROUPS))

    t_all = time.perf_counter()
    procs = []
    for idx, core_range in enumerate(CORE_GROUPS):
        p = mp.Process(target=worker, args=(idx, core_range, barrier), daemon=False)
        p.start()
        procs.append(p)

    for p in procs:
        p.join()

    print(f"[parallel] total wall time: {time.perf_counter() - t_all:.2f}s; "
          f"outputs in: {OUT_DIR}")

Performance Tips

  1. Core Allocation: Each worker needs 4 consecutive NeuronCores (TP=4); a helper for deriving valid core groups is sketched after this list
  2. Memory: Budget roughly 12GB per worker
  3. Parallelism:
    • inf2.24xlarge: 3 workers (cores 0-3, 4-7, 8-11)
    • inf2.48xlarge: 6 workers (cores 0-3, 4-7, 8-11, 12-15, 16-19, 20-23)
    • inf2.8xlarge: not usable at TP=4 (only 2 NeuronCores)
  4. Batch Size: Generate images sequentially within each worker for stability
  5. TP4 vs TP2: TP4 can lower per-image latency, but each worker consumes twice as many cores, so an instance fits half as many parallel workers
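
If you prefer deriving CORE_GROUPS over hard-coding them, here is a minimal sketch; make_core_groups is a hypothetical helper, and the core counts per instance type come from the list above:

def make_core_groups(total_cores: int, tp_degree: int = 4) -> list[str]:
    # Split cores 0..total_cores-1 into consecutive TP-sized ranges like "0-3"
    if total_cores % tp_degree:
        raise ValueError(f"{total_cores} cores not divisible by TP={tp_degree}")
    return [f"{i}-{i + tp_degree - 1}" for i in range(0, total_cores, tp_degree)]

# make_core_groups(12) -> ["0-3", "4-7", "8-11"]   (inf2.24xlarge)
# make_core_groups(24) -> ["0-3", ..., "20-23"]    (inf2.48xlarge, 6 groups)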

Troubleshooting

  • Out of Memory: Reduce number of parallel workers
  • Core Conflicts: Ensure CORE_GROUPS entries don't overlap and each group spans exactly 4 cores; a validation sketch follows below
  • Slow Performance: Check NEURON_RT_VISIBLE_CORES is set correctly per worker
  • Communication Errors: Ensure each worker has unique NEURON_RT_ROOT_COMM_ID port
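
To catch overlapping or mis-sized groups before spawning workers, a small check along these lines (validate_core_groups is a hypothetical helper that parses the same "lo-hi" strings as CORE_GROUPS):

def validate_core_groups(groups: list[str], tp_degree: int = 4) -> None:
    seen: set[int] = set()
    for g in groups:
        lo, hi = (int(x) for x in g.split("-"))
        cores = set(range(lo, hi + 1))
        if len(cores) != tp_degree:
            raise ValueError(f"group {g!r} has {len(cores)} cores, expected {tp_degree}")
        if cores & seen:
            raise ValueError(f"group {g!r} overlaps cores {sorted(cores & seen)}")
        seen |= cores

validate_core_groups(["0-3", "4-7", "8-11"])  # passes; raises on overlap or wrong size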