Files
ez-api/internal/service/log_partition.go

273 lines
6.4 KiB
Go

package service
import (
"fmt"
"sort"
"strings"
"time"
"gorm.io/gorm"
)
// LogPartitioningMode selects how log records are split across physical
// tables. Configuration strings are mapped onto these values by
// normalizePartitioningMode; anything unrecognized becomes LogPartitioningOff.
type LogPartitioningMode string

const (
	// LogPartitioningOff disables partitioning; only the base table is used.
	LogPartitioningOff LogPartitioningMode = "off"
	// LogPartitioningMonthly uses one table per UTC calendar month
	// (suffix YYYYMM).
	LogPartitioningMonthly LogPartitioningMode = "monthly"
	// LogPartitioningDaily uses one table per UTC calendar day
	// (suffix YYYYMMDD).
	LogPartitioningDaily LogPartitioningMode = "daily"
)
// LogPartition describes one physical partition table together with the UTC
// time range it covers. The range is half-open: rows with timestamps in
// [Start, End) map to Table.
type LogPartition struct {
	Table      string    // physical table name, e.g. log_records_202401
	Start, End time.Time // covered range; End is the start of the next period
}
// LogPartitioner manages time-based partition tables for log records and the
// UNION ALL view that reads across them. Build one with NewLogPartitioner.
type LogPartitioner struct {
	db                   *gorm.DB
	mode                 LogPartitioningMode
	baseTable, viewTable string // base table name and union-view name
}
// NewLogPartitioner returns a partitioner bound to db.
//
// mode is trimmed and matched case-insensitively against "off", "monthly"
// and "daily"; any other value disables partitioning. The base table is
// "log_records" and the union view is "log_records_all".
func NewLogPartitioner(db *gorm.DB, mode string) *LogPartitioner {
	p := new(LogPartitioner)
	p.db = db
	p.mode = normalizePartitioningMode(mode)
	p.baseTable = "log_records"
	p.viewTable = "log_records_all"
	return p
}
// Enabled reports whether partitioning is active. All of the following must
// hold: the receiver and its database handle are non-nil, the mode is not
// off, and the database dialect is PostgreSQL.
func (p *LogPartitioner) Enabled() bool {
	return p != nil &&
		p.db != nil &&
		p.mode != LogPartitioningOff &&
		p.db.Dialector.Name() == "postgres"
}
// ViewName returns the relation to read log records from. This is the union
// view when partitioning is enabled and the base table otherwise. A nil
// receiver yields "log_records".
func (p *LogPartitioner) ViewName() string {
	switch {
	case p == nil:
		return "log_records"
	case p.Enabled():
		return p.viewTable
	default:
		return p.baseTable
	}
}
// TableForTime maps t, interpreted in UTC, to its partition table name. The
// name is "<base>_YYYYMMDD" in daily mode and "<base>_YYYYMM" otherwise.
// When partitioning is disabled it returns "log_records".
func (p *LogPartitioner) TableForTime(t time.Time) string {
	if p == nil || !p.Enabled() {
		return "log_records"
	}
	year, month, day := t.UTC().Date()
	if p.mode == LogPartitioningDaily {
		return fmt.Sprintf("%s_%04d%02d%02d", p.baseTable, year, int(month), day)
	}
	// Monthly, and any other enabled mode, gets one table per month.
	return fmt.Sprintf("%s_%04d%02d", p.baseTable, year, int(month))
}
// EnsurePartitionFor makes sure the partition table for t exists, refreshes
// the union view, and returns the partition table's name. When partitioning
// is disabled it returns "log_records" without touching the database. On
// error the returned name is empty.
//
// NOTE(review): the view is redefined (CREATE OR REPLACE VIEW) on every
// call, even when the table already existed. If this runs per write, caching
// the last ensured table may be worthwhile.
func (p *LogPartitioner) EnsurePartitionFor(t time.Time) (string, error) {
	if p == nil || !p.Enabled() {
		return "log_records", nil
	}
	table := p.TableForTime(t)
	err := p.ensureTable(table)
	if err == nil {
		err = p.ensureView()
	}
	if err != nil {
		return "", err
	}
	return table, nil
}
// ListPartitions returns every partition table recognized for the current
// mode, with its time range, ordered by range start. Tables whose suffix
// cannot be parsed are skipped. It returns (nil, nil) when partitioning is
// disabled.
func (p *LogPartitioner) ListPartitions() ([]LogPartition, error) {
	if p == nil || !p.Enabled() {
		return nil, nil
	}
	names, err := p.listPartitionTables()
	if err != nil {
		return nil, err
	}
	result := make([]LogPartition, 0, len(names))
	for _, name := range names {
		if start, end, ok := p.parsePartitionRange(name); ok {
			result = append(result, LogPartition{Table: name, Start: start, End: end})
		}
	}
	sort.Slice(result, func(a, b int) bool {
		return result[a].Start.Before(result[b].Start)
	})
	return result, nil
}
// DropPartitionsBefore drops every partition table whose range ends strictly
// before cutoff and returns how many tables were dropped. It is a no-op
// returning 0 when partitioning is disabled.
//
// Partition ranges are half-open [Start, End). A partition whose End equals
// cutoff holds only rows older than cutoff, but it is still kept.
// NOTE(review): confirm that keeping it is the intended retention behavior.
//
// The union view selects from every partition table, and PostgreSQL refuses
// to drop a table that a view depends on (DROP TABLE defaults to RESTRICT).
// So the view is first redefined without the doomed partitions, and only then
// are the tables dropped. Both steps run in a single transaction. On error
// nothing is dropped, the view is unchanged, and (0, err) is returned.
func (p *LogPartitioner) DropPartitionsBefore(cutoff time.Time) (int, error) {
	if p == nil || !p.Enabled() {
		return 0, nil
	}
	partitions, err := p.ListPartitions()
	if err != nil {
		return 0, err
	}
	cutoff = cutoff.UTC()
	var doomed []string
	exclude := make(map[string]bool)
	for _, part := range partitions {
		if part.End.Before(cutoff) {
			doomed = append(doomed, part.Table)
			exclude[part.Table] = true
		}
	}
	if len(doomed) == 0 {
		return 0, nil
	}
	err = p.db.Transaction(func(tx *gorm.DB) error {
		// Run the helpers against the transaction instead of p.db.
		scoped := *p
		scoped.db = tx
		if err := scoped.rebuildViewExcluding(exclude); err != nil {
			return err
		}
		for _, table := range doomed {
			if err := scoped.dropTable(table); err != nil {
				return err
			}
		}
		return nil
	})
	if err != nil {
		return 0, err
	}
	return len(doomed), nil
}

