mirror of
https://github.com/w-okada/voice-changer.git
synced 2025-01-23 13:35:12 +03:00
Harden web server security
This commit is contained in:
parent
11672e9653
commit
cf2b693334
@ -65,6 +65,9 @@ def setupArgParser():
|
||||
parser.add_argument("--rmvpe", type=str, default="pretrain/rmvpe.pt", help="path to rmvpe")
|
||||
parser.add_argument("--rmvpe_onnx", type=str, default="pretrain/rmvpe.onnx", help="path to rmvpe onnx")
|
||||
|
||||
parser.add_argument("--host", type=str, default='127.0.0.1', help="IP address of the network interface to listen for HTTP connections. Specify 0.0.0.0 to listen on all interfaces.")
|
||||
parser.add_argument("--allowed-origins", action='append', default=[], help="List of URLs to allow connection from, i.e. https://example.com. Allows http(s)://127.0.0.1:{port} and http(s)://localhost:{port} by default.")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -114,16 +117,19 @@ vcparams.setParams(voiceChangerParams)
|
||||
|
||||
printMessage(f"Booting PHASE :{__name__}", level=2)
|
||||
|
||||
HOST = args.host
|
||||
PORT = args.p
|
||||
|
||||
|
||||
def localServer(logLevel: str = "critical"):
|
||||
def localServer(logLevel: str = "critical", key_path: str | None = None, cert_path: str | None = None):
|
||||
try:
|
||||
uvicorn.run(
|
||||
f"{os.path.basename(__file__)[:-3]}:app_socketio",
|
||||
host="0.0.0.0",
|
||||
host=HOST,
|
||||
port=int(PORT),
|
||||
reload=False if hasattr(sys, "_MEIPASS") else True,
|
||||
ssl_keyfile=key_path,
|
||||
ssl_certfile=cert_path,
|
||||
log_level=logLevel,
|
||||
)
|
||||
except Exception as e:
|
||||
@ -134,7 +140,7 @@ if __name__ == "MMVCServerSIO":
|
||||
mp.freeze_support()
|
||||
|
||||
voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams)
|
||||
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams)
|
||||
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, PORT, args.allowed_origins)
|
||||
app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager)
|
||||
|
||||
|
||||
@ -220,34 +226,26 @@ if __name__ == "__main__":
|
||||
printMessage("In many cases, it will launch when you access any of the following URLs.", level=2)
|
||||
if "EX_PORT" in locals() and "EX_IP" in locals(): # シェルスクリプト経由起動(docker)
|
||||
if args.https == 1:
|
||||
printMessage(f"https://127.0.0.1:{EX_PORT}/", level=1)
|
||||
printMessage(f"https://localhost:{EX_PORT}/", level=1)
|
||||
for ip in EX_IP.strip().split(" "):
|
||||
printMessage(f"https://{ip}:{EX_PORT}/", level=1)
|
||||
else:
|
||||
printMessage(f"http://127.0.0.1:{EX_PORT}/", level=1)
|
||||
printMessage(f"http://localhost:{EX_PORT}/", level=1)
|
||||
else: # 直接python起動
|
||||
if args.https == 1:
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
s.connect((args.test_connect, 80))
|
||||
hostname = s.getsockname()[0]
|
||||
printMessage(f"https://127.0.0.1:{PORT}/", level=1)
|
||||
printMessage(f"https://localhost:{PORT}/", level=1)
|
||||
printMessage(f"https://{hostname}:{PORT}/", level=1)
|
||||
else:
|
||||
printMessage(f"http://127.0.0.1:{PORT}/", level=1)
|
||||
printMessage(f"http://localhost:{PORT}/", level=1)
|
||||
|
||||
# サーバ起動
|
||||
if args.https:
|
||||
# HTTPS サーバ起動
|
||||
try:
|
||||
uvicorn.run(
|
||||
f"{os.path.basename(__file__)[:-3]}:app_socketio",
|
||||
host="0.0.0.0",
|
||||
port=int(PORT),
|
||||
reload=False if hasattr(sys, "_MEIPASS") else True,
|
||||
ssl_keyfile=key_path,
|
||||
ssl_certfile=cert_path,
|
||||
log_level=args.logLevel,
|
||||
)
|
||||
localServer(args.logLevel, key_path, cert_path)
|
||||
except Exception as e:
|
||||
logger.error(f"[Voice Changer] Web Server(https) Launch Exception, {e}")
|
||||
|
||||
@ -256,12 +254,12 @@ if __name__ == "__main__":
|
||||
p.start()
|
||||
try:
|
||||
if sys.platform.startswith("win"):
|
||||
process = subprocess.Popen([NATIVE_CLIENT_FILE_WIN, "--disable-gpu", "-u", f"http://127.0.0.1:{PORT}/"])
|
||||
process = subprocess.Popen([NATIVE_CLIENT_FILE_WIN, "--disable-gpu", "-u", f"http://localhost:{PORT}/"])
|
||||
return_code = process.wait()
|
||||
logger.info("client closed.")
|
||||
p.terminate()
|
||||
elif sys.platform.startswith("darwin"):
|
||||
process = subprocess.Popen([NATIVE_CLIENT_FILE_MAC, "--disable-gpu", "-u", f"http://127.0.0.1:{PORT}/"])
|
||||
process = subprocess.Popen([NATIVE_CLIENT_FILE_MAC, "--disable-gpu", "-u", f"http://localhost:{PORT}/"])
|
||||
return_code = process.wait()
|
||||
logger.info("client closed.")
|
||||
p.terminate()
|
||||
|
@ -1,9 +1,9 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from restapi.mods.trustedorigin import TrustedOriginMiddleware
|
||||
from fastapi import FastAPI, Request, Response, HTTPException
|
||||
from fastapi.routing import APIRoute
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from typing import Callable
|
||||
@ -43,17 +43,17 @@ class MMVC_Rest:
|
||||
cls,
|
||||
voiceChangerManager: VoiceChangerManager,
|
||||
voiceChangerParams: VoiceChangerParams,
|
||||
port: int,
|
||||
allowedOrigins: list[str],
|
||||
):
|
||||
if cls._instance is None:
|
||||
logger.info("[Voice Changer] MMVC_Rest initializing...")
|
||||
app_fastapi = FastAPI()
|
||||
app_fastapi.router.route_class = ValidationErrorLoggingRoute
|
||||
app_fastapi.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
TrustedOriginMiddleware,
|
||||
allowed_origins=allowedOrigins,
|
||||
port=port
|
||||
)
|
||||
|
||||
app_fastapi.mount(
|
||||
|
49
server/restapi/mods/trustedorigin.py
Normal file
49
server/restapi/mods/trustedorigin.py
Normal file
@ -0,0 +1,49 @@
|
||||
import typing
|
||||
|
||||
from urllib.parse import urlparse
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.responses import PlainTextResponse
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com."
|
||||
|
||||
|
||||
class TrustedOriginMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
allowed_origins: typing.Optional[typing.Sequence[str]] = None,
|
||||
port: typing.Optional[int] = None,
|
||||
) -> None:
|
||||
schemas = ['http', 'https']
|
||||
local_origins = [f'{schema}://{origin}' for schema in schemas for origin in ['127.0.0.1', 'localhost']]
|
||||
if port is not None:
|
||||
local_origins = [f'{origin}:{port}' for origin in local_origins]
|
||||
|
||||
if not allowed_origins:
|
||||
allowed_origins = local_origins
|
||||
else:
|
||||
for origin in allowed_origins:
|
||||
assert urlparse(origin).scheme, ENFORCE_URL_ORIGIN_FORMAT
|
||||
allowed_origins = local_origins + allowed_origins
|
||||
|
||||
self.app = app
|
||||
self.allowed_origins = list(allowed_origins)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] not in (
|
||||
"http",
|
||||
"websocket",
|
||||
): # pragma: no cover
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
headers = Headers(scope=scope)
|
||||
origin = headers.get("origin", "")
|
||||
# Origin header is not present for same origin
|
||||
if not origin or origin in self.allowed_origins:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
response = PlainTextResponse("Invalid origin header", status_code=400)
|
||||
await response(scope, receive, send)
|
Loading…
Reference in New Issue
Block a user