Files
2026-03-22 17:37:40 +08:00

1010 lines
22 KiB
Go

// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package dbs
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeDNS/internal/models"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/cockroachdb/pebble"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/types"
"log"
"os"
"path/filepath"
"strings"
"sync"
)
// Key prefixes of the pebble keyspace. Each entity type lives under its own
// prefix followed by an 8-byte big-endian id (see encodeInt64). Keys are
// persisted, so these strings are part of the on-disk format.
const (
	domainPrefix       = "domains:"         // domainKey: domain JSON by domain id
	domainClusterIndex = "domains_cluster:" // domainClusterKey: empty-valued index entry by cluster id + domain id
	recordPrefix       = "records:"         // recordKey: record JSON by record id
	routePrefix        = "routes:"          // routeKey: route JSON by route id
	keyPrefix          = "keys:"            // keyKey: NSKey JSON by key id
	agentIPPrefix      = "agent_ips:"       // agentIPKey: AgentIP JSON by id
)

// defaultWriteOptions is passed to every write and batch commit in this file.
// Sync is false, so commits do not wait for the WAL to be synced to disk.
var defaultWriteOptions = &pebble.WriteOptions{Sync: false}

// flushRawDB flushes rawDB. It is a variable rather than a function,
// presumably so it can be swapped out (e.g. in tests).
// NOTE(review): no caller is visible in this part of the file.
var flushRawDB = func(rawDB *pebble.DB) error {
	return rawDB.Flush()
}

// sharedRawDBLocker guards sharedRawDBMap and the refs counters inside it.
var sharedRawDBLocker sync.Mutex

