Merge pull request 'fix cors headers in development' (#249) from cors into main

This commit is contained in:
clan-bot
2023-09-06 12:12:12 +00:00
3 changed files with 27 additions and 61 deletions

View File

@@ -1,10 +1,8 @@
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from .assets import asset_path from .assets import asset_path
from .config import settings
from .routers import health, machines, root from .routers import health, machines, root
@@ -14,16 +12,6 @@ def setup_app() -> FastAPI:
app.include_router(machines.router) app.include_router(machines.router)
app.include_router(root.router) app.include_router(root.router)
if settings.env.is_development():
# TODO make this configurable
app.add_middleware(
CORSMiddleware,
allow_origins="http://${settings.dev_host}:${settings.dev_port}",
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
else:
app.mount("/static", StaticFiles(directory=asset_path()), name="static") app.mount("/static", StaticFiles(directory=asset_path()), name="static")
for route in app.routes: for route in app.routes:

View File

@@ -1,38 +0,0 @@
# config.py
import logging
import os
from enum import Enum
from pydantic import BaseSettings
logger = logging.getLogger(__name__)
class EnvType(Enum):
production = "production"
development = "development"
@staticmethod
def from_environment() -> "EnvType":
t = os.environ.get("CLAN_WEBUI_ENV", "production")
try:
return EnvType[t]
except KeyError:
logger.warning(f"Invalid environment type: {t}, fallback to production")
return EnvType.production
def is_production(self) -> bool:
return self == EnvType.production
def is_development(self) -> bool:
return self == EnvType.development
class Settings(BaseSettings):
env: EnvType = EnvType.from_environment()
dev_port: int = int(os.environ.get("CLAN_WEBUI_DEV_PORT", 3000))
dev_host: str = os.environ.get("CLAN_WEBUI_DEV_HOST", "localhost")
# global instance
settings = Settings()

View File

@@ -1,6 +1,5 @@
import argparse import argparse
import logging import logging
import os
import subprocess import subprocess
import time import time
import urllib.request import urllib.request
@@ -27,11 +26,23 @@ def defer_open_browser(base_url: str) -> None:
@contextmanager @contextmanager
def spawn_node_dev_server() -> Iterator[None]: def spawn_node_dev_server(host: str, port: int) -> Iterator[None]:
logger.info("Starting node dev server...") logger.info("Starting node dev server...")
path = Path(__file__).parent.parent.parent.parent / "ui" path = Path(__file__).parent.parent.parent.parent / "ui"
with subprocess.Popen( with subprocess.Popen(
["direnv", "exec", path, "npm", "run", "dev"], [
"direnv",
"exec",
path,
"npm",
"run",
"dev",
"--",
"--hostname",
host,
"--port",
str(port),
],
cwd=path, cwd=path,
) as proc: ) as proc:
try: try:
@@ -42,16 +53,21 @@ def spawn_node_dev_server() -> Iterator[None]:
def start_server(args: argparse.Namespace) -> None: def start_server(args: argparse.Namespace) -> None:
with ExitStack() as stack: with ExitStack() as stack:
headers: list[tuple[str, str]] = []
if args.dev: if args.dev:
os.environ["CLAN_WEBUI_ENV"] = "development" stack.enter_context(spawn_node_dev_server(args.dev_host, args.dev_port))
os.environ["CLAN_WEBUI_DEV_PORT"] = str(args.dev_port)
os.environ["CLAN_WEBUI_DEV_HOST"] = args.dev_host
stack.enter_context(spawn_node_dev_server())
open_url = f"http://{args.dev_host}:{args.dev_port}" open_url = f"http://{args.dev_host}:{args.dev_port}"
host = args.dev_host
if ":" in host:
host = f"[{host}]"
headers = [
(
"Access-Control-Allow-Origin",
f"http://{host}:{args.dev_port}",
)
]
else: else:
os.environ["CLAN_WEBUI_ENV"] = "production"
open_url = f"http://[{args.host}]:{args.port}" open_url = f"http://[{args.host}]:{args.port}"
if not args.no_open: if not args.no_open:
@@ -63,5 +79,5 @@ def start_server(args: argparse.Namespace) -> None:
port=args.port, port=args.port,
log_level=args.log_level, log_level=args.log_level,
reload=args.reload, reload=args.reload,
headers=[("Access-Control-Allow-Origin", "*")], headers=headers,
) )