// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. //go:build plus package dbs import ( "encoding/binary" "encoding/json" "errors" "github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs" "github.com/TeaOSLab/EdgeDNS/internal/models" "github.com/TeaOSLab/EdgeDNS/internal/remotelogs" "github.com/cockroachdb/pebble" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/types" "log" "os" "path/filepath" "strings" "sync" ) const ( domainPrefix = "domains:" domainClusterIndex = "domains_cluster:" recordPrefix = "records:" routePrefix = "routes:" keyPrefix = "keys:" agentIPPrefix = "agent_ips:" ) var defaultWriteOptions = &pebble.WriteOptions{Sync: false} var flushRawDB = func(rawDB *pebble.DB) error { return rawDB.Flush() } var sharedRawDBLocker sync.Mutex var sharedRawDBMap = map[string]*sharedRawDB{} type sharedRawDB struct { rawDB *pebble.DB refs int } type DB struct { path string storePath string rawDB *pebble.DB } type domainValue struct { Id int64 `json:"id"` ClusterId int64 `json:"clusterId"` UserId int64 `json:"userId"` Name string `json:"name"` TSIGJSON string `json:"tsigJSON,omitempty"` Version int64 `json:"version"` } type recordValue struct { Id int64 `json:"id"` DomainId int64 `json:"domainId"` Name string `json:"name"` Type dnsconfigs.RecordType `json:"type"` Value string `json:"value"` MXPriority int32 `json:"mxPriority"` SRVPriority int32 `json:"srvPriority"` SRVWeight int32 `json:"srvWeight"` SRVPort int32 `json:"srvPort"` CAAFlag int32 `json:"caaFlag"` CAATag string `json:"caaTag"` TTL int32 `json:"ttl"` Weight int32 `json:"weight"` RouteIds []string `json:"routeIds,omitempty"` Version int64 `json:"version"` } type routeValue struct { Id int64 `json:"id"` UserId int64 `json:"userId"` RangesJSON string `json:"rangesJSON"` Priority int32 `json:"priority"` Order int32 `json:"order"` Version int64 `json:"version"` } func NewDB(path string) *DB { return &DB{ path: path, storePath: storePath(path), } } func (this *DB) Init() error { var dir = filepath.Dir(this.storePath) _, err := os.Stat(dir) if err != nil { err = os.MkdirAll(dir, 0777) if err != nil { return err } remotelogs.Println("DB", "create database dir '"+dir+"'") } rawDB, err := openSharedRawDB(this.storePath) if err != nil { return err } this.rawDB = rawDB return this.migrateSQLiteIfNeeded() } func (this *DB) InsertDomain(domainId int64, clusterId int64, userId int64, name string, tsigJSON []byte, version int64) error { if this.rawDB == nil { return errors.New("db should not be nil") } this.log("InsertDomain", "domain:", domainId, "user:", userId, "name:", name) return this.saveDomain(&domainValue{ Id: domainId, ClusterId: clusterId, UserId: userId, Name: name, TSIGJSON: string(tsigJSON), Version: version, }) } func (this *DB) UpdateDomain(domainId int64, clusterId int64, userId int64, name string, tsigJSON []byte, version int64) error { if this.rawDB == nil { return errors.New("db should not be nil") } this.log("UpdateDomain", "domain:", domainId, "user:", userId, "name:", name) oldValue, err := this.readDomainValue(domainId) if err != nil { if isNotFound(err) { return nil } return err } return this.saveDomainWithOld(&domainValue{ Id: domainId, ClusterId: clusterId, UserId: userId, Name: name, TSIGJSON: string(tsigJSON), Version: version, }, oldValue) } func (this *DB) DeleteDomain(domainId int64) error { if this.rawDB == nil { return errors.New("db should not be nil") } this.log("DeleteDomain", "domain:", domainId) oldValue, err := this.readDomainValue(domainId) if err != nil { if isNotFound(err) { return nil } return err } batch := this.rawDB.NewBatch() defer func() { _ = batch.Close() }() err = batch.Delete(domainKey(domainId), defaultWriteOptions) if err != nil { return err } err = batch.Delete(domainClusterKey(oldValue.ClusterId, domainId), defaultWriteOptions) if err != nil && !isNotFound(err) { return err } return batch.Commit(defaultWriteOptions) } func (this *DB) ExistsDomain(domainId int64) (bool, error) { if this.rawDB == nil { return false, errors.New("db should not be nil") } _, err := this.readDomainValue(domainId) if err != nil { if isNotFound(err) { return false, nil } return false, err } return true, nil } func (this *DB) ListDomains(clusterId int64, offset int, size int) (domains []*models.NSDomain, err error) { if this.rawDB == nil { return nil, errors.New("db should not be nil") } if size <= 0 { return nil, nil } var prefix = append([]byte(domainClusterIndex), encodeInt64(clusterId)...) it, err := this.newIterator(prefix) if err != nil { return nil, err } defer func() { _ = it.Close() }() var skipped = 0 for it.First(); it.Valid(); it.Next() { if skipped < offset { skipped++ continue } var keyBytes = append([]byte(nil), it.Key()...) var domainId = decodeInt64(keyBytes[len(prefix):]) domainValue, valueErr := this.readDomainValue(domainId) if valueErr != nil { if isNotFound(valueErr) { continue } return nil, valueErr } domain, decodeErr := decodeDomain(domainValue) if decodeErr != nil { return nil, decodeErr } domains = append(domains, domain) if len(domains) >= size { break } } return domains, it.Error() } func (this *DB) InsertRecord(recordId int64, domainId int64, name string, recordType dnsconfigs.RecordType, value string, mxPriority int32, srvPriority int32, srvWeight int32, srvPort int32, caaFlag int32, caaTag string, ttl int32, weight int32, routeIds []string, version int64) error { if this.rawDB == nil { return errors.New("db should not be nil") } this.log("InsertRecord", "domain:", domainId, "name:", name) return this.saveJSON(recordKey(recordId), &recordValue{ Id: recordId, DomainId: domainId, Name: name, Type: recordType, Value: value, MXPriority: mxPriority, SRVPriority: srvPriority, SRVWeight: srvWeight, SRVPort: srvPort, CAAFlag: caaFlag, CAATag: caaTag, TTL: ttl, Weight: weight, RouteIds: routeIds, Version: version, }) } func (this *DB) UpdateRecord(recordId int64, domainId int64, name string, recordType dnsconfigs.RecordType, value string, mxPriority int32, srvPriority int32, srvWeight int32, srvPort int32, caaFlag int32, caaTag string, ttl int32, weight int32, routeIds []string, version int64) error { if this.rawDB == nil { return errors.New("db should not be nil") } this.log("UpdateRecord", "domain:", domainId, "name:", name) _, err := this.readRecordValue(recordId) if err != nil { if isNotFound(err) { return nil } return err } return this.saveJSON(recordKey(recordId), &recordValue{ Id: recordId, DomainId: domainId, Name: name, Type: recordType, Value: value, MXPriority: mxPriority, SRVPriority: srvPriority, SRVWeight: srvWeight, SRVPort: srvPort, CAAFlag: caaFlag, CAATag: caaTag, TTL: ttl, Weight: weight, RouteIds: routeIds, Version: version, }) } func (this *DB) DeleteRecord(recordId int64) error { if this.rawDB == nil { return errors.New("db should not be nil") } this.log("DeleteRecord", "record:", recordId) err := this.rawDB.Delete(recordKey(recordId), defaultWriteOptions) if err != nil && !isNotFound(err) { return err } return nil } func (this *DB) ExistsRecord(recordId int64) (bool, error) { if this.rawDB == nil { return false, errors.New("db should not be nil") } _, err := this.readRecordValue(recordId) if err != nil { if isNotFound(err) { return false, nil } return false, err } return true, nil } func (this *DB) ListRecords(offset int, size int) (records []*models.NSRecord, err error) { if this.rawDB == nil { return nil, errors.New("db should not be nil") } if size <= 0 { return nil, nil } it, err := this.newIterator([]byte(recordPrefix)) if err != nil { return nil, err } defer func() { _ = it.Close() }() var skipped = 0 for it.First(); it.Valid(); it.Next() { if skipped < offset { skipped++ continue } valueBytes, valueErr := it.ValueAndErr() if valueErr != nil { return nil, valueErr } record, decodeErr := decodeRecord(valueBytes) if decodeErr != nil { return nil, decodeErr } records = append(records, record) if len(records) >= size { break } } return records, it.Error() } func (this *DB) InsertRoute(routeId int64, userId int64, rangesJSON []byte, order int32, priority int32, version int64) error { if this.rawDB == nil { return errors.New("db should not be nil") } this.log("InsertRoute", "route:", routeId, "user:", userId) return this.saveJSON(routeKey(routeId), &routeValue{ Id: routeId, UserId: userId, RangesJSON: string(rangesJSON), Order: order, Priority: priority, Version: version, }) } func (this *DB) UpdateRoute(routeId int64, userId int64, rangesJSON []byte, order int32, priority int32, version int64) error { if this.rawDB == nil { return errors.New("db should not be nil") } this.log("UpdateRoute", "route:", routeId, "user:", userId) _, err := this.readRouteValue(routeId) if err != nil { if isNotFound(err) { return nil } return err } return this.saveJSON(routeKey(routeId), &routeValue{ Id: routeId, UserId: userId, RangesJSON: string(rangesJSON), Order: order, Priority: priority, Version: version, }) } func (this *DB) DeleteRoute(routeId int64) error { if this.rawDB == nil { return errors.New("db should not be nil") } this.log("DeleteRoute", "route:", routeId) err := this.rawDB.Delete(routeKey(routeId), defaultWriteOptions) if err != nil && !isNotFound(err) { return err } return nil } func (this *DB) ExistsRoute(routeId int64) (bool, error) { if this.rawDB == nil { return false, errors.New("db should not be nil") } _, err := this.readRouteValue(routeId) if err != nil { if isNotFound(err) { return false, nil } return false, err } return true, nil } func (this *DB) ListRoutes(offset int64, size int64) (routes []*models.NSRoute, err error) { if this.rawDB == nil { return nil, errors.New("db should not be nil") } if size <= 0 { return nil, nil } it, err := this.newIterator([]byte(routePrefix)) if err != nil { return nil, err } defer func() { _ = it.Close() }() var skipped int64 for it.First(); it.Valid(); it.Next() { if skipped < offset { skipped++ continue } valueBytes, valueErr := it.ValueAndErr() if valueErr != nil { return nil, valueErr } route, decodeErr := decodeRoute(valueBytes) if decodeErr != nil { return nil, decodeErr } routes = append(routes, route) if int64(len(routes)) >= size { break } } return routes, it.Error() } func (this *DB) InsertKey(keyId int64, domainId int64, zoneId int64, algo string, secret string, secretType string, version int64) error { if this.rawDB == nil { return errors.New("db should not be nil") } this.log("InsertKey", "key:", keyId, "domain:", domainId, "zone:", zoneId) return this.saveJSON(keyKey(keyId), &models.NSKey{ Id: keyId, DomainId: domainId, ZoneId: zoneId, Algo: algo, Secret: secret, SecretType: secretType, Version: version, }) } func (this *DB) UpdateKey(keyId int64, domainId int64, zoneId int64, algo string, secret string, secretType string, version int64) error { if this.rawDB == nil { return errors.New("db should not be nil") } this.log("UpdateKey", "key:", keyId, "domain:", domainId, "zone:", zoneId) _, err := this.readKeyValue(keyId) if err != nil { if isNotFound(err) { return nil } return err } return this.saveJSON(keyKey(keyId), &models.NSKey{ Id: keyId, DomainId: domainId, ZoneId: zoneId, Algo: algo, Secret: secret, SecretType: secretType, Version: version, }) } func (this *DB) ListKeys(offset int, size int) (keys []*models.NSKey, err error) { if this.rawDB == nil { return nil, errors.New("db should not be nil") } if size <= 0 { return nil, nil } it, err := this.newIterator([]byte(keyPrefix)) if err != nil { return nil, err } defer func() { _ = it.Close() }() var skipped = 0 for it.First(); it.Valid(); it.Next() { if skipped < offset { skipped++ continue } valueBytes, valueErr := it.ValueAndErr() if valueErr != nil { return nil, valueErr } key, decodeErr := decodeKey(valueBytes) if decodeErr != nil { return nil, decodeErr } keys = append(keys, key) if len(keys) >= size { break } } return keys, it.Error() } func (this *DB) DeleteKey(keyId int64) error { if this.rawDB == nil { return errors.New("db should not be nil") } this.log("DeleteKey", "key:", keyId) err := this.rawDB.Delete(keyKey(keyId), defaultWriteOptions) if err != nil && !isNotFound(err) { return err } return nil } func (this *DB) ExistsKey(keyId int64) (bool, error) { if this.rawDB == nil { return false, errors.New("db should not be nil") } _, err := this.readKeyValue(keyId) if err != nil { if isNotFound(err) { return false, nil } return false, err } return true, nil } func (this *DB) InsertAgentIP(ipId int64, ip string, agentCode string) error { if this.rawDB == nil { return errors.New("db should not be nil") } this.log("InsertAgentIP", "id:", ipId, "ip:", ip, "agent:", agentCode) return this.saveJSON(agentIPKey(ipId), &models.AgentIP{ Id: ipId, IP: ip, AgentCode: agentCode, }) } func (this *DB) ListAgentIPs(offset int64, size int64) (agentIPs []*models.AgentIP, err error) { if this.rawDB == nil { return nil, errors.New("db should not be nil") } if size <= 0 { return nil, nil } it, err := this.newIterator([]byte(agentIPPrefix)) if err != nil { return nil, err } defer func() { _ = it.Close() }() var skipped int64 for it.First(); it.Valid(); it.Next() { if skipped < offset { skipped++ continue } valueBytes, valueErr := it.ValueAndErr() if valueErr != nil { return nil, valueErr } agentIP, decodeErr := decodeAgentIP(valueBytes) if decodeErr != nil { return nil, decodeErr } agentIPs = append(agentIPs, agentIP) if int64(len(agentIPs)) >= size { break } } return agentIPs, it.Error() } func (this *DB) Close() error { if this.rawDB == nil { return nil } var err = closeSharedRawDB(this.storePath) if err != nil { return err } this.rawDB = nil return nil } func (this *DB) saveDomain(value *domainValue) error { oldValue, err := this.readDomainValue(value.Id) if err != nil && !isNotFound(err) { return err } if isNotFound(err) { oldValue = nil } return this.saveDomainWithOld(value, oldValue) } func (this *DB) saveDomainWithOld(value *domainValue, oldValue *domainValue) error { valueBytes, err := json.Marshal(value) if err != nil { return err } batch := this.rawDB.NewBatch() defer func() { _ = batch.Close() }() if oldValue != nil && oldValue.ClusterId != value.ClusterId { err = batch.Delete(domainClusterKey(oldValue.ClusterId, oldValue.Id), defaultWriteOptions) if err != nil && !isNotFound(err) { return err } } err = batch.Set(domainKey(value.Id), valueBytes, defaultWriteOptions) if err != nil { return err } err = batch.Set(domainClusterKey(value.ClusterId, value.Id), nil, defaultWriteOptions) if err != nil { return err } return batch.Commit(defaultWriteOptions) } func (this *DB) saveJSON(key []byte, value any) error { valueBytes, err := json.Marshal(value) if err != nil { return err } return this.rawDB.Set(key, valueBytes, defaultWriteOptions) } func (this *DB) newIterator(prefix []byte) (*pebble.Iterator, error) { return this.rawDB.NewIter(&pebble.IterOptions{ LowerBound: prefix, UpperBound: prefixUpperBound(prefix), }) } func (this *DB) readDomainValue(domainId int64) (*domainValue, error) { valueBytes, err := this.get(domainKey(domainId)) if err != nil { return nil, err } var value = &domainValue{} err = json.Unmarshal(valueBytes, value) if err != nil { return nil, err } return value, nil } func (this *DB) readRecordValue(recordId int64) (*recordValue, error) { valueBytes, err := this.get(recordKey(recordId)) if err != nil { return nil, err } var value = &recordValue{} err = json.Unmarshal(valueBytes, value) if err != nil { return nil, err } return value, nil } func (this *DB) readRouteValue(routeId int64) (*routeValue, error) { valueBytes, err := this.get(routeKey(routeId)) if err != nil { return nil, err } var value = &routeValue{} err = json.Unmarshal(valueBytes, value) if err != nil { return nil, err } return value, nil } func (this *DB) readKeyValue(keyId int64) (*models.NSKey, error) { valueBytes, err := this.get(keyKey(keyId)) if err != nil { return nil, err } return decodeKey(valueBytes) } func (this *DB) get(key []byte) ([]byte, error) { valueBytes, closer, err := this.rawDB.Get(key) if err != nil { return nil, err } defer func() { _ = closer.Close() }() return append([]byte(nil), valueBytes...), nil } func decodeDomain(value *domainValue) (*models.NSDomain, error) { var domain = &models.NSDomain{ Id: value.Id, ClusterId: value.ClusterId, UserId: value.UserId, Name: value.Name, Version: value.Version, } if len(value.TSIGJSON) > 0 { var tsigConfig = &dnsconfigs.NSTSIGConfig{} err := json.Unmarshal([]byte(value.TSIGJSON), tsigConfig) if err != nil { remotelogs.Error("decode tsig string failed: "+err.Error()+", domain:"+domain.Name, ", domainId: "+types.String(domain.Id)) } else { domain.TSIG = tsigConfig } } return domain, nil } func decodeRecord(valueBytes []byte) (*models.NSRecord, error) { var value = &recordValue{} err := json.Unmarshal(valueBytes, value) if err != nil { return nil, err } return &models.NSRecord{ Id: value.Id, DomainId: value.DomainId, Name: value.Name, Type: value.Type, Value: value.Value, MXPriority: value.MXPriority, SRVPriority: value.SRVPriority, SRVWeight: value.SRVWeight, SRVPort: value.SRVPort, CAAFlag: value.CAAFlag, CAATag: value.CAATag, Ttl: value.TTL, Weight: value.Weight, RouteIds: value.RouteIds, Version: value.Version, }, nil } func decodeRoute(valueBytes []byte) (*models.NSRoute, error) { var value = &routeValue{} err := json.Unmarshal(valueBytes, value) if err != nil { return nil, err } ranges, err := models.InitRangesFromJSON([]byte(value.RangesJSON)) if err != nil { return nil, err } return &models.NSRoute{ Id: value.Id, UserId: value.UserId, Ranges: ranges, Priority: value.Priority, Order: value.Order, Version: value.Version, }, nil } func decodeKey(valueBytes []byte) (*models.NSKey, error) { var value = &models.NSKey{} err := json.Unmarshal(valueBytes, value) if err != nil { return nil, err } return value, nil } func decodeAgentIP(valueBytes []byte) (*models.AgentIP, error) { var value = &models.AgentIP{} err := json.Unmarshal(valueBytes, value) if err != nil { return nil, err } return value, nil } func domainKey(domainId int64) []byte { return append([]byte(domainPrefix), encodeInt64(domainId)...) } func domainClusterKey(clusterId int64, domainId int64) []byte { var key = append([]byte(domainClusterIndex), encodeInt64(clusterId)...) key = append(key, encodeInt64(domainId)...) return key } func recordKey(recordId int64) []byte { return append([]byte(recordPrefix), encodeInt64(recordId)...) } func routeKey(routeId int64) []byte { return append([]byte(routePrefix), encodeInt64(routeId)...) } func keyKey(keyId int64) []byte { return append([]byte(keyPrefix), encodeInt64(keyId)...) } func agentIPKey(agentIPId int64) []byte { return append([]byte(agentIPPrefix), encodeInt64(agentIPId)...) } func prefixUpperBound(prefix []byte) []byte { var bound = make([]byte, len(prefix)+1) copy(bound, prefix) bound[len(prefix)] = 0xFF return bound } func encodeInt64(v int64) []byte { if v < 0 { v = 0 } var b [8]byte binary.BigEndian.PutUint64(b[:], uint64(v)) return b[:] } func decodeInt64(b []byte) int64 { if len(b) < 8 { return 0 } return int64(binary.BigEndian.Uint64(b[:8])) } func storePath(path string) string { if strings.HasSuffix(path, ".db") { return strings.TrimSuffix(path, ".db") + ".store" } return path + ".store" } func openSharedRawDB(path string) (*pebble.DB, error) { sharedRawDBLocker.Lock() defer sharedRawDBLocker.Unlock() shared, ok := sharedRawDBMap[path] if ok { shared.refs++ return shared.rawDB, nil } _ = os.MkdirAll(path, 0777) rawDB, err := pebble.Open(path, &pebble.Options{}) if err != nil { return nil, err } sharedRawDBMap[path] = &sharedRawDB{ rawDB: rawDB, refs: 1, } return rawDB, nil } func closeSharedRawDB(path string) error { sharedRawDBLocker.Lock() shared, ok := sharedRawDBMap[path] if !ok { sharedRawDBLocker.Unlock() return nil } shared.refs-- if shared.refs > 0 { sharedRawDBLocker.Unlock() return nil } delete(sharedRawDBMap, path) sharedRawDBLocker.Unlock() var lastErr error if shared.rawDB != nil { err := shared.rawDB.Close() if err != nil { lastErr = err } } return lastErr } func isNotFound(err error) bool { return err != nil && errors.Is(err, pebble.ErrNotFound) } // 打印日志 func (this *DB) log(args ...any) { if !Tea.IsTesting() { return } if len(args) == 0 { return } args[0] = "[" + types.String(args[0]) + "]" log.Println(args...) }