// sharedRawDBMap maps a store path to its reference-counted pebble handle, so
// several DB values opened on the same path share one pebble.DB.
var sharedRawDBMap = map[string]*sharedRawDB{}
// sharedRawDB is a pebble handle shared by every DB opened on the same store
// path, together with the number of DB values currently holding it.
type sharedRawDB struct {
	rawDB *pebble.DB
	refs  int // incremented by openSharedRawDB, decremented by closeSharedRawDB
}
// DB is a pebble-backed store for NS domains, records, routes, keys and agent
// IPs. Create it with NewDB and call Init before use. The data methods return
// an error while the pebble handle is nil (before Init or after Close).
type DB struct {
	path      string     // path exactly as passed to NewDB
	storePath string     // pebble directory derived from path by storePath()
	rawDB     *pebble.DB // shared pebble handle; nil when not initialized/closed
}
// domainValue is the JSON document stored under domainKey. The JSON tags
// (and field order) define the persisted format; decodeDomain turns it into a
// models.NSDomain.
type domainValue struct {
	Id        int64  `json:"id"`
	ClusterId int64  `json:"clusterId"`
	UserId    int64  `json:"userId"`
	Name      string `json:"name"`
	TSIGJSON  string `json:"tsigJSON,omitempty"` // raw TSIG config JSON, decoded lazily by decodeDomain
	Version   int64  `json:"version"`
}
// recordValue is the JSON document stored under recordKey. It mirrors the
// fields of models.NSRecord (see decodeRecord); the JSON tags define the
// persisted format.
type recordValue struct {
	Id          int64                 `json:"id"`
	DomainId    int64                 `json:"domainId"`
	Name        string                `json:"name"`
	Type        dnsconfigs.RecordType `json:"type"`
	Value       string                `json:"value"`
	MXPriority  int32                 `json:"mxPriority"`
	SRVPriority int32                 `json:"srvPriority"`
	SRVWeight   int32                 `json:"srvWeight"`
	SRVPort     int32                 `json:"srvPort"`
	CAAFlag     int32                 `json:"caaFlag"`
	CAATag      string                `json:"caaTag"`
	TTL         int32                 `json:"ttl"` // mapped to NSRecord.Ttl
	Weight      int32                 `json:"weight"`
	RouteIds    []string              `json:"routeIds,omitempty"`
	Version     int64                 `json:"version"`
}
// routeValue is the JSON document stored under routeKey. RangesJSON holds the
// raw ranges JSON, which decodeRoute parses with models.InitRangesFromJSON.
type routeValue struct {
	Id         int64  `json:"id"`
	UserId     int64  `json:"userId"`
	RangesJSON string `json:"rangesJSON"`
	Priority   int32  `json:"priority"`
	Order      int32  `json:"order"`
	Version    int64  `json:"version"`
}
// NewDB returns a DB bound to path. The pebble store lives in the directory
// computed by storePath(path); nothing is opened until Init is called.
func NewDB(path string) *DB {
	db := &DB{path: path}
	db.storePath = storePath(path)
	return db
}
// Init creates the parent directory of the store if it cannot be stat'ed,
// acquires the (possibly shared) pebble handle for storePath and finally runs
// migrateSQLiteIfNeeded, returning its error.
func (this *DB) Init() error {
	parentDir := filepath.Dir(this.storePath)
	if _, statErr := os.Stat(parentDir); statErr != nil {
		if mkdirErr := os.MkdirAll(parentDir, 0777); mkdirErr != nil {
			return mkdirErr
		}
		remotelogs.Println("DB", "create database dir '"+parentDir+"'")
	}

	handle, err := openSharedRawDB(this.storePath)
	if err != nil {
		return err
	}
	this.rawDB = handle

	return this.migrateSQLiteIfNeeded()
}
// InsertDomain stores a domain document together with its cluster index
// entry. An existing domain with the same id is overwritten, and its old
// cluster index entry is removed if the cluster changed (via saveDomain).
func (this *DB) InsertDomain(domainId int64, clusterId int64, userId int64, name string, tsigJSON []byte, version int64) error {
	if this.rawDB == nil {
		return errors.New("db should not be nil")
	}
	this.log("InsertDomain", "domain:", domainId, "user:", userId, "name:", name)

	value := &domainValue{
		Id:        domainId,
		ClusterId: clusterId,
		UserId:    userId,
		Name:      name,
		TSIGJSON:  string(tsigJSON),
		Version:   version,
	}
	return this.saveDomain(value)
}
// UpdateDomain overwrites an existing domain and moves its cluster index entry
// if the cluster changed. Updating a domain that does not exist is a silent
// no-op (returns nil).
func (this *DB) UpdateDomain(domainId int64, clusterId int64, userId int64, name string, tsigJSON []byte, version int64) error {
	if this.rawDB == nil {
		return errors.New("db should not be nil")
	}
	this.log("UpdateDomain", "domain:", domainId, "user:", userId, "name:", name)

	previous, err := this.readDomainValue(domainId)
	switch {
	case isNotFound(err):
		return nil
	case err != nil:
		return err
	}

	updated := &domainValue{
		Id:        domainId,
		ClusterId: clusterId,
		UserId:    userId,
		Name:      name,
		TSIGJSON:  string(tsigJSON),
		Version:   version,
	}
	return this.saveDomainWithOld(updated, previous)
}
// DeleteDomain removes a domain document and its cluster index entry in a
// single batch. Deleting a missing domain returns nil.
func (this *DB) DeleteDomain(domainId int64) error {
	if this.rawDB == nil {
		return errors.New("db should not be nil")
	}
	this.log("DeleteDomain", "domain:", domainId)

	stored, err := this.readDomainValue(domainId)
	if isNotFound(err) {
		return nil
	}
	if err != nil {
		return err
	}

	batch := this.rawDB.NewBatch()
	defer func() { _ = batch.Close() }()

	// The cluster id needed for the index key comes from the stored document.
	if err = batch.Delete(domainKey(domainId), defaultWriteOptions); err != nil {
		return err
	}
	if err = batch.Delete(domainClusterKey(stored.ClusterId, domainId), defaultWriteOptions); err != nil && !isNotFound(err) {
		return err
	}
	return batch.Commit(defaultWriteOptions)
}
// ExistsDomain reports whether a domain document with domainId is stored and
// decodes as JSON. Lookup/decode errors other than "not found" are returned.
func (this *DB) ExistsDomain(domainId int64) (bool, error) {
	if this.rawDB == nil {
		return false, errors.New("db should not be nil")
	}

	_, err := this.readDomainValue(domainId)
	switch {
	case err == nil:
		return true, nil
	case isNotFound(err):
		return false, nil
	default:
		return false, err
	}
}
// ListDomains returns up to size domains of clusterId in domain-id order,
// skipping the first offset entries of the cluster index. Index entries whose
// domain document is missing are skipped; they still count toward offset.
func (this *DB) ListDomains(clusterId int64, offset int, size int) (domains []*models.NSDomain, err error) {
	if this.rawDB == nil {
		return nil, errors.New("db should not be nil")
	}
	if size <= 0 {
		return nil, nil
	}

	// Index keys are: domainClusterIndex + 8-byte cluster id + 8-byte domain id.
	indexPrefix := append([]byte(domainClusterIndex), encodeInt64(clusterId)...)
	iter, err := this.newIterator(indexPrefix)
	if err != nil {
		return nil, err
	}
	defer func() {
		_ = iter.Close()
	}()

	toSkip := offset
	for iter.First(); iter.Valid(); iter.Next() {
		if toSkip > 0 {
			toSkip--
			continue
		}

		indexKey := append([]byte(nil), iter.Key()...)
		stored, readErr := this.readDomainValue(decodeInt64(indexKey[len(indexPrefix):]))
		if isNotFound(readErr) {
			continue
		}
		if readErr != nil {
			return nil, readErr
		}

		domain, decodeErr := decodeDomain(stored)
		if decodeErr != nil {
			return nil, decodeErr
		}
		if domains = append(domains, domain); len(domains) >= size {
			break
		}
	}
	return domains, iter.Error()
}
// InsertRecord stores (or overwrites) the JSON document for recordId.
func (this *DB) InsertRecord(recordId int64, domainId int64, name string, recordType dnsconfigs.RecordType, value string, mxPriority int32, srvPriority int32, srvWeight int32, srvPort int32, caaFlag int32, caaTag string, ttl int32, weight int32, routeIds []string, version int64) error {
	if this.rawDB == nil {
		return errors.New("db should not be nil")
	}
	this.log("InsertRecord", "domain:", domainId, "name:", name)

	doc := &recordValue{
		Id:          recordId,
		DomainId:    domainId,
		Name:        name,
		Type:        recordType,
		Value:       value,
		MXPriority:  mxPriority,
		SRVPriority: srvPriority,
		SRVWeight:   srvWeight,
		SRVPort:     srvPort,
		CAAFlag:     caaFlag,
		CAATag:      caaTag,
		TTL:         ttl,
		Weight:      weight,
		RouteIds:    routeIds,
		Version:     version,
	}
	return this.saveJSON(recordKey(recordId), doc)
}
// UpdateRecord overwrites the document for recordId only if it already
// exists; updating a missing record returns nil without writing.
func (this *DB) UpdateRecord(recordId int64, domainId int64, name string, recordType dnsconfigs.RecordType, value string, mxPriority int32, srvPriority int32, srvWeight int32, srvPort int32, caaFlag int32, caaTag string, ttl int32, weight int32, routeIds []string, version int64) error {
	if this.rawDB == nil {
		return errors.New("db should not be nil")
	}
	this.log("UpdateRecord", "domain:", domainId, "name:", name)

	_, err := this.readRecordValue(recordId)
	switch {
	case isNotFound(err):
		return nil
	case err != nil:
		return err
	}

	doc := &recordValue{
		Id:          recordId,
		DomainId:    domainId,
		Name:        name,
		Type:        recordType,
		Value:       value,
		MXPriority:  mxPriority,
		SRVPriority: srvPriority,
		SRVWeight:   srvWeight,
		SRVPort:     srvPort,
		CAAFlag:     caaFlag,
		CAATag:      caaTag,
		TTL:         ttl,
		Weight:      weight,
		RouteIds:    routeIds,
		Version:     version,
	}
	return this.saveJSON(recordKey(recordId), doc)
}
// DeleteRecord removes the document for recordId; a missing record is not an
// error.
func (this *DB) DeleteRecord(recordId int64) error {
	if this.rawDB == nil {
		return errors.New("db should not be nil")
	}
	this.log("DeleteRecord", "record:", recordId)

	if err := this.rawDB.Delete(recordKey(recordId), defaultWriteOptions); err != nil && !isNotFound(err) {
		return err
	}
	return nil
}
// ExistsRecord reports whether a record document with recordId is stored and
// decodes as JSON. Errors other than "not found" are returned.
func (this *DB) ExistsRecord(recordId int64) (bool, error) {
	if this.rawDB == nil {
		return false, errors.New("db should not be nil")
	}

	_, err := this.readRecordValue(recordId)
	switch {
	case err == nil:
		return true, nil
	case isNotFound(err):
		return false, nil
	default:
		return false, err
	}
}
// ListRecords returns up to size records in record-id order after skipping
// the first offset records.
func (this *DB) ListRecords(offset int, size int) (records []*models.NSRecord, err error) {
	if this.rawDB == nil {
		return nil, errors.New("db should not be nil")
	}
	if size <= 0 {
		return nil, nil
	}

	iter, err := this.newIterator([]byte(recordPrefix))
	if err != nil {
		return nil, err
	}
	defer func() {
		_ = iter.Close()
	}()

	toSkip := offset
	for iter.First(); iter.Valid(); iter.Next() {
		if toSkip > 0 {
			toSkip--
			continue
		}
		raw, rawErr := iter.ValueAndErr()
		if rawErr != nil {
			return nil, rawErr
		}
		record, decodeErr := decodeRecord(raw)
		if decodeErr != nil {
			return nil, decodeErr
		}
		if records = append(records, record); len(records) >= size {
			break
		}
	}
	return records, iter.Error()
}
// InsertRoute stores (or overwrites) the JSON document for routeId.
func (this *DB) InsertRoute(routeId int64, userId int64, rangesJSON []byte, order int32, priority int32, version int64) error {
	if this.rawDB == nil {
		return errors.New("db should not be nil")
	}
	this.log("InsertRoute", "route:", routeId, "user:", userId)

	doc := &routeValue{
		Id:         routeId,
		UserId:     userId,
		RangesJSON: string(rangesJSON),
		Order:      order,
		Priority:   priority,
		Version:    version,
	}
	return this.saveJSON(routeKey(routeId), doc)
}
// UpdateRoute overwrites the document for routeId only if it already exists;
// updating a missing route returns nil without writing.
func (this *DB) UpdateRoute(routeId int64, userId int64, rangesJSON []byte, order int32, priority int32, version int64) error {
	if this.rawDB == nil {
		return errors.New("db should not be nil")
	}
	this.log("UpdateRoute", "route:", routeId, "user:", userId)

	_, err := this.readRouteValue(routeId)
	switch {
	case isNotFound(err):
		return nil
	case err != nil:
		return err
	}

	doc := &routeValue{
		Id:         routeId,
		UserId:     userId,
		RangesJSON: string(rangesJSON),
		Order:      order,
		Priority:   priority,
		Version:    version,
	}
	return this.saveJSON(routeKey(routeId), doc)
}
// DeleteRoute removes the document for routeId; a missing route is not an
// error.
func (this *DB) DeleteRoute(routeId int64) error {
	if this.rawDB == nil {
		return errors.New("db should not be nil")
	}
	this.log("DeleteRoute", "route:", routeId)

	if err := this.rawDB.Delete(routeKey(routeId), defaultWriteOptions); err != nil && !isNotFound(err) {
		return err
	}
	return nil
}
// ExistsRoute reports whether a route document with routeId is stored and
// decodes as JSON. Errors other than "not found" are returned.
func (this *DB) ExistsRoute(routeId int64) (bool, error) {
	if this.rawDB == nil {
		return false, errors.New("db should not be nil")
	}

	_, err := this.readRouteValue(routeId)
	switch {
	case err == nil:
		return true, nil
	case isNotFound(err):
		return false, nil
	default:
		return false, err
	}
}
// ListRoutes returns up to size routes in route-id order after skipping the
// first offset routes.
func (this *DB) ListRoutes(offset int64, size int64) (routes []*models.NSRoute, err error) {
	if this.rawDB == nil {
		return nil, errors.New("db should not be nil")
	}
	if size <= 0 {
		return nil, nil
	}

	iter, err := this.newIterator([]byte(routePrefix))
	if err != nil {
		return nil, err
	}
	defer func() {
		_ = iter.Close()
	}()

	toSkip := offset
	for iter.First(); iter.Valid(); iter.Next() {
		if toSkip > 0 {
			toSkip--
			continue
		}
		raw, rawErr := iter.ValueAndErr()
		if rawErr != nil {
			return nil, rawErr
		}
		route, decodeErr := decodeRoute(raw)
		if decodeErr != nil {
			return nil, decodeErr
		}
		if routes = append(routes, route); int64(len(routes)) >= size {
			break
		}
	}
	return routes, iter.Error()
}
// InsertKey stores (or overwrites) the NSKey JSON document for keyId.
func (this *DB) InsertKey(keyId int64, domainId int64, zoneId int64, algo string, secret string, secretType string, version int64) error {
	if this.rawDB == nil {
		return errors.New("db should not be nil")
	}
	this.log("InsertKey", "key:", keyId, "domain:", domainId, "zone:", zoneId)

	doc := &models.NSKey{
		Id:         keyId,
		DomainId:   domainId,
		ZoneId:     zoneId,
		Algo:       algo,
		Secret:     secret,
		SecretType: secretType,
		Version:    version,
	}
	return this.saveJSON(keyKey(keyId), doc)
}
// UpdateKey overwrites the document for keyId only if it already exists;
// updating a missing key returns nil without writing.
func (this *DB) UpdateKey(keyId int64, domainId int64, zoneId int64, algo string, secret string, secretType string, version int64) error {
	if this.rawDB == nil {
		return errors.New("db should not be nil")
	}
	this.log("UpdateKey", "key:", keyId, "domain:", domainId, "zone:", zoneId)

	_, err := this.readKeyValue(keyId)
	switch {
	case isNotFound(err):
		return nil
	case err != nil:
		return err
	}

	doc := &models.NSKey{
		Id:         keyId,
		DomainId:   domainId,
		ZoneId:     zoneId,
		Algo:       algo,
		Secret:     secret,
		SecretType: secretType,
		Version:    version,
	}
	return this.saveJSON(keyKey(keyId), doc)
}
// ListKeys returns up to size keys in key-id order after skipping the first
// offset keys.
func (this *DB) ListKeys(offset int, size int) (keys []*models.NSKey, err error) {
	if this.rawDB == nil {
		return nil, errors.New("db should not be nil")
	}
	if size <= 0 {
		return nil, nil
	}

	iter, err := this.newIterator([]byte(keyPrefix))
	if err != nil {
		return nil, err
	}
	defer func() {
		_ = iter.Close()
	}()

	toSkip := offset
	for iter.First(); iter.Valid(); iter.Next() {
		if toSkip > 0 {
			toSkip--
			continue
		}
		raw, rawErr := iter.ValueAndErr()
		if rawErr != nil {
			return nil, rawErr
		}
		nsKey, decodeErr := decodeKey(raw)
		if decodeErr != nil {
			return nil, decodeErr
		}
		if keys = append(keys, nsKey); len(keys) >= size {
			break
		}
	}
	return keys, iter.Error()
}
// DeleteKey removes the document for keyId; a missing key is not an error.
func (this *DB) DeleteKey(keyId int64) error {
	if this.rawDB == nil {
		return errors.New("db should not be nil")
	}
	this.log("DeleteKey", "key:", keyId)

	if err := this.rawDB.Delete(keyKey(keyId), defaultWriteOptions); err != nil && !isNotFound(err) {
		return err
	}
	return nil
}
// ExistsKey reports whether a key document with keyId is stored and decodes
// as JSON. Errors other than "not found" are returned.
func (this *DB) ExistsKey(keyId int64) (bool, error) {
	if this.rawDB == nil {
		return false, errors.New("db should not be nil")
	}

	_, err := this.readKeyValue(keyId)
	switch {
	case err == nil:
		return true, nil
	case isNotFound(err):
		return false, nil
	default:
		return false, err
	}
}
// InsertAgentIP stores (or overwrites) the AgentIP JSON document for ipId.
func (this *DB) InsertAgentIP(ipId int64, ip string, agentCode string) error {
	if this.rawDB == nil {
		return errors.New("db should not be nil")
	}
	this.log("InsertAgentIP", "id:", ipId, "ip:", ip, "agent:", agentCode)

	doc := &models.AgentIP{
		Id:        ipId,
		IP:        ip,
		AgentCode: agentCode,
	}
	return this.saveJSON(agentIPKey(ipId), doc)
}
// ListAgentIPs returns up to size agent IPs in id order after skipping the
// first offset entries.
func (this *DB) ListAgentIPs(offset int64, size int64) (agentIPs []*models.AgentIP, err error) {
	if this.rawDB == nil {
		return nil, errors.New("db should not be nil")
	}
	if size <= 0 {
		return nil, nil
	}

	iter, err := this.newIterator([]byte(agentIPPrefix))
	if err != nil {
		return nil, err
	}
	defer func() {
		_ = iter.Close()
	}()

	toSkip := offset
	for iter.First(); iter.Valid(); iter.Next() {
		if toSkip > 0 {
			toSkip--
			continue
		}
		raw, rawErr := iter.ValueAndErr()
		if rawErr != nil {
			return nil, rawErr
		}
		agentIP, decodeErr := decodeAgentIP(raw)
		if decodeErr != nil {
			return nil, decodeErr
		}
		if agentIPs = append(agentIPs, agentIP); int64(len(agentIPs)) >= size {
			break
		}
	}
	return agentIPs, iter.Error()
}
// Close releases this DB's reference to the shared pebble store. The store
// itself is closed when its last reference is released. Closing a DB that was
// never initialized, or was already closed, is a no-op returning nil.
func (this *DB) Close() error {
	if this.rawDB == nil {
		return nil
	}

	// Clear the handle before releasing, even if closing fails.
	// closeSharedRawDB drops this DB's reference and removes the map entry
	// before it closes the pebble handle, so the reference is gone either
	// way. Keeping the pointer after an error would let a second Close
	// decrement the refcount of a store re-opened later on the same path.
	// It would also let the other methods run against a closed pebble.DB.
	this.rawDB = nil
	return closeSharedRawDB(this.storePath)
}
// saveDomain writes value, looking up the currently stored version (if any)
// first so that saveDomainWithOld can drop a stale cluster index entry.
func (this *DB) saveDomain(value *domainValue) error {
	previous, err := this.readDomainValue(value.Id)
	switch {
	case err == nil:
		// Existing domain: previous is used to clean up the index.
	case isNotFound(err):
		previous = nil
	default:
		return err
	}
	return this.saveDomainWithOld(value, previous)
}
// saveDomainWithOld atomically writes the domain document and its cluster
// index entry (an empty value). If oldValue is non-nil and belonged to a
// different cluster, its old index entry is deleted in the same batch.
func (this *DB) saveDomainWithOld(value *domainValue, oldValue *domainValue) error {
	encoded, err := json.Marshal(value)
	if err != nil {
		return err
	}

	batch := this.rawDB.NewBatch()
	defer func() { _ = batch.Close() }()

	clusterChanged := oldValue != nil && oldValue.ClusterId != value.ClusterId
	if clusterChanged {
		staleErr := batch.Delete(domainClusterKey(oldValue.ClusterId, oldValue.Id), defaultWriteOptions)
		if staleErr != nil && !isNotFound(staleErr) {
			return staleErr
		}
	}

	if err = batch.Set(domainKey(value.Id), encoded, defaultWriteOptions); err != nil {
		return err
	}
	if err = batch.Set(domainClusterKey(value.ClusterId, value.Id), nil, defaultWriteOptions); err != nil {
		return err
	}
	return batch.Commit(defaultWriteOptions)
}
// saveJSON marshals value to JSON and stores it under key.
func (this *DB) saveJSON(key []byte, value any) error {
	encoded, err := json.Marshal(value)
	if err == nil {
		err = this.rawDB.Set(key, encoded, defaultWriteOptions)
	}
	return err
}
// newIterator opens an iterator restricted to keys in
// [prefix, prefixUpperBound(prefix)). The caller must Close it.
func (this *DB) newIterator(prefix []byte) (*pebble.Iterator, error) {
	bounds := &pebble.IterOptions{LowerBound: prefix}
	bounds.UpperBound = prefixUpperBound(prefix)
	return this.rawDB.NewIter(bounds)
}
// readDomainValue loads and JSON-decodes the domain document for domainId.
// A missing key yields the pebble not-found error (see isNotFound).
func (this *DB) readDomainValue(domainId int64) (*domainValue, error) {
	raw, err := this.get(domainKey(domainId))
	if err != nil {
		return nil, err
	}
	decoded := new(domainValue)
	if err := json.Unmarshal(raw, decoded); err != nil {
		return nil, err
	}
	return decoded, nil
}
// readRecordValue loads and JSON-decodes the record document for recordId.
// A missing key yields the pebble not-found error (see isNotFound).
func (this *DB) readRecordValue(recordId int64) (*recordValue, error) {
	raw, err := this.get(recordKey(recordId))
	if err != nil {
		return nil, err
	}
	decoded := new(recordValue)
	if err := json.Unmarshal(raw, decoded); err != nil {
		return nil, err
	}
	return decoded, nil
}
// readRouteValue loads and JSON-decodes the route document for routeId.
// A missing key yields the pebble not-found error (see isNotFound).
func (this *DB) readRouteValue(routeId int64) (*routeValue, error) {
	raw, err := this.get(routeKey(routeId))
	if err != nil {
		return nil, err
	}
	decoded := new(routeValue)
	if err := json.Unmarshal(raw, decoded); err != nil {
		return nil, err
	}
	return decoded, nil
}
// readKeyValue loads and decodes the NSKey document for keyId.
// A missing key yields the pebble not-found error (see isNotFound).
func (this *DB) readKeyValue(keyId int64) (*models.NSKey, error) {
	if raw, err := this.get(keyKey(keyId)); err != nil {
		return nil, err
	} else {
		return decodeKey(raw)
	}
}
// get returns a private copy of the value stored under key. The slice
// returned by pebble's Get is only valid until its closer is closed, so it is
// copied before the closer is released.
func (this *DB) get(key []byte) ([]byte, error) {
	data, closer, err := this.rawDB.Get(key)
	if err != nil {
		return nil, err
	}
	owned := append([]byte(nil), data...)
	_ = closer.Close()
	return owned, nil
}
// decodeDomain converts a stored domainValue into a models.NSDomain.
//
// A TSIG config that fails to parse is logged and left nil instead of failing
// the whole domain, so the returned error is currently always nil.
func decodeDomain(value *domainValue) (*models.NSDomain, error) {
	domain := &models.NSDomain{
		Id:        value.Id,
		ClusterId: value.ClusterId,
		UserId:    value.UserId,
		Name:      value.Name,
		Version:   value.Version,
	}
	if len(value.TSIGJSON) == 0 {
		return domain, nil
	}

	tsigConfig := &dnsconfigs.NSTSIGConfig{}
	if err := json.Unmarshal([]byte(value.TSIGJSON), tsigConfig); err != nil {
		// Pass a tag first and the full message second, matching the
		// remotelogs.Println("DB", ...) call in Init. Previously the message was
		// passed as the tag and only ", domainId: ..." as the description.
		remotelogs.Error("DB", "decode tsig string failed: "+err.Error()+", domain: "+domain.Name+", domainId: "+types.String(domain.Id))
		return domain, nil
	}
	domain.TSIG = tsigConfig
	return domain, nil
}
// decodeRecord parses a stored record document into a models.NSRecord.
func decodeRecord(valueBytes []byte) (*models.NSRecord, error) {
	var stored recordValue
	if err := json.Unmarshal(valueBytes, &stored); err != nil {
		return nil, err
	}

	record := &models.NSRecord{
		Id:          stored.Id,
		DomainId:    stored.DomainId,
		Name:        stored.Name,
		Type:        stored.Type,
		Value:       stored.Value,
		MXPriority:  stored.MXPriority,
		SRVPriority: stored.SRVPriority,
		SRVWeight:   stored.SRVWeight,
		SRVPort:     stored.SRVPort,
		CAAFlag:     stored.CAAFlag,
		CAATag:      stored.CAATag,
		Ttl:         stored.TTL,
		Weight:      stored.Weight,
		RouteIds:    stored.RouteIds,
		Version:     stored.Version,
	}
	return record, nil
}
// decodeRoute parses a stored route document and its embedded ranges JSON
// into a models.NSRoute.
func decodeRoute(valueBytes []byte) (*models.NSRoute, error) {
	var stored routeValue
	if err := json.Unmarshal(valueBytes, &stored); err != nil {
		return nil, err
	}

	ranges, rangesErr := models.InitRangesFromJSON([]byte(stored.RangesJSON))
	if rangesErr != nil {
		return nil, rangesErr
	}

	route := &models.NSRoute{
		Id:       stored.Id,
		UserId:   stored.UserId,
		Ranges:   ranges,
		Priority: stored.Priority,
		Order:    stored.Order,
		Version:  stored.Version,
	}
	return route, nil
}
// decodeKey parses a stored key document into a models.NSKey.
func decodeKey(valueBytes []byte) (*models.NSKey, error) {
	nsKey := new(models.NSKey)
	if err := json.Unmarshal(valueBytes, nsKey); err != nil {
		return nil, err
	}
	return nsKey, nil
}
// decodeAgentIP parses a stored agent IP document into a models.AgentIP.
func decodeAgentIP(valueBytes []byte) (*models.AgentIP, error) {
	agentIP := new(models.AgentIP)
	if err := json.Unmarshal(valueBytes, agentIP); err != nil {
		return nil, err
	}
	return agentIP, nil
}
// domainKey builds the key of a domain document: domainPrefix + 8-byte id.
func domainKey(domainId int64) []byte {
	key := make([]byte, 0, len(domainPrefix)+8)
	key = append(key, domainPrefix...)
	return append(key, encodeInt64(domainId)...)
}
// domainClusterKey builds a cluster index key:
// domainClusterIndex + 8-byte cluster id + 8-byte domain id, so all domains of
// a cluster are contiguous and ordered by domain id.
func domainClusterKey(clusterId int64, domainId int64) []byte {
	key := make([]byte, 0, len(domainClusterIndex)+16)
	key = append(key, domainClusterIndex...)
	key = append(key, encodeInt64(clusterId)...)
	return append(key, encodeInt64(domainId)...)
}
// recordKey builds the key of a record document: recordPrefix + 8-byte id.
func recordKey(recordId int64) []byte {
	key := make([]byte, 0, len(recordPrefix)+8)
	key = append(key, recordPrefix...)
	return append(key, encodeInt64(recordId)...)
}
// routeKey builds the key of a route document: routePrefix + 8-byte id.
func routeKey(routeId int64) []byte {
	key := make([]byte, 0, len(routePrefix)+8)
	key = append(key, routePrefix...)
	return append(key, encodeInt64(routeId)...)
}
// keyKey builds the key of an NSKey document: keyPrefix + 8-byte id.
func keyKey(keyId int64) []byte {
	key := make([]byte, 0, len(keyPrefix)+8)
	key = append(key, keyPrefix...)
	return append(key, encodeInt64(keyId)...)
}
// agentIPKey builds the key of an agent IP document: agentIPPrefix + 8-byte id.
func agentIPKey(agentIPId int64) []byte {
	key := make([]byte, 0, len(agentIPPrefix)+8)
	key = append(key, agentIPPrefix...)
	return append(key, encodeInt64(agentIPId)...)
}
// prefixUpperBound returns the smallest key that is greater than every key
// starting with prefix, for use as an exclusive iterator UpperBound.
//
// It drops trailing 0xFF bytes and increments the last remaining byte. If
// prefix is empty or consists only of 0xFF bytes, no such key exists, and nil
// (meaning "no upper bound") is returned. prefix itself is never modified.
//
// The earlier prefix+0xFF form wrongly excluded keys like prefix+0xFF+...
// Keys built with encodeInt64 never start their id part with 0xFF, so
// iteration over existing data is unchanged.
func prefixUpperBound(prefix []byte) []byte {
	for i := len(prefix) - 1; i >= 0; i-- {
		if prefix[i] != 0xFF {
			bound := make([]byte, i+1)
			copy(bound, prefix[:i+1])
			bound[i]++
			return bound
		}
	}
	return nil
}
// encodeInt64 encodes v as 8 big-endian bytes so that keys sort by numeric
// id. Negative values are clamped to 0, so they encode identically to 0, and
// the first byte is always at most 0x7F.
func encodeInt64(v int64) []byte {
	clamped := uint64(0)
	if v > 0 {
		clamped = uint64(v)
	}
	buf := make([]byte, 8)
	binary.BigEndian.PutUint64(buf, clamped)
	return buf
}
// decodeInt64 reads the first 8 bytes of b as a big-endian int64. Inputs
// shorter than 8 bytes decode to 0; extra trailing bytes are ignored.
func decodeInt64(b []byte) int64 {
	if len(b) >= 8 {
		return int64(binary.BigEndian.Uint64(b))
	}
	return 0
}
// storePath maps a database path to its pebble directory: a single trailing
// ".db" is replaced by ".store"; otherwise ".store" is appended.
func storePath(path string) string {
	return strings.TrimSuffix(path, ".db") + ".store"
}
// openSharedRawDB returns the pebble handle for path, opening it on first use
// and otherwise bumping its reference count. Each successful call must be
// balanced by one closeSharedRawDB(path).
func openSharedRawDB(path string) (*pebble.DB, error) {
	sharedRawDBLocker.Lock()
	defer sharedRawDBLocker.Unlock()

	if existing, found := sharedRawDBMap[path]; found {
		existing.refs++
		return existing.rawDB, nil
	}

	// The MkdirAll error is deliberately ignored; a real problem is expected
	// to surface from pebble.Open below.
	_ = os.MkdirAll(path, 0777)
	opened, err := pebble.Open(path, &pebble.Options{})
	if err != nil {
		return nil, err
	}
	sharedRawDBMap[path] = &sharedRawDB{rawDB: opened, refs: 1}
	return opened, nil
}
func closeSharedRawDB(path string) error {
sharedRawDBLocker.Lock()
shared, ok := sharedRawDBMap[path]
if !ok {
sharedRawDBLocker.Unlock()
return nil
}
shared.refs--
if shared.refs > 0 {
sharedRawDBLocker.Unlock()
return nil
}
delete(sharedRawDBMap, path)
sharedRawDBLocker.Unlock()
var lastErr error
if shared.rawDB != nil {
err := shared.rawDB.Close()
if err != nil {
lastErr = err
}
}
return lastErr
}
// isNotFound reports whether err is, or wraps, pebble.ErrNotFound. A nil err
// is never "not found".
func isNotFound(err error) bool {
	return errors.Is(err, pebble.ErrNotFound)
}
// log prints a debug line with the standard logger, but only while running
// under tests (Tea.IsTesting). The first argument is used as a tag and is
// wrapped in square brackets, in place, before printing.
func (this *DB) log(args ...any) {
	if !Tea.IsTesting() || len(args) == 0 {
		return
	}
	tag := types.String(args[0])
	args[0] = "[" + tag + "]"
	log.Println(args...)
}