diff --git a/extensions/openai/script.py b/extensions/openai/script.py index f907cdbb..0a887de2 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -86,6 +86,20 @@ app.add_middleware( ) +@app.middleware("http") +async def validate_host_header(request: Request, call_next): + # Be strict about only approving access to localhost by default + if not (shared.args.listen or shared.args.public_api): + host = request.headers.get("host", "").split(":")[0] + if host not in ["localhost", "127.0.0.1"]: + return JSONResponse( + status_code=400, + content={"detail": "Invalid host header"} + ) + + return await call_next(request) + + @app.options("/", dependencies=check_key) async def options_route(): return JSONResponse(content="OK") diff --git a/modules/gradio_hijack.py b/modules/gradio_hijack.py index 2ddd983a..8e3bb0d9 100644 --- a/modules/gradio_hijack.py +++ b/modules/gradio_hijack.py @@ -1,5 +1,6 @@ ''' -Copied from: https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14184 +Most of the code here was adapted from: +https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14184 ''' import inspect @@ -7,6 +8,28 @@ import warnings from functools import wraps import gradio as gr +import gradio.routes +from starlette.middleware.trustedhost import TrustedHostMiddleware + +from modules import shared + +orig_create_app = gradio.routes.App.create_app + + +# Be strict about only approving access to localhost by default +def create_app_with_trustedhost(*args, **kwargs): + app = orig_create_app(*args, **kwargs) + + if not (shared.args.listen or shared.args.share): + app.add_middleware( + TrustedHostMiddleware, + allowed_hosts=["localhost", "127.0.0.1"] + ) + + return app + + +gradio.routes.App.create_app = create_app_with_trustedhost class GradioDeprecationWarning(DeprecationWarning):