only enable corsmiddleware when in dev mode and allow all origins
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -11,18 +13,39 @@ from .error_handlers import clan_error_handler
|
||||
from .routers import clan_modules, flake, health, machines, root, vms
|
||||
from .tags import tags_metadata
|
||||
|
||||
origins = [
|
||||
"http://localhost:3000",
|
||||
]
|
||||
# Logging setup
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnvType(Enum):
|
||||
production = "production"
|
||||
development = "development"
|
||||
|
||||
@staticmethod
|
||||
def from_environment() -> "EnvType":
|
||||
t = os.environ.get("CLAN_WEBUI_ENV", "production")
|
||||
try:
|
||||
return EnvType[t]
|
||||
except KeyError:
|
||||
log.warning(f"Invalid environment type: {t}, fallback to production")
|
||||
return EnvType.production
|
||||
|
||||
def is_production(self) -> bool:
|
||||
return self == EnvType.production
|
||||
|
||||
def is_development(self) -> bool:
|
||||
return self == EnvType.development
|
||||
|
||||
|
||||
def setup_app() -> FastAPI:
|
||||
env = EnvType.from_environment()
|
||||
app = FastAPI()
|
||||
|
||||
if env.is_development():
|
||||
# Allow CORS in development mode for nextjs dev server
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
@@ -76,6 +77,8 @@ def spawn_node_dev_server(host: IPvAnyAddress, port: int) -> Iterator[None]:
|
||||
|
||||
|
||||
def start_server(args: argparse.Namespace) -> None:
|
||||
os.environ["CLAN_WEBUI_ENV"] = "development" if args.dev else "production"
|
||||
|
||||
with ExitStack() as stack:
|
||||
headers: list[tuple[str, str]] = []
|
||||
if args.dev:
|
||||
@@ -85,20 +88,6 @@ def start_server(args: argparse.Namespace) -> None:
|
||||
host = args.dev_host
|
||||
if ":" in host:
|
||||
host = f"[{host}]"
|
||||
headers = [
|
||||
# (
|
||||
# "Access-Control-Allow-Origin",
|
||||
# f"http://{host}:{args.dev_port}",
|
||||
# ),
|
||||
# (
|
||||
# "Access-Control-Allow-Methods",
|
||||
# "DELETE, GET, HEAD, OPTIONS, PATCH, POST, PUT"
|
||||
# ),
|
||||
# (
|
||||
# "Allow",
|
||||
# "DELETE, GET, HEAD, OPTIONS, PATCH, POST, PUT"
|
||||
# )
|
||||
]
|
||||
else:
|
||||
base_url = f"http://{args.host}:{args.port}"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user