v1.5.1 增强程序稳定性
This commit is contained in:
@@ -22,13 +22,15 @@ type Manager struct {
|
||||
|
||||
db *dbs.DB
|
||||
|
||||
lastId int64
|
||||
lastId int64
|
||||
ReadyCh chan struct{} // 初始加载完成后关闭
|
||||
}
|
||||
|
||||
func NewManager(db *dbs.DB) *Manager {
|
||||
return &Manager{
|
||||
ipMap: map[string]string{},
|
||||
db: db,
|
||||
ipMap: map[string]string{},
|
||||
db: db,
|
||||
ReadyCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,6 +53,9 @@ func (this *Manager) Start() {
|
||||
}
|
||||
}
|
||||
|
||||
// 通知初始加载完成
|
||||
close(this.ReadyCh)
|
||||
|
||||
// 定时获取
|
||||
var duration = 30 * time.Minute
|
||||
if Tea.IsTesting() {
|
||||
|
||||
@@ -3,6 +3,14 @@
|
||||
|
||||
package configs
|
||||
|
||||
import "github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var SharedNodeConfig *dnsconfigs.NSNodeConfig
|
||||
var SharedNodeConfig atomic.Pointer[dnsconfigs.NSNodeConfig]
|
||||
|
||||
// LoadSharedNodeConfig 返回当前节点配置(并发安全)
|
||||
func LoadSharedNodeConfig() *dnsconfigs.NSNodeConfig {
|
||||
return SharedNodeConfig.Load()
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package teaconst
|
||||
|
||||
const (
|
||||
Version = "1.5.0" //1.3.8.2
|
||||
Version = "1.5.1" //1.3.8.2
|
||||
|
||||
ProductName = "Edge DNS"
|
||||
ProcessName = "edge-dns"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
279
EdgeDNS/internal/dbs/db_migrate_sqlite.go
Normal file
279
EdgeDNS/internal/dbs/db_migrate_sqlite.go
Normal file
@@ -0,0 +1,279 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build plus
|
||||
|
||||
package dbs
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/models"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const sqliteMigrationMarkerKey = "meta:migrated_from_sqlite"
|
||||
|
||||
func (this *DB) migrateSQLiteIfNeeded() error {
|
||||
if len(this.path) == 0 || this.path == this.storePath {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := os.Stat(this.path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
ok, err := this.hasMigrationMarker()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
removeErr := removeSQLiteFiles(this.path)
|
||||
if removeErr != nil {
|
||||
remotelogs.Warn("DB", "remove sqlite files failed after migration: "+removeErr.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
remotelogs.Println("DB", "migrating sqlite database from '"+this.path+"' to '"+this.storePath+"' ...")
|
||||
|
||||
err = this.truncateMigratedData()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = this.importSQLiteData()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = this.rawDB.Set([]byte(sqliteMigrationMarkerKey), []byte("1"), defaultWriteOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = flushRawDB(this.rawDB)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
removeErr := removeSQLiteFiles(this.path)
|
||||
if removeErr != nil {
|
||||
remotelogs.Warn("DB", "remove sqlite files failed after migration: "+removeErr.Error())
|
||||
}
|
||||
|
||||
remotelogs.Println("DB", "migrated sqlite database to pebble")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *DB) hasMigrationMarker() (bool, error) {
|
||||
_, closer, err := this.rawDB.Get([]byte(sqliteMigrationMarkerKey))
|
||||
if err != nil {
|
||||
if isNotFound(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
_ = closer.Close()
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (this *DB) truncateMigratedData() error {
|
||||
for _, prefix := range []string{
|
||||
domainPrefix,
|
||||
domainClusterIndex,
|
||||
recordPrefix,
|
||||
routePrefix,
|
||||
keyPrefix,
|
||||
agentIPPrefix,
|
||||
"meta:",
|
||||
} {
|
||||
err := this.rawDB.DeleteRange([]byte(prefix), prefixUpperBound([]byte(prefix)), defaultWriteOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *DB) importSQLiteData() error {
|
||||
sqliteDB, err := sql.Open("sqlite3", "file:"+this.path+"?mode=ro")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = sqliteDB.Close()
|
||||
}()
|
||||
|
||||
err = this.importSQLiteDomains(sqliteDB)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = this.importSQLiteRecords(sqliteDB)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = this.importSQLiteRoutes(sqliteDB)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = this.importSQLiteKeys(sqliteDB)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return this.importSQLiteAgentIPs(sqliteDB)
|
||||
}
|
||||
|
||||
func (this *DB) importSQLiteDomains(sqliteDB *sql.DB) error {
|
||||
rows, err := sqliteDB.Query(`SELECT "id", "clusterId", "userId", "name", "tsig", "version" FROM "domains_v2" ORDER BY "id" ASC`)
|
||||
if err != nil {
|
||||
return ignoreMissingTable(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
value := &domainValue{}
|
||||
err = rows.Scan(&value.Id, &value.ClusterId, &value.UserId, &value.Name, &value.TSIGJSON, &value.Version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = this.saveDomain(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (this *DB) importSQLiteRecords(sqliteDB *sql.DB) error {
|
||||
rows, err := sqliteDB.Query(`SELECT "id", "domainId", "name", "type", "value", "mxPriority", "srvPriority", "srvWeight", "srvPort", "caaFlag", "caaTag", "ttl", "weight", "routeIds", "version" FROM "records_v2" ORDER BY "id" ASC`)
|
||||
if err != nil {
|
||||
return ignoreMissingTable(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
value := &recordValue{}
|
||||
var routeIDs string
|
||||
err = rows.Scan(&value.Id, &value.DomainId, &value.Name, &value.Type, &value.Value, &value.MXPriority, &value.SRVPriority, &value.SRVWeight, &value.SRVPort, &value.CAAFlag, &value.CAATag, &value.TTL, &value.Weight, &routeIDs, &value.Version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(routeIDs) > 0 {
|
||||
value.RouteIds = strings.Split(routeIDs, ",")
|
||||
}
|
||||
|
||||
err = this.saveJSON(recordKey(value.Id), value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (this *DB) importSQLiteRoutes(sqliteDB *sql.DB) error {
|
||||
rows, err := sqliteDB.Query(`SELECT "id", "userId", "ranges", "priority", "order", "version" FROM "routes_v2" ORDER BY "id" ASC`)
|
||||
if err != nil {
|
||||
return ignoreMissingTable(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
value := &routeValue{}
|
||||
err = rows.Scan(&value.Id, &value.UserId, &value.RangesJSON, &value.Priority, &value.Order, &value.Version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = this.saveJSON(routeKey(value.Id), value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (this *DB) importSQLiteKeys(sqliteDB *sql.DB) error {
|
||||
rows, err := sqliteDB.Query(`SELECT "id", "domainId", "zoneId", "algo", "secret", "secretType", "version" FROM "keys" ORDER BY "id" ASC`)
|
||||
if err != nil {
|
||||
return ignoreMissingTable(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
value := &models.NSKey{}
|
||||
err = rows.Scan(&value.Id, &value.DomainId, &value.ZoneId, &value.Algo, &value.Secret, &value.SecretType, &value.Version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = this.saveJSON(keyKey(value.Id), value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (this *DB) importSQLiteAgentIPs(sqliteDB *sql.DB) error {
|
||||
rows, err := sqliteDB.Query(`SELECT "id", "ip", "agentCode" FROM "agentIPs" ORDER BY "id" ASC`)
|
||||
if err != nil {
|
||||
return ignoreMissingTable(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
value := &models.AgentIP{}
|
||||
err = rows.Scan(&value.Id, &value.IP, &value.AgentCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = this.saveJSON(agentIPKey(value.Id), value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func ignoreMissingTable(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if strings.Contains(err.Error(), "no such table") {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func removeSQLiteFiles(path string) error {
|
||||
var lastErr error
|
||||
for _, filename := range []string{path, path + "-shm", path + "-wal"} {
|
||||
err := os.Remove(filename)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
63
EdgeDNS/internal/dbs/db_migrate_sqlite_internal_test.go
Normal file
63
EdgeDNS/internal/dbs/db_migrate_sqlite_internal_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
//go:build plus
|
||||
|
||||
package dbs
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/pebble"
|
||||
)
|
||||
|
||||
func TestDB_MigrateSQLiteIfNeeded_FlushFailureKeepsSQLite(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sqlitePath := filepath.Join(dir, "data.db")
|
||||
|
||||
sqliteDB, err := sql.Open("sqlite3", sqlitePath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = sqliteDB.Exec(`CREATE TABLE "domains_v2" ("id" INTEGER, "clusterId" INTEGER, "userId" INTEGER, "name" TEXT, "tsig" TEXT, "version" INTEGER)`)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "requires cgo") || strings.Contains(err.Error(), "CGO_ENABLED=0") {
|
||||
_ = sqliteDB.Close()
|
||||
t.Skip("sqlite3 driver is unavailable in current environment")
|
||||
}
|
||||
_ = sqliteDB.Close()
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = sqliteDB.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
oldFlushRawDB := flushRawDB
|
||||
flushRawDB = func(rawDB *pebble.DB) error {
|
||||
return errors.New("flush failed")
|
||||
}
|
||||
defer func() {
|
||||
flushRawDB = oldFlushRawDB
|
||||
}()
|
||||
|
||||
db := NewDB(sqlitePath)
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
}()
|
||||
|
||||
err = db.Init()
|
||||
if err == nil {
|
||||
t.Fatal("expected migration error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "flush failed") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
_, err = os.Stat(sqlitePath)
|
||||
if err != nil {
|
||||
t.Fatalf("sqlite source should remain after flush failure: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ package firewalls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -25,6 +26,7 @@ import (
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
|
||||
@@ -39,7 +41,7 @@ func init() {
|
||||
return
|
||||
}
|
||||
|
||||
var nodeConfig = configs.SharedNodeConfig
|
||||
var nodeConfig = configs.LoadSharedNodeConfig()
|
||||
if nodeConfig != nil {
|
||||
err := SharedDDoSProtectionManager.Apply(nodeConfig.DDoSProtection)
|
||||
if err != nil {
|
||||
@@ -49,7 +51,7 @@ func init() {
|
||||
})
|
||||
|
||||
events.On(events.EventNFTablesReady, func() {
|
||||
var nodeConfig = configs.SharedNodeConfig
|
||||
var nodeConfig = configs.LoadSharedNodeConfig()
|
||||
if nodeConfig != nil {
|
||||
err := SharedDDoSProtectionManager.Apply(nodeConfig.DDoSProtection)
|
||||
if err != nil {
|
||||
@@ -80,7 +82,7 @@ func NewDDoSProtectionManager() *DDoSProtectionManager {
|
||||
func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error {
|
||||
// 同集群节点IP白名单
|
||||
var allowIPListChanged = false
|
||||
var nodeConfig = configs.SharedNodeConfig
|
||||
var nodeConfig = configs.LoadSharedNodeConfig()
|
||||
if nodeConfig != nil {
|
||||
var allowIPList = nodeConfig.AllowedIPs
|
||||
if !utils.EqualStrings(allowIPList, this.lastAllowIPList) {
|
||||
@@ -175,6 +177,7 @@ func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig)
|
||||
for _, portConfig := range tcpConfig.Ports {
|
||||
// 校验端口范围
|
||||
if portConfig.Port <= 0 || portConfig.Port > 65535 {
|
||||
remotelogs.Warn("DDOS", "skipping invalid port in DDoS config: "+types.String(portConfig.Port))
|
||||
continue
|
||||
}
|
||||
if !lists.ContainsInt32(ports, portConfig.Port) {
|
||||
@@ -280,7 +283,7 @@ func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig)
|
||||
for _, port := range ports {
|
||||
// TODO 让用户选择是drop还是reject
|
||||
if maxConnections > 0 {
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "count", "over", types.String(maxConnections), "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}))
|
||||
var cmd = this.nftCommand( "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "count", "over", types.String(maxConnections), "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
@@ -291,7 +294,7 @@ func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig)
|
||||
|
||||
// TODO 让用户选择是drop还是reject
|
||||
if maxConnectionsPerIP > 0 {
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "meter", "meter-"+protocol+"-"+types.String(port)+"-max-connections", "{ "+protocol+" saddr ct count over "+types.String(maxConnectionsPerIP)+" }", "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}))
|
||||
var cmd = this.nftCommand( "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "meter", "meter-"+protocol+"-"+types.String(port)+"-max-connections", "{ "+protocol+" saddr ct count over "+types.String(maxConnectionsPerIP)+" }", "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
@@ -304,7 +307,7 @@ func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig)
|
||||
// TODO 让用户选择是drop还是reject
|
||||
if newConnectionsMinutelyRate > 0 {
|
||||
if newConnectionsMinutelyRateBlockTimeout > 0 {
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsMinutelyRate)+"/minute burst "+types.String(newConnectionsMinutelyRate+3)+" packets }", "add", "@deny_set", "{"+protocol+" saddr timeout "+types.String(newConnectionsMinutelyRateBlockTimeout)+"s}", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsMinutelyRate), types.String(newConnectionsMinutelyRateBlockTimeout)}))
|
||||
var cmd = this.nftCommand( "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsMinutelyRate)+"/minute burst "+types.String(newConnectionsMinutelyRate+3)+" packets }", "add", "@deny_set", "{"+protocol+" saddr timeout "+types.String(newConnectionsMinutelyRateBlockTimeout)+"s}", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsMinutelyRate), types.String(newConnectionsMinutelyRateBlockTimeout)}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
@@ -312,7 +315,7 @@ func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig)
|
||||
return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, stderr.String())
|
||||
}
|
||||
} else {
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsMinutelyRate)+"/minute burst "+types.String(newConnectionsMinutelyRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", "0"}))
|
||||
var cmd = this.nftCommand( "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsMinutelyRate)+"/minute burst "+types.String(newConnectionsMinutelyRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", "0"}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
@@ -326,7 +329,7 @@ func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig)
|
||||
// TODO 让用户选择是drop还是reject
|
||||
if newConnectionsSecondlyRate > 0 {
|
||||
if newConnectionsSecondlyRateBlockTimeout > 0 {
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-secondly-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsSecondlyRate)+"/second burst "+types.String(newConnectionsSecondlyRate+3)+" packets }", "add", "@deny_set", "{"+protocol+" saddr timeout "+types.String(newConnectionsSecondlyRateBlockTimeout)+"s}", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsSecondlyRate", types.String(newConnectionsSecondlyRate), types.String(newConnectionsSecondlyRateBlockTimeout)}))
|
||||
var cmd = this.nftCommand( "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-secondly-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsSecondlyRate)+"/second burst "+types.String(newConnectionsSecondlyRate+3)+" packets }", "add", "@deny_set", "{"+protocol+" saddr timeout "+types.String(newConnectionsSecondlyRateBlockTimeout)+"s}", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsSecondlyRate", types.String(newConnectionsSecondlyRate), types.String(newConnectionsSecondlyRateBlockTimeout)}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
@@ -334,7 +337,7 @@ func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig)
|
||||
return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, stderr.String())
|
||||
}
|
||||
} else {
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-secondly-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsSecondlyRate)+"/second burst "+types.String(newConnectionsSecondlyRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsSecondlyRate", "0"}))
|
||||
var cmd = this.nftCommand( "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-secondly-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsSecondlyRate)+"/second burst "+types.String(newConnectionsSecondlyRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsSecondlyRate", "0"}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
@@ -365,6 +368,12 @@ func (this *DDoSProtectionManager) removeTCPRules() error {
|
||||
}
|
||||
|
||||
// 组合user data
|
||||
// nftCommand 创建带 10s 超时的 nft 命令
|
||||
func (this *DDoSProtectionManager) nftCommand(args ...string) *exec.Cmd {
|
||||
ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
return exec.CommandContext(ctx, this.nftPath, args...)
|
||||
}
|
||||
|
||||
// 数据中不能包含字母、数字、下划线以外的数据
|
||||
func (this *DDoSProtectionManager) encodeUserData(attrs []string) string {
|
||||
if attrs == nil {
|
||||
|
||||
@@ -35,6 +35,12 @@ func init() {
|
||||
if runtime.GOOS == "linux" {
|
||||
var ticker = time.NewTicker(3 * time.Minute)
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
remotelogs.Error("FIREWALL", fmt.Sprintf("goroutine panic: %v", r))
|
||||
}
|
||||
}()
|
||||
|
||||
for range ticker.C {
|
||||
// if already ready, we break
|
||||
if nftablesIsReady {
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package goman
|
||||
|
||||
import (
|
||||
"log"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -17,6 +18,12 @@ func New(f func()) {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("[GOMAN]goroutine panic at %s:%d: %v", file, line, r)
|
||||
}
|
||||
}()
|
||||
|
||||
locker.Lock()
|
||||
instanceId++
|
||||
|
||||
@@ -45,6 +52,12 @@ func NewWithArgs(f func(args ...interface{}), args ...interface{}) {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("[GOMAN]goroutine panic at %s:%d: %v", file, line, r)
|
||||
}
|
||||
}()
|
||||
|
||||
locker.Lock()
|
||||
instanceId++
|
||||
|
||||
|
||||
@@ -225,7 +225,7 @@ func (this *APIStream) handleCheckLocalFirewall(message *pb.NSNodeStreamMessage)
|
||||
"version": version,
|
||||
}
|
||||
|
||||
var protectionConfig = sharedNodeConfig.DDoSProtection
|
||||
var protectionConfig = dnsNodeConfig().DDoSProtection
|
||||
err = firewalls.SharedDDoSProtectionManager.Apply(protectionConfig)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, dataMessage.Name+"was installed, but apply DDoS protection config failed: "+err.Error())
|
||||
|
||||
@@ -34,6 +34,7 @@ import (
|
||||
"regexp"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
@@ -46,7 +47,18 @@ var sharedDomainManager *DomainManager
|
||||
var sharedRecordManager *RecordManager
|
||||
var sharedRouteManager *RouteManager
|
||||
var sharedKeyManager *KeyManager
|
||||
var sharedNodeConfig = &dnsconfigs.NSNodeConfig{}
|
||||
|
||||
var sharedNodeConfig atomic.Pointer[dnsconfigs.NSNodeConfig]
|
||||
|
||||
func init() {
|
||||
// 初始化默认空配置
|
||||
sharedNodeConfig.Store(&dnsconfigs.NSNodeConfig{})
|
||||
}
|
||||
|
||||
// dnsNodeConfig 返回当前 DNS 节点配置(并发安全)
|
||||
func dnsNodeConfig() *dnsconfigs.NSNodeConfig {
|
||||
return sharedNodeConfig.Load()
|
||||
}
|
||||
|
||||
func NewDNSNode() *DNSNode {
|
||||
return &DNSNode{
|
||||
@@ -105,13 +117,19 @@ func (this *DNSNode) Start() {
|
||||
events.Notify(events.EventStart)
|
||||
|
||||
// 监控状态
|
||||
go NewNodeStatusExecutor().Listen()
|
||||
goman.New(func() {
|
||||
NewNodeStatusExecutor().Listen()
|
||||
})
|
||||
|
||||
// 连接API
|
||||
go NewAPIStream().Start()
|
||||
goman.New(func() {
|
||||
NewAPIStream().Start()
|
||||
})
|
||||
|
||||
// 启动
|
||||
go this.start()
|
||||
goman.New(func() {
|
||||
this.start()
|
||||
})
|
||||
|
||||
// Hold住进程
|
||||
logs.Println("[DNS_NODE]started")
|
||||
@@ -224,7 +242,7 @@ func (this *DNSNode) listenSock() error {
|
||||
}
|
||||
|
||||
// 启动监听
|
||||
go func() {
|
||||
goman.New(func() {
|
||||
this.sock.OnCommand(func(cmd *gosock.Command) {
|
||||
switch cmd.Code {
|
||||
case "pid":
|
||||
@@ -262,7 +280,7 @@ func (this *DNSNode) listenSock() error {
|
||||
if err != nil {
|
||||
logs.Println("NODE", err.Error())
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
events.On(events.EventQuit, func() {
|
||||
logs.Println("[DNS_NODE]", "quit unix sock")
|
||||
@@ -317,9 +335,9 @@ func (this *DNSNode) start() {
|
||||
return
|
||||
}
|
||||
|
||||
sharedNodeConfig = config
|
||||
sharedNodeConfig.Store(config)
|
||||
|
||||
configs.SharedNodeConfig = config
|
||||
configs.SharedNodeConfig.Store(config)
|
||||
events.Notify(events.EventReload)
|
||||
|
||||
sharedNodeConfigManager.reload(config)
|
||||
@@ -343,29 +361,47 @@ func (this *DNSNode) start() {
|
||||
return
|
||||
}
|
||||
|
||||
go sharedNodeConfigManager.Start()
|
||||
goman.New(func() {
|
||||
sharedNodeConfigManager.Start()
|
||||
})
|
||||
|
||||
sharedDomainManager = NewDomainManager(db, config.ClusterId)
|
||||
go sharedDomainManager.Start()
|
||||
goman.New(func() {
|
||||
sharedDomainManager.Start()
|
||||
})
|
||||
|
||||
sharedRecordManager = NewRecordManager(db)
|
||||
go sharedRecordManager.Start()
|
||||
goman.New(func() {
|
||||
sharedRecordManager.Start()
|
||||
})
|
||||
|
||||
sharedRouteManager = NewRouteManager(db)
|
||||
go sharedRouteManager.Start()
|
||||
goman.New(func() {
|
||||
sharedRouteManager.Start()
|
||||
})
|
||||
|
||||
sharedKeyManager = NewKeyManager(db)
|
||||
go sharedKeyManager.Start()
|
||||
goman.New(func() {
|
||||
sharedKeyManager.Start()
|
||||
})
|
||||
|
||||
agents.SharedManager = agents.NewManager(db)
|
||||
go agents.SharedManager.Start()
|
||||
goman.New(func() {
|
||||
agents.SharedManager.Start()
|
||||
})
|
||||
|
||||
// 发送通知,这里发送通知需要在DomainManager、RecordeManager等加载完成之后
|
||||
time.Sleep(1 * time.Second)
|
||||
// 等待所有 Manager 初始加载完成后再发送通知
|
||||
<-sharedDomainManager.readyCh
|
||||
<-sharedRecordManager.readyCh
|
||||
<-sharedRouteManager.readyCh
|
||||
<-sharedKeyManager.readyCh
|
||||
<-agents.SharedManager.ReadyCh
|
||||
events.Notify(events.EventLoaded)
|
||||
|
||||
// 启动循环
|
||||
go this.loop()
|
||||
goman.New(func() {
|
||||
this.loop()
|
||||
})
|
||||
}
|
||||
|
||||
// 更新配置Loop
|
||||
@@ -439,8 +475,9 @@ func (this *DNSNode) updateDDoS(rpcClient *rpc.RPCClient) error {
|
||||
return err
|
||||
}
|
||||
if len(resp.DdosProtectionJSON) == 0 {
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.DDoSProtection = nil
|
||||
var cfg = dnsNodeConfig()
|
||||
if cfg != nil {
|
||||
cfg.DDoSProtection = nil
|
||||
}
|
||||
} else {
|
||||
var ddosProtectionConfig = &ddosconfigs.ProtectionConfig{}
|
||||
@@ -449,8 +486,9 @@ func (this *DNSNode) updateDDoS(rpcClient *rpc.RPCClient) error {
|
||||
return fmt.Errorf("decode DDoS protection config failed: %w", err)
|
||||
}
|
||||
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.DDoSProtection = ddosProtectionConfig
|
||||
var cfg = dnsNodeConfig()
|
||||
if cfg != nil {
|
||||
cfg.DDoSProtection = ddosProtectionConfig
|
||||
}
|
||||
|
||||
err = firewalls.SharedDDoSProtectionManager.Apply(ddosProtectionConfig)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
@@ -27,7 +28,7 @@ func init() {
|
||||
sharedListenManager = NewListenManager()
|
||||
|
||||
events.On(events.EventReload, func() {
|
||||
sharedListenManager.Update(sharedNodeConfig)
|
||||
sharedListenManager.Update(dnsNodeConfig())
|
||||
})
|
||||
events.On(events.EventQuit, func() {
|
||||
_ = sharedListenManager.ShutdownAll()
|
||||
@@ -161,6 +162,12 @@ func (this *ListenManager) Update(config *dnsconfigs.NSNodeConfig) {
|
||||
this.serverMap[fullAddr] = server
|
||||
this.locker.Unlock()
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
remotelogs.Error("LISTEN_MANAGER", fmt.Sprintf("goroutine panic: %v", r))
|
||||
}
|
||||
}()
|
||||
|
||||
remotelogs.Println("LISTEN_MANAGER", "listen '"+fullAddr+"'")
|
||||
err = server.ListenAndServe()
|
||||
if err != nil {
|
||||
@@ -194,6 +201,11 @@ func (this *ListenManager) Update(config *dnsconfigs.NSNodeConfig) {
|
||||
|
||||
// 添加端口到firewalld
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
remotelogs.Error("LISTEN_MANAGER", fmt.Sprintf("firewalld goroutine panic: %v", r))
|
||||
}
|
||||
}()
|
||||
this.addToFirewalld(serverAddrs)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -22,11 +23,13 @@ type DomainManager struct {
|
||||
namesMap map[string]map[int64]*models.NSDomain // domain name => { domainId => domain }
|
||||
clusterId int64
|
||||
|
||||
db *dbs.DB
|
||||
version int64
|
||||
locker *sync.RWMutex
|
||||
db *dbs.DB
|
||||
version int64
|
||||
locker *sync.RWMutex
|
||||
dbWriteFailures atomic.Int64 // DB 写入累计失败次数
|
||||
|
||||
notifier chan bool
|
||||
readyCh chan struct{} // 初始加载完成后关闭
|
||||
}
|
||||
|
||||
// NewDomainManager 获取域名管理器对象
|
||||
@@ -38,6 +41,7 @@ func NewDomainManager(db *dbs.DB, clusterId int64) *DomainManager {
|
||||
clusterId: clusterId,
|
||||
notifier: make(chan bool, 8),
|
||||
locker: &sync.RWMutex{},
|
||||
readyCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,6 +69,9 @@ func (this *DomainManager) Start() {
|
||||
}
|
||||
}
|
||||
|
||||
// 通知初始加载完成
|
||||
close(this.readyCh)
|
||||
|
||||
// 更新
|
||||
var ticker = time.NewTicker(20 * time.Second)
|
||||
for {
|
||||
@@ -244,7 +251,8 @@ func (this *DomainManager) processDomain(domain *pb.NSDomain) {
|
||||
if this.db != nil {
|
||||
err := this.db.DeleteDomain(domain.Id)
|
||||
if err != nil {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "delete domain from db failed: "+err.Error())
|
||||
count := this.dbWriteFailures.Add(1)
|
||||
remotelogs.Error("DOMAIN_MANAGER", "delete domain from db failed (total failures: "+types.String(count)+"): "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,17 +263,20 @@ func (this *DomainManager) processDomain(domain *pb.NSDomain) {
|
||||
if this.db != nil {
|
||||
exists, err := this.db.ExistsDomain(domain.Id)
|
||||
if err != nil {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "query failed: "+err.Error())
|
||||
count := this.dbWriteFailures.Add(1)
|
||||
remotelogs.Error("DOMAIN_MANAGER", "query failed (total failures: "+types.String(count)+"): "+err.Error())
|
||||
} else {
|
||||
if exists {
|
||||
err = this.db.UpdateDomain(domain.Id, domain.NsCluster.Id, domain.UserId, domain.Name, domain.TsigJSON, domain.Version)
|
||||
if err != nil {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "update failed: "+err.Error())
|
||||
count := this.dbWriteFailures.Add(1)
|
||||
remotelogs.Error("DOMAIN_MANAGER", "update failed (total failures: "+types.String(count)+"): "+err.Error())
|
||||
}
|
||||
} else {
|
||||
err = this.db.InsertDomain(domain.Id, domain.NsCluster.Id, domain.UserId, domain.Name, domain.TsigJSON, domain.Version)
|
||||
if err != nil {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "insert failed: "+err.Error())
|
||||
count := this.dbWriteFailures.Add(1)
|
||||
remotelogs.Error("DOMAIN_MANAGER", "insert failed (total failures: "+types.String(count)+"): "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ type KeyManager struct {
|
||||
version int64
|
||||
|
||||
notifier chan bool
|
||||
readyCh chan struct{} // 初始加载完成后关闭
|
||||
}
|
||||
|
||||
// NewKeyManager 获取密钥管理器
|
||||
@@ -31,6 +32,7 @@ func NewKeyManager(db *dbs.DB) *KeyManager {
|
||||
zoneKeyMap: map[int64]*models.NSKeys{},
|
||||
db: db,
|
||||
notifier: make(chan bool, 8),
|
||||
readyCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,6 +60,9 @@ func (this *KeyManager) Start() {
|
||||
}
|
||||
}
|
||||
|
||||
// 通知初始加载完成
|
||||
close(this.readyCh)
|
||||
|
||||
// 更新
|
||||
var ticker = time.NewTicker(1 * time.Minute)
|
||||
for {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
@@ -87,8 +88,8 @@ func (this *NodeConfigManager) Loop() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sharedNodeConfig = config
|
||||
configs.SharedNodeConfig = config
|
||||
sharedNodeConfig.Store(config)
|
||||
configs.SharedNodeConfig.Store(config)
|
||||
|
||||
this.reload(config)
|
||||
|
||||
@@ -180,6 +181,12 @@ func (this *NodeConfigManager) changeAPINodeAddrs(apiNodeAddrs []*serverconfigs.
|
||||
|
||||
// 异步检测,防止阻塞
|
||||
go func(v int64) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
remotelogs.Error("NODE", fmt.Sprintf("goroutine panic: %v", r))
|
||||
}
|
||||
}()
|
||||
|
||||
// 测试新的API节点地址
|
||||
if rpcClient.TestEndpoints(addrs) {
|
||||
config.RPCEndpoints = addrs
|
||||
|
||||
@@ -9,7 +9,9 @@ import (
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/models"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -17,11 +19,13 @@ import (
|
||||
type RecordManager struct {
|
||||
recordsMap map[int64]*models.DomainRecords // domainId => RecordsMap
|
||||
|
||||
db *dbs.DB
|
||||
locker sync.RWMutex
|
||||
version int64
|
||||
db *dbs.DB
|
||||
locker sync.RWMutex
|
||||
version int64
|
||||
dbWriteFailures atomic.Int64 // DB 写入累计失败次数
|
||||
|
||||
notifier chan bool
|
||||
readyCh chan struct{} // 初始加载完成后关闭
|
||||
}
|
||||
|
||||
// NewRecordManager 获取新记录管理器对象
|
||||
@@ -30,6 +34,7 @@ func NewRecordManager(db *dbs.DB) *RecordManager {
|
||||
db: db,
|
||||
recordsMap: map[int64]*models.DomainRecords{},
|
||||
notifier: make(chan bool, 8),
|
||||
readyCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,6 +62,9 @@ func (this *RecordManager) Start() {
|
||||
}
|
||||
}
|
||||
|
||||
// 通知初始加载完成
|
||||
close(this.readyCh)
|
||||
|
||||
// 更新
|
||||
var ticker = time.NewTicker(30 * time.Second)
|
||||
for {
|
||||
@@ -158,7 +166,7 @@ func (this *RecordManager) FindRecords(domainId int64, routeCodes []string, reco
|
||||
this.locker.RLock()
|
||||
domainRecords, ok := this.recordsMap[domainId]
|
||||
if ok {
|
||||
records, routeCode = domainRecords.Find(routeCodes, recordName, recordType, sharedNodeConfig.Answer, strictMode)
|
||||
records, routeCode = domainRecords.Find(routeCodes, recordName, recordType, dnsNodeConfig().Answer, strictMode)
|
||||
}
|
||||
this.locker.RUnlock()
|
||||
return
|
||||
@@ -190,7 +198,8 @@ func (this *RecordManager) processRecord(record *pb.NSRecord) {
|
||||
if this.db != nil {
|
||||
err := this.db.DeleteRecord(record.Id)
|
||||
if err != nil {
|
||||
remotelogs.Error("RECORD_MANAGER", "delete record from db failed: "+err.Error())
|
||||
count := this.dbWriteFailures.Add(1)
|
||||
remotelogs.Error("RECORD_MANAGER", "delete record from db failed (total failures: "+types.String(count)+"): "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,7 +210,8 @@ func (this *RecordManager) processRecord(record *pb.NSRecord) {
|
||||
if this.db != nil {
|
||||
exists, err := this.db.ExistsRecord(record.Id)
|
||||
if err != nil {
|
||||
remotelogs.Error("RECORD_MANAGER", "query failed: "+err.Error())
|
||||
count := this.dbWriteFailures.Add(1)
|
||||
remotelogs.Error("RECORD_MANAGER", "query failed (total failures: "+types.String(count)+"): "+err.Error())
|
||||
} else {
|
||||
var routeIds = []string{}
|
||||
for _, route := range record.NsRoutes {
|
||||
@@ -211,12 +221,14 @@ func (this *RecordManager) processRecord(record *pb.NSRecord) {
|
||||
if exists {
|
||||
err = this.db.UpdateRecord(record.Id, record.NsDomain.Id, record.Name, record.Type, record.Value, record.MxPriority, record.SrvPriority, record.SrvWeight, record.SrvPort, record.CaaFlag, record.CaaTag, record.Ttl, record.Weight, routeIds, record.Version)
|
||||
if err != nil {
|
||||
remotelogs.Error("RECORD_MANAGER", "update failed: "+err.Error())
|
||||
count := this.dbWriteFailures.Add(1)
|
||||
remotelogs.Error("RECORD_MANAGER", "update failed (total failures: "+types.String(count)+"): "+err.Error())
|
||||
}
|
||||
} else {
|
||||
err = this.db.InsertRecord(record.Id, record.NsDomain.Id, record.Name, record.Type, record.Value, record.MxPriority, record.SrvPriority, record.SrvWeight, record.SrvPort, record.CaaFlag, record.CaaTag, record.Ttl, record.Weight, routeIds, record.Version)
|
||||
if err != nil {
|
||||
remotelogs.Error("RECORD_MANAGER", "insert failed: "+err.Error())
|
||||
count := this.dbWriteFailures.Add(1)
|
||||
remotelogs.Error("RECORD_MANAGER", "insert failed (total failures: "+types.String(count)+"): "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ type RouteManager struct {
|
||||
locker sync.RWMutex
|
||||
|
||||
notifier chan bool
|
||||
readyCh chan struct{} // 初始加载完成后关闭
|
||||
|
||||
ispRouteMap map[string]string // name => code
|
||||
chinaRouteMap map[string]string // name => code
|
||||
@@ -45,6 +46,7 @@ func NewRouteManager(db *dbs.DB) *RouteManager {
|
||||
userRouteMap: map[int64][]int64{},
|
||||
|
||||
notifier: make(chan bool, 8),
|
||||
readyCh: make(chan struct{}),
|
||||
|
||||
ispRouteMap: map[string]string{},
|
||||
chinaRouteMap: map[string]string{},
|
||||
@@ -79,6 +81,9 @@ func (this *RouteManager) Start() {
|
||||
}
|
||||
}
|
||||
|
||||
// 通知初始加载完成
|
||||
close(this.readyCh)
|
||||
|
||||
// 更新
|
||||
var ticker = time.NewTicker(1 * time.Minute)
|
||||
for {
|
||||
|
||||
@@ -3,6 +3,7 @@ package nodes
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/accesslogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"strconv"
|
||||
@@ -24,7 +25,9 @@ func NewNSAccessLogQueue() *NSAccessLogQueue {
|
||||
queue := &NSAccessLogQueue{
|
||||
queue: make(chan *pb.NSAccessLog, maxSize),
|
||||
}
|
||||
go queue.Start()
|
||||
goman.New(func() {
|
||||
queue.Start()
|
||||
})
|
||||
|
||||
return queue
|
||||
}
|
||||
@@ -93,10 +96,11 @@ Loop:
|
||||
var clusterId int64
|
||||
var needWriteFile = true
|
||||
var needReportAPI = true
|
||||
if sharedNodeConfig != nil {
|
||||
clusterId = sharedNodeConfig.ClusterId
|
||||
if sharedNodeConfig.AccessLogWriteTargets != nil {
|
||||
targets := sharedNodeConfig.AccessLogWriteTargets
|
||||
var cfg = dnsNodeConfig()
|
||||
if cfg != nil {
|
||||
clusterId = cfg.ClusterId
|
||||
if cfg.AccessLogWriteTargets != nil {
|
||||
targets := cfg.AccessLogWriteTargets
|
||||
needWriteFile = targets.File || targets.ClickHouse
|
||||
needReportAPI = targets.MySQL
|
||||
}
|
||||
|
||||
@@ -31,7 +31,11 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var sharedRecursionDNSClient = &dns.Client{}
|
||||
var sharedRecursionDNSClient = &dns.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
type httpContextKey struct {
|
||||
key string
|
||||
@@ -171,9 +175,70 @@ func (this *Server) init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addECSOption 向 DNS 请求中添加 EDNS Client Subnet (ECS) 信息
|
||||
// addECSOption 向 DNS 请求中设置 EDNS Client Subnet (ECS)。
|
||||
// 如果请求已携带 ECS 则覆盖(避免双 ECS 导致上游 malformed request)。
|
||||
func addECSOption(req *dns.Msg, clientIP string) {
|
||||
if len(clientIP) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
ip := net.ParseIP(clientIP)
|
||||
if ip == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var ecs = &dns.EDNS0_SUBNET{
|
||||
Code: dns.EDNS0SUBNET,
|
||||
}
|
||||
if ip.To4() != nil {
|
||||
ecs.Family = 1 // IPv4
|
||||
ecs.SourceNetmask = 24
|
||||
ecs.Address = ip.To4()
|
||||
} else {
|
||||
ecs.Family = 2 // IPv6
|
||||
ecs.SourceNetmask = 56
|
||||
ecs.Address = ip
|
||||
}
|
||||
|
||||
// 查找或创建 OPT 记录
|
||||
var opt = req.IsEdns0()
|
||||
if opt == nil {
|
||||
req.SetEdns0(4096, false)
|
||||
opt = req.IsEdns0()
|
||||
}
|
||||
if opt != nil {
|
||||
// 删除已有的 ECS option,避免出现双 EDNS0_SUBNET
|
||||
var filtered []dns.EDNS0
|
||||
for _, o := range opt.Option {
|
||||
if o.Option() != dns.EDNS0SUBNET {
|
||||
filtered = append(filtered, o)
|
||||
}
|
||||
}
|
||||
opt.Option = append(filtered, ecs)
|
||||
}
|
||||
}
|
||||
|
||||
// stripECSFromExtra 从 Extra section 中移除 OPT 记录里的 EDNS0_SUBNET,
|
||||
// 防止服务端注入的 ECS 信息回传给下游客户端(隐私泄露风险)。
|
||||
func stripECSFromExtra(extra []dns.RR) []dns.RR {
|
||||
for _, rr := range extra {
|
||||
if opt, ok := rr.(*dns.OPT); ok {
|
||||
var filtered []dns.EDNS0
|
||||
for _, o := range opt.Option {
|
||||
if o.Option() != dns.EDNS0SUBNET {
|
||||
filtered = append(filtered, o)
|
||||
}
|
||||
}
|
||||
opt.Option = filtered
|
||||
}
|
||||
}
|
||||
return extra
|
||||
}
|
||||
|
||||
// 查询递归DNS
|
||||
func (this *Server) lookupRecursionDNS(req *dns.Msg, resp *dns.Msg) error {
|
||||
var config = sharedNodeConfig.RecursionConfig
|
||||
func (this *Server) lookupRecursionDNS(req *dns.Msg, resp *dns.Msg, clientIP string) error {
|
||||
var config = dnsNodeConfig().RecursionConfig
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -182,6 +247,9 @@ func (this *Server) lookupRecursionDNS(req *dns.Msg, resp *dns.Msg) error {
|
||||
}
|
||||
|
||||
// 是否允许
|
||||
if len(req.Question) == 0 {
|
||||
return nil
|
||||
}
|
||||
var domain = strings.TrimSuffix(req.Question[0].Name, ".")
|
||||
if len(config.DenyDomains) > 0 && configutils.MatchDomains(config.DenyDomains, domain) {
|
||||
return nil
|
||||
@@ -190,6 +258,9 @@ func (this *Server) lookupRecursionDNS(req *dns.Msg, resp *dns.Msg) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 携带客户端真实 IP(ECS)向上游查询
|
||||
addECSOption(req, clientIP)
|
||||
|
||||
if config.UseLocalHosts {
|
||||
// TODO 需要缓存文件内容
|
||||
resolveConfig, err := dns.ClientConfigFromFile("/etc/resolv.conf")
|
||||
@@ -206,7 +277,12 @@ func (this *Server) lookupRecursionDNS(req *dns.Msg, resp *dns.Msg) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Answer = r.Answer
|
||||
if r != nil {
|
||||
resp.Rcode = r.Rcode
|
||||
resp.Answer = r.Answer
|
||||
resp.Ns = r.Ns
|
||||
resp.Extra = stripECSFromExtra(r.Extra)
|
||||
}
|
||||
} else if len(config.Hosts) > 0 {
|
||||
var host = config.Hosts[rands.Int(0, len(config.Hosts)-1)]
|
||||
if host.Port <= 0 {
|
||||
@@ -216,7 +292,12 @@ func (this *Server) lookupRecursionDNS(req *dns.Msg, resp *dns.Msg) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Answer = r.Answer
|
||||
if r != nil {
|
||||
resp.Rcode = r.Rcode
|
||||
resp.Answer = r.Answer
|
||||
resp.Ns = r.Ns
|
||||
resp.Extra = stripECSFromExtra(r.Extra)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -270,7 +351,8 @@ func (this *Server) parseAction(questionName string, remoteAddr *string) (string
|
||||
// 记录日志
|
||||
func (this *Server) addLog(networking string, question dns.Question, domainId int64, routeCode string, record *models.NSRecord, isRecursive bool, writer dns.ResponseWriter, remoteAddr string, err error) {
|
||||
// 访问日志
|
||||
var accessLogRef = sharedNodeConfig.AccessLogRef
|
||||
var nodeConfig = dnsNodeConfig()
|
||||
var accessLogRef = nodeConfig.AccessLogRef
|
||||
if accessLogRef != nil && accessLogRef.IsOn {
|
||||
if domainId == 0 && !accessLogRef.LogMissingDomains {
|
||||
return
|
||||
@@ -282,7 +364,7 @@ func (this *Server) addLog(networking string, question dns.Question, domainId in
|
||||
|
||||
var now = time.Now()
|
||||
var pbAccessLog = &pb.NSAccessLog{
|
||||
NsNodeId: sharedNodeConfig.Id,
|
||||
NsNodeId: nodeConfig.Id,
|
||||
RemoteAddr: remoteAddr,
|
||||
NsDomainId: domainId,
|
||||
QuestionName: question.Name,
|
||||
@@ -428,8 +510,14 @@ func (this *Server) handleDNSMessage(writer dns.ResponseWriter, req *dns.Msg) {
|
||||
domain, recordName = sharedDomainManager.SplitDomain(fullName)
|
||||
if domain == nil {
|
||||
// 检查递归DNS
|
||||
if sharedNodeConfig.RecursionConfig != nil && sharedNodeConfig.RecursionConfig.IsOn {
|
||||
err := this.lookupRecursionDNS(req, resp)
|
||||
var recursionConfig = dnsNodeConfig().RecursionConfig
|
||||
if recursionConfig != nil && recursionConfig.IsOn {
|
||||
// 提取客户端 IP 用于 ECS
|
||||
var clientIP = remoteAddr
|
||||
if clientHost, _, splitErr := net.SplitHostPort(clientIP); splitErr == nil && len(clientHost) > 0 {
|
||||
clientIP = clientHost
|
||||
}
|
||||
err := this.lookupRecursionDNS(req, resp, clientIP)
|
||||
if err != nil {
|
||||
this.addLog(networking, question, 0, "", nil, true, writer, remoteAddr, err)
|
||||
} else {
|
||||
@@ -459,7 +547,7 @@ func (this *Server) handleDNSMessage(writer dns.ResponseWriter, req *dns.Msg) {
|
||||
|
||||
// 是否为NS记录,用于验证域名所有权
|
||||
if question.Qtype == dns.TypeNS {
|
||||
var hosts = sharedNodeConfig.Hosts
|
||||
var hosts = dnsNodeConfig().Hosts
|
||||
var l = len(hosts)
|
||||
var record = &models.NSRecord{
|
||||
Id: 0,
|
||||
@@ -518,7 +606,7 @@ func (this *Server) handleDNSMessage(writer dns.ResponseWriter, req *dns.Msg) {
|
||||
}
|
||||
|
||||
// 解析Agent
|
||||
if sharedNodeConfig.DetectAgents {
|
||||
if dnsNodeConfig().DetectAgents {
|
||||
agents.SharedQueue.Push(clientIP)
|
||||
}
|
||||
|
||||
@@ -569,7 +657,7 @@ func (this *Server) handleDNSMessage(writer dns.ResponseWriter, req *dns.Msg) {
|
||||
}
|
||||
|
||||
// 对 NS.example.com NS|SOA 处理
|
||||
if (question.Qtype == dns.TypeNS || (question.Qtype == dns.TypeSOA && len(records) == 0)) && lists.ContainsString(sharedNodeConfig.Hosts, fullName) {
|
||||
if (question.Qtype == dns.TypeNS || (question.Qtype == dns.TypeSOA && len(records) == 0)) && lists.ContainsString(dnsNodeConfig().Hosts, fullName) {
|
||||
var recordDNSType string
|
||||
switch question.Qtype {
|
||||
case dns.TypeNS:
|
||||
@@ -663,7 +751,7 @@ func (this *Server) handleDNSMessage(writer dns.ResponseWriter, req *dns.Msg) {
|
||||
}
|
||||
case dnsconfigs.RecordTypeNS:
|
||||
if record.Id == 0 {
|
||||
var hosts = sharedNodeConfig.Hosts
|
||||
var hosts = dnsNodeConfig().Hosts
|
||||
var l = len(hosts)
|
||||
if l > 0 {
|
||||
// 随机
|
||||
@@ -900,8 +988,9 @@ func (this *Server) handleHTTPJSONAPI(writer http.ResponseWriter, req *http.Requ
|
||||
|
||||
// 组合SOA回复信息
|
||||
func (this *Server) composeSOAAnswer(question dns.Question, record *models.NSRecord, resp *dns.Msg) {
|
||||
var config = sharedNodeConfig.SOA
|
||||
var serial = sharedNodeConfig.SOASerial
|
||||
var nodeCfg = dnsNodeConfig()
|
||||
var config = nodeCfg.SOA
|
||||
var serial = nodeCfg.SOASerial
|
||||
|
||||
if config == nil {
|
||||
config = dnsconfigs.DefaultNSSOAConfig()
|
||||
@@ -909,7 +998,7 @@ func (this *Server) composeSOAAnswer(question dns.Question, record *models.NSRec
|
||||
|
||||
var mName = config.MName
|
||||
if len(mName) == 0 {
|
||||
var hosts = sharedNodeConfig.Hosts
|
||||
var hosts = nodeCfg.Hosts
|
||||
var l = len(hosts)
|
||||
if l > 0 {
|
||||
var index = rands.Int(0, l-1)
|
||||
@@ -919,7 +1008,7 @@ func (this *Server) composeSOAAnswer(question dns.Question, record *models.NSRec
|
||||
|
||||
var rName = config.RName
|
||||
if len(rName) == 0 {
|
||||
rName = sharedNodeConfig.Email
|
||||
rName = nodeCfg.Email
|
||||
}
|
||||
rName = strings.ReplaceAll(rName, "@", ".")
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/configs"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/events"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/utils"
|
||||
@@ -20,7 +21,9 @@ func init() {
|
||||
|
||||
events.On(events.EventStart, func() {
|
||||
task := NewSyncAPINodesTask()
|
||||
go task.Start()
|
||||
goman.New(func() {
|
||||
task.Start()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -52,7 +55,8 @@ func (this *SyncAPINodesTask) Start() {
|
||||
|
||||
func (this *SyncAPINodesTask) Loop() error {
|
||||
// 如果有节点定制的API节点地址
|
||||
var hasCustomizedAPINodeAddrs = sharedNodeConfig != nil && len(sharedNodeConfig.APINodeAddrs) > 0
|
||||
var nodeConfig = dnsNodeConfig()
|
||||
var hasCustomizedAPINodeAddrs = nodeConfig != nil && len(nodeConfig.APINodeAddrs) > 0
|
||||
|
||||
config, err := configs.LoadAPIConfig()
|
||||
if err != nil {
|
||||
|
||||
@@ -29,6 +29,12 @@ func init() {
|
||||
|
||||
events.On(events.EventStart, func() {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
remotelogs.Error("UPGRADE_MANAGER", fmt.Sprintf("goroutine panic: %v", r))
|
||||
}
|
||||
}()
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
remotelogs.Error("UPGRADE_MANAGER", err.Error())
|
||||
|
||||
@@ -20,6 +20,12 @@ func init() {
|
||||
// 定期上传日志
|
||||
var ticker = time.NewTicker(60 * time.Second)
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logs.Println("[LOG]goroutine panic:", r)
|
||||
}
|
||||
}()
|
||||
|
||||
for range ticker.C {
|
||||
err := uploadLogs()
|
||||
if err != nil {
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/events"
|
||||
@@ -45,6 +46,12 @@ func (this *StatManager) Start() {
|
||||
this.ticker = time.NewTicker(1 * time.Minute)
|
||||
}
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
remotelogs.Error("STAT", fmt.Sprintf("goroutine panic: %v", r))
|
||||
}
|
||||
}()
|
||||
|
||||
for range this.ticker.C {
|
||||
err := this.Loop()
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user