Files
waf-platform/EdgeDNS/internal/dbs/db.go
2026-02-04 20:27:13 +08:00

784 lines
20 KiB
Go

// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package dbs
import (
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeDNS/internal/models"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
dbutils "github.com/TeaOSLab/EdgeDNS/internal/utils/dbs"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/types"
_ "github.com/mattn/go-sqlite3"
"log"
"os"
"path/filepath"
"strings"
)
// Names of the SQLite tables managed by DB. They are spliced (quoted) into
// the CREATE TABLE and prepared-statement SQL built in Init.
const (
	tableDomains  = "domains_v2" // domain rows: id, clusterId, userId, name, version, tsig
	tableRecords  = "records_v2" // DNS record rows, keyed by id, linked via domainId
	tableRoutes   = "routes_v2"  // route rows: ranges JSON, order, priority, userId, version
	tableKeys     = "keys"       // key rows: domainId, zoneId, algo, secret, secretType, version
	tableAgentIPs = "agentIPs"   // agent IP rows: ip, agentCode
)
// DB is the local SQLite store. It must be initialized with Init before use;
// every data method refuses to run while db is nil.
type DB struct {
	db   *dbutils.DB // assigned at the very end of a successful Init
	path string      // location of the SQLite file

	// prepared statements on the domains table
	insertDomainStmt *dbutils.Stmt
	updateDomainStmt *dbutils.Stmt
	deleteDomainStmt *dbutils.Stmt
	existsDomainStmt *dbutils.Stmt
	listDomainsStmt  *dbutils.Stmt

	// prepared statements on the records table
	insertRecordStmt *dbutils.Stmt
	updateRecordStmt *dbutils.Stmt
	existsRecordStmt *dbutils.Stmt
	deleteRecordStmt *dbutils.Stmt
	listRecordsStmt  *dbutils.Stmt

	// prepared statements on the routes table
	insertRouteStmt *dbutils.Stmt
	updateRouteStmt *dbutils.Stmt
	deleteRouteStmt *dbutils.Stmt
	listRoutesStmt  *dbutils.Stmt
	existsRouteStmt *dbutils.Stmt

	// prepared statements on the keys table
	insertKeyStmt *dbutils.Stmt
	updateKeyStmt *dbutils.Stmt
	deleteKeyStmt *dbutils.Stmt
	listKeysStmt  *dbutils.Stmt
	existsKeyStmt *dbutils.Stmt

	// prepared statements on the agent IPs table
	insertAgentIPStmt *dbutils.Stmt
	listAgentIPsStmt  *dbutils.Stmt
}
// NewDB returns a DB bound to the SQLite file at path. Nothing is opened or
// created until Init is called.
func NewDB(path string) *DB {
	var store = &DB{}
	store.path = path
	return store
}
// Init makes the DB usable. It performs these steps in order:
//   - creates the directory of this.path when it cannot be stat'ed;
//   - opens the SQLite file with WAL journaling and EXCLUSIVE locking mode;
//   - creates all tables if they do not exist yet;
//   - prepares every statement used by the other methods.
//
// this.db is only assigned when all of that succeeded. On failure the opened
// database and every statement prepared so far are closed again. Leaving them
// open would keep the EXCLUSIVE lock on the file and make a retry fail.
func (this *DB) Init() error {
	err := this.initDataDir()
	if err != nil {
		return err
	}

	// TODO think about the data safety of data.db
	db, err := dbutils.OpenWriter("file:" + this.path + "?cache=shared&mode=rwc&_journal_mode=WAL&_locking_mode=EXCLUSIVE")
	if err != nil {
		return err
	}
	db.SetMaxOpenConns(1)

	/**_, err = db.Exec("VACUUM")
	if err != nil {
		return err
	}**/

	err = this.initTables(db)
	if err == nil {
		err = this.initStmts(db)
	}
	if err != nil {
		// Release the handle and any statements already prepared. Close closes
		// non-nil statements plus the handle; afterwards the DB is left
		// uninitialized again.
		this.db = db
		_ = this.Close()
		this.db = nil
		return err
	}

	this.db = db
	return nil
}

