diff --git a/modules/block_requests.py b/modules/block_requests.py deleted file mode 100644 index 911e41d9..00000000 --- a/modules/block_requests.py +++ /dev/null @@ -1,96 +0,0 @@ -import builtins -import io -import re - -import requests - -from modules import shared, ui -from modules.logging_colors import logger - -original_open = open -original_get = requests.get -original_print = print - - -class RequestBlocker: - - def __enter__(self): - requests.get = my_get - - def __exit__(self, exc_type, exc_value, traceback): - requests.get = original_get - - -class OpenMonkeyPatch: - - def __enter__(self): - builtins.open = my_open - builtins.print = my_print - - def __exit__(self, exc_type, exc_value, traceback): - builtins.open = original_open - builtins.print = original_print - - -def my_get(url, **kwargs): - logger.info('Unwanted HTTP request redirected to localhost :)') - kwargs.setdefault('allow_redirects', True) - return requests.api.request('get', 'http://127.0.0.1/', **kwargs) - - -def my_open(*args, **kwargs): - filename = str(args[0]) - if filename.endswith(('index.html', 'share.html')): - with original_open(*args, **kwargs) as f: - file_contents = f.read() - - if len(args) > 1 and args[1] == 'rb': - file_contents = file_contents.decode('utf-8') - - file_contents = file_contents.replace('\t\t', '') - file_contents = file_contents.replace('cdnjs.cloudflare.com', '127.0.0.1') - file_contents = file_contents.replace( - '', - '\n ' - '\n ' - '\n ' - '\n ' - '\n ' - '\n ' - '\n ' - '\n ' - '\n ' - '\n ' - f'\n ' - '\n ' - f'\n ' - '\n ' - ) - - file_contents = re.sub( - r'@media \(prefers-color-scheme: dark\) \{\s*body \{([^}]*)\}\s*\}', - r'body.dark {\1}', - file_contents, - flags=re.DOTALL - ) - - if len(args) > 1 and args[1] == 'rb': - file_contents = file_contents.encode('utf-8') - return io.BytesIO(file_contents) - else: - return io.StringIO(file_contents) - - else: - return original_open(*args, **kwargs) - - -def my_print(*args, **kwargs): - if len(args) > 0 and 'To create a public link, set `share=True`' in args[0]: - return - else: - if len(args) > 0 and 'Running on local URL' in args[0]: - args = list(args) - args[0] = f"\n{args[0].strip()}\n" - args = tuple(args) - - original_print(*args, **kwargs) diff --git a/modules/gradio_hijack.py b/modules/gradio_hijack.py deleted file mode 100644 index 817da40c..00000000 --- a/modules/gradio_hijack.py +++ /dev/null @@ -1,97 +0,0 @@ -''' -Most of the code here was adapted from: -https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14184 -''' - -import inspect -import warnings -from functools import wraps - -import gradio as gr -import gradio.routes -import gradio.utils -from starlette.middleware.trustedhost import TrustedHostMiddleware - -from modules import shared - -orig_create_app = gradio.routes.App.create_app - - -# Be strict about only approving access to localhost by default -def create_app_with_trustedhost(*args, **kwargs): - app = orig_create_app(*args, **kwargs) - - if not (shared.args.listen or shared.args.share): - app.add_middleware( - TrustedHostMiddleware, - allowed_hosts=["localhost", "127.0.0.1"] - ) - - return app - - -gradio.routes.App.create_app = create_app_with_trustedhost -gradio.utils.launch_counter = lambda: None - - -class GradioDeprecationWarning(DeprecationWarning): - pass - - -def repair(grclass): - if not getattr(grclass, 'EVENTS', None): - return - - @wraps(grclass.__init__) - def __repaired_init__(self, *args, tooltip=None, source=None, original=grclass.__init__, **kwargs): - if source: - kwargs["sources"] = [source] - - allowed_kwargs = inspect.signature(original).parameters - fixed_kwargs = {} - for k, v in kwargs.items(): - if k in allowed_kwargs: - fixed_kwargs[k] = v - else: - warnings.warn(f"unexpected argument for {grclass.__name__}: {k}", GradioDeprecationWarning, stacklevel=2) - - original(self, *args, **fixed_kwargs) - - self.webui_tooltip = tooltip - - for event in self.EVENTS: - replaced_event = getattr(self, str(event)) - - def fun(*xargs, _js=None, replaced_event=replaced_event, **xkwargs): - if _js: - xkwargs['js'] = _js - - return replaced_event(*xargs, **xkwargs) - - setattr(self, str(event), fun) - - grclass.__init__ = __repaired_init__ - grclass.update = gr.update - - -for component in set(gr.components.__all__ + gr.layouts.__all__): - repair(getattr(gr, component, None)) - - -class Dependency(gr.events.Dependency): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def then(*xargs, _js=None, **xkwargs): - if _js: - xkwargs['js'] = _js - - return original_then(*xargs, **xkwargs) - - original_then = self.then - self.then = then - - -gr.events.Dependency = Dependency - -gr.Box = gr.Group diff --git a/server.py b/server.py index e98ff153..bd23f368 100644 --- a/server.py +++ b/server.py @@ -3,13 +3,7 @@ import shutil import warnings from pathlib import Path -# Monkey-patch HfFolder for gradio 4.x compatibility with huggingface-hub 1.x -import huggingface_hub -if not hasattr(huggingface_hub, 'HfFolder'): - huggingface_hub.HfFolder = type('HfFolder', (), {'get_token': staticmethod(huggingface_hub.get_token)}) - from modules import shared -from modules.block_requests import OpenMonkeyPatch, RequestBlocker from modules.image_models import load_image_model from modules.logging_colors import logger from modules.prompts import load_prompt @@ -32,9 +26,7 @@ warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_na warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()') warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_names" has conflict') -with RequestBlocker(): - from modules import gradio_hijack - import gradio as gr +import gradio as gr import matplotlib @@ -148,7 +140,24 @@ def create_interface(): # Interface state elements shared.input_elements = ui.list_interface_input_elements() - with gr.Blocks(css=css, analytics_enabled=False, title=title, theme=ui.theme) as shared.gradio['interface']: + # Head HTML for font preloads, KaTeX, highlight.js, morphdom, and global JS + head_html = '\n'.join([ + '', + '', + '', + '', + '', + '', + '', + '', + '', + '', + f'', + '', + f'', + ]) + + with gr.Blocks(css=css, analytics_enabled=False, title=title, theme=ui.theme, head=head_html) as shared.gradio['interface']: # Interface state shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements}) @@ -234,21 +243,20 @@ def create_interface(): # Launch the interface shared.gradio['interface'].queue() - with OpenMonkeyPatch(): - shared.gradio['interface'].launch( - max_threads=64, - prevent_thread_lock=True, - share=shared.args.share, - server_name=None if not shared.args.listen else (shared.args.listen_host or '0.0.0.0'), - server_port=shared.args.listen_port, - inbrowser=shared.args.auto_launch, - auth=auth or None, - ssl_verify=False if (shared.args.ssl_keyfile or shared.args.ssl_certfile) else True, - ssl_keyfile=shared.args.ssl_keyfile, - ssl_certfile=shared.args.ssl_certfile, - root_path=shared.args.subpath, - allowed_paths=allowed_paths, - ) + shared.gradio['interface'].launch( + max_threads=64, + prevent_thread_lock=True, + share=shared.args.share, + server_name=None if not shared.args.listen else (shared.args.listen_host or '0.0.0.0'), + server_port=shared.args.listen_port, + inbrowser=shared.args.auto_launch, + auth=auth or None, + ssl_verify=False if (shared.args.ssl_keyfile or shared.args.ssl_certfile) else True, + ssl_keyfile=shared.args.ssl_keyfile, + ssl_certfile=shared.args.ssl_certfile, + root_path=shared.args.subpath, + allowed_paths=allowed_paths, + ) if __name__ == "__main__":