diff --git a/server.py b/server.py index 3611880da..4d78ee571 100644 --- a/server.py +++ b/server.py @@ -80,6 +80,27 @@ def create_cors_middleware(allowed_origin: str): return cors_middleware +def create_origin_only_middleware(): + @web.middleware + async def origin_only_middleware(request: web.Request, handler): + if request.method == "OPTIONS": + response = web.Response() + else: + response = await handler(request) + + if 'Host' in request.headers and 'Origin' in request.headers: + host = request.headers['Host'] + origin = request.headers['Origin'] + host_domain = host.lower() + origin_domain = urllib.parse.urlparse(origin).netloc.lower() + if host_domain != origin_domain: + logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain)) + return web.Response(status=403) + + return response + + return origin_only_middleware + class PromptServer(): def __init__(self, loop): PromptServer.instance = self @@ -99,6 +120,8 @@ class PromptServer(): middlewares = [cache_control] if args.enable_cors_header: middlewares.append(create_cors_middleware(args.enable_cors_header)) + else: + middlewares.append(create_origin_only_middleware()) max_upload_size = round(args.max_upload_size * 1024 * 1024) self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)