# FLUX.1-lite-8B (1024x1024) - Neuron TP4

Compiled FLUX.1-lite-8B model for AWS Inferentia2 with Tensor Parallelism degree 4 (TP=4).
## Model Details
- Resolution: 1024x1024
- Tensor Parallelism: 4 cores per inference
- Recommended Instances: inf2.24xlarge (12 NeuronCores = 3 workers) or inf2.48xlarge (24 NeuronCores = 6 workers); inf2.8xlarge exposes only 2 NeuronCores, too few for a TP=4 worker
- Guidance Scale: 3.5 (recommended)
- Steps: 28 (recommended)
## Installation

```bash
pip install "optimum[neuronx]" torch-neuronx neuronx-cc --extra-index-url=https://pip.repos.neuron.amazonaws.com
```
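Before loading the model, it can be worth confirming that the Neuron stack is importable. A minimal sanity check (the script name is illustrative; it only exercises imports):

```python
# verify_neuron_install.py -- minimal import check (assumes the pip install above succeeded)
import torch
import torch_neuronx  # raises ImportError if the Neuron PyTorch plugin is missing
from optimum.neuron import NeuronFluxPipeline  # raises if optimum[neuronx] is missing

print("torch:", torch.__version__)
print("Neuron imports OK; pipeline class available:", NeuronFluxPipeline.__name__)
```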
## Usage

### Single Inference (Non-Parallel)

Simple single-image generation:
```python
#!/usr/bin/env python3
# single_inference_tp4.py
import os
import time

import torch
from optimum.neuron import NeuronFluxPipeline

MODEL_DIR = "/home/ubuntu/flux.1-lite-8B_1024x1024_neuronx_tp4"
GUIDANCE = 3.5
STEPS = 28
OUT_DIR = "outputs_single"
os.makedirs(OUT_DIR, exist_ok=True)

# Load the compiled pipeline (all TP=4 shards)
print("Loading model...")
t0 = time.perf_counter()
pipe = NeuronFluxPipeline.from_pretrained(MODEL_DIR)
t1 = time.perf_counter()
print(f"Model loaded in {t1 - t0:.2f}s")

# Generate one image with a fixed seed for reproducibility
prompt = "A futuristic city skyline at golden hour, ultra-detailed, cinematic lighting"
seed = 42
generator = torch.Generator(device="cpu").manual_seed(seed)

print("Generating image...")
t2 = time.perf_counter()
image = pipe(
    prompt=prompt,
    num_images_per_prompt=1,
    num_inference_steps=STEPS,
    guidance_scale=GUIDANCE,
    generator=generator,
).images[0]
t3 = time.perf_counter()

out_path = os.path.join(OUT_DIR, "output.png")
image.save(out_path)
print(f"Image generated in {t3 - t2:.2f}s -> {out_path}")
```
### Parallel Inference (Multi-Worker)

For inf2.24xlarge (12 cores = 3 workers) or inf2.48xlarge (24 cores = 6 workers), run multiple workers in parallel:
```python
#!/usr/bin/env python3
# parallel_flux_tp4_batch.py
import os
import time
import multiprocessing as mp
from typing import Sequence

MODEL_DIR = "/home/ubuntu/flux.1-lite-8B_1024x1024_neuronx_tp4"
GUIDANCE = 3.5
STEPS = 28

# 3 workers × 4 cores each = 12 cores (6 chips) on inf2.24xlarge (TP=4 per worker)
# For inf2.48xlarge (24 cores), use: ["0-3", "4-7", "8-11", "12-15", "16-19", "20-23"]
CORE_GROUPS: Sequence[str] = ["0-3", "4-7", "8-11"]

# 10 prompts (edit as you like)
PROMPTS: list[str] = [
    "A futuristic city skyline at golden hour, ultra-detailed, cinematic lighting",
    "A serene mossy forest with sunbeams and floating dust motes, photorealistic",
    "Cyberpunk alley in the rain, neon reflections, puddles, street vendors",
    "Watercolor painting of koi fish swirling in a pond, soft brush strokes",
    "Isometric cozy bedroom, warm light, plants, bookshelf, minimalist decor",
    "Macro photo of a dewdrop on a leaf, extreme detail, shallow depth of field",
    "Spaceship approaching a ringed planet, volumetric clouds, dramatic scale",
    "Minimal product photo of wireless earbuds on marble, studio lighting",
    "Vaporwave beach at night with palm trees and a giant moon on the horizon",
    "Ancient library with towering shelves, ladders, soft shafts of light",
]

IMAGES_PER_WORKER = 10  # generate 10 images sequentially in each worker
OUT_DIR = "outputs_flux_parallel_tp4"  # output files: w<worker>_i<img>.png


def worker(idx: int, core_range: str, barrier) -> None:
    # ---- isolate this worker BEFORE importing neuron/torch ----
    os.environ["NEURON_RT_VISIBLE_CORES"] = core_range
    os.environ["NEURON_RT_ROOT_COMM_ID"] = f"127.0.0.1:{62080 + idx}"  # unique per worker
    if os.environ.get("NEURON_RT_DISABLE_EXECUTION_BARRIER") == "1":
        del os.environ["NEURON_RT_DISABLE_EXECUTION_BARRIER"]
    os.environ.setdefault("FI_PROVIDER", "tcp")
    os.environ.setdefault("OMP_NUM_THREADS", "1")
    os.environ.setdefault("MALLOC_ARENA_MAX", "8")

    import torch
    from optimum.neuron import NeuronFluxPipeline

    os.makedirs(OUT_DIR, exist_ok=True)

    # ---- LOAD ONCE ----
    t0 = time.perf_counter()
    pipe = NeuronFluxPipeline.from_pretrained(MODEL_DIR)
    t1 = time.perf_counter()
    print(f"[{core_range}] load: {t1 - t0:.2f}s", flush=True)

    # ---- WAIT UNTIL ALL WORKERS LOADED ----
    barrier.wait()

    # ---- GENERATE SEQUENTIALLY ----
    # make seeds distinct per worker & image for reproducibility
    gen_start = time.perf_counter()
    for i in range(IMAGES_PER_WORKER):
        prompt = PROMPTS[i % len(PROMPTS)]
        seed = 1000 * idx + i
        g = torch.Generator(device="cpu").manual_seed(seed)
        t2 = time.perf_counter()
        img = pipe(
            prompt=prompt,
            num_images_per_prompt=1,
            num_inference_steps=STEPS,
            guidance_scale=GUIDANCE,
            generator=g,
        ).images[0]
        t3 = time.perf_counter()
        out_path = os.path.join(OUT_DIR, f"w{idx}_i{i:02d}.png")
        img.save(out_path)
        print(f"[{core_range}] img {i+1}/{IMAGES_PER_WORKER} | {t3 - t2:.2f}s -> {out_path}", flush=True)
    gen_end = time.perf_counter()
    print(f"[{core_range}] total gen time: {gen_end - gen_start:.2f}s "
          f"(avg {(gen_end - gen_start)/IMAGES_PER_WORKER:.2f}s/img)", flush=True)


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    # barrier so all workers start generating only after every pipeline is loaded
    barrier = mp.Barrier(len(CORE_GROUPS))
    t_all = time.perf_counter()
    procs = []
    for idx, core_range in enumerate(CORE_GROUPS):
        p = mp.Process(target=worker, args=(idx, core_range, barrier), daemon=False)
        p.start()
        procs.append(p)
    for p in procs:
        p.join()
    print(f"[parallel] total wall time: {time.perf_counter() - t_all:.2f}s; "
          f"outputs in: {OUT_DIR}")
```
## Performance Tips

- **Core Allocation**: Each worker needs 4 consecutive NeuronCores (TP=4)
- **Memory**: Ensure sufficient memory per worker (~12 GB each)
- **Parallelism** (core groups can also be derived programmatically; see the helper sketch after this list):
  - inf2.24xlarge: 3 workers (cores 0-3, 4-7, 8-11)
  - inf2.48xlarge: 6 workers (cores 0-3, 4-7, 8-11, 12-15, 16-19, 20-23)
- **Batch Size**: Generate images sequentially within each worker for stability
- **TP4 vs TP2**: TP=4 typically lowers per-image latency, but each worker consumes twice as many cores, so fewer workers fit on a given instance
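If you script deployments across instance sizes, `CORE_GROUPS` can be computed from the total NeuronCore count rather than hard-coded. A small helper sketch (the function name `tp4_core_groups` is illustrative, not part of any library):

```python
# Derive consecutive TP4 core groups ("0-3", "4-7", ...) from a total core count.
def tp4_core_groups(total_cores: int) -> list[str]:
    # each worker takes 4 consecutive cores; leftover cores (< 4) are ignored
    return [f"{c}-{c + 3}" for c in range(0, total_cores - 3, 4)]

print(tp4_core_groups(12))  # ['0-3', '4-7', '8-11']        -> inf2.24xlarge
print(tp4_core_groups(24))  # 6 groups, '0-3' .. '20-23'    -> inf2.48xlarge
```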
## Troubleshooting

- **Out of Memory**: Reduce the number of parallel workers
- **Core Conflicts**: Ensure `CORE_GROUPS` entries don't overlap and each group has exactly 4 cores (see the validation sketch below)
- **Slow Performance**: Check that `NEURON_RT_VISIBLE_CORES` is set correctly per worker
- **Communication Errors**: Ensure each worker has a unique `NEURON_RT_ROOT_COMM_ID` port
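To catch the overlap and group-size mistakes above before any worker spawns, you can validate `CORE_GROUPS` up front. A minimal sketch (the helper `validate_core_groups` is illustrative, not part of the Neuron SDK):

```python
# Fail fast on malformed core groups before starting any worker process.
def validate_core_groups(groups: list[str]) -> None:
    """Assert each 'lo-hi' group spans exactly 4 cores and no two groups overlap."""
    seen: set[int] = set()
    for g in groups:
        lo, hi = map(int, g.split("-"))
        cores = set(range(lo, hi + 1))
        assert len(cores) == 4, f"group {g!r} must span exactly 4 cores"
        assert not (cores & seen), f"group {g!r} overlaps a previous group"
        seen |= cores

validate_core_groups(["0-3", "4-7", "8-11"])  # passes; raises AssertionError on bad input
```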