mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-12-06 07:12:10 +01:00
Improve host header validation in local mode
This commit is contained in:
parent
a317450dfa
commit
bc55feaf3e
|
|
@ -86,6 +86,20 @@ app.add_middleware(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.middleware("http")
|
||||||
|
async def validate_host_header(request: Request, call_next):
|
||||||
|
# Be strict about only approving access to localhost by default
|
||||||
|
if not (shared.args.listen or shared.args.public_api):
|
||||||
|
host = request.headers.get("host", "").split(":")[0]
|
||||||
|
if host not in ["localhost", "127.0.0.1"]:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content={"detail": "Invalid host header"}
|
||||||
|
)
|
||||||
|
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
||||||
@app.options("/", dependencies=check_key)
|
@app.options("/", dependencies=check_key)
|
||||||
async def options_route():
|
async def options_route():
|
||||||
return JSONResponse(content="OK")
|
return JSONResponse(content="OK")
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
'''
|
'''
|
||||||
Copied from: https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14184
|
Most of the code here was adapted from:
|
||||||
|
https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14184
|
||||||
'''
|
'''
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
|
|
@ -7,6 +8,28 @@ import warnings
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
import gradio.routes
|
||||||
|
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
orig_create_app = gradio.routes.App.create_app
|
||||||
|
|
||||||
|
|
||||||
|
# Be strict about only approving access to localhost by default
|
||||||
|
def create_app_with_trustedhost(*args, **kwargs):
|
||||||
|
app = orig_create_app(*args, **kwargs)
|
||||||
|
|
||||||
|
if not (shared.args.listen or shared.args.share):
|
||||||
|
app.add_middleware(
|
||||||
|
TrustedHostMiddleware,
|
||||||
|
allowed_hosts=["localhost", "127.0.0.1"]
|
||||||
|
)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
gradio.routes.App.create_app = create_app_with_trustedhost
|
||||||
|
|
||||||
|
|
||||||
class GradioDeprecationWarning(DeprecationWarning):
|
class GradioDeprecationWarning(DeprecationWarning):
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue