Improve host header validation in local mode

This commit is contained in:
oobabooga 2025-04-26 15:07:35 -07:00
parent a317450dfa
commit bc55feaf3e
2 changed files with 38 additions and 1 deletions

View file

@ -86,6 +86,20 @@ app.add_middleware(
) )
@app.middleware("http")
async def validate_host_header(request: Request, call_next):
# Be strict about only approving access to localhost by default
if not (shared.args.listen or shared.args.public_api):
host = request.headers.get("host", "").split(":")[0]
if host not in ["localhost", "127.0.0.1"]:
return JSONResponse(
status_code=400,
content={"detail": "Invalid host header"}
)
return await call_next(request)
@app.options("/", dependencies=check_key) @app.options("/", dependencies=check_key)
async def options_route(): async def options_route():
return JSONResponse(content="OK") return JSONResponse(content="OK")

View file

@ -1,5 +1,6 @@
''' '''
Copied from: https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14184 Most of the code here was adapted from:
https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14184
''' '''
import inspect import inspect
@ -7,6 +8,28 @@ import warnings
from functools import wraps from functools import wraps
import gradio as gr import gradio as gr
import gradio.routes
from starlette.middleware.trustedhost import TrustedHostMiddleware
from modules import shared
orig_create_app = gradio.routes.App.create_app
# Be strict about only approving access to localhost by default
def create_app_with_trustedhost(*args, **kwargs):
app = orig_create_app(*args, **kwargs)
if not (shared.args.listen or shared.args.share):
app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=["localhost", "127.0.0.1"]
)
return app
gradio.routes.App.create_app = create_app_with_trustedhost
class GradioDeprecationWarning(DeprecationWarning): class GradioDeprecationWarning(DeprecationWarning):