diff --git a/modules/models_settings.py b/modules/models_settings.py
index 3a2400d4..6b9493ca 100644
--- a/modules/models_settings.py
+++ b/modules/models_settings.py
@@ -379,12 +379,15 @@ def estimate_vram(gguf_file, gpu_layers, ctx_size, cache_type):
     return vram
 
 
-def get_nvidia_free_vram():
+def get_nvidia_vram(return_free=True):
     """
-    Calculates the total free VRAM across all NVIDIA GPUs by parsing nvidia-smi output.
+    Calculates VRAM statistics across all NVIDIA GPUs by parsing nvidia-smi output.
+
+    Args:
+        return_free (bool): If True, returns free VRAM. If False, returns total VRAM.
 
     Returns:
-        int: The total free VRAM in MiB summed across all detected NVIDIA GPUs.
+        int: Either the total free VRAM or total VRAM in MiB summed across all detected NVIDIA GPUs.
              Returns -1 if nvidia-smi command fails (not found, error, etc.).
              Returns 0 if nvidia-smi succeeds but no GPU memory info found.
     """
@@ -412,17 +415,21 @@ def get_nvidia_free_vram():
             # No GPUs found in expected format
             return 0
 
+        total_vram_mib = 0
         total_free_vram_mib = 0
+
         for used_mem_str, total_mem_str in matches:
             try:
                 used_mib = int(used_mem_str)
                 total_mib = int(total_mem_str)
+                total_vram_mib += total_mib
                 total_free_vram_mib += (total_mib - used_mib)
             except ValueError:
                 # Skip malformed entries
                 pass
 
-        return total_free_vram_mib
+        # Return either free or total VRAM based on the flag
+        return total_free_vram_mib if return_free else total_vram_mib
 
     except FileNotFoundError:
         # nvidia-smi not found (likely no NVIDIA drivers installed)
@@ -473,8 +480,10 @@ def update_gpu_layers_and_vram(loader, model, gpu_layers, ctx_size, cache_type,
         # No user setting, auto-adjust from the maximum
         current_layers = max_layers  # Start from max
 
-        # Auto-adjust based on available VRAM
-        available_vram = get_nvidia_free_vram()
+        # Auto-adjust based on available/total VRAM
+        # If a model is loaded and it's for the UI, use the total VRAM to avoid confusion
+        return_free = False if (for_ui and shared.model_name not in [None, 'None']) else True
+        available_vram = get_nvidia_vram(return_free=return_free)
         if available_vram > 0:
             tolerance = 906
             while current_layers > 0 and estimate_vram(model, current_layers, ctx_size, cache_type) > available_vram - tolerance:
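
For reference, the sketch below shows roughly how the changed hunks fit together once applied. It is an illustration only, not the file's actual contents: the nvidia-smi invocation and the MiB-parsing regex are outside the hunks shown above and are assumed here, and the model_loaded / for_ui variables in the small demo at the bottom are hypothetical stand-ins for shared.model_name and the caller's for_ui argument.

import re
import subprocess


def get_nvidia_vram(return_free=True):
    """Sum free (or total) VRAM in MiB across all NVIDIA GPUs via nvidia-smi."""
    try:
        # Assumption: the unchanged part of the function shells out to plain
        # `nvidia-smi` and scrapes the "usedMiB / totalMiB" column of its table.
        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, check=False)
        if result.returncode != 0:
            return -1

        matches = re.findall(r"(\d+)\s*MiB\s*/\s*(\d+)\s*MiB", result.stdout)
        if not matches:
            # nvidia-smi ran, but no GPU memory info was found
            return 0

        total_vram_mib = 0
        total_free_vram_mib = 0

        for used_mem_str, total_mem_str in matches:
            try:
                used_mib = int(used_mem_str)
                total_mib = int(total_mem_str)
                total_vram_mib += total_mib
                total_free_vram_mib += (total_mib - used_mib)
            except ValueError:
                # Skip malformed entries
                pass

        # The new flag selects which of the two sums the caller gets back
        return total_free_vram_mib if return_free else total_vram_mib
    except FileNotFoundError:
        # nvidia-smi not found (likely no NVIDIA drivers installed)
        return -1


if __name__ == '__main__':
    # Mirrors the call-site change: report total VRAM when a model is already
    # loaded and the call is for the UI, free VRAM otherwise. model_loaded is a
    # hypothetical stand-in for `shared.model_name not in [None, 'None']`.
    for_ui, model_loaded = True, False
    return_free = not (for_ui and model_loaded)
    print(get_nvidia_vram(return_free=return_free))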