diff --git a/cmd/server/main.go b/cmd/server/main.go index 86e43cf..de440d1 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -7,6 +7,7 @@ import ( "net/http" "os" "os/signal" + "strings" "syscall" "time" @@ -54,6 +55,21 @@ func fatal(logger *slog.Logger, msg string, args ...any) { os.Exit(1) } +func isOriginAllowed(allowed []string, origin string) bool { + if len(allowed) == 0 { + return false + } + for _, item := range allowed { + if item == "*" { + return true + } + if strings.EqualFold(strings.TrimSpace(item), strings.TrimSpace(origin)) { + return true + } + } + return false +} + func main() { logger, _ := logging.New(logging.Options{Service: "ez-api"}) @@ -184,9 +200,18 @@ func main() { r := gin.Default() r.Use(middleware.RequestID()) + allowedOrigins := cfg.CORS.AllowOrigins + allowAllOrigins := isOriginAllowed(allowedOrigins, "*") + // CORS Middleware r.Use(func(c *gin.Context) { - c.Writer.Header().Set("Access-Control-Allow-Origin", "*") // TODO: Restrict this in production + origin := c.Request.Header.Get("Origin") + if allowAllOrigins { + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + } else if origin != "" && isOriginAllowed(allowedOrigins, origin) { + c.Writer.Header().Set("Access-Control-Allow-Origin", origin) + c.Writer.Header().Add("Vary", "Origin") + } c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With") c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE") @@ -257,6 +282,7 @@ func main() { adminGroup.POST("/models", handler.CreateModel) adminGroup.GET("/models", handler.ListModels) adminGroup.PUT("/models/:id", handler.UpdateModel) + adminGroup.DELETE("/models/:id", handler.DeleteModel) adminGroup.GET("/logs", handler.ListLogs) adminGroup.DELETE("/logs", handler.DeleteLogs) adminGroup.GET("/logs/stats", handler.LogStats) diff --git a/internal/api/admin_handler.go b/internal/api/admin_handler.go index cc30a4b..158a61a 100644 --- a/internal/api/admin_handler.go +++ b/internal/api/admin_handler.go @@ -129,12 +129,19 @@ func toMasterView(m model.Master) MasterView { // @Tags admin // @Produce json // @Security AdminAuth +// @Param page query int false "page (1-based)" +// @Param limit query int false "limit (default 50, max 200)" +// @Param search query string false "search by name/group" // @Success 200 {array} MasterView // @Failure 500 {object} gin.H // @Router /admin/masters [get] func (h *AdminHandler) ListMasters(c *gin.Context) { var masters []model.Master - if err := h.db.Order("id desc").Find(&masters).Error; err != nil { + q := h.db.Model(&model.Master{}).Order("id desc") + query := parseListQuery(c) + q = applyListSearch(q, query.Search, "name", `"group"`) + q = applyListPagination(q, query) + if err := q.Find(&masters).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list masters", "details": err.Error()}) return } diff --git a/internal/api/binding_handler.go b/internal/api/binding_handler.go index 5ba0708..21aeee1 100644 --- a/internal/api/binding_handler.go +++ b/internal/api/binding_handler.go @@ -80,12 +80,19 @@ func (h *Handler) CreateBinding(c *gin.Context) { // @Tags admin // @Produce json // @Security AdminAuth +// @Param page query int false "page (1-based)" +// @Param limit query int false "limit (default 50, max 200)" +// @Param search query string false "search by namespace/public_model/route_group" // @Success 200 {array} model.Binding // @Failure 500 {object} gin.H // @Router /admin/bindings [get] func (h *Handler) ListBindings(c *gin.Context) { var out []model.Binding - if err := h.db.Find(&out).Error; err != nil { + q := h.db.Model(&model.Binding{}).Order("id desc") + query := parseListQuery(c) + q = applyListSearch(q, query.Search, "namespace", "public_model", "route_group") + q = applyListPagination(q, query) + if err := q.Find(&out).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list bindings", "details": err.Error()}) return } diff --git a/internal/api/handler.go b/internal/api/handler.go index b5609aa..f390e1a 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -369,12 +369,19 @@ func (h *Handler) CreateModel(c *gin.Context) { // @Tags admin // @Produce json // @Security AdminAuth +// @Param page query int false "page (1-based)" +// @Param limit query int false "limit (default 50, max 200)" +// @Param search query string false "search by name/kind" // @Success 200 {array} model.Model // @Failure 500 {object} gin.H // @Router /admin/models [get] func (h *Handler) ListModels(c *gin.Context) { var models []model.Model - if err := h.db.Find(&models).Error; err != nil { + q := h.db.Model(&model.Model{}).Order("id desc") + query := parseListQuery(c) + q = applyListSearch(q, query.Search, "name", "kind") + q = applyListPagination(q, query) + if err := q.Find(&models).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list models", "details": err.Error()}) return } @@ -457,6 +464,45 @@ func (h *Handler) UpdateModel(c *gin.Context) { c.JSON(http.StatusOK, existing) } +// DeleteModel godoc +// @Summary Delete a model +// @Description Delete a model by id +// @Tags admin +// @Produce json +// @Security AdminAuth +// @Param id path int true "Model ID" +// @Success 200 {object} gin.H +// @Failure 400 {object} gin.H +// @Failure 404 {object} gin.H +// @Failure 500 {object} gin.H +// @Router /admin/models/{id} [delete] +func (h *Handler) DeleteModel(c *gin.Context) { + idParam := c.Param("id") + id, err := strconv.Atoi(idParam) + if err != nil || id <= 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"}) + return + } + + var existing model.Model + if err := h.db.First(&existing, id).Error; err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "model not found"}) + return + } + + if err := h.db.Delete(&existing).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete model", "details": err.Error()}) + return + } + + if err := h.sync.SyncModelDelete(&existing); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to sync model delete", "details": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"status": "deleted"}) +} + // SyncSnapshot godoc // @Summary Force sync snapshot // @Description Force full synchronization of DB state to Redis diff --git a/internal/api/list_query.go b/internal/api/list_query.go new file mode 100644 index 0000000..b9fda5c --- /dev/null +++ b/internal/api/list_query.go @@ -0,0 +1,88 @@ +package api + +import ( + "strconv" + "strings" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type listQuery struct { + Page int + Limit int + Offset int + Search string + Enabled bool +} + +func parseListQuery(c *gin.Context) listQuery { + var q listQuery + if c == nil { + return q + } + if raw := strings.TrimSpace(c.Query("search")); raw != "" { + q.Search = raw + q.Enabled = true + } + if raw := strings.TrimSpace(c.Query("page")); raw != "" { + if v, err := strconv.Atoi(raw); err == nil && v > 0 { + q.Page = v + } + q.Enabled = true + } + if raw := strings.TrimSpace(c.Query("limit")); raw != "" { + if v, err := strconv.Atoi(raw); err == nil && v > 0 { + q.Limit = v + } + q.Enabled = true + } + if q.Enabled { + if q.Limit <= 0 { + q.Limit = 50 + } + if q.Limit > 200 { + q.Limit = 200 + } + if q.Page <= 0 { + q.Page = 1 + } + q.Offset = (q.Page - 1) * q.Limit + } + return q +} + +func applyListSearch(q *gorm.DB, search string, fields ...string) *gorm.DB { + if q == nil { + return q + } + search = strings.TrimSpace(search) + if search == "" || len(fields) == 0 { + return q + } + pattern := "%" + strings.ToLower(search) + "%" + clauses := make([]string, 0, len(fields)) + args := make([]any, 0, len(fields)) + for _, field := range fields { + field = strings.TrimSpace(field) + if field == "" { + continue + } + clauses = append(clauses, "LOWER("+field+") LIKE ?") + args = append(args, pattern) + } + if len(clauses) == 0 { + return q + } + return q.Where(strings.Join(clauses, " OR "), args...) +} + +func applyListPagination(q *gorm.DB, query listQuery) *gorm.DB { + if q == nil { + return q + } + if !query.Enabled { + return q + } + return q.Limit(query.Limit).Offset(query.Offset) +} diff --git a/internal/api/master_handler.go b/internal/api/master_handler.go index 287c75f..cede56c 100644 --- a/internal/api/master_handler.go +++ b/internal/api/master_handler.go @@ -215,6 +215,9 @@ func toTokenView(k model.Key) TokenView { // @Tags master // @Produce json // @Security MasterAuth +// @Param page query int false "page (1-based)" +// @Param limit query int false "limit (default 50, max 200)" +// @Param search query string false "search by group/scopes/namespaces/status" // @Success 200 {array} TokenView // @Failure 401 {object} gin.H // @Failure 500 {object} gin.H @@ -228,7 +231,11 @@ func (h *MasterHandler) ListTokens(c *gin.Context) { m := master.(*model.Master) var keys []model.Key - if err := h.db.Where("master_id = ?", m.ID).Order("id desc").Find(&keys).Error; err != nil { + q := h.db.Model(&model.Key{}).Where("master_id = ?", m.ID).Order("id desc") + query := parseListQuery(c) + q = applyListSearch(q, query.Search, `"group"`, "scopes", "default_namespace", "namespaces", "status") + q = applyListPagination(q, query) + if err := q.Find(&keys).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list tokens", "details": err.Error()}) return } diff --git a/internal/api/model_handler_test.go b/internal/api/model_handler_test.go index 9fa4003..cc8c308 100644 --- a/internal/api/model_handler_test.go +++ b/internal/api/model_handler_test.go @@ -103,3 +103,47 @@ func TestCreateModel_InvalidKind_Returns400(t *testing.T) { t.Fatalf("expected 400, got %d body=%s", rr.Code, rr.Body.String()) } } + +func TestDeleteModel_RemovesMeta(t *testing.T) { + h, db, mr := newTestHandlerWithRedis(t) + + r := gin.New() + r.POST("/admin/models", h.CreateModel) + r.DELETE("/admin/models/:id", h.DeleteModel) + + reqBody := map[string]any{ + "name": "ns.del", + } + b, _ := json.Marshal(reqBody) + + req := httptest.NewRequest(http.MethodPost, "/admin/models", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + if rr.Code != http.StatusCreated { + t.Fatalf("expected 201, got %d body=%s", rr.Code, rr.Body.String()) + } + var created model.Model + if err := json.Unmarshal(rr.Body.Bytes(), &created); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + delReq := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/admin/models/%d", created.ID), nil) + delRec := httptest.NewRecorder() + r.ServeHTTP(delRec, delReq) + + if delRec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", delRec.Code, delRec.Body.String()) + } + if raw := mr.HGet("meta:models", "ns.del"); raw != "" { + t.Fatalf("expected meta:models[ns.del] removed, got %q", raw) + } + var remaining int64 + if err := db.Model(&model.Model{}).Where("name = ?", "ns.del").Count(&remaining).Error; err != nil { + t.Fatalf("count: %v", err) + } + if remaining != 0 { + t.Fatalf("expected model deleted, got count=%d", remaining) + } +} diff --git a/internal/api/provider_admin_handler.go b/internal/api/provider_admin_handler.go index 408aa99..e03f585 100644 --- a/internal/api/provider_admin_handler.go +++ b/internal/api/provider_admin_handler.go @@ -20,12 +20,19 @@ import ( // @Tags admin // @Produce json // @Security AdminAuth +// @Param page query int false "page (1-based)" +// @Param limit query int false "limit (default 50, max 200)" +// @Param search query string false "search by name/type/base_url/group" // @Success 200 {array} model.Provider // @Failure 500 {object} gin.H // @Router /admin/providers [get] func (h *Handler) ListProviders(c *gin.Context) { var providers []model.Provider - if err := h.db.Order("id desc").Find(&providers).Error; err != nil { + q := h.db.Model(&model.Provider{}).Order("id desc") + query := parseListQuery(c) + q = applyListSearch(q, query.Search, "name", `"type"`, "base_url", `"group"`) + q = applyListPagination(q, query) + if err := q.Find(&providers).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list providers", "details": err.Error()}) return } diff --git a/internal/config/config.go b/internal/config/config.go index 7782401..ee3b200 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,6 +11,7 @@ import ( type Config struct { Server ServerConfig + CORS CORSConfig Postgres PostgresConfig Redis RedisConfig Log LogConfig @@ -24,6 +25,10 @@ type ServerConfig struct { Port string } +type CORSConfig struct { + AllowOrigins []string +} + type AuthConfig struct { JWTSecret string } @@ -70,6 +75,7 @@ func Load() (*Config, error) { v := viper.New() v.SetDefault("server.port", "8080") + v.SetDefault("cors.allow_origins", "*") v.SetDefault("postgres.dsn", "host=localhost user=postgres password=postgres dbname=ezapi port=5432 sslmode=disable") v.SetDefault("redis.addr", "localhost:6379") v.SetDefault("redis.password", "") @@ -96,6 +102,7 @@ func Load() (*Config, error) { v.AutomaticEnv() _ = v.BindEnv("server.port", "EZ_API_PORT") + _ = v.BindEnv("cors.allow_origins", "EZ_CORS_ALLOW_ORIGINS") _ = v.BindEnv("postgres.dsn", "EZ_PG_DSN") _ = v.BindEnv("redis.addr", "EZ_REDIS_ADDR") _ = v.BindEnv("redis.password", "EZ_REDIS_PASSWORD") @@ -137,6 +144,9 @@ func Load() (*Config, error) { Server: ServerConfig{ Port: v.GetString("server.port"), }, + CORS: CORSConfig{ + AllowOrigins: splitCommaList(v.GetString("cors.allow_origins")), + }, Postgres: PostgresConfig{ DSN: v.GetString("postgres.dsn"), }, @@ -176,3 +186,16 @@ func Load() (*Config, error) { return cfg, nil } + +func splitCommaList(raw string) []string { + parts := strings.Split(raw, ",") + out := make([]string, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + out = append(out, part) + } + return out +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index be1c2b8..da10b06 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -5,6 +5,7 @@ import "testing" func TestLoad_LogDSNOverride(t *testing.T) { t.Setenv("EZ_LOG_PG_DSN", "host=log-db user=postgres dbname=logs") t.Setenv("EZ_LOG_PARTITIONING", "monthly") + t.Setenv("EZ_CORS_ALLOW_ORIGINS", "https://a.example.com,https://b.example.com") cfg, err := Load() if err != nil { t.Fatalf("load config: %v", err) @@ -15,4 +16,7 @@ func TestLoad_LogDSNOverride(t *testing.T) { if cfg.Log.Partitioning != "monthly" { t.Fatalf("expected log partitioning to be set, got %q", cfg.Log.Partitioning) } + if len(cfg.CORS.AllowOrigins) != 2 { + t.Fatalf("expected cors allow origins, got %v", cfg.CORS.AllowOrigins) + } } diff --git a/internal/service/sync.go b/internal/service/sync.go index 7b4fefe..c4374c6 100644 --- a/internal/service/sync.go +++ b/internal/service/sync.go @@ -176,6 +176,25 @@ func (s *SyncService) SyncModel(m *model.Model) error { return nil } +// SyncModelDelete removes model metadata from Redis and refreshes meta:models_meta. +func (s *SyncService) SyncModelDelete(m *model.Model) error { + if m == nil { + return fmt.Errorf("model required") + } + name := strings.TrimSpace(m.Name) + if name == "" { + return fmt.Errorf("model name required") + } + ctx := context.Background() + if err := s.rdb.HDel(ctx, "meta:models", name).Err(); err != nil { + return fmt.Errorf("delete meta:models: %w", err) + } + if err := s.refreshModelsMetaFromRedis(ctx, "db"); err != nil { + return err + } + return nil +} + type providerSnapshot struct { ID uint `json:"id"` Name string `json:"name"` diff --git a/internal/service/sync_test.go b/internal/service/sync_test.go index 3b1460e..dea6f51 100644 --- a/internal/service/sync_test.go +++ b/internal/service/sync_test.go @@ -93,6 +93,26 @@ func TestSyncKey_WritesTokenID(t *testing.T) { } } +func TestSyncModelDelete_RemovesMeta(t *testing.T) { + mr := miniredis.RunT(t) + rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + svc := NewSyncService(rdb) + + mr.HSet("meta:models", "ns.m", `{"name":"ns.m"}`) + + m := &model.Model{Name: "ns.m"} + if err := svc.SyncModelDelete(m); err != nil { + t.Fatalf("SyncModelDelete: %v", err) + } + + if got := mr.HGet("meta:models", "ns.m"); got != "" { + t.Fatalf("expected meta:models entry removed, got %q", got) + } + if v := mr.HGet("meta:models_meta", "version"); v == "" { + t.Fatalf("expected meta:models_meta.version to be set") + } +} + func TestSyncProviderDelete_RemovesSnapshotAndRouting(t *testing.T) { mr := miniredis.RunT(t) rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()})