// initDataDir creates the directory holding the database file if os.Stat fails on it.
func (this *DB) initDataDir() error {
	var dir = filepath.Dir(this.path)
	_, err := os.Stat(dir)
	if err == nil {
		return nil
	}
	err = os.MkdirAll(dir, 0777)
	if err != nil {
		return err
	}
	remotelogs.Println("DB", "create database dir '"+dir+"'")
	return nil
}

// initTables creates every table (and the domains clusterId index) if missing.
func (this *DB) initTables(db *dbutils.DB) error {
	var schemas = []struct {
		sql string
		// Kept from the original code: a "duplicate column name" error is
		// tolerated for these tables (presumably a leftover of column migrations).
		ignoreDuplicateColumn bool
	}{
		{`CREATE TABLE IF NOT EXISTS "` + tableDomains + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"clusterId" integer DEFAULT 0,
"userId" integer DEFAULT 0,
"name" varchar(255),
"version" integer DEFAULT 0,
"tsig" text
);
CREATE INDEX IF NOT EXISTS "clusterId"
ON "` + tableDomains + `" (
"clusterId"
);
`, false},
		{`CREATE TABLE IF NOT EXISTS "` + tableRecords + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"domainId" integer DEFAULT 0,
"name" varchar(255),
"type" varchar(32),
"value" varchar(4096),
"mxPriority" integer DEFAULT 10,
"srvPriority" integer DEFAULT 10,
"srvWeight" integer DEFAULT 10,
"srvPort" integer DEFAULT 0,
"caaFlag" integer DEFAULT 0,
"caaTag" varchar(16),
"ttl" integer DEFAULT 0,
"weight" integer DEFAULT 0,
"routeIds" varchar(512),
"version" integer DEFAULT 0
);
`, true},
		{`CREATE TABLE IF NOT EXISTS "` + tableRoutes + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"ranges" text,
"order" integer DEFAULT 0,
"priority" integer DEFAULT 0,
"userId" integer DEFAULT 0,
"version" integer DEFAULT 0
);
`, true},
		{`CREATE TABLE IF NOT EXISTS "` + tableKeys + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"domainId" integer DEFAULT 0,
"zoneId" integer DEFAULT 0,
"algo" varchar(128),
"secret" varchar(4096),
"secretType" varchar(32),
"version" integer DEFAULT 0
);`, false},
		{`CREATE TABLE IF NOT EXISTS "` + tableAgentIPs + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"ip" varchar(64),
