Remove gradio monkey-patches (moved to gradio fork)

This commit is contained in:
oobabooga 2026-03-03 17:17:36 -08:00
parent e9f22813e4
commit 2260e530c9
3 changed files with 33 additions and 218 deletions

View file

@ -1,96 +0,0 @@
import builtins
import io
import re
import requests
from modules import shared, ui
from modules.logging_colors import logger
original_open = open
original_get = requests.get
original_print = print
class RequestBlocker:
def __enter__(self):
requests.get = my_get
def __exit__(self, exc_type, exc_value, traceback):
requests.get = original_get
class OpenMonkeyPatch:
def __enter__(self):
builtins.open = my_open
builtins.print = my_print
def __exit__(self, exc_type, exc_value, traceback):
builtins.open = original_open
builtins.print = original_print
def my_get(url, **kwargs):
logger.info('Unwanted HTTP request redirected to localhost :)')
kwargs.setdefault('allow_redirects', True)
return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
def my_open(*args, **kwargs):
filename = str(args[0])
if filename.endswith(('index.html', 'share.html')):
with original_open(*args, **kwargs) as f:
file_contents = f.read()
if len(args) > 1 and args[1] == 'rb':
file_contents = file_contents.decode('utf-8')
file_contents = file_contents.replace('\t\t<script\n\t\t\tsrc="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"\n\t\t\tasync\n\t\t></script>', '')
file_contents = file_contents.replace('cdnjs.cloudflare.com', '127.0.0.1')
file_contents = file_contents.replace(
'</head>',
'\n <link rel="preload" href="file/css/Inter/Inter-VariableFont_opsz,wght.ttf" as="font" type="font/ttf" crossorigin>'
'\n <link rel="preload" href="file/css/Inter/Inter-Italic-VariableFont_opsz,wght.ttf" as="font" type="font/ttf" crossorigin>'
'\n <link rel="preload" href="file/css/NotoSans/NotoSans-Medium.woff2" as="font" type="font/woff2" crossorigin>'
'\n <link rel="preload" href="file/css/NotoSans/NotoSans-MediumItalic.woff2" as="font" type="font/woff2" crossorigin>'
'\n <link rel="preload" href="file/css/NotoSans/NotoSans-Bold.woff2" as="font" type="font/woff2" crossorigin>'
'\n <script src="file/js/katex/katex.min.js"></script>'
'\n <script src="file/js/katex/auto-render.min.js"></script>'
'\n <script src="file/js/highlightjs/highlight.min.js"></script>'
'\n <script src="file/js/highlightjs/highlightjs-copy.min.js"></script>'
'\n <script src="file/js/morphdom/morphdom-umd.min.js"></script>'
f'\n <link id="highlight-css" rel="stylesheet" href="file/css/highlightjs/{"github-dark" if shared.settings["dark_theme"] else "github"}.min.css">'
'\n <script>hljs.addPlugin(new CopyButtonPlugin());</script>'
f'\n <script>{ui.global_scope_js}</script>'
'\n </head>'
)
file_contents = re.sub(
r'@media \(prefers-color-scheme: dark\) \{\s*body \{([^}]*)\}\s*\}',
r'body.dark {\1}',
file_contents,
flags=re.DOTALL
)
if len(args) > 1 and args[1] == 'rb':
file_contents = file_contents.encode('utf-8')
return io.BytesIO(file_contents)
else:
return io.StringIO(file_contents)
else:
return original_open(*args, **kwargs)
def my_print(*args, **kwargs):
if len(args) > 0 and 'To create a public link, set `share=True`' in args[0]:
return
else:
if len(args) > 0 and 'Running on local URL' in args[0]:
args = list(args)
args[0] = f"\n{args[0].strip()}\n"
args = tuple(args)
original_print(*args, **kwargs)

View file

@ -1,97 +0,0 @@
'''
Most of the code here was adapted from:
https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14184
'''
import inspect
import warnings
from functools import wraps
import gradio as gr
import gradio.routes
import gradio.utils
from starlette.middleware.trustedhost import TrustedHostMiddleware
from modules import shared
orig_create_app = gradio.routes.App.create_app
# Be strict about only approving access to localhost by default
def create_app_with_trustedhost(*args, **kwargs):
app = orig_create_app(*args, **kwargs)
if not (shared.args.listen or shared.args.share):
app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=["localhost", "127.0.0.1"]
)
return app
gradio.routes.App.create_app = create_app_with_trustedhost
gradio.utils.launch_counter = lambda: None
class GradioDeprecationWarning(DeprecationWarning):
pass
def repair(grclass):
if not getattr(grclass, 'EVENTS', None):
return
@wraps(grclass.__init__)
def __repaired_init__(self, *args, tooltip=None, source=None, original=grclass.__init__, **kwargs):
if source:
kwargs["sources"] = [source]
allowed_kwargs = inspect.signature(original).parameters
fixed_kwargs = {}
for k, v in kwargs.items():
if k in allowed_kwargs:
fixed_kwargs[k] = v
else:
warnings.warn(f"unexpected argument for {grclass.__name__}: {k}", GradioDeprecationWarning, stacklevel=2)
original(self, *args, **fixed_kwargs)
self.webui_tooltip = tooltip
for event in self.EVENTS:
replaced_event = getattr(self, str(event))
def fun(*xargs, _js=None, replaced_event=replaced_event, **xkwargs):
if _js:
xkwargs['js'] = _js
return replaced_event(*xargs, **xkwargs)
setattr(self, str(event), fun)
grclass.__init__ = __repaired_init__
grclass.update = gr.update
for component in set(gr.components.__all__ + gr.layouts.__all__):
repair(getattr(gr, component, None))
class Dependency(gr.events.Dependency):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def then(*xargs, _js=None, **xkwargs):
if _js:
xkwargs['js'] = _js
return original_then(*xargs, **xkwargs)
original_then = self.then
self.then = then
gr.events.Dependency = Dependency
gr.Box = gr.Group

