Image generation: Yield partial results for batch count > 1

This commit is contained in:
oobabooga 2025-12-03 16:13:07 -08:00
parent 49c60882bf
commit fbca54957e

View file

@ -677,7 +677,8 @@ def generate(state):
if not model_name or model_name == 'None':
logger.error("No image model selected. Go to the Model tab and select a model.")
return []
yield []
return
if shared.image_model is None:
result = load_image_model(
@ -690,7 +691,8 @@ def generate(state):
)
if result is None:
logger.error(f"Failed to load model `{model_name}`.")
return []
yield []
return
shared.image_model_name = model_name
@ -760,6 +762,7 @@ def generate(state):
generator.manual_seed(int(seed + i))
batch_results = shared.image_model(**gen_kwargs).images
all_images.extend(batch_results)
yield all_images
t1 = time.time()
save_generated_images(all_images, state, seed)
@ -768,12 +771,12 @@ def generate(state):
total_steps = state["image_steps"] * int(state['image_batch_count'])
logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})')
return all_images
yield all_images
except Exception as e:
logger.error(f"Image generation failed: {e}")
traceback.print_exc()
return []
yield []
def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model, quant_method):