package config import ( "fmt" "os" "strings" "time" "github.com/spf13/viper" ) type Config struct { Server ServerConfig CORS CORSConfig Postgres PostgresConfig Redis RedisConfig Log LogConfig Auth AuthConfig ModelRegistry ModelRegistryConfig Quota QuotaConfig Internal InternalConfig SyncOutbox SyncOutboxConfig TokenRefresh TokenRefreshConfig } type ServerConfig struct { Port string SwaggerHost string // Swagger UI 显示的 Host (可选,留空则使用相对路径) } type CORSConfig struct { AllowOrigins []string } type AuthConfig struct { JWTSecret string } type PostgresConfig struct { DSN string } type RedisConfig struct { Addr string Password string DB int } type LogConfig struct { BatchSize int FlushInterval time.Duration QueueCapacity int RetentionDays int MaxRecords int DSN string Partitioning string } type ModelRegistryConfig struct { Enabled bool RefreshSeconds int ModelsDevBaseURL string ModelsDevAPIBaseURL string ModelsDevRef string CacheDir string TimeoutSeconds int } type QuotaConfig struct { ResetIntervalSeconds int } type InternalConfig struct { StatsToken string AllowAnonymous bool } type SyncOutboxConfig struct { Enabled bool IntervalSeconds int BatchSize int MaxRetries int } type TokenRefreshConfig struct { IntervalSeconds int RefreshSkewSeconds int BatchSize int MaxRetries int } func Load() (*Config, error) { v := viper.New() v.SetDefault("server.port", "8080") v.SetDefault("server.swagger_host", "") v.SetDefault("cors.allow_origins", "*") v.SetDefault("postgres.dsn", "host=localhost user=postgres password=postgres dbname=ezapi port=5432 sslmode=disable") v.SetDefault("redis.addr", "localhost:6379") v.SetDefault("redis.password", "") v.SetDefault("redis.db", 0) v.SetDefault("log.batch_size", 10) v.SetDefault("log.flush_ms", 1000) v.SetDefault("log.queue_capacity", 10000) v.SetDefault("log.retention_days", 30) v.SetDefault("log.max_records", 1000000) v.SetDefault("log.dsn", "") v.SetDefault("log.partitioning", "off") v.SetDefault("auth.jwt_secret", "change_me_in_production") v.SetDefault("model_registry.enabled", false) v.SetDefault("model_registry.refresh_seconds", 1800) v.SetDefault("model_registry.models_dev_base_url", "https://codeload.github.com/sst/models.dev/tar.gz") v.SetDefault("model_registry.models_dev_api_base_url", "https://api.github.com") v.SetDefault("model_registry.models_dev_ref", "dev") v.SetDefault("model_registry.cache_dir", "./data/model-registry") v.SetDefault("model_registry.timeout_seconds", 30) v.SetDefault("quota.reset_interval_seconds", 300) v.SetDefault("internal.stats_token", "") v.SetDefault("internal.allow_anonymous", false) v.SetDefault("sync_outbox.enabled", true) v.SetDefault("sync_outbox.interval_seconds", 5) v.SetDefault("sync_outbox.batch_size", 200) v.SetDefault("sync_outbox.max_retries", 10) v.SetDefault("token_refresh.interval_seconds", 1800) v.SetDefault("token_refresh.refresh_skew_seconds", 3000) v.SetDefault("token_refresh.batch_size", 200) v.SetDefault("token_refresh.max_retries", 3) v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) v.AutomaticEnv() _ = v.BindEnv("server.port", "EZ_API_PORT") _ = v.BindEnv("server.swagger_host", "EZ_SWAGGER_HOST") _ = v.BindEnv("cors.allow_origins", "EZ_CORS_ALLOW_ORIGINS") _ = v.BindEnv("postgres.dsn", "EZ_PG_DSN") _ = v.BindEnv("redis.addr", "EZ_REDIS_ADDR") _ = v.BindEnv("redis.password", "EZ_REDIS_PASSWORD") _ = v.BindEnv("redis.db", "EZ_REDIS_DB") _ = v.BindEnv("log.batch_size", "EZ_LOG_BATCH_SIZE") _ = v.BindEnv("log.flush_ms", "EZ_LOG_FLUSH_MS") _ = v.BindEnv("log.queue_capacity", "EZ_LOG_QUEUE") _ = v.BindEnv("log.retention_days", "EZ_LOG_RETENTION_DAYS") _ = v.BindEnv("log.max_records", "EZ_LOG_MAX_RECORDS") _ = v.BindEnv("log.dsn", "EZ_LOG_PG_DSN") _ = v.BindEnv("log.partitioning", "EZ_LOG_PARTITIONING") _ = v.BindEnv("auth.jwt_secret", "EZ_JWT_SECRET") _ = v.BindEnv("model_registry.enabled", "EZ_MODEL_REGISTRY_ENABLED") _ = v.BindEnv("model_registry.refresh_seconds", "EZ_MODEL_REGISTRY_REFRESH_SECONDS") _ = v.BindEnv("model_registry.models_dev_base_url", "EZ_MODEL_REGISTRY_MODELS_DEV_BASE_URL") _ = v.BindEnv("model_registry.models_dev_api_base_url", "EZ_MODEL_REGISTRY_MODELS_DEV_API_BASE_URL") _ = v.BindEnv("model_registry.models_dev_ref", "EZ_MODEL_REGISTRY_MODELS_DEV_REF") _ = v.BindEnv("model_registry.cache_dir", "EZ_MODEL_REGISTRY_CACHE_DIR") _ = v.BindEnv("model_registry.timeout_seconds", "EZ_MODEL_REGISTRY_TIMEOUT_SECONDS") _ = v.BindEnv("quota.reset_interval_seconds", "EZ_QUOTA_RESET_INTERVAL_SECONDS") _ = v.BindEnv("internal.stats_token", "EZ_INTERNAL_STATS_TOKEN") _ = v.BindEnv("internal.allow_anonymous", "EZ_INTERNAL_ALLOW_ANON") _ = v.BindEnv("sync_outbox.enabled", "EZ_SYNC_OUTBOX_ENABLED") _ = v.BindEnv("sync_outbox.interval_seconds", "EZ_SYNC_OUTBOX_INTERVAL_SECONDS") _ = v.BindEnv("sync_outbox.batch_size", "EZ_SYNC_OUTBOX_BATCH_SIZE") _ = v.BindEnv("sync_outbox.max_retries", "EZ_SYNC_OUTBOX_MAX_RETRIES") _ = v.BindEnv("token_refresh.interval_seconds", "EZ_TOKEN_REFRESH_INTERVAL_SECONDS") _ = v.BindEnv("token_refresh.refresh_skew_seconds", "EZ_TOKEN_REFRESH_SKEW_SECONDS") _ = v.BindEnv("token_refresh.batch_size", "EZ_TOKEN_REFRESH_BATCH_SIZE") _ = v.BindEnv("token_refresh.max_retries", "EZ_TOKEN_REFRESH_MAX_RETRIES") if configFile := os.Getenv("EZ_CONFIG_FILE"); configFile != "" { v.SetConfigFile(configFile) } else { v.SetConfigName("config") v.SetConfigType("yaml") v.AddConfigPath(".") v.AddConfigPath("./config") } if err := v.ReadInConfig(); err != nil { if _, ok := err.(viper.ConfigFileNotFoundError); !ok { return nil, fmt.Errorf("read config: %w", err) } } cfg := &Config{ Server: ServerConfig{ Port: v.GetString("server.port"), SwaggerHost: strings.TrimSpace(v.GetString("server.swagger_host")), }, CORS: CORSConfig{ AllowOrigins: splitCommaList(v.GetString("cors.allow_origins")), }, Postgres: PostgresConfig{ DSN: v.GetString("postgres.dsn"), }, Redis: RedisConfig{ Addr: v.GetString("redis.addr"), Password: v.GetString("redis.password"), DB: v.GetInt("redis.db"), }, Log: LogConfig{ BatchSize: v.GetInt("log.batch_size"), FlushInterval: time.Duration(v.GetInt("log.flush_ms")) * time.Millisecond, QueueCapacity: v.GetInt("log.queue_capacity"), RetentionDays: v.GetInt("log.retention_days"), MaxRecords: v.GetInt("log.max_records"), DSN: strings.TrimSpace(v.GetString("log.dsn")), Partitioning: strings.TrimSpace(v.GetString("log.partitioning")), }, Auth: AuthConfig{ JWTSecret: v.GetString("auth.jwt_secret"), }, ModelRegistry: ModelRegistryConfig{ Enabled: v.GetBool("model_registry.enabled"), RefreshSeconds: v.GetInt("model_registry.refresh_seconds"), ModelsDevBaseURL: v.GetString("model_registry.models_dev_base_url"), ModelsDevAPIBaseURL: v.GetString("model_registry.models_dev_api_base_url"), ModelsDevRef: v.GetString("model_registry.models_dev_ref"), CacheDir: v.GetString("model_registry.cache_dir"), TimeoutSeconds: v.GetInt("model_registry.timeout_seconds"), }, Quota: QuotaConfig{ ResetIntervalSeconds: v.GetInt("quota.reset_interval_seconds"), }, Internal: InternalConfig{ StatsToken: v.GetString("internal.stats_token"), AllowAnonymous: v.GetBool("internal.allow_anonymous"), }, SyncOutbox: SyncOutboxConfig{ Enabled: v.GetBool("sync_outbox.enabled"), IntervalSeconds: v.GetInt("sync_outbox.interval_seconds"), BatchSize: v.GetInt("sync_outbox.batch_size"), MaxRetries: v.GetInt("sync_outbox.max_retries"), }, TokenRefresh: TokenRefreshConfig{ IntervalSeconds: v.GetInt("token_refresh.interval_seconds"), RefreshSkewSeconds: v.GetInt("token_refresh.refresh_skew_seconds"), BatchSize: v.GetInt("token_refresh.batch_size"), MaxRetries: v.GetInt("token_refresh.max_retries"), }, } return cfg, nil } func splitCommaList(raw string) []string { parts := strings.Split(raw, ",") out := make([]string, 0, len(parts)) for _, part := range parts { part = strings.TrimSpace(part) if part == "" { continue } out = append(out, part) } return out }