package service

import (
	"fmt"
	"sort"
	"strings"
	"time"

	"gorm.io/gorm"
)

// LogPartitioningMode selects how log records are split across physical tables.
type LogPartitioningMode string

const (
	// LogPartitioningOff keeps every record in the base table.
	LogPartitioningOff LogPartitioningMode = "off"
	// LogPartitioningMonthly uses one table per UTC calendar month (name suffix YYYYMM).
	LogPartitioningMonthly LogPartitioningMode = "monthly"
	// LogPartitioningDaily uses one table per UTC calendar day (name suffix YYYYMMDD).
	LogPartitioningDaily LogPartitioningMode = "daily"
)

// LogPartition describes one partition table and the UTC time range
// [Start, End) encoded in its name.
type LogPartition struct {
	Table string
	Start time.Time
	End   time.Time
}

// LogPartitioner manages time-bucketed copies of the log_records table plus a
// UNION ALL view (log_records_all) over the base table and every partition.
// Partitioning is only active on the "postgres" dialect with a mode other than
// "off"; otherwise every method falls back to the plain base table.
type LogPartitioner struct {
	db        *gorm.DB
	mode      LogPartitioningMode
	baseTable string
	viewTable string
}

// NewLogPartitioner returns a partitioner bound to db. mode is trimmed and
// matched case-insensitively against "daily" and "monthly"; any other value
// disables partitioning.
func NewLogPartitioner(db *gorm.DB, mode string) *LogPartitioner {
	return &LogPartitioner{
		db:        db,
		mode:      normalizePartitioningMode(mode),
		baseTable: "log_records",
		viewTable: "log_records_all",
	}
}

// Enabled reports whether partitioning is active: the partitioner and its
// database are non-nil, the mode is not "off", and the dialect is postgres.
func (p *LogPartitioner) Enabled() bool {
	if p == nil || p.db == nil {
		return false
	}
	if p.mode == LogPartitioningOff {
		return false
	}
	return p.db.Dialector.Name() == "postgres"
}

// ViewName returns the relation that reads should target: the UNION ALL view
// when partitioning is enabled, otherwise the base table.
func (p *LogPartitioner) ViewName() string {
	if p == nil {
		return "log_records"
	}
	if p.Enabled() {
		return p.viewTable
	}
	return p.baseTable
}

// TableForTime returns the partition table name for t (interpreted in UTC).
// When partitioning is disabled it returns "log_records".
func (p *LogPartitioner) TableForTime(t time.Time) string {
	if p == nil || !p.Enabled() {
		return "log_records"
	}
	t = t.UTC()
	switch p.mode {
	case LogPartitioningDaily:
		return fmt.Sprintf("%s_%04d%02d%02d", p.baseTable, t.Year(), int(t.Month()), t.Day())
	default:
		// Monthly is also the fallback for any other enabled mode value.
		return fmt.Sprintf("%s_%04d%02d", p.baseTable, t.Year(), int(t.Month()))
	}
}

// EnsurePartitionFor creates (if needed) the partition covering t, refreshes
// the view, and returns the partition table name. When partitioning is
// disabled it returns "log_records" without touching the database.
//
// NOTE(review): the view is redefined on every call, even when the partition
// already existed. That keeps the view self-healing but issues DDL per call;
// callers on a hot path may want to cache the result per bucket.
func (p *LogPartitioner) EnsurePartitionFor(t time.Time) (string, error) {
	if p == nil || !p.Enabled() {
		return "log_records", nil
	}
	table := p.TableForTime(t)
	if err := p.ensureTable(table); err != nil {
		return "", err
	}
	if err := p.ensureView(); err != nil {
		return "", err
	}
	return table, nil
}

// ListPartitions returns the existing partitions for the current mode, sorted
// by start time. Tables whose suffix is not a valid date are skipped.
func (p *LogPartitioner) ListPartitions() ([]LogPartition, error) {
	if p == nil || !p.Enabled() {
		return nil, nil
	}
	tables, err := p.listPartitionTables()
	if err != nil {
		return nil, err
	}
	partitions := make([]LogPartition, 0, len(tables))
	for _, table := range tables {
		start, end, ok := p.parsePartitionRange(table)
		if !ok {
			continue
		}
		partitions = append(partitions, LogPartition{Table: table, Start: start, End: end})
	}
	sort.Slice(partitions, func(i, j int) bool {
		return partitions[i].Start.Before(partitions[j].Start)
	})
	return partitions, nil
}

