Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2026-01-15 13:10:20 +01:00
Add exception handling while generating images
This commit is contained in:
parent 62cf5aa51c
commit 0c4bdb19c8
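
In effect, the whole body of generate() moves inside a try block, so any failure during model loading or sampling is logged and the caller receives an empty list instead of an unhandled exception. A minimal sketch of that contract, with run_pipeline as a hypothetical stand-in for the real shared.image_model(**gen_kwargs) call:

import logging
import traceback

logger = logging.getLogger(__name__)

def generate_safely(run_pipeline, **gen_kwargs):
    try:
        # Any failure here (CUDA OOM, bad dtype, shape mismatch, ...) previously
        # propagated out of generate() and crashed the caller.
        return run_pipeline(**gen_kwargs).images
    except Exception as e:
        # The commit's contract: log the error, keep the full traceback for
        # debugging, and hand back an empty list instead of an exception.
        logger.error(f"Image generation failed: {e}")
        traceback.print_exc()
        return []

Returning [] matches the early-exit paths that already exist in the function (no model selected, failed load), so downstream UI code only ever sees a list of images.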
@@ -633,88 +633,94 @@ def generate(state):
     """
     import torch
 
-    model_name = state['image_model_menu']
+    try:
+        model_name = state['image_model_menu']
 
-    if not model_name or model_name == 'None':
-        logger.error("No image model selected. Go to the Model tab and select a model.")
-        return []
+        if not model_name or model_name == 'None':
+            logger.error("No image model selected. Go to the Model tab and select a model.")
+            return []
 
-    if shared.image_model is None:
-        result = load_image_model(
-            model_name,
-            dtype=state['image_dtype'],
-            attn_backend=state['image_attn_backend'],
-            cpu_offload=state['image_cpu_offload'],
-            compile_model=state['image_compile'],
-            quant_method=state['image_quant']
-        )
-        if result is None:
-            logger.error(f"Failed to load model `{model_name}`.")
-            return []
+        if shared.image_model is None:
+            result = load_image_model(
+                model_name,
+                dtype=state['image_dtype'],
+                attn_backend=state['image_attn_backend'],
+                cpu_offload=state['image_cpu_offload'],
+                compile_model=state['image_compile'],
+                quant_method=state['image_quant']
+            )
+            if result is None:
+                logger.error(f"Failed to load model `{model_name}`.")
+                return []
 
-    shared.image_model_name = model_name
+        shared.image_model_name = model_name
 
-    seed = state['image_seed']
-    if seed == -1:
-        seed = np.random.randint(0, 2**32 - 1)
+        seed = state['image_seed']
+        if seed == -1:
+            seed = np.random.randint(0, 2**32 - 1)
 
-    device = get_device()
-    if device is None:
-        device = "cpu"
-    generator = torch.Generator(device).manual_seed(int(seed))
+        device = get_device()
+        if device is None:
+            device = "cpu"
+        generator = torch.Generator(device).manual_seed(int(seed))
 
-    all_images = []
+        all_images = []
 
-    # Get pipeline type for parameter adjustment
-    pipeline_type = getattr(shared, 'image_pipeline_type', None)
-    if pipeline_type is None:
-        pipeline_type = get_pipeline_type(shared.image_model)
+        # Get pipeline type for parameter adjustment
+        pipeline_type = getattr(shared, 'image_pipeline_type', None)
+        if pipeline_type is None:
+            pipeline_type = get_pipeline_type(shared.image_model)
 
-    # Process Prompt
-    prompt = state['image_prompt']
+        # Process Prompt
+        prompt = state['image_prompt']
 
-    # Apply "Positive Magic" for Qwen models only
-    if pipeline_type == 'qwenimage':
-        magic_suffix = ", Ultra HD, 4K, cinematic composition"
-        # Avoid duplication if user already added it
-        if magic_suffix.strip(", ") not in prompt:
-            prompt += magic_suffix
+        # Apply "Positive Magic" for Qwen models only
+        if pipeline_type == 'qwenimage':
+            magic_suffix = ", Ultra HD, 4K, cinematic composition"
+            # Avoid duplication if user already added it
+            if magic_suffix.strip(", ") not in prompt:
+                prompt += magic_suffix
 
-    # Build generation kwargs
-    gen_kwargs = {
-        "prompt": prompt,
-        "negative_prompt": state['image_neg_prompt'],
-        "height": int(state['image_height']),
-        "width": int(state['image_width']),
-        "num_inference_steps": int(state['image_steps']),
-        "num_images_per_prompt": int(state['image_batch_size']),
-        "generator": generator,
-    }
+        # Build generation kwargs
+        gen_kwargs = {
+            "prompt": prompt,
+            "negative_prompt": state['image_neg_prompt'],
+            "height": int(state['image_height']),
+            "width": int(state['image_width']),
+            "num_inference_steps": int(state['image_steps']),
+            "num_images_per_prompt": int(state['image_batch_size']),
+            "generator": generator,
+        }
 
-    # Add pipeline-specific parameters for CFG
-    cfg_val = state.get('image_cfg_scale', 0.0)
+        # Add pipeline-specific parameters for CFG
+        cfg_val = state.get('image_cfg_scale', 0.0)
 
-    if pipeline_type == 'qwenimage':
-        # Qwen-Image uses true_cfg_scale (typically 4.0)
-        gen_kwargs["true_cfg_scale"] = cfg_val
-    else:
-        # Z-Image and others use guidance_scale (typically 0.0 for Turbo)
-        gen_kwargs["guidance_scale"] = cfg_val
+        if pipeline_type == 'qwenimage':
+            # Qwen-Image uses true_cfg_scale (typically 4.0)
+            gen_kwargs["true_cfg_scale"] = cfg_val
+        else:
+            # Z-Image and others use guidance_scale (typically 0.0 for Turbo)
+            gen_kwargs["guidance_scale"] = cfg_val
 
-    t0 = time.time()
-    for i in range(int(state['image_batch_count'])):
-        generator.manual_seed(int(seed + i))
-        batch_results = shared.image_model(**gen_kwargs).images
-        all_images.extend(batch_results)
+        t0 = time.time()
+        for i in range(int(state['image_batch_count'])):
+            generator.manual_seed(int(seed + i))
+            batch_results = shared.image_model(**gen_kwargs).images
+            all_images.extend(batch_results)
 
-    t1 = time.time()
-    save_generated_images(all_images, state, seed)
+        t1 = time.time()
+        save_generated_images(all_images, state, seed)
 
-    total_images = int(state['image_batch_count']) * int(state['image_batch_size'])
-    total_steps = state["image_steps"] * int(state['image_batch_count'])
-    logger.info(f'Generated {total_images} images in {(t1-t0):.2f} seconds ({total_steps/(t1-t0):.2f} steps/s, seed {seed})')
+        total_images = int(state['image_batch_count']) * int(state['image_batch_size'])
+        total_steps = state["image_steps"] * int(state['image_batch_count'])
+        logger.info(f'Generated {total_images} images in {(t1-t0):.2f} seconds ({total_steps/(t1-t0):.2f} steps/s, seed {seed})')
 
-    return all_images
+        return all_images
+
+    except Exception as e:
+        logger.error(f"Image generation failed: {e}")
+        traceback.print_exc()
+        return []
 
 
 def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model, quant_method):
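
One detail worth noting in the unchanged loop above: the generator is reseeded with seed + i before each batch, so every batch is reproducible on its own rather than depending on how many batches ran before it. A standalone sketch of that behavior using only real torch APIs (base_seed, the batch count, and the tensor shape are arbitrary):

import torch

base_seed = 12345
g = torch.Generator("cpu")
for i in range(3):  # stands in for image_batch_count
    g.manual_seed(base_seed + i)  # batch i is reproducible in isolation
    noise = torch.randn(4, generator=g)
    print(i, noise.tolist())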