// ensureTable creates the partition table if it does not exist yet. The new
// table copies the base table's definition via
// CREATE TABLE ... (LIKE base INCLUDING ALL). Names that are not valid
// partition tables for the current mode are rejected with an error.
func (p *LogPartitioner) ensureTable(table string) error {
	if p == nil || !p.Enabled() {
		return nil
	}
	if table == "" || !p.validPartitionTable(table) {
		return fmt.Errorf("invalid partition table %q", table)
	}
	// Skip the DDL round trip when the table is already there.
	if p.db.Migrator().HasTable(table) {
		return nil
	}
	sql := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (LIKE %s INCLUDING ALL)", quoteIdent(table), quoteIdent(p.baseTable))
	return p.db.Exec(sql).Error
}

// ensureView (re)defines the union view over the base table and every
// existing partition table.
func (p *LogPartitioner) ensureView() error {
	return p.rebuildViewExcluding(nil)
}

// rebuildViewExcluding issues CREATE OR REPLACE VIEW for the union view. The
// view selects from the base table plus every existing partition table not
// present in exclude; a nil exclude keeps all partitions. Partitions are
// unioned in name order so the view definition is deterministic.
func (p *LogPartitioner) rebuildViewExcluding(exclude map[string]bool) error {
	if p == nil || !p.Enabled() {
		return nil
	}
	tables, err := p.listPartitionTables()
	if err != nil {
		return err
	}
	sort.Strings(tables)
	selects := make([]string, 0, len(tables)+1)
	selects = append(selects, fmt.Sprintf("SELECT * FROM %s", quoteIdent(p.baseTable)))
	for _, table := range tables {
		if table == p.baseTable || exclude[table] {
			continue
		}
		selects = append(selects, fmt.Sprintf("SELECT * FROM %s", quoteIdent(table)))
	}
	viewSQL := fmt.Sprintf("CREATE OR REPLACE VIEW %s AS %s", quoteIdent(p.viewTable), strings.Join(selects, " UNION ALL "))
	return p.db.Exec(viewSQL).Error
}
// listPartitionTables returns the names of base tables in the current schema
// that match the partition naming scheme for the current mode. The LIKE
// pattern is only a coarse pre-filter ("_" is a LIKE wildcard); the exact
// check is validPartitionTable. The order is whatever the database returns.
func (p *LogPartitioner) listPartitionTables() ([]string, error) {
	if p == nil || !p.Enabled() {
		return nil, nil
	}
	const query = `SELECT table_name FROM information_schema.tables WHERE table_schema = current_schema() AND table_type = 'BASE TABLE' AND table_name LIKE ?`
	pattern := p.baseTable + "_%"
	var candidates []string
	if err := p.db.Raw(query, pattern).Scan(&candidates).Error; err != nil {
		return nil, err
	}
	matched := make([]string, 0, len(candidates))
	for _, name := range candidates {
		if !p.validPartitionTable(name) {
			continue
		}
		matched = append(matched, name)
	}
	return matched, nil
}
// parsePartitionRange decodes a partition table name into its half-open UTC
// range [start, end). Daily tables cover one day and monthly tables cover one
// month. ok is false when the name is not a valid partition name for the
// current mode or its suffix is not a real date.
func (p *LogPartitioner) parsePartitionRange(table string) (time.Time, time.Time, bool) {
	var zero time.Time
	if !p.validPartitionTable(table) {
		return zero, zero, false
	}
	suffix := strings.TrimPrefix(table, p.baseTable+"_")
	daily := p.mode == LogPartitioningDaily
	layout := "200601"
	if daily {
		layout = "20060102"
	}
	// The layouts are all-digit, so their length equals the expected suffix length.
	if len(suffix) != len(layout) {
		return zero, zero, false
	}
	parsed, err := time.Parse(layout, suffix)
	if err != nil {
		return zero, zero, false
	}
	year, month, day := parsed.Date()
	if !daily {
		day = 1
	}
	start := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
	if daily {
		return start, start.AddDate(0, 0, 1), true
	}
	return start, start.AddDate(0, 1, 0), true
}
// validPartitionTable reports whether table is "<base>_" followed by exactly
// 8 digits (daily mode) or exactly 6 digits (any other mode).
func (p *LogPartitioner) validPartitionTable(table string) bool {
	if p == nil || table == "" {
		return false
	}
	prefix := p.baseTable + "_"
	if !strings.HasPrefix(table, prefix) {
		return false
	}
	suffix := table[len(prefix):]
	want := 6
	if p.mode == LogPartitioningDaily {
		want = 8
	}
	return len(suffix) == want && isDigits(suffix)
}
// dropTable drops a single partition table if it exists. Names that are not
// valid partition tables are refused, so the base table and the view can
// never be dropped through this path. It is a no-op when partitioning is
// disabled.
func (p *LogPartitioner) dropTable(table string) error {
	switch {
	case p == nil || !p.Enabled():
		return nil
	case !p.validPartitionTable(table):
		return fmt.Errorf("invalid partition table %q", table)
	}
	stmt := "DROP TABLE IF EXISTS " + quoteIdent(table)
	return p.db.Exec(stmt).Error
}
// normalizePartitioningMode converts a configuration string into a
// LogPartitioningMode. Surrounding whitespace is ignored and matching is
// case-insensitive; unknown values map to LogPartitioningOff.
func normalizePartitioningMode(raw string) LogPartitioningMode {
	switch mode := LogPartitioningMode(strings.ToLower(strings.TrimSpace(raw))); mode {
	case LogPartitioningDaily, LogPartitioningMonthly:
		return mode
	}
	return LogPartitioningOff
}
// quoteIdent wraps name in double quotes for use as a SQL identifier,
// doubling any embedded double quote.
func quoteIdent(name string) string {
	var b strings.Builder
	b.Grow(len(name) + 2)
	b.WriteByte('"')
	for i := 0; i < len(name); i++ {
		c := name[i]
		// '"' is ASCII, so it never appears inside a multi-byte UTF-8 sequence.
		if c == '"' {
			b.WriteByte('"')
		}
		b.WriteByte(c)
	}
	b.WriteByte('"')
	return b.String()
}
// isDigits reports whether raw is non-empty and consists only of the ASCII
// digits '0' through '9'.
func isDigits(raw string) bool {
	if len(raw) == 0 {
		return false
	}
	for i := 0; i < len(raw); i++ {
		if c := raw[i]; c < '0' || c > '9' {
			return false
		}
	}
	return true
}