"agentCode" varchar(128)
);`, false},
	}

	for _, schema := range schemas {
		_, err := db.Exec(schema.sql)
		if err == nil {
			continue
		}
		if schema.ignoreDuplicateColumn && strings.Contains(err.Error(), "duplicate column name") {
			continue
		}
		return err
	}
	return nil
}

// initStmts prepares all statements and stores them on this, in declaration order.
// It stops at the first failure; statements prepared before it stay assigned so
// the caller can close them.
func (this *DB) initStmts(db *dbutils.DB) error {
	var specs = []struct {
		target **dbutils.Stmt
		query  string
	}{
		// domain statements
		{&this.insertDomainStmt, `INSERT INTO "` + tableDomains + `" ("id", "clusterId", "userId", "name", "tsig", "version") VALUES (?, ?, ?, ?, ?, ?)`},
		{&this.updateDomainStmt, `UPDATE "` + tableDomains + `" SET "clusterId"=?, "userId"=?, "name"=?, "tsig"=?, "version"=? WHERE "id"=?`},
		{&this.deleteDomainStmt, `DELETE FROM "` + tableDomains + `" WHERE id=?`},
		{&this.existsDomainStmt, `SELECT "id" FROM "` + tableDomains + `" WHERE "id"=? LIMIT 1`},
		{&this.listDomainsStmt, `SELECT "id", "clusterId", "userId", "name", "tsig", "version" FROM "` + tableDomains + `" WHERE "clusterId"=? ORDER BY "id" ASC LIMIT ? OFFSET ?`},

		// record statements
		{&this.insertRecordStmt, `INSERT INTO "` + tableRecords + `" ("id", "domainId", "name", "type", "value", "mxPriority", "srvPriority", "srvWeight", "srvPort", "caaFlag", "caaTag", "ttl", "weight", "routeIds", "version") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`},
		{&this.updateRecordStmt, `UPDATE "` + tableRecords + `" SET "domainId"=?, "name"=?, "type"=?, "value"=?, "mxPriority"=?, "srvPriority"=?, "srvWeight"=?, "srvPort"=?, "caaFlag"=?, "caaTag"=?, "ttl"=?, "weight"=?, "routeIds"=?, "version"=? WHERE "id"=?`},
		{&this.existsRecordStmt, `SELECT "id" FROM "` + tableRecords + `" WHERE "id"=? LIMIT 1`},
		{&this.deleteRecordStmt, `DELETE FROM "` + tableRecords + `" WHERE id=?`},
		{&this.listRecordsStmt, `SELECT "id", "domainId", "name", "type", "value", "mxPriority", "srvPriority", "srvWeight", "srvPort", "caaFlag", "caaTag", "ttl", "weight", "routeIds", "version" FROM "` + tableRecords + `" ORDER BY "id" ASC LIMIT ? OFFSET ?`},

		// route statements
		{&this.insertRouteStmt, `INSERT INTO "` + tableRoutes + `" ("id", "userId", "ranges", "order", "priority", "version") VALUES (?, ?, ?, ?, ?, ?)`},
		{&this.updateRouteStmt, `UPDATE "` + tableRoutes + `" SET "userId"=?, "ranges"=?, "order"=?, "priority"=?, "version"=? WHERE "id"=?`},
		{&this.deleteRouteStmt, `DELETE FROM "` + tableRoutes + `" WHERE "id"=?`},
		{&this.listRoutesStmt, `SELECT "id", "userId", "ranges", "priority", "order", "version" FROM "` + tableRoutes + `" ORDER BY "id" ASC LIMIT ? OFFSET ?`},
		{&this.existsRouteStmt, `SELECT "id" FROM "` + tableRoutes + `" WHERE "id"=? LIMIT 1`},

		// key statements
		{&this.insertKeyStmt, `INSERT INTO "` + tableKeys + `" ("id", "domainId", "zoneId", "algo", "secret", "secretType", "version") VALUES (?, ?, ?, ?, ?, ?, ?)`},
		{&this.updateKeyStmt, `UPDATE "` + tableKeys + `" SET "domainId"=?, "zoneId"=?, "algo"=?, "secret"=?, "secretType"=?, "version"=? WHERE "id"=?`},
		{&this.deleteKeyStmt, `DELETE FROM "` + tableKeys + `" WHERE "id"=?`},
		{&this.listKeysStmt, `SELECT "id", "domainId", "zoneId", "algo", "secret", "secretType", "version" FROM "` + tableKeys + `" ORDER BY "id" ASC LIMIT ? OFFSET ?`},
		{&this.existsKeyStmt, `SELECT "id" FROM "` + tableKeys + `" WHERE "id"=? LIMIT 1`},

		// agent ip record statements
		{&this.insertAgentIPStmt, `INSERT INTO "` + tableAgentIPs + `" ("id", "ip", "agentCode") VALUES (?, ?, ?)`},
		{&this.listAgentIPsStmt, `SELECT "id", "ip", "agentCode" FROM "` + tableAgentIPs + `" ORDER BY "id" ASC LIMIT ? OFFSET ?`},
	}

	for _, spec := range specs {
		stmt, err := db.Prepare(spec.query)
		if err != nil {
			return err
		}
		*spec.target = stmt
	}
	return nil
}
// InsertDomain adds a domain row with an explicit id. tsigJSON is stored as
// text; a nil or empty slice is stored as an empty string.
func (this *DB) InsertDomain(domainId int64, clusterId int64, userId int64, name string, tsigJSON []byte, version int64) error {
	if this.db == nil {
		return errors.New("db should not be nil")
	}

	this.log("InsertDomain", "domain:", domainId, "user:", userId, "name:", name)

	var tsig = string(tsigJSON)
	_, err := this.insertDomainStmt.Exec(domainId, clusterId, userId, name, tsig, version)
	return err
}
// UpdateDomain overwrites every column of the domain row identified by
// domainId. Updating an id that does not exist is not reported as an error.
func (this *DB) UpdateDomain(domainId int64, clusterId int64, userId int64, name string, tsigJSON []byte, version int64) error {
	if this.db == nil {
		return errors.New("db should not be nil")
	}

	this.log("UpdateDomain", "domain:", domainId, "user:", userId, "name:", name)

	var tsig = string(tsigJSON)
	// the WHERE "id"=? placeholder comes last in the statement
	_, err := this.updateDomainStmt.Exec(clusterId, userId, name, tsig, version, domainId)
	return err
}
// DeleteDomain removes the domain row with the given id. Only the domains
// table is touched by this statement.
func (this *DB) DeleteDomain(domainId int64) error {
	if this.db == nil {
		return errors.New("db should not be nil")
	}

	this.log("DeleteDomain", "domain:", domainId)

	_, err := this.deleteDomainStmt.Exec(domainId)
	return err
}
// ExistsDomain reports whether a domain row with the given id exists.
// An error hit while reading the result is returned rather than being
// mistaken for "not found".
func (this *DB) ExistsDomain(domainId int64) (bool, error) {
	if this.db == nil {
		return false, errors.New("db should not be nil")
	}
	rows, err := this.existsDomainStmt.Query(domainId)
	if err != nil {
		return false, err
	}
	// close on every return path
	defer func() {
		_ = rows.Close()
	}()

	if rows.Next() {
		return true, nil
	}
	// Next returns false both at the end of the result set and on failure;
	// Err distinguishes the two.
	return false, rows.Err()
}
// ListDomains returns at most size domains of the given cluster, ordered by id
// ascending and skipping the first offset rows.
//
// A non-empty tsig column is JSON-decoded into domain.TSIG. A value that fails
// to decode is logged and the domain is returned without TSIG. This is
// best-effort: the decode error is never returned to the caller. Scan and
// iteration errors abort the listing and are returned.
func (this *DB) ListDomains(clusterId int64, offset int, size int) (domains []*models.NSDomain, err error) {
	if this.db == nil {
		return nil, errors.New("db should not be nil")
	}
	rows, err := this.listDomainsStmt.Query(clusterId, size, offset)
	if err != nil {
		return nil, err
	}
	defer func() {
		_ = rows.Close()
	}()

	for rows.Next() {
		var domain = &models.NSDomain{}
		var tsigString string
		err = rows.Scan(&domain.Id, &domain.ClusterId, &domain.UserId, &domain.Name, &tsigString, &domain.Version)
		if err != nil {
			return nil, err
		}
		if len(tsigString) > 0 {
			var tsigConfig = &dnsconfigs.NSTSIGConfig{}
			// Separate variable: a bad TSIG is only logged and must not leak
			// into the named result err returned at the end.
			decodeErr := json.Unmarshal([]byte(tsigString), tsigConfig)
			if decodeErr != nil {
				remotelogs.Error("DB", "decode tsig string failed: "+decodeErr.Error()+", domain: "+domain.Name+", domainId: "+types.String(domain.Id))
			} else {
				domain.TSIG = tsigConfig
			}
		}
		domains = append(domains, domain)
	}

	// surface errors that ended the iteration early instead of returning a truncated list
	err = rows.Err()
	if err != nil {
		return nil, err
	}
	return domains, nil
}
// InsertRecord adds a record row with an explicit id. routeIds are stored as a
// single comma-joined string (an empty slice becomes "").
func (this *DB) InsertRecord(recordId int64, domainId int64, name string, recordType dnsconfigs.RecordType, value string, mxPriority int32, srvPriority int32, srvWeight int32, srvPort int32, caaFlag int32, caaTag string, ttl int32, weight int32, routeIds []string, version int64) error {
	if this.db == nil {
		return errors.New("db should not be nil")
	}

	this.log("InsertRecord", "domain:", domainId, "name:", name)

	var joinedRouteIds = strings.Join(routeIds, ",")
	_, err := this.insertRecordStmt.Exec(recordId, domainId, name, recordType, value, mxPriority, srvPriority, srvWeight, srvPort, caaFlag, caaTag, ttl, weight, joinedRouteIds, version)
	return err
}
// UpdateRecord overwrites every column of the record row identified by
// recordId. routeIds are stored comma-joined. A missing id is not an error.
func (this *DB) UpdateRecord(recordId int64, domainId int64, name string, recordType dnsconfigs.RecordType, value string, mxPriority int32, srvPriority int32, srvWeight int32, srvPort int32, caaFlag int32, caaTag string, ttl int32, weight int32, routeIds []string, version int64) error {
	if this.db == nil {
		return errors.New("db should not be nil")
	}

	this.log("UpdateRecord", "domain:", domainId, "name:", name)

	var joinedRouteIds = strings.Join(routeIds, ",")
	// recordId binds the trailing WHERE "id"=? placeholder
	_, err := this.updateRecordStmt.Exec(domainId, name, recordType, value, mxPriority, srvPriority, srvWeight, srvPort, caaFlag, caaTag, ttl, weight, joinedRouteIds, version, recordId)
	return err
}
// DeleteRecord removes the record row with the given id.
func (this *DB) DeleteRecord(recordId int64) error {
	if this.db == nil {
		return errors.New("db should not be nil")
	}

	this.log("DeleteRecord", "record:", recordId)

	_, err := this.deleteRecordStmt.Exec(recordId)
	return err
}
// ExistsRecord reports whether a record row with the given id exists.
// An error hit while reading the result is returned rather than being
// mistaken for "not found".
func (this *DB) ExistsRecord(recordId int64) (bool, error) {
	if this.db == nil {
		return false, errors.New("db should not be nil")
	}
	rows, err := this.existsRecordStmt.Query(recordId)
	if err != nil {
		return false, err
	}
	// close on every return path
	defer func() {
		_ = rows.Close()
	}()

	if rows.Next() {
		return true, nil
	}
	// Next returns false both at the end of the result set and on failure;
	// Err distinguishes the two.
	return false, rows.Err()
}
// ListRecords lists one page of records: at most size rows ordered by id
// ascending, skipping the first offset rows. The stored comma-joined routeIds
// are split back into record.RouteIds (left nil when empty). Scan and
// iteration errors are returned instead of a partial list.
// TODO in the future, only load records that belong to this cluster
func (this *DB) ListRecords(offset int, size int) (records []*models.NSRecord, err error) {
	if this.db == nil {
		return nil, errors.New("db should not be nil")
	}
	rows, err := this.listRecordsStmt.Query(size, offset)
	if err != nil {
		return nil, err
	}
	defer func() {
		_ = rows.Close()
	}()

	for rows.Next() {
		var record = &models.NSRecord{}
		var routeIds = ""
		err = rows.Scan(&record.Id, &record.DomainId, &record.Name, &record.Type, &record.Value, &record.MXPriority, &record.SRVPriority, &record.SRVWeight, &record.SRVPort, &record.CAAFlag, &record.CAATag, &record.Ttl, &record.Weight, &routeIds, &record.Version)
		if err != nil {
			return nil, err
		}
		if len(routeIds) > 0 {
			record.RouteIds = strings.Split(routeIds, ",")
		}
		records = append(records, record)
	}

	// surface errors that ended the iteration early instead of returning a truncated list
	err = rows.Err()
	if err != nil {
		return nil, err
	}
	return records, nil
}
// InsertRoute creates a route row with an explicit id; rangesJSON is stored as text.
func (this *DB) InsertRoute(routeId int64, userId int64, rangesJSON []byte, order int32, priority int32, version int64) error {
	if this.db == nil {
		return errors.New("db should not be nil")
	}

	this.log("InsertRoute", "route:", routeId, "user:", userId)

	var ranges = string(rangesJSON)
	_, err := this.insertRouteStmt.Exec(routeId, userId, ranges, order, priority, version)
	return err
}
// UpdateRoute overwrites every column of the route row identified by routeId.
// A missing id is not reported as an error.
func (this *DB) UpdateRoute(routeId int64, userId int64, rangesJSON []byte, order int32, priority int32, version int64) error {
	if this.db == nil {
		return errors.New("db should not be nil")
	}

	this.log("UpdateRoute", "route:", routeId, "user:", userId)

	var ranges = string(rangesJSON)
	// routeId binds the trailing WHERE "id"=? placeholder
	_, err := this.updateRouteStmt.Exec(userId, ranges, order, priority, version, routeId)
	return err
}
// DeleteRoute removes the route row with the given id.
func (this *DB) DeleteRoute(routeId int64) error {
	if this.db == nil {
		return errors.New("db should not be nil")
	}

	this.log("DeleteRoute", "route:", routeId)

	_, err := this.deleteRouteStmt.Exec(routeId)
	return err
}
// ExistsRoute reports whether a route row with the given id exists.
// An error hit while reading the result is returned rather than being
// mistaken for "not found".
func (this *DB) ExistsRoute(routeId int64) (bool, error) {
	if this.db == nil {
		return false, errors.New("db should not be nil")
	}
	rows, err := this.existsRouteStmt.Query(routeId)
	if err != nil {
		return false, err
	}
	// close on every return path
	defer func() {
		_ = rows.Close()
	}()

	if rows.Next() {
		return true, nil
	}
	// Next returns false both at the end of the result set and on failure;
	// Err distinguishes the two.
	return false, rows.Err()
}
// ListRoutes lists one page of routes: at most size rows ordered by id
// ascending, skipping the first offset rows. The stored ranges text is decoded
// with models.InitRangesFromJSON. A decode, scan or iteration error aborts the
// listing and is returned.
func (this *DB) ListRoutes(offset int64, size int64) (routes []*models.NSRoute, err error) {
	if this.db == nil {
		return nil, errors.New("db should not be nil")
	}
	rows, err := this.listRoutesStmt.Query(size, offset)
	if err != nil {
		return nil, err
	}
	defer func() {
		_ = rows.Close()
	}()

	for rows.Next() {
		var route = &models.NSRoute{}
		var rangesString = ""
		// column order in the SELECT is id, userId, ranges, priority, order, version
		err = rows.Scan(&route.Id, &route.UserId, &rangesString, &route.Priority, &route.Order, &route.Version)
		if err != nil {
			return nil, err
		}
		route.Ranges, err = models.InitRangesFromJSON([]byte(rangesString))
		if err != nil {
			return nil, err
		}
		routes = append(routes, route)
	}

	// surface errors that ended the iteration early instead of returning a truncated list
	err = rows.Err()
	if err != nil {
		return nil, err
	}
	return routes, nil
}
// InsertKey adds a key row with an explicit id. The secret itself is not logged.
func (this *DB) InsertKey(keyId int64, domainId int64, zoneId int64, algo string, secret string, secretType string, version int64) error {
	if this.db == nil {
		return errors.New("db should not be nil")
	}

	this.log("InsertKey", "key:", keyId, "domain:", domainId, "zone:", zoneId)

	_, err := this.insertKeyStmt.Exec(keyId, domainId, zoneId, algo, secret, secretType, version)
	return err
}
// UpdateKey overwrites every column of the key row identified by keyId.
// A missing id is not reported as an error.
func (this *DB) UpdateKey(keyId int64, domainId int64, zoneId int64, algo string, secret string, secretType string, version int64) error {
	if this.db == nil {
		return errors.New("db should not be nil")
	}

	this.log("UpdateKey", "key:", keyId, "domain:", domainId, "zone:", zoneId)

	// keyId binds the trailing WHERE "id"=? placeholder
	_, err := this.updateKeyStmt.Exec(domainId, zoneId, algo, secret, secretType, version, keyId)
	return err
}
// ListKeys lists one page of keys: at most size rows ordered by id ascending,
// skipping the first offset rows. Scan and iteration errors are returned
// instead of a partial list.
func (this *DB) ListKeys(offset int, size int) (keys []*models.NSKey, err error) {
	if this.db == nil {
		return nil, errors.New("db should not be nil")
	}
	rows, err := this.listKeysStmt.Query(size, offset)
	if err != nil {
		return nil, err
	}
	defer func() {
		_ = rows.Close()
	}()

	for rows.Next() {
		var key = &models.NSKey{}
		err = rows.Scan(&key.Id, &key.DomainId, &key.ZoneId, &key.Algo, &key.Secret, &key.SecretType, &key.Version)
		if err != nil {
			return nil, err
		}
		keys = append(keys, key)
	}

	// surface errors that ended the iteration early instead of returning a truncated list
	err = rows.Err()
	if err != nil {
		return nil, err
	}
	return keys, nil
}
// DeleteKey removes the key row with the given id.
func (this *DB) DeleteKey(keyId int64) error {
	if this.db == nil {
		return errors.New("db should not be nil")
	}

	this.log("DeleteKey", "key:", keyId)

	_, err := this.deleteKeyStmt.Exec(keyId)
	return err
}
// ExistsKey reports whether a key row with the given id exists.
// An error hit while reading the result is returned rather than being
// mistaken for "not found".
func (this *DB) ExistsKey(keyId int64) (bool, error) {
	if this.db == nil {
		return false, errors.New("db should not be nil")
	}
	rows, err := this.existsKeyStmt.Query(keyId)
	if err != nil {
		return false, err
	}
	// close on every return path
	defer func() {
		_ = rows.Close()
	}()

	if rows.Next() {
		return true, nil
	}
	// Next returns false both at the end of the result set and on failure;
	// Err distinguishes the two.
	return false, rows.Err()
}
// InsertAgentIP adds an agent IP row with an explicit id.
func (this *DB) InsertAgentIP(ipId int64, ip string, agentCode string) error {
	if this.db == nil {
		return errors.New("db should not be nil")
	}

	this.log("InsertAgentIP", "id:", ipId, "ip:", ip, "agent:", agentCode)

	_, err := this.insertAgentIPStmt.Exec(ipId, ip, agentCode)
	return err
}
// ListAgentIPs lists one page of agent IPs: at most size rows ordered by id
// ascending, skipping the first offset rows. Scan and iteration errors are
// returned instead of a partial list.
func (this *DB) ListAgentIPs(offset int64, size int64) (agentIPs []*models.AgentIP, err error) {
	if this.db == nil {
		return nil, errors.New("db should not be nil")
	}
	rows, err := this.listAgentIPsStmt.Query(size, offset)
	if err != nil {
		return nil, err
	}
	defer func() {
		_ = rows.Close()
	}()

	for rows.Next() {
		var agentIP = &models.AgentIP{}
		err = rows.Scan(&agentIP.Id, &agentIP.IP, &agentIP.AgentCode)
		if err != nil {
			return nil, err
		}
		agentIPs = append(agentIPs, agentIP)
	}

	// surface errors that ended the iteration early instead of returning a truncated list
	err = rows.Err()
	if err != nil {
		return nil, err
	}
	return agentIPs, nil
}
// Close releases every prepared statement (close errors are ignored) and then
// the database handle, returning the handle's close error.
//
// Afterwards the DB is treated as uninitialized: data methods return
// "db should not be nil" instead of running closed statements, and calling
// Close again is a no-op instead of closing the handle twice.
func (this *DB) Close() error {
	if this.db == nil {
		return nil
	}

	for _, stmt := range []*dbutils.Stmt{
		this.insertDomainStmt,
		this.updateDomainStmt,
		this.deleteDomainStmt,
		this.existsDomainStmt,
		this.listDomainsStmt,

		this.insertRecordStmt,
		this.updateRecordStmt,
		this.existsRecordStmt,
		this.deleteRecordStmt,
		this.listRecordsStmt,

		this.insertRouteStmt,
		this.updateRouteStmt,
		this.deleteRouteStmt,
		this.listRoutesStmt,
		this.existsRouteStmt,

		this.insertKeyStmt,
		this.updateKeyStmt,
		this.deleteKeyStmt,
		this.listKeysStmt,
		this.existsKeyStmt,

		this.insertAgentIPStmt,
		this.listAgentIPsStmt,
	} {
		// statements are nil when Init stopped before preparing them
		if stmt != nil {
			_ = stmt.Close()
		}
	}

	var db = this.db
	this.db = nil
	return db.Close()
}
// log prints a debug line, but only while running under tests
// (Tea.IsTesting). The first argument is used as a tag and printed as "[tag]";
// note that this overwrites args[0] in the caller-supplied slice.
func (this *DB) log(args ...any) {
	if !Tea.IsTesting() || len(args) == 0 {
		return
	}

	var tag = types.String(args[0])
	args[0] = "[" + tag + "]"
	log.Println(args...)
}