mirror of
https://github.com/w-okada/voice-changer.git
synced 2025-01-23 13:35:12 +03:00
Harden web server security
This commit is contained in:
parent
11672e9653
commit
cf2b693334
@ -65,6 +65,9 @@ def setupArgParser():
|
|||||||
parser.add_argument("--rmvpe", type=str, default="pretrain/rmvpe.pt", help="path to rmvpe")
|
parser.add_argument("--rmvpe", type=str, default="pretrain/rmvpe.pt", help="path to rmvpe")
|
||||||
parser.add_argument("--rmvpe_onnx", type=str, default="pretrain/rmvpe.onnx", help="path to rmvpe onnx")
|
parser.add_argument("--rmvpe_onnx", type=str, default="pretrain/rmvpe.onnx", help="path to rmvpe onnx")
|
||||||
|
|
||||||
|
parser.add_argument("--host", type=str, default='127.0.0.1', help="IP address of the network interface to listen for HTTP connections. Specify 0.0.0.0 to listen on all interfaces.")
|
||||||
|
parser.add_argument("--allowed-origins", action='append', default=[], help="List of URLs to allow connection from, i.e. https://example.com. Allows http(s)://127.0.0.1:{port} and http(s)://localhost:{port} by default.")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -114,16 +117,19 @@ vcparams.setParams(voiceChangerParams)
|
|||||||
|
|
||||||
printMessage(f"Booting PHASE :{__name__}", level=2)
|
printMessage(f"Booting PHASE :{__name__}", level=2)
|
||||||
|
|
||||||
|
HOST = args.host
|
||||||
PORT = args.p
|
PORT = args.p
|
||||||
|
|
||||||
|
|
||||||
def localServer(logLevel: str = "critical"):
|
def localServer(logLevel: str = "critical", key_path: str | None = None, cert_path: str | None = None):
|
||||||
try:
|
try:
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
f"{os.path.basename(__file__)[:-3]}:app_socketio",
|
f"{os.path.basename(__file__)[:-3]}:app_socketio",
|
||||||
host="0.0.0.0",
|
host=HOST,
|
||||||
port=int(PORT),
|
port=int(PORT),
|
||||||
reload=False if hasattr(sys, "_MEIPASS") else True,
|
reload=False if hasattr(sys, "_MEIPASS") else True,
|
||||||
|
ssl_keyfile=key_path,
|
||||||
|
ssl_certfile=cert_path,
|
||||||
log_level=logLevel,
|
log_level=logLevel,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -134,7 +140,7 @@ if __name__ == "MMVCServerSIO":
|
|||||||
mp.freeze_support()
|
mp.freeze_support()
|
||||||
|
|
||||||
voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams)
|
voiceChangerManager = VoiceChangerManager.get_instance(voiceChangerParams)
|
||||||
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams)
|
app_fastapi = MMVC_Rest.get_instance(voiceChangerManager, voiceChangerParams, PORT, args.allowed_origins)
|
||||||
app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager)
|
app_socketio = MMVC_SocketIOApp.get_instance(app_fastapi, voiceChangerManager)
|
||||||
|
|
||||||
|
|
||||||
@ -220,34 +226,26 @@ if __name__ == "__main__":
|
|||||||
printMessage("In many cases, it will launch when you access any of the following URLs.", level=2)
|
printMessage("In many cases, it will launch when you access any of the following URLs.", level=2)
|
||||||
if "EX_PORT" in locals() and "EX_IP" in locals(): # シェルスクリプト経由起動(docker)
|
if "EX_PORT" in locals() and "EX_IP" in locals(): # シェルスクリプト経由起動(docker)
|
||||||
if args.https == 1:
|
if args.https == 1:
|
||||||
printMessage(f"https://127.0.0.1:{EX_PORT}/", level=1)
|
printMessage(f"https://localhost:{EX_PORT}/", level=1)
|
||||||
for ip in EX_IP.strip().split(" "):
|
for ip in EX_IP.strip().split(" "):
|
||||||
printMessage(f"https://{ip}:{EX_PORT}/", level=1)
|
printMessage(f"https://{ip}:{EX_PORT}/", level=1)
|
||||||
else:
|
else:
|
||||||
printMessage(f"http://127.0.0.1:{EX_PORT}/", level=1)
|
printMessage(f"http://localhost:{EX_PORT}/", level=1)
|
||||||
else: # 直接python起動
|
else: # 直接python起動
|
||||||
if args.https == 1:
|
if args.https == 1:
|
||||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
s.connect((args.test_connect, 80))
|
s.connect((args.test_connect, 80))
|
||||||
hostname = s.getsockname()[0]
|
hostname = s.getsockname()[0]
|
||||||
printMessage(f"https://127.0.0.1:{PORT}/", level=1)
|
printMessage(f"https://localhost:{PORT}/", level=1)
|
||||||
printMessage(f"https://{hostname}:{PORT}/", level=1)
|
printMessage(f"https://{hostname}:{PORT}/", level=1)
|
||||||
else:
|
else:
|
||||||
printMessage(f"http://127.0.0.1:{PORT}/", level=1)
|
printMessage(f"http://localhost:{PORT}/", level=1)
|
||||||
|
|
||||||
# サーバ起動
|
# サーバ起動
|
||||||
if args.https:
|
if args.https:
|
||||||
# HTTPS サーバ起動
|
# HTTPS サーバ起動
|
||||||
try:
|
try:
|
||||||
uvicorn.run(
|
localServer(args.logLevel, key_path, cert_path)
|
||||||
f"{os.path.basename(__file__)[:-3]}:app_socketio",
|
|
||||||
host="0.0.0.0",
|
|
||||||
port=int(PORT),
|
|
||||||
reload=False if hasattr(sys, "_MEIPASS") else True,
|
|
||||||
ssl_keyfile=key_path,
|
|
||||||
ssl_certfile=cert_path,
|
|
||||||
log_level=args.logLevel,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[Voice Changer] Web Server(https) Launch Exception, {e}")
|
logger.error(f"[Voice Changer] Web Server(https) Launch Exception, {e}")
|
||||||
|
|
||||||
@ -256,12 +254,12 @@ if __name__ == "__main__":
|
|||||||
p.start()
|
p.start()
|
||||||
try:
|
try:
|
||||||
if sys.platform.startswith("win"):
|
if sys.platform.startswith("win"):
|
||||||
process = subprocess.Popen([NATIVE_CLIENT_FILE_WIN, "--disable-gpu", "-u", f"http://127.0.0.1:{PORT}/"])
|
process = subprocess.Popen([NATIVE_CLIENT_FILE_WIN, "--disable-gpu", "-u", f"http://localhost:{PORT}/"])
|
||||||
return_code = process.wait()
|
return_code = process.wait()
|
||||||
logger.info("client closed.")
|
logger.info("client closed.")
|
||||||
p.terminate()
|
p.terminate()
|
||||||
elif sys.platform.startswith("darwin"):
|
elif sys.platform.startswith("darwin"):
|
||||||
process = subprocess.Popen([NATIVE_CLIENT_FILE_MAC, "--disable-gpu", "-u", f"http://127.0.0.1:{PORT}/"])
|
process = subprocess.Popen([NATIVE_CLIENT_FILE_MAC, "--disable-gpu", "-u", f"http://localhost:{PORT}/"])
|
||||||
return_code = process.wait()
|
return_code = process.wait()
|
||||||
logger.info("client closed.")
|
logger.info("client closed.")
|
||||||
p.terminate()
|
p.terminate()
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from restapi.mods.trustedorigin import TrustedOriginMiddleware
|
||||||
from fastapi import FastAPI, Request, Response, HTTPException
|
from fastapi import FastAPI, Request, Response, HTTPException
|
||||||
from fastapi.routing import APIRoute
|
from fastapi.routing import APIRoute
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
@ -43,17 +43,17 @@ class MMVC_Rest:
|
|||||||
cls,
|
cls,
|
||||||
voiceChangerManager: VoiceChangerManager,
|
voiceChangerManager: VoiceChangerManager,
|
||||||
voiceChangerParams: VoiceChangerParams,
|
voiceChangerParams: VoiceChangerParams,
|
||||||
|
port: int,
|
||||||
|
allowedOrigins: list[str],
|
||||||
):
|
):
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
logger.info("[Voice Changer] MMVC_Rest initializing...")
|
logger.info("[Voice Changer] MMVC_Rest initializing...")
|
||||||
app_fastapi = FastAPI()
|
app_fastapi = FastAPI()
|
||||||
app_fastapi.router.route_class = ValidationErrorLoggingRoute
|
app_fastapi.router.route_class = ValidationErrorLoggingRoute
|
||||||
app_fastapi.add_middleware(
|
app_fastapi.add_middleware(
|
||||||
CORSMiddleware,
|
TrustedOriginMiddleware,
|
||||||
allow_origins=["*"],
|
allowed_origins=allowedOrigins,
|
||||||
allow_credentials=True,
|
port=port
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
app_fastapi.mount(
|
app_fastapi.mount(
|
||||||
|
49
server/restapi/mods/trustedorigin.py
Normal file
49
server/restapi/mods/trustedorigin.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import typing
|
||||||
|
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from starlette.datastructures import Headers
|
||||||
|
from starlette.responses import PlainTextResponse
|
||||||
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||||
|
|
||||||
|
ENFORCE_URL_ORIGIN_FORMAT = "Input origins must be well-formed URLs, i.e. https://google.com or https://www.google.com."
|
||||||
|
|
||||||
|
|
||||||
|
class TrustedOriginMiddleware:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
app: ASGIApp,
|
||||||
|
allowed_origins: typing.Optional[typing.Sequence[str]] = None,
|
||||||
|
port: typing.Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
schemas = ['http', 'https']
|
||||||
|
local_origins = [f'{schema}://{origin}' for schema in schemas for origin in ['127.0.0.1', 'localhost']]
|
||||||
|
if port is not None:
|
||||||
|
local_origins = [f'{origin}:{port}' for origin in local_origins]
|
||||||
|
|
||||||
|
if not allowed_origins:
|
||||||
|
allowed_origins = local_origins
|
||||||
|
else:
|
||||||
|
for origin in allowed_origins:
|
||||||
|
assert urlparse(origin).scheme, ENFORCE_URL_ORIGIN_FORMAT
|
||||||
|
allowed_origins = local_origins + allowed_origins
|
||||||
|
|
||||||
|
self.app = app
|
||||||
|
self.allowed_origins = list(allowed_origins)
|
||||||
|
|
||||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
|
if scope["type"] not in (
|
||||||
|
"http",
|
||||||
|
"websocket",
|
||||||
|
): # pragma: no cover
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
headers = Headers(scope=scope)
|
||||||
|
origin = headers.get("origin", "")
|
||||||
|
# Origin header is not present for same origin
|
||||||
|
if not origin or origin in self.allowed_origins:
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
response = PlainTextResponse("Invalid origin header", status_code=400)
|
||||||
|
await response(scope, receive, send)
|
Loading…
Reference in New Issue
Block a user