View file

@ -3,13 +3,7 @@ import shutil
import warnings
from pathlib import Path
# Monkey-patch HfFolder for gradio 4.x compatibility with huggingface-hub 1.x
import huggingface_hub
if not hasattr(huggingface_hub, 'HfFolder'):
huggingface_hub.HfFolder = type('HfFolder', (), {'get_token': staticmethod(huggingface_hub.get_token)})
from modules import shared
from modules.block_requests import OpenMonkeyPatch, RequestBlocker
from modules.image_models import load_image_model
from modules.logging_colors import logger
from modules.prompts import load_prompt
@ -32,9 +26,7 @@ warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_na
warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()')
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_names" has conflict')
with RequestBlocker():
from modules import gradio_hijack
import gradio as gr
import gradio as gr
import matplotlib
@ -148,7 +140,24 @@ def create_interface():
# Interface state elements
shared.input_elements = ui.list_interface_input_elements()
with gr.Blocks(css=css, analytics_enabled=False, title=title, theme=ui.theme) as shared.gradio['interface']:
# Head HTML for font preloads, KaTeX, highlight.js, morphdom, and global JS
head_html = '\n'.join([
'<link rel="preload" href="file/css/Inter/Inter-VariableFont_opsz,wght.ttf" as="font" type="font/ttf" crossorigin>',
'<link rel="preload" href="file/css/Inter/Inter-Italic-VariableFont_opsz,wght.ttf" as="font" type="font/ttf" crossorigin>',
'<link rel="preload" href="file/css/NotoSans/NotoSans-Medium.woff2" as="font" type="font/woff2" crossorigin>',
'<link rel="preload" href="file/css/NotoSans/NotoSans-MediumItalic.woff2" as="font" type="font/woff2" crossorigin>',
'<link rel="preload" href="file/css/NotoSans/NotoSans-Bold.woff2" as="font" type="font/woff2" crossorigin>',
'<script src="file/js/katex/katex.min.js"></script>',
'<script src="file/js/katex/auto-render.min.js"></script>',
'<script src="file/js/highlightjs/highlight.min.js"></script>',
'<script src="file/js/highlightjs/highlightjs-copy.min.js"></script>',
'<script src="file/js/morphdom/morphdom-umd.min.js"></script>',
f'<link id="highlight-css" rel="stylesheet" href="file/css/highlightjs/{"github-dark" if shared.settings["dark_theme"] else "github"}.min.css">',
'<script>hljs.addPlugin(new CopyButtonPlugin());</script>',
f'<script>{ui.global_scope_js}</script>',
])
with gr.Blocks(css=css, analytics_enabled=False, title=title, theme=ui.theme, head=head_html) as shared.gradio['interface']:
# Interface state
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
@ -234,21 +243,20 @@ def create_interface():
# Launch the interface
shared.gradio['interface'].queue()
with OpenMonkeyPatch():
shared.gradio['interface'].launch(
max_threads=64,
prevent_thread_lock=True,
share=shared.args.share,
server_name=None if not shared.args.listen else (shared.args.listen_host or '0.0.0.0'),
server_port=shared.args.listen_port,
inbrowser=shared.args.auto_launch,
auth=auth or None,
ssl_verify=False if (shared.args.ssl_keyfile or shared.args.ssl_certfile) else True,
ssl_keyfile=shared.args.ssl_keyfile,
ssl_certfile=shared.args.ssl_certfile,
root_path=shared.args.subpath,
allowed_paths=allowed_paths,
)
shared.gradio['interface'].launch(
max_threads=64,
prevent_thread_lock=True,
share=shared.args.share,
server_name=None if not shared.args.listen else (shared.args.listen_host or '0.0.0.0'),
server_port=shared.args.listen_port,
inbrowser=shared.args.auto_launch,
auth=auth or None,
ssl_verify=False if (shared.args.ssl_keyfile or shared.args.ssl_certfile) else True,
ssl_keyfile=shared.args.ssl_keyfile,
ssl_certfile=shared.args.ssl_certfile,
root_path=shared.args.subpath,
allowed_paths=allowed_paths,
)
if __name__ == "__main__":