mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2026-03-06 21:53:50 +01:00
Remove gradio monkey-patches (moved to gradio fork)
This commit is contained in:
parent
e9f22813e4
commit
2260e530c9
|
|
@ -1,96 +0,0 @@
|
|||
import builtins
|
||||
import io
|
||||
import re
|
||||
|
||||
import requests
|
||||
|
||||
from modules import shared, ui
|
||||
from modules.logging_colors import logger
|
||||
|
||||
original_open = open
|
||||
original_get = requests.get
|
||||
original_print = print
|
||||
|
||||
|
||||
class RequestBlocker:
|
||||
|
||||
def __enter__(self):
|
||||
requests.get = my_get
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
requests.get = original_get
|
||||
|
||||
|
||||
class OpenMonkeyPatch:
|
||||
|
||||
def __enter__(self):
|
||||
builtins.open = my_open
|
||||
builtins.print = my_print
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
builtins.open = original_open
|
||||
builtins.print = original_print
|
||||
|
||||
|
||||
def my_get(url, **kwargs):
|
||||
logger.info('Unwanted HTTP request redirected to localhost :)')
|
||||
kwargs.setdefault('allow_redirects', True)
|
||||
return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
|
||||
|
||||
|
||||
def my_open(*args, **kwargs):
|
||||
filename = str(args[0])
|
||||
if filename.endswith(('index.html', 'share.html')):
|
||||
with original_open(*args, **kwargs) as f:
|
||||
file_contents = f.read()
|
||||
|
||||
if len(args) > 1 and args[1] == 'rb':
|
||||
file_contents = file_contents.decode('utf-8')
|
||||
|
||||
file_contents = file_contents.replace('\t\t<script\n\t\t\tsrc="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"\n\t\t\tasync\n\t\t></script>', '')
|
||||
file_contents = file_contents.replace('cdnjs.cloudflare.com', '127.0.0.1')
|
||||
file_contents = file_contents.replace(
|
||||
'</head>',
|
||||
'\n <link rel="preload" href="file/css/Inter/Inter-VariableFont_opsz,wght.ttf" as="font" type="font/ttf" crossorigin>'
|
||||
'\n <link rel="preload" href="file/css/Inter/Inter-Italic-VariableFont_opsz,wght.ttf" as="font" type="font/ttf" crossorigin>'
|
||||
'\n <link rel="preload" href="file/css/NotoSans/NotoSans-Medium.woff2" as="font" type="font/woff2" crossorigin>'
|
||||
'\n <link rel="preload" href="file/css/NotoSans/NotoSans-MediumItalic.woff2" as="font" type="font/woff2" crossorigin>'
|
||||
'\n <link rel="preload" href="file/css/NotoSans/NotoSans-Bold.woff2" as="font" type="font/woff2" crossorigin>'
|
||||
'\n <script src="file/js/katex/katex.min.js"></script>'
|
||||
'\n <script src="file/js/katex/auto-render.min.js"></script>'
|
||||
'\n <script src="file/js/highlightjs/highlight.min.js"></script>'
|
||||
'\n <script src="file/js/highlightjs/highlightjs-copy.min.js"></script>'
|
||||
'\n <script src="file/js/morphdom/morphdom-umd.min.js"></script>'
|
||||
f'\n <link id="highlight-css" rel="stylesheet" href="file/css/highlightjs/{"github-dark" if shared.settings["dark_theme"] else "github"}.min.css">'
|
||||
'\n <script>hljs.addPlugin(new CopyButtonPlugin());</script>'
|
||||
f'\n <script>{ui.global_scope_js}</script>'
|
||||
'\n </head>'
|
||||
)
|
||||
|
||||
file_contents = re.sub(
|
||||
r'@media \(prefers-color-scheme: dark\) \{\s*body \{([^}]*)\}\s*\}',
|
||||
r'body.dark {\1}',
|
||||
file_contents,
|
||||
flags=re.DOTALL
|
||||
)
|
||||
|
||||
if len(args) > 1 and args[1] == 'rb':
|
||||
file_contents = file_contents.encode('utf-8')
|
||||
return io.BytesIO(file_contents)
|
||||
else:
|
||||
return io.StringIO(file_contents)
|
||||
|
||||
else:
|
||||
return original_open(*args, **kwargs)
|
||||
|
||||
|
||||
def my_print(*args, **kwargs):
|
||||
if len(args) > 0 and 'To create a public link, set `share=True`' in args[0]:
|
||||
return
|
||||
else:
|
||||
if len(args) > 0 and 'Running on local URL' in args[0]:
|
||||
args = list(args)
|
||||
args[0] = f"\n{args[0].strip()}\n"
|
||||
args = tuple(args)
|
||||
|
||||
original_print(*args, **kwargs)
|
||||
|
|
@ -1,97 +0,0 @@
|
|||
'''
|
||||
Most of the code here was adapted from:
|
||||
https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14184
|
||||
'''
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from functools import wraps
|
||||
|
||||
import gradio as gr
|
||||
import gradio.routes
|
||||
import gradio.utils
|
||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||
|
||||
from modules import shared
|
||||
|
||||
orig_create_app = gradio.routes.App.create_app
|
||||
|
||||
|
||||
# Be strict about only approving access to localhost by default
|
||||
def create_app_with_trustedhost(*args, **kwargs):
|
||||
app = orig_create_app(*args, **kwargs)
|
||||
|
||||
if not (shared.args.listen or shared.args.share):
|
||||
app.add_middleware(
|
||||
TrustedHostMiddleware,
|
||||
allowed_hosts=["localhost", "127.0.0.1"]
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
gradio.routes.App.create_app = create_app_with_trustedhost
|
||||
gradio.utils.launch_counter = lambda: None
|
||||
|
||||
|
||||
class GradioDeprecationWarning(DeprecationWarning):
|
||||
pass
|
||||
|
||||
|
||||
def repair(grclass):
|
||||
if not getattr(grclass, 'EVENTS', None):
|
||||
return
|
||||
|
||||
@wraps(grclass.__init__)
|
||||
def __repaired_init__(self, *args, tooltip=None, source=None, original=grclass.__init__, **kwargs):
|
||||
if source:
|
||||
kwargs["sources"] = [source]
|
||||
|
||||
allowed_kwargs = inspect.signature(original).parameters
|
||||
fixed_kwargs = {}
|
||||
for k, v in kwargs.items():
|
||||
if k in allowed_kwargs:
|
||||
fixed_kwargs[k] = v
|
||||
else:
|
||||
warnings.warn(f"unexpected argument for {grclass.__name__}: {k}", GradioDeprecationWarning, stacklevel=2)
|
||||
|
||||
original(self, *args, **fixed_kwargs)
|
||||
|
||||
self.webui_tooltip = tooltip
|
||||
|
||||
for event in self.EVENTS:
|
||||
replaced_event = getattr(self, str(event))
|
||||
|
||||
def fun(*xargs, _js=None, replaced_event=replaced_event, **xkwargs):
|
||||
if _js:
|
||||
xkwargs['js'] = _js
|
||||
|
||||
return replaced_event(*xargs, **xkwargs)
|
||||
|
||||
setattr(self, str(event), fun)
|
||||
|
||||
grclass.__init__ = __repaired_init__
|
||||
grclass.update = gr.update
|
||||
|
||||
|
||||
for component in set(gr.components.__all__ + gr.layouts.__all__):
|
||||
repair(getattr(gr, component, None))
|
||||
|
||||
|
||||
class Dependency(gr.events.Dependency):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def then(*xargs, _js=None, **xkwargs):
|
||||
if _js:
|
||||
xkwargs['js'] = _js
|
||||
|
||||
return original_then(*xargs, **xkwargs)
|
||||
|
||||
original_then = self.then
|
||||
self.then = then
|
||||
|
||||
|
||||
gr.events.Dependency = Dependency
|
||||
|
||||
gr.Box = gr.Group
|
||||
58
server.py
58
server.py
|
|
@ -3,13 +3,7 @@ import shutil
|
|||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
# Monkey-patch HfFolder for gradio 4.x compatibility with huggingface-hub 1.x
|
||||
import huggingface_hub
|
||||
if not hasattr(huggingface_hub, 'HfFolder'):
|
||||
huggingface_hub.HfFolder = type('HfFolder', (), {'get_token': staticmethod(huggingface_hub.get_token)})
|
||||
|
||||
from modules import shared
|
||||
from modules.block_requests import OpenMonkeyPatch, RequestBlocker
|
||||
from modules.image_models import load_image_model
|
||||
from modules.logging_colors import logger
|
||||
from modules.prompts import load_prompt
|
||||
|
|
@ -32,9 +26,7 @@ warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_na
|
|||
warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()')
|
||||
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_names" has conflict')
|
||||
|
||||
with RequestBlocker():
|
||||
from modules import gradio_hijack
|
||||
import gradio as gr
|
||||
import gradio as gr
|
||||
|
||||
import matplotlib
|
||||
|
||||
|
|
@ -148,7 +140,24 @@ def create_interface():
|
|||
# Interface state elements
|
||||
shared.input_elements = ui.list_interface_input_elements()
|
||||
|
||||
with gr.Blocks(css=css, analytics_enabled=False, title=title, theme=ui.theme) as shared.gradio['interface']:
|
||||
# Head HTML for font preloads, KaTeX, highlight.js, morphdom, and global JS
|
||||
head_html = '\n'.join([
|
||||
'<link rel="preload" href="file/css/Inter/Inter-VariableFont_opsz,wght.ttf" as="font" type="font/ttf" crossorigin>',
|
||||
'<link rel="preload" href="file/css/Inter/Inter-Italic-VariableFont_opsz,wght.ttf" as="font" type="font/ttf" crossorigin>',
|
||||
'<link rel="preload" href="file/css/NotoSans/NotoSans-Medium.woff2" as="font" type="font/woff2" crossorigin>',
|
||||
'<link rel="preload" href="file/css/NotoSans/NotoSans-MediumItalic.woff2" as="font" type="font/woff2" crossorigin>',
|
||||
'<link rel="preload" href="file/css/NotoSans/NotoSans-Bold.woff2" as="font" type="font/woff2" crossorigin>',
|
||||
'<script src="file/js/katex/katex.min.js"></script>',
|
||||
'<script src="file/js/katex/auto-render.min.js"></script>',
|
||||
'<script src="file/js/highlightjs/highlight.min.js"></script>',
|
||||
'<script src="file/js/highlightjs/highlightjs-copy.min.js"></script>',
|
||||
'<script src="file/js/morphdom/morphdom-umd.min.js"></script>',
|
||||
f'<link id="highlight-css" rel="stylesheet" href="file/css/highlightjs/{"github-dark" if shared.settings["dark_theme"] else "github"}.min.css">',
|
||||
'<script>hljs.addPlugin(new CopyButtonPlugin());</script>',
|
||||
f'<script>{ui.global_scope_js}</script>',
|
||||
])
|
||||
|
||||
with gr.Blocks(css=css, analytics_enabled=False, title=title, theme=ui.theme, head=head_html) as shared.gradio['interface']:
|
||||
|
||||
# Interface state
|
||||
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||
|
|
@ -234,21 +243,20 @@ def create_interface():
|
|||
|
||||
# Launch the interface
|
||||
shared.gradio['interface'].queue()
|
||||
with OpenMonkeyPatch():
|
||||
shared.gradio['interface'].launch(
|
||||
max_threads=64,
|
||||
prevent_thread_lock=True,
|
||||
share=shared.args.share,
|
||||
server_name=None if not shared.args.listen else (shared.args.listen_host or '0.0.0.0'),
|
||||
server_port=shared.args.listen_port,
|
||||
inbrowser=shared.args.auto_launch,
|
||||
auth=auth or None,
|
||||
ssl_verify=False if (shared.args.ssl_keyfile or shared.args.ssl_certfile) else True,
|
||||
ssl_keyfile=shared.args.ssl_keyfile,
|
||||
ssl_certfile=shared.args.ssl_certfile,
|
||||
root_path=shared.args.subpath,
|
||||
allowed_paths=allowed_paths,
|
||||
)
|
||||
shared.gradio['interface'].launch(
|
||||
max_threads=64,
|
||||
prevent_thread_lock=True,
|
||||
share=shared.args.share,
|
||||
server_name=None if not shared.args.listen else (shared.args.listen_host or '0.0.0.0'),
|
||||
server_port=shared.args.listen_port,
|
||||
inbrowser=shared.args.auto_launch,
|
||||
auth=auth or None,
|
||||
ssl_verify=False if (shared.args.ssl_keyfile or shared.args.ssl_certfile) else True,
|
||||
ssl_keyfile=shared.args.ssl_keyfile,
|
||||
ssl_certfile=shared.args.ssl_certfile,
|
||||
root_path=shared.args.subpath,
|
||||
allowed_paths=allowed_paths,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in a new issue