// DropPartitionsBefore drops every partition whose range ends strictly before
// cutoff and returns how many were dropped. End is exclusive, so a partition
// ending exactly at cutoff is conservatively kept.
//
// The view references every partition, and PostgreSQL refuses to drop a table
// that a view depends on (it would require CASCADE). The view is therefore
// redefined without the expired partitions before they are dropped. The view
// change and all drops run in one transaction. On error nothing is dropped and
// the returned count is 0.
func (p *LogPartitioner) DropPartitionsBefore(cutoff time.Time) (int, error) {
	if p == nil || !p.Enabled() {
		return 0, nil
	}
	tables, err := p.listPartitionTables()
	if err != nil {
		return 0, err
	}
	cutoff = cutoff.UTC()

	// Split the current partitions into those to drop and those the view keeps.
	// Tables with an unparseable suffix are kept, matching ListPartitions.
	var expired []string
	keep := make([]string, 0, len(tables))
	for _, table := range tables {
		_, end, ok := p.parsePartitionRange(table)
		if ok && end.Before(cutoff) {
			expired = append(expired, table)
			continue
		}
		keep = append(keep, table)
	}
	if len(expired) == 0 {
		return 0, nil
	}

	err = p.db.Transaction(func(tx *gorm.DB) error {
		if err := p.replaceView(tx, keep); err != nil {
			return err
		}
		for _, table := range expired {
			if err := p.dropTableOn(tx, table); err != nil {
				return err
			}
		}
		return nil
	})
	if err != nil {
		return 0, err
	}
	return len(expired), nil
}

// ensureTable creates table as a structural copy of the base table
// (LIKE ... INCLUDING ALL) unless it already exists.
func (p *LogPartitioner) ensureTable(table string) error {
	if p == nil || !p.Enabled() {
		return nil
	}
	if !p.validPartitionTable(table) {
		return fmt.Errorf("invalid partition table %q", table)
	}
	if p.db.Migrator().HasTable(table) {
		return nil
	}
	stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (LIKE %s INCLUDING ALL)", quoteIdent(table), quoteIdent(p.baseTable))
	if err := p.db.Exec(stmt).Error; err != nil {
		return fmt.Errorf("create partition table %q: %w", table, err)
	}
	return nil
}

// ensureView (re)defines the view over the base table and every partition
// table that currently exists for this mode.
func (p *LogPartitioner) ensureView() error {
	if p == nil || !p.Enabled() {
		return nil
	}
	tables, err := p.listPartitionTables()
	if err != nil {
		return err
	}
	return p.replaceView(p.db, tables)
}

// replaceView runs CREATE OR REPLACE VIEW on db (which may be a transaction),
// defining the view as the base table UNION ALL each of tables.
//
// NOTE(review): SELECT * requires every partition to have the same column
// layout as the base table. Partitions created before a base-table migration
// may break the UNION; confirm migrations are applied to partitions as well.
func (p *LogPartitioner) replaceView(db *gorm.DB, tables []string) error {
	selects := make([]string, 0, len(tables)+1)
	selects = append(selects, "SELECT * FROM "+quoteIdent(p.baseTable))
	for _, table := range tables {
		if table == p.baseTable {
			continue
		}
		selects = append(selects, "SELECT * FROM "+quoteIdent(table))
	}
	viewSQL := fmt.Sprintf("CREATE OR REPLACE VIEW %s AS %s", quoteIdent(p.viewTable), strings.Join(selects, " UNION ALL "))
	if err := db.Exec(viewSQL).Error; err != nil {
		return fmt.Errorf("replace view %q: %w", p.viewTable, err)
	}
	return nil
}

// listPartitionTables returns the names of base tables in the current schema
// that are valid partition names for the current mode, ordered by name. With
// fixed-length digit suffixes, name order is also chronological.
//
// NOTE(review): only names for the current mode are returned. After switching
// between daily and monthly, partitions in the old format drop out of the view
// and are never dropped by retention.
func (p *LogPartitioner) listPartitionTables() ([]string, error) {
	if p == nil || !p.Enabled() {
		return nil, nil
	}
	var tables []string
	// "_" is a LIKE wildcard, so this pattern can over-match; the
	// validPartitionTable filter below makes the result exact.
	err := p.db.Raw(
		`SELECT table_name FROM information_schema.tables WHERE table_schema = current_schema() AND table_type = 'BASE TABLE' AND table_name LIKE ? ORDER BY table_name`,
		p.baseTable+"_%",
	).Scan(&tables).Error
	if err != nil {
		return nil, fmt.Errorf("list partitions of %q: %w", p.baseTable, err)
	}
	out := make([]string, 0, len(tables))
	for _, table := range tables {
		if p.validPartitionTable(table) {
			out = append(out, table)
		}
	}
	return out, nil
}

// parsePartitionRange decodes the UTC [start, end) range encoded in a
// partition table's suffix. ok is false when the name is not a valid partition
// name for the current mode or the digits are not a real calendar date.
func (p *LogPartitioner) parsePartitionRange(table string) (time.Time, time.Time, bool) {
	if !p.validPartitionTable(table) {
		return time.Time{}, time.Time{}, false
	}
	// validPartitionTable has already checked the suffix length for the mode.
	raw := strings.TrimPrefix(table, p.baseTable+"_")
	if p.mode == LogPartitioningDaily {
		t, err := time.Parse("20060102", raw)
		if err != nil {
			return time.Time{}, time.Time{}, false
		}
		start := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
		return start, start.AddDate(0, 0, 1), true
	}
	t, err := time.Parse("200601", raw)
	if err != nil {
		return time.Time{}, time.Time{}, false
	}
	start := time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC)
	return start, start.AddDate(0, 1, 0), true
}

// validPartitionTable reports whether table is "<baseTable>_" followed by 8
// digits (daily mode) or 6 digits (any other mode).
func (p *LogPartitioner) validPartitionTable(table string) bool {
	if p == nil || table == "" {
		return false
	}
	raw, found := strings.CutPrefix(table, p.baseTable+"_")
	if !found {
		return false
	}
	if p.mode == LogPartitioningDaily {
		return len(raw) == 8 && isDigits(raw)
	}
	return len(raw) == 6 && isDigits(raw)
}

// dropTable drops a validated partition table using the partitioner's own
// connection. It is a no-op when partitioning is disabled.
func (p *LogPartitioner) dropTable(table string) error {
	if p == nil || !p.Enabled() {
		return nil
	}
	return p.dropTableOn(p.db, table)
}

// dropTableOn drops a validated partition table using db (which may be a
// transaction). Invalid names are rejected before any SQL is issued.
func (p *LogPartitioner) dropTableOn(db *gorm.DB, table string) error {
	if !p.validPartitionTable(table) {
		return fmt.Errorf("invalid partition table %q", table)
	}
	if err := db.Exec("DROP TABLE IF EXISTS " + quoteIdent(table)).Error; err != nil {
		return fmt.Errorf("drop partition table %q: %w", table, err)
	}
	return nil
}

// normalizePartitioningMode maps a user-supplied mode string to a known mode.
// Unknown or empty values map to LogPartitioningOff.
func normalizePartitioningMode(raw string) LogPartitioningMode {
	switch strings.ToLower(strings.TrimSpace(raw)) {
	case string(LogPartitioningDaily):
		return LogPartitioningDaily
	case string(LogPartitioningMonthly):
		return LogPartitioningMonthly
	default:
		return LogPartitioningOff
	}
}

// quoteIdent quotes name as a SQL identifier, doubling any embedded quotes.
func quoteIdent(name string) string {
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}

// isDigits reports whether raw is non-empty and consists only of ASCII digits.
func isDigits(raw string) bool {
	for _, r := range raw {
		if r < '0' || r > '9' {
			return false
		}
	}
	return raw != ""
}