Initial commit (code only without large binaries)
This commit is contained in:
256
EdgeDNS/internal/nodes/api_stream.go
Normal file
256
EdgeDNS/internal/nodes/api_stream.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/messageconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/events"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/utils"
|
||||
executils "github.com/TeaOSLab/EdgeDNS/internal/utils/exec"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type APIStream struct {
|
||||
stream pb.NSNodeService_NsNodeStreamClient
|
||||
|
||||
isQuiting bool
|
||||
cancelFunc context.CancelFunc
|
||||
}
|
||||
|
||||
func NewAPIStream() *APIStream {
|
||||
return &APIStream{}
|
||||
}
|
||||
|
||||
func (this *APIStream) Start() {
|
||||
events.On(events.EventQuit, func() {
|
||||
this.isQuiting = true
|
||||
if this.cancelFunc != nil {
|
||||
this.cancelFunc()
|
||||
}
|
||||
})
|
||||
for {
|
||||
if this.isQuiting {
|
||||
return
|
||||
}
|
||||
err := this.loop()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("API_STREAM", err.Error())
|
||||
} else {
|
||||
remotelogs.Error("API_STREAM", err.Error())
|
||||
}
|
||||
time.Sleep(10 * time.Second)
|
||||
continue
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func (this *APIStream) loop() error {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(rpcClient.Context())
|
||||
this.cancelFunc = cancelFunc
|
||||
|
||||
defer func() {
|
||||
cancelFunc()
|
||||
}()
|
||||
|
||||
nodeStream, err := rpcClient.NSNodeRPC.NsNodeStream(ctx)
|
||||
if err != nil {
|
||||
if this.isQuiting {
|
||||
return nil
|
||||
}
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
|
||||
this.stream = nodeStream
|
||||
|
||||
for {
|
||||
if this.isQuiting {
|
||||
logs.Println("API_STREAM", "quit")
|
||||
break
|
||||
}
|
||||
|
||||
message, err := nodeStream.Recv()
|
||||
if err != nil {
|
||||
if this.isQuiting {
|
||||
remotelogs.Println("API_STREAM", "quit")
|
||||
return nil
|
||||
}
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
|
||||
// 处理消息
|
||||
switch message.Code {
|
||||
case messageconfigs.NSMessageCodeConnectedAPINode: // 连接API节点成功
|
||||
err = this.handleConnectedAPINode(message)
|
||||
case messageconfigs.NSMessageCodeNewNodeTask: // 有新的任务
|
||||
err = this.handleNewNodeTask(message)
|
||||
case messageconfigs.NSMessageCodeCheckSystemdService: // 检查Systemd服务
|
||||
err = this.handleCheckSystemdService(message)
|
||||
case messageconfigs.MessageCodeCheckLocalFirewall: // 检查本地防火墙
|
||||
err = this.handleCheckLocalFirewall(message)
|
||||
default:
|
||||
err = this.handleUnknownMessage(message)
|
||||
}
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("API_STREAM", "handle message failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("API_STREAM", "handle message failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 连接API节点成功
|
||||
func (this *APIStream) handleConnectedAPINode(message *pb.NSNodeStreamMessage) error {
|
||||
// 更改连接的APINode信息
|
||||
if len(message.DataJSON) == 0 {
|
||||
return nil
|
||||
}
|
||||
msg := &messageconfigs.ConnectedAPINodeMessage{}
|
||||
err := json.Unmarshal(message.DataJSON, msg)
|
||||
if err != nil {
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
|
||||
_, err = rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
|
||||
remotelogs.Println("API_STREAM", "connected to api node '"+strconv.FormatInt(msg.APINodeId, 10)+"'")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 处理配置变化
|
||||
func (this *APIStream) handleNewNodeTask(message *pb.NSNodeStreamMessage) error {
|
||||
select {
|
||||
case nodeTaskNotify <- true:
|
||||
default:
|
||||
|
||||
}
|
||||
this.replyOk(message.RequestId, "ok")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查Systemd服务
|
||||
func (this *APIStream) handleCheckSystemdService(message *pb.NSNodeStreamMessage) error {
|
||||
systemctl, err := executils.LookPath("systemctl")
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "'systemctl' not found")
|
||||
return nil
|
||||
}
|
||||
if len(systemctl) == 0 {
|
||||
this.replyFail(message.RequestId, "'systemctl' not found")
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := utils.NewCommandExecutor()
|
||||
shortName := teaconst.SystemdServiceName
|
||||
cmd.Add(systemctl, "is-enabled", shortName)
|
||||
output, err := cmd.Run()
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "'systemctl' command error: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
if output == "enabled" {
|
||||
this.replyOk(message.RequestId, "ok")
|
||||
} else {
|
||||
this.replyFail(message.RequestId, "not installed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查本地防火墙
|
||||
func (this *APIStream) handleCheckLocalFirewall(message *pb.NSNodeStreamMessage) error {
|
||||
var dataMessage = &messageconfigs.CheckLocalFirewallMessage{}
|
||||
err := json.Unmarshal(message.DataJSON, dataMessage)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
// nft
|
||||
if dataMessage.Name == "nftables" {
|
||||
if runtime.GOOS != "linux" {
|
||||
this.replyFail(message.RequestId, "not Linux system")
|
||||
return nil
|
||||
}
|
||||
|
||||
nft, err := executils.LookPath("nft")
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "'nft' not found: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
var cmd = exec.Command(nft, "--version")
|
||||
var output = &bytes.Buffer{}
|
||||
cmd.Stdout = output
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "get version failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
var outputString = output.String()
|
||||
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
|
||||
if len(versionMatches) <= 1 {
|
||||
this.replyFail(message.RequestId, "can not get nft version")
|
||||
return nil
|
||||
}
|
||||
var version = versionMatches[1]
|
||||
|
||||
var result = maps.Map{
|
||||
"version": version,
|
||||
}
|
||||
|
||||
var protectionConfig = sharedNodeConfig.DDoSProtection
|
||||
err = firewalls.SharedDDoSProtectionManager.Apply(protectionConfig)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, dataMessage.Name+"was installed, but apply DDoS protection config failed: "+err.Error())
|
||||
} else {
|
||||
this.replyOk(message.RequestId, string(result.AsJSON()))
|
||||
}
|
||||
} else {
|
||||
this.replyFail(message.RequestId, "invalid firewall name '"+dataMessage.Name+"'")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 处理未知消息
|
||||
func (this *APIStream) handleUnknownMessage(message *pb.NSNodeStreamMessage) error {
|
||||
this.replyFail(message.RequestId, "unknown message code '"+message.Code+"'")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 回复失败
|
||||
func (this *APIStream) replyFail(requestId int64, message string) {
|
||||
_ = this.stream.Send(&pb.NSNodeStreamMessage{RequestId: requestId, IsOk: false, Message: message})
|
||||
}
|
||||
|
||||
// 回复成功
|
||||
func (this *APIStream) replyOk(requestId int64, message string) {
|
||||
_ = this.stream.Send(&pb.NSNodeStreamMessage{RequestId: requestId, IsOk: true, Message: message})
|
||||
}
|
||||
455
EdgeDNS/internal/nodes/dns_node.go
Normal file
455
EdgeDNS/internal/nodes/dns_node.go
Normal file
@@ -0,0 +1,455 @@
|
||||
//go:build plus
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/agents"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/configs"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/events"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/utils"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/iwind/gosock/pkg/gosock"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
var DaemonIsOn = false
|
||||
var DaemonPid = 0
|
||||
var nodeTaskNotify = make(chan bool, 8)
|
||||
|
||||
var sharedDomainManager *DomainManager
|
||||
var sharedRecordManager *RecordManager
|
||||
var sharedRouteManager *RouteManager
|
||||
var sharedKeyManager *KeyManager
|
||||
var sharedNodeConfig = &dnsconfigs.NSNodeConfig{}
|
||||
|
||||
func NewDNSNode() *DNSNode {
|
||||
return &DNSNode{
|
||||
sock: gosock.NewTmpSock(teaconst.ProcessName),
|
||||
}
|
||||
}
|
||||
|
||||
type DNSNode struct {
|
||||
sock *gosock.Sock
|
||||
|
||||
RPC *rpc.RPCClient
|
||||
}
|
||||
|
||||
func (this *DNSNode) Start() {
|
||||
// 设置netdns
|
||||
// 这个需要放在所有网络访问的最前面
|
||||
_ = os.Setenv("GODEBUG", "netdns=go")
|
||||
|
||||
// 判断是否在守护进程下
|
||||
_, ok := os.LookupEnv("EdgeDaemon")
|
||||
if ok {
|
||||
remotelogs.Println("DNS_NODE", "start from daemon")
|
||||
DaemonIsOn = true
|
||||
DaemonPid = os.Getppid()
|
||||
}
|
||||
|
||||
// 设置DNS解析库
|
||||
err := os.Setenv("GODEBUG", "netdns=go")
|
||||
if err != nil {
|
||||
remotelogs.Error("DNS_NODE", "[DNS_RESOLVER]set env failed: "+err.Error())
|
||||
}
|
||||
|
||||
// 处理异常
|
||||
this.handlePanic()
|
||||
|
||||
// 监听signal
|
||||
this.listenSignals()
|
||||
|
||||
// 本地Sock
|
||||
err = this.listenSock()
|
||||
if err != nil {
|
||||
logs.Println("[DNS_NODE]" + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 启动IP库
|
||||
remotelogs.Println("DNS_NODE", "initializing ip library ...")
|
||||
err = iplibrary.InitPlus()
|
||||
if err != nil {
|
||||
remotelogs.Error("DNS_NODE", "initialize ip library failed: "+err.Error())
|
||||
}
|
||||
runtime.GC()
|
||||
debug.FreeOSMemory()
|
||||
|
||||
// 触发事件
|
||||
events.Notify(events.EventStart)
|
||||
|
||||
// 监控状态
|
||||
go NewNodeStatusExecutor().Listen()
|
||||
|
||||
// 连接API
|
||||
go NewAPIStream().Start()
|
||||
|
||||
// 启动
|
||||
go this.start()
|
||||
|
||||
// Hold住进程
|
||||
logs.Println("[DNS_NODE]started")
|
||||
select {}
|
||||
}
|
||||
|
||||
// Daemon 实现守护进程
|
||||
func (this *DNSNode) Daemon() {
|
||||
var isDebug = lists.ContainsString(os.Args, "debug")
|
||||
for {
|
||||
conn, err := this.sock.Dial()
|
||||
if err != nil {
|
||||
if isDebug {
|
||||
log.Println("[DAEMON]starting ...")
|
||||
}
|
||||
|
||||
// 尝试启动
|
||||
err = func() error {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 可以标记当前是从守护进程启动的
|
||||
_ = os.Setenv("EdgeDaemon", "on")
|
||||
_ = os.Setenv("EdgeBackground", "on")
|
||||
|
||||
var cmd = exec.Command(exe)
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = cmd.Wait()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
if isDebug {
|
||||
log.Println("[DAEMON]", err)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
} else {
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
} else {
|
||||
_ = conn.Close()
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test 测试配置
|
||||
func (this *DNSNode) Test() error {
|
||||
// 检查是否能连接API
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return fmt.Errorf("test rpc failed: %w", err)
|
||||
}
|
||||
_, err = rpcClient.APINodeRPC.FindCurrentAPINodeVersion(rpcClient.Context(), &pb.FindCurrentAPINodeVersionRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("test rpc failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InstallSystemService 安装系统服务
|
||||
func (this *DNSNode) InstallSystemService() error {
|
||||
shortName := teaconst.SystemdServiceName
|
||||
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manager := utils.NewServiceManager(shortName, teaconst.ProductName)
|
||||
err = manager.Install(exe, []string{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 监听一些信号
|
||||
func (this *DNSNode) listenSignals() {
|
||||
var queue = make(chan os.Signal, 8)
|
||||
signal.Notify(queue, syscall.SIGTERM, syscall.SIGINT, syscall.SIGKILL, syscall.SIGQUIT)
|
||||
goman.New(func() {
|
||||
for range queue {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
utils.Exit()
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 监听本地sock
|
||||
func (this *DNSNode) listenSock() error {
|
||||
// 检查是否在运行
|
||||
if this.sock.IsListening() {
|
||||
reply, err := this.sock.Send(&gosock.Command{Code: "pid"})
|
||||
if err == nil {
|
||||
return errors.New("error: the process is already running, pid: " + maps.NewMap(reply.Params).GetString("pid"))
|
||||
} else {
|
||||
return errors.New("error: the process is already running")
|
||||
}
|
||||
}
|
||||
|
||||
// 启动监听
|
||||
go func() {
|
||||
this.sock.OnCommand(func(cmd *gosock.Command) {
|
||||
switch cmd.Code {
|
||||
case "pid":
|
||||
_ = cmd.Reply(&gosock.Command{
|
||||
Code: "pid",
|
||||
Params: map[string]interface{}{
|
||||
"pid": os.Getpid(),
|
||||
},
|
||||
})
|
||||
case "info":
|
||||
exePath, _ := os.Executable()
|
||||
_ = cmd.Reply(&gosock.Command{
|
||||
Code: "info",
|
||||
Params: map[string]interface{}{
|
||||
"pid": os.Getpid(),
|
||||
"version": teaconst.Version,
|
||||
"path": exePath,
|
||||
},
|
||||
})
|
||||
case "stop":
|
||||
_ = cmd.ReplyOk()
|
||||
|
||||
// 退出主进程
|
||||
events.Notify(events.EventQuit)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
os.Exit(0)
|
||||
case "gc":
|
||||
runtime.GC()
|
||||
debug.FreeOSMemory()
|
||||
_ = cmd.ReplyOk()
|
||||
}
|
||||
})
|
||||
|
||||
err := this.sock.Listen()
|
||||
if err != nil {
|
||||
logs.Println("NODE", err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
events.On(events.EventQuit, func() {
|
||||
logs.Println("[DNS_NODE]", "quit unix sock")
|
||||
_ = this.sock.Close()
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 启动
|
||||
func (this *DNSNode) start() {
|
||||
client, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
remotelogs.Error("DNS_NODE", err.Error())
|
||||
return
|
||||
}
|
||||
this.RPC = client
|
||||
tryTimes := 0
|
||||
var configJSON []byte
|
||||
for {
|
||||
resp, err := client.NSNodeRPC.FindCurrentNSNodeConfig(client.Context(), &pb.FindCurrentNSNodeConfigRequest{})
|
||||
if err != nil {
|
||||
tryTimes++
|
||||
if tryTimes%10 == 0 {
|
||||
remotelogs.Error("NODE", "read config from API failed: "+err.Error())
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// 不做长时间的无意义的重试
|
||||
if tryTimes > 1000 {
|
||||
remotelogs.Error("NODE", "load failed: unable to read config from API")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
configJSON = resp.NsNodeJSON
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(configJSON) == 0 {
|
||||
remotelogs.Error("NODE", "can not find node config")
|
||||
return
|
||||
}
|
||||
var config = &dnsconfigs.NSNodeConfig{}
|
||||
err = json.Unmarshal(configJSON, config)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "decode config failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
err = config.Init(context.TODO())
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "init config failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sharedNodeConfig = config
|
||||
|
||||
configs.SharedNodeConfig = config
|
||||
events.Notify(events.EventReload)
|
||||
|
||||
sharedNodeConfigManager.reload(config)
|
||||
|
||||
apiConfig, _ := configs.SharedAPIConfig()
|
||||
if apiConfig != nil {
|
||||
apiConfig.NumberId = config.Id
|
||||
}
|
||||
|
||||
var db = dbs.NewDB(Tea.Root + "/data/data-" + types.String(config.Id) + "-" + config.NodeId + "-v0.1.0.db")
|
||||
err = db.Init()
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "init database failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
go sharedNodeConfigManager.Start()
|
||||
|
||||
sharedDomainManager = NewDomainManager(db, config.ClusterId)
|
||||
go sharedDomainManager.Start()
|
||||
|
||||
sharedRecordManager = NewRecordManager(db)
|
||||
go sharedRecordManager.Start()
|
||||
|
||||
sharedRouteManager = NewRouteManager(db)
|
||||
go sharedRouteManager.Start()
|
||||
|
||||
sharedKeyManager = NewKeyManager(db)
|
||||
go sharedKeyManager.Start()
|
||||
|
||||
agents.SharedManager = agents.NewManager(db)
|
||||
go agents.SharedManager.Start()
|
||||
|
||||
// 发送通知,这里发送通知需要在DomainManager、RecordeManager等加载完成之后
|
||||
time.Sleep(1 * time.Second)
|
||||
events.Notify(events.EventLoaded)
|
||||
|
||||
// 启动循环
|
||||
go this.loop()
|
||||
}
|
||||
|
||||
// 更新配置Loop
|
||||
func (this *DNSNode) loop() {
|
||||
var ticker = time.NewTicker(60 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-nodeTaskNotify:
|
||||
}
|
||||
|
||||
err := this.processTasks()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("DNS_NODE", "process tasks: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("DNS_NODE", "process tasks: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理任务
|
||||
func (this *DNSNode) processTasks() error {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 所有的任务
|
||||
tasksResp, err := rpcClient.NodeTaskRPC.FindNodeTasks(rpcClient.Context(), &pb.FindNodeTasksRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, task := range tasksResp.NodeTasks {
|
||||
switch task.Type {
|
||||
case "nsConfigChanged":
|
||||
sharedNodeConfigManager.NotifyChange()
|
||||
case "nsDomainChanged":
|
||||
sharedDomainManager.NotifyUpdate()
|
||||
case "nsRecordChanged":
|
||||
sharedRecordManager.NotifyUpdate()
|
||||
case "nsRouteChanged":
|
||||
sharedRouteManager.NotifyUpdate()
|
||||
case "nsKeyChanged":
|
||||
sharedKeyManager.NotifyUpdate()
|
||||
case "nsDDoSProtectionChanged":
|
||||
err := this.updateDDoS(rpcClient)
|
||||
if err != nil {
|
||||
remotelogs.Error("DNS_NODE", "apply DDoS config failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
_, err = rpcClient.NodeTaskRPC.ReportNodeTaskDone(rpcClient.Context(), &pb.ReportNodeTaskDoneRequest{
|
||||
NodeTaskId: task.Id,
|
||||
IsOk: true,
|
||||
Error: "",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *DNSNode) updateDDoS(rpcClient *rpc.RPCClient) error {
|
||||
resp, err := rpcClient.NSNodeRPC.FindNSNodeDDoSProtection(rpcClient.Context(), &pb.FindNSNodeDDoSProtectionRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resp.DdosProtectionJSON) == 0 {
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.DDoSProtection = nil
|
||||
}
|
||||
} else {
|
||||
var ddosProtectionConfig = &ddosconfigs.ProtectionConfig{}
|
||||
err = json.Unmarshal(resp.DdosProtectionJSON, ddosProtectionConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode DDoS protection config failed: %w", err)
|
||||
}
|
||||
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.DDoSProtection = ddosProtectionConfig
|
||||
}
|
||||
|
||||
err = firewalls.SharedDDoSProtectionManager.Apply(ddosProtectionConfig)
|
||||
if err != nil {
|
||||
// 不阻塞
|
||||
remotelogs.Error("NODE", "apply DDoS protection failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
306
EdgeDNS/internal/nodes/dns_node_test.go
Normal file
306
EdgeDNS/internal/nodes/dns_node_test.go
Normal file
@@ -0,0 +1,306 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build plus
|
||||
|
||||
package nodes_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"github.com/miekg/dns"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDNS_Query_A_UDP(t *testing.T) {
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
c := new(dns.Client)
|
||||
|
||||
for _, query := range []query{
|
||||
{"hello.world.teaos.cn", dns.TypeA},
|
||||
{"cdn.teaos.cn", dns.TypeA},
|
||||
{"hello.teaos.cn", dns.TypeCNAME},
|
||||
{"hello.teaos.cn", dns.TypeA},
|
||||
{"edgecdn.teaos.cn", dns.TypeA},
|
||||
} {
|
||||
var m = new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
r, _, err := c.Exchange(m, "127.0.0.1:53")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
t.Log(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_Query_A_Many(t *testing.T) {
|
||||
type queryDef struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
var c = new(dns.Client)
|
||||
|
||||
for i := 0; i < 10000; i++ {
|
||||
for _, query := range []queryDef{
|
||||
{"hello.goedge.cn", dns.TypeA},
|
||||
} {
|
||||
var m = new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
r, _, err := c.Exchange(m, "127.0.0.1:53")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
_ = r
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_Query_CNAME(t *testing.T) {
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
c := new(dns.Client)
|
||||
|
||||
for _, query := range []query{
|
||||
{"hello.teaos.cn", dns.TypeA},
|
||||
} {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
r, _, err := c.Exchange(m, "127.0.0.1:53")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
t.Log(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_Query_A_TCP(t *testing.T) {
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
c := new(dns.Client)
|
||||
|
||||
for _, query := range []query{
|
||||
{"hello.world.teaos.cn", dns.TypeA},
|
||||
{"cdn.teaos.cn", dns.TypeA},
|
||||
{"hello.teaos.cn", dns.TypeCNAME},
|
||||
{"goedge.cn", dns.TypeA},
|
||||
{"www.goedge.cn", dns.TypeA},
|
||||
} {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
//r, _, err := c.Exchange(m, "127.0.0.1:54")
|
||||
conn, err := dns.Dial("tcp", "127.0.0.1:53")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r, _, err := c.ExchangeWithConn(m, conn)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
t.Log(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_Query_A_TLS(t *testing.T) {
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
c := new(dns.Client)
|
||||
|
||||
for _, query := range []query{
|
||||
{"hello.world.teaos.cn", dns.TypeA},
|
||||
{"cdn.teaos.cn", dns.TypeA},
|
||||
{"hello.teaos.cn", dns.TypeCNAME},
|
||||
} {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
//r, _, err := c.Exchange(m, "127.0.0.1:54")
|
||||
conn, err := dns.DialWithTLS("tcp", "127.0.0.1:853", &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r, _, err := c.ExchangeWithConn(m, conn)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
t.Log(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_Query_Internet(t *testing.T) {
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
c := new(dns.Client)
|
||||
for _, query := range []query{
|
||||
{"1.goedge.cn", dns.TypeA},
|
||||
} {
|
||||
m := new(dns.Msg)
|
||||
//m.RecursionDesired = true
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
|
||||
conn, err := dns.Dial("udp", "ns1.teaos.cn:53")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r, _, err := c.ExchangeWithConn(m, conn)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
t.Log(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_Query_TSIG(t *testing.T) {
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
c := new(dns.Client)
|
||||
c.TsigSecret = map[string]string{"teaos.": "NzhhZDExMzM5NWMwN2Q5OWM5YTFhMzgxZWNkZGMwMDA2ODUzODdiYTM2ODA5N2I2YjYwZWRlNmNlNjlhMzdmM2JmNjcxZmQ4NzVjMjI1Y2QwOTQ2Njk5OWY0MzRkMTJkNTczNjFlZDgwYmQxZWZjZDM4ZjAxNDNmM2Y2NTU1YjE="}
|
||||
|
||||
for _, query := range []query{
|
||||
{"hello.cdn.teaos.cn", dns.TypeA},
|
||||
{"cdn.teaos.cn", dns.TypeA},
|
||||
{"hello.teaos.cn", dns.TypeCNAME},
|
||||
{"hello.teaos.cn", dns.TypeA},
|
||||
{"edgecdn.teaos.cn", dns.TypeA},
|
||||
} {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
m.SetTsig("teaos.", dns.HmacSHA512, 300, time.Now().Unix())
|
||||
r, _, err := c.Exchange(m, "127.0.0.1:53")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
t.Log(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_Query_A_Route(t *testing.T) {
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
c := new(dns.Client)
|
||||
|
||||
for _, query := range []query{
|
||||
{"route.com", dns.TypeA},
|
||||
} {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
r, _, err := c.Exchange(m, "127.0.0.1:53")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
t.Log(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_Query_Recursion(t *testing.T) {
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
c := new(dns.Client)
|
||||
|
||||
for _, query := range []query{
|
||||
{"example.org", dns.TypeA},
|
||||
} {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
r, _, err := c.Exchange(m, "127.0.0.1:53")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
t.Log(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_Flood(t *testing.T) {
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
var c = new(dns.Client)
|
||||
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
for _, query := range []query{
|
||||
{"hello.world.teaos.cn", dns.TypeA},
|
||||
{"cdn.teaos.cn", dns.TypeA},
|
||||
{"hello.teaos.cn", dns.TypeCNAME},
|
||||
{"hello.teaos.cn", dns.TypeA},
|
||||
{"edgecdn.teaos.cn", dns.TypeA},
|
||||
} {
|
||||
var m = new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
r, _, err := c.Exchange(m, "192.168.2.60:58")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
_ = r
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDNSNode(b *testing.B) {
|
||||
var c = new(dns.Client)
|
||||
conn, err := c.Dial("192.168.2.60:58")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, query := range []query{
|
||||
{"cdn.teaos.cn", dns.TypeA},
|
||||
} {
|
||||
var m = new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
_, _, err := c.ExchangeWithConn(m, conn)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
179
EdgeDNS/internal/nodes/http_writer.go
Normal file
179
EdgeDNS/internal/nodes/http_writer.go
Normal file
@@ -0,0 +1,179 @@
|
||||
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/miekg/dns"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type HTTPWriter struct {
|
||||
rawConn net.Conn
|
||||
rawWriter http.ResponseWriter
|
||||
contentType string
|
||||
}
|
||||
|
||||
func NewHTTPWriter(rawWriter http.ResponseWriter, rawConn net.Conn, contentType string) *HTTPWriter {
|
||||
return &HTTPWriter{
|
||||
rawWriter: rawWriter,
|
||||
rawConn: rawConn,
|
||||
contentType: contentType,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) LocalAddr() net.Addr {
|
||||
return this.rawConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) RemoteAddr() net.Addr {
|
||||
return this.rawConn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) WriteMsg(msg *dns.Msg) error {
|
||||
if msg == nil {
|
||||
return errors.New("'msg' should not be nil")
|
||||
}
|
||||
|
||||
msgData, err := this.encodeMsg(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.rawWriter.Header().Set("Content-Length", types.String(len(msgData)))
|
||||
this.rawWriter.Header().Set("Content-Type", this.contentType)
|
||||
|
||||
// cache-control
|
||||
if len(msg.Answer) > 0 {
|
||||
var minTtl uint32
|
||||
for _, answer := range msg.Answer {
|
||||
var header = answer.Header()
|
||||
if header != nil && header.Ttl > 0 && (minTtl == 0 || header.Ttl < minTtl) {
|
||||
minTtl = header.Ttl
|
||||
}
|
||||
}
|
||||
if minTtl > 0 {
|
||||
this.rawWriter.Header().Set("Cache-Control", "max-age="+types.String(minTtl))
|
||||
}
|
||||
}
|
||||
|
||||
this.rawWriter.WriteHeader(http.StatusOK)
|
||||
_, err = this.rawWriter.Write(msgData)
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) Write(p []byte) (int, error) {
|
||||
this.rawWriter.Header().Set("Content-Length", types.String(len(p)))
|
||||
this.rawWriter.WriteHeader(http.StatusOK)
|
||||
return this.rawWriter.Write(p)
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) TsigStatus() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) TsigTimersOnly(timersOnly bool) {
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) Hijack() {
|
||||
hijacker, ok := this.rawWriter.(http.Hijacker)
|
||||
if ok {
|
||||
_, _, _ = hijacker.Hijack()
|
||||
}
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) encodeMsg(msg *dns.Msg) ([]byte, error) {
|
||||
if this.contentType == "application/x-javascript" || this.contentType == "application/json" {
|
||||
var result = map[string]any{
|
||||
"Status": 0,
|
||||
"TC": msg.Truncated,
|
||||
"RD": msg.RecursionDesired,
|
||||
"RA": msg.RecursionAvailable,
|
||||
"AD": msg.AuthenticatedData,
|
||||
"CD": msg.CheckingDisabled,
|
||||
}
|
||||
|
||||
// questions
|
||||
var questionMaps = []map[string]any{}
|
||||
for _, question := range msg.Question {
|
||||
questionMaps = append(questionMaps, map[string]any{
|
||||
"name": question.Name,
|
||||
"type": question.Qtype,
|
||||
})
|
||||
}
|
||||
result["Question"] = questionMaps
|
||||
|
||||
// answers
|
||||
var answerMaps = []map[string]any{}
|
||||
for _, answer := range msg.Answer {
|
||||
var answerMap = map[string]any{
|
||||
"name": answer.Header().Name,
|
||||
"type": answer.Header().Rrtype,
|
||||
"TTL": answer.Header().Ttl,
|
||||
}
|
||||
|
||||
switch x := answer.(type) {
|
||||
case *dns.A:
|
||||
answerMap["data"] = x.A.String()
|
||||
case *dns.AAAA:
|
||||
answerMap["data"] = x.AAAA.String()
|
||||
case *dns.CNAME:
|
||||
answerMap["data"] = x.Target
|
||||
case *dns.TXT:
|
||||
answerMap["data"] = x.Txt
|
||||
case *dns.NS:
|
||||
answerMap["data"] = x.Ns
|
||||
case *dns.MX:
|
||||
answerMap["data"] = x.Mx
|
||||
answerMap["preference"] = x.Preference
|
||||
default:
|
||||
var answerValue = reflect.ValueOf(answer).Elem()
|
||||
var answerType = answerValue.Type()
|
||||
|
||||
var countFields = answerType.NumField()
|
||||
for i := 0; i < countFields; i++ {
|
||||
var fieldName = answerType.Field(i).Name
|
||||
var fieldValue = answerValue.FieldByName(fieldName)
|
||||
if !fieldValue.IsValid() {
|
||||
continue
|
||||
}
|
||||
var fieldInterface = fieldValue.Interface()
|
||||
if fieldInterface == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
_, ok := fieldInterface.(dns.RR_Header)
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if countFields == 2 {
|
||||
answerMap["data"] = fieldValue.Interface()
|
||||
} else {
|
||||
answerMap[fieldName] = fieldValue.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
answerMaps = append(answerMaps, answerMap)
|
||||
}
|
||||
result["Answer"] = answerMaps
|
||||
|
||||
if Tea.IsTesting() {
|
||||
return json.MarshalIndent(result, "", " ")
|
||||
} else {
|
||||
return json.Marshal(result)
|
||||
}
|
||||
} else {
|
||||
return msg.Pack()
|
||||
}
|
||||
}
|
||||
306
EdgeDNS/internal/nodes/listen_manager.go
Normal file
306
EdgeDNS/internal/nodes/listen_manager.go
Normal file
@@ -0,0 +1,306 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/events"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/utils"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var sharedListenManager *ListenManager = nil
|
||||
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
return
|
||||
}
|
||||
|
||||
sharedListenManager = NewListenManager()
|
||||
|
||||
events.On(events.EventReload, func() {
|
||||
sharedListenManager.Update(sharedNodeConfig)
|
||||
})
|
||||
events.On(events.EventQuit, func() {
|
||||
_ = sharedListenManager.ShutdownAll()
|
||||
})
|
||||
}
|
||||
|
||||
// ListenManager 端口监听管理器
|
||||
type ListenManager struct {
|
||||
locker sync.Mutex
|
||||
serverMap map[string]*Server // addr => *Server
|
||||
|
||||
firewalld *firewalls.Firewalld
|
||||
lastPortStrings string
|
||||
lastTCPPortRanges [][2]int
|
||||
lastUDPPortRanges [][2]int
|
||||
}
|
||||
|
||||
// NewListenManager 获取新对象
|
||||
func NewListenManager() *ListenManager {
|
||||
return &ListenManager{
|
||||
serverMap: map[string]*Server{},
|
||||
firewalld: firewalls.NewFirewalld(),
|
||||
}
|
||||
}
|
||||
|
||||
// Update 修改配置
|
||||
func (this *ListenManager) Update(config *dnsconfigs.NSNodeConfig) {
|
||||
// 构造服务配置
|
||||
var serverConfigs = []*ServerConfig{}
|
||||
var serverAddrs = []string{}
|
||||
|
||||
// 如果没有配置,则配置一些默认的端口
|
||||
if config.TCP == nil && config.TLS == nil && config.UDP == nil {
|
||||
config.TCP = &serverconfigs.TCPProtocolConfig{}
|
||||
config.TCP.IsOn = true
|
||||
config.TCP.Listen = []*serverconfigs.NetworkAddressConfig{
|
||||
{
|
||||
Protocol: serverconfigs.ProtocolTCP,
|
||||
MinPort: 53,
|
||||
MaxPort: 53,
|
||||
},
|
||||
}
|
||||
|
||||
config.UDP = &serverconfigs.UDPProtocolConfig{}
|
||||
config.UDP.IsOn = true
|
||||
config.UDP.Listen = []*serverconfigs.NetworkAddressConfig{
|
||||
{
|
||||
Protocol: serverconfigs.ProtocolUDP,
|
||||
MinPort: 53,
|
||||
MaxPort: 53,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 读取配置
|
||||
if config.TCP != nil && config.TCP.IsOn {
|
||||
for _, listen := range config.TCP.Listen {
|
||||
for port := listen.MinPort; port <= listen.MaxPort; port++ {
|
||||
serverConfigs = append(serverConfigs, &ServerConfig{
|
||||
Protocol: listen.Protocol,
|
||||
Host: listen.Host,
|
||||
Port: port,
|
||||
SSLPolicy: nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
if config.TLS != nil && config.TLS.IsOn {
|
||||
for _, listen := range config.TLS.Listen {
|
||||
if config.TLS.SSLPolicy == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for port := listen.MinPort; port <= listen.MaxPort; port++ {
|
||||
serverConfigs = append(serverConfigs, &ServerConfig{
|
||||
Protocol: listen.Protocol,
|
||||
Host: listen.Host,
|
||||
Port: port,
|
||||
SSLPolicy: config.TLS.SSLPolicy,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
if config.DoH != nil && config.DoH.IsOn {
|
||||
for _, listen := range config.DoH.Listen {
|
||||
if config.DoH.SSLPolicy == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for port := listen.MinPort; port <= listen.MaxPort; port++ {
|
||||
serverConfigs = append(serverConfigs, &ServerConfig{
|
||||
Protocol: listen.Protocol,
|
||||
Host: listen.Host,
|
||||
Port: port,
|
||||
SSLPolicy: config.DoH.SSLPolicy,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
if config.UDP != nil && config.UDP.IsOn {
|
||||
for _, listen := range config.UDP.Listen {
|
||||
for port := listen.MinPort; port <= listen.MaxPort; port++ {
|
||||
serverConfigs = append(serverConfigs, &ServerConfig{
|
||||
Protocol: listen.Protocol,
|
||||
Host: listen.Host,
|
||||
Port: port,
|
||||
SSLPolicy: nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 启动新的
|
||||
var addrMap = map[string]bool{} // addr => bool
|
||||
for _, serverConfig := range serverConfigs {
|
||||
var fullAddr = serverConfig.FullAddr()
|
||||
serverAddrs = append(serverAddrs, fullAddr)
|
||||
addrMap[fullAddr] = true
|
||||
this.locker.Lock()
|
||||
server, ok := this.serverMap[fullAddr]
|
||||
this.locker.Unlock()
|
||||
if !ok {
|
||||
// 启动新的
|
||||
var err error
|
||||
server, err = NewServer(serverConfig)
|
||||
if err != nil {
|
||||
remotelogs.Error("LISTEN_MANAGER", "create listener '"+fullAddr+"' failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
this.locker.Lock()
|
||||
this.serverMap[fullAddr] = server
|
||||
this.locker.Unlock()
|
||||
go func() {
|
||||
remotelogs.Println("LISTEN_MANAGER", "listen '"+fullAddr+"'")
|
||||
err = server.ListenAndServe()
|
||||
if err != nil {
|
||||
this.locker.Lock()
|
||||
delete(this.serverMap, fullAddr)
|
||||
this.locker.Unlock()
|
||||
remotelogs.Error("LISTEN_MANAGER", "listen '"+fullAddr+"' failed: "+err.Error())
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
// 更新配置
|
||||
server.Reload(serverConfig)
|
||||
}
|
||||
}
|
||||
|
||||
// 停止老的
|
||||
this.locker.Lock()
|
||||
for fullAddr, server := range this.serverMap {
|
||||
_, ok := addrMap[fullAddr]
|
||||
if !ok {
|
||||
delete(this.serverMap, fullAddr)
|
||||
|
||||
remotelogs.Println("LISTEN_MANAGER", "shutdown "+fullAddr)
|
||||
err := server.Shutdown()
|
||||
if err != nil {
|
||||
remotelogs.Error("LISTEN_MANAGER", "shutdown listener '"+fullAddr+"' failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
this.locker.Unlock()
|
||||
|
||||
// 添加端口到firewalld
|
||||
go func() {
|
||||
this.addToFirewalld(serverAddrs)
|
||||
}()
|
||||
}
|
||||
|
||||
// ShutdownAll 关闭所有的监听端口
|
||||
func (this *ListenManager) ShutdownAll() error {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
var lastErr error
|
||||
for _, server := range this.serverMap {
|
||||
err := server.Shutdown()
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (this *ListenManager) addToFirewalld(serverAddrs []string) {
|
||||
if runtime.GOOS != "linux" {
|
||||
return
|
||||
}
|
||||
|
||||
if this.firewalld == nil || !this.firewalld.IsReady() {
|
||||
return
|
||||
}
|
||||
|
||||
// 组合端口号
|
||||
var portStrings = []string{}
|
||||
var udpPorts = []int{}
|
||||
var tcpPorts = []int{}
|
||||
for _, addr := range serverAddrs {
|
||||
var protocol = "tcp"
|
||||
if strings.HasPrefix(addr, "udp") {
|
||||
protocol = "udp"
|
||||
}
|
||||
|
||||
var lastIndex = strings.LastIndex(addr, ":")
|
||||
if lastIndex > 0 {
|
||||
var portString = addr[lastIndex+1:]
|
||||
portStrings = append(portStrings, portString+"/"+protocol)
|
||||
|
||||
switch protocol {
|
||||
case "tcp":
|
||||
tcpPorts = append(tcpPorts, types.Int(portString))
|
||||
case "udp":
|
||||
udpPorts = append(udpPorts, types.Int(portString))
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(portStrings) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否有变化
|
||||
sort.Strings(portStrings)
|
||||
var newPortStrings = strings.Join(portStrings, ",")
|
||||
if newPortStrings == this.lastPortStrings {
|
||||
return
|
||||
}
|
||||
this.lastPortStrings = newPortStrings
|
||||
|
||||
remotelogs.Println("FIREWALLD", "opening ports automatically ...")
|
||||
defer func() {
|
||||
remotelogs.Println("FIREWALLD", "open ports successfully")
|
||||
}()
|
||||
|
||||
// 合并端口
|
||||
var tcpPortRanges = utils.MergePorts(tcpPorts)
|
||||
var udpPortRanges = utils.MergePorts(udpPorts)
|
||||
|
||||
defer func() {
|
||||
this.lastTCPPortRanges = tcpPortRanges
|
||||
this.lastUDPPortRanges = udpPortRanges
|
||||
}()
|
||||
|
||||
// 删除老的不存在的端口
|
||||
var tcpPortRangesMap = map[string]bool{}
|
||||
var udpPortRangesMap = map[string]bool{}
|
||||
for _, portRange := range tcpPortRanges {
|
||||
tcpPortRangesMap[this.firewalld.PortRangeString(portRange, "tcp")] = true
|
||||
}
|
||||
for _, portRange := range udpPortRanges {
|
||||
udpPortRangesMap[this.firewalld.PortRangeString(portRange, "udp")] = true
|
||||
}
|
||||
|
||||
for _, portRange := range this.lastTCPPortRanges {
|
||||
var s = this.firewalld.PortRangeString(portRange, "tcp")
|
||||
_, ok := tcpPortRangesMap[s]
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
remotelogs.Println("FIREWALLD", "remove port '"+s+"'")
|
||||
_ = this.firewalld.RemovePortRangePermanently(portRange, "tcp")
|
||||
}
|
||||
for _, portRange := range this.lastUDPPortRanges {
|
||||
var s = this.firewalld.PortRangeString(portRange, "udp")
|
||||
_, ok := udpPortRangesMap[s]
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
remotelogs.Println("FIREWALLD", "remove port '"+s+"'")
|
||||
_ = this.firewalld.RemovePortRangePermanently(portRange, "udp")
|
||||
}
|
||||
|
||||
// 添加新的
|
||||
_ = this.firewalld.AllowPortRangesPermanently(tcpPortRanges, "tcp")
|
||||
_ = this.firewalld.AllowPortRangesPermanently(udpPortRanges, "udp")
|
||||
}
|
||||
319
EdgeDNS/internal/nodes/manager_domain.go
Normal file
319
EdgeDNS/internal/nodes/manager_domain.go
Normal file
@@ -0,0 +1,319 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/models"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DomainManager 域名管理器
|
||||
type DomainManager struct {
|
||||
domainMap map[int64]*models.NSDomain // domainId => domain
|
||||
namesMap map[string]map[int64]*models.NSDomain // domain name => { domainId => domain }
|
||||
clusterId int64
|
||||
|
||||
db *dbs.DB
|
||||
version int64
|
||||
locker *sync.RWMutex
|
||||
|
||||
notifier chan bool
|
||||
}
|
||||
|
||||
// NewDomainManager 获取域名管理器对象
|
||||
func NewDomainManager(db *dbs.DB, clusterId int64) *DomainManager {
|
||||
return &DomainManager{
|
||||
db: db,
|
||||
domainMap: map[int64]*models.NSDomain{},
|
||||
namesMap: map[string]map[int64]*models.NSDomain{},
|
||||
clusterId: clusterId,
|
||||
notifier: make(chan bool, 8),
|
||||
locker: &sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动自动任务
|
||||
func (this *DomainManager) Start() {
|
||||
remotelogs.Println("DOMAIN_MANAGER", "starting ...")
|
||||
|
||||
// 从本地数据库中加载数据
|
||||
err := this.Load()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("DOMAIN_MANAGER", "load failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "load failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化运行
|
||||
err = this.LoopAll()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("DOMAIN_MANAGER", "loop failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "loop failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 更新
|
||||
var ticker = time.NewTicker(20 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-this.notifier:
|
||||
}
|
||||
|
||||
err = this.LoopAll()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("DOMAIN_MANAGER", "loop failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "loop failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *DomainManager) LoopAll() error {
|
||||
for {
|
||||
hasNext, err := this.Loop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasNext {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load 从数据库中加载数据
|
||||
func (this *DomainManager) Load() error {
|
||||
var offset = 0
|
||||
var size = 10000
|
||||
for {
|
||||
domains, err := this.db.ListDomains(this.clusterId, offset, size)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(domains) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
for _, domain := range domains {
|
||||
this.domainMap[domain.Id] = domain
|
||||
|
||||
nameMap, ok := this.namesMap[domain.Name]
|
||||
if ok {
|
||||
nameMap[domain.Id] = domain
|
||||
} else {
|
||||
this.namesMap[domain.Name] = map[int64]*models.NSDomain{
|
||||
domain.Id: domain,
|
||||
}
|
||||
}
|
||||
|
||||
if domain.Version > this.version {
|
||||
this.version = domain.Version
|
||||
}
|
||||
}
|
||||
this.locker.Unlock()
|
||||
|
||||
offset += size
|
||||
}
|
||||
|
||||
if this.version > 0 {
|
||||
this.version++
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Loop 单次循环任务
|
||||
func (this *DomainManager) Loop() (hasNext bool, err error) {
|
||||
client, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
resp, err := client.NSDomainRPC.ListNSDomainsAfterVersion(client.Context(), &pb.ListNSDomainsAfterVersionRequest{
|
||||
Version: this.version,
|
||||
Size: 20000,
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var domains = resp.NsDomains
|
||||
if len(domains) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
for _, domain := range domains {
|
||||
this.processDomain(domain)
|
||||
if domain.Version > this.version {
|
||||
this.version = domain.Version
|
||||
}
|
||||
}
|
||||
this.version++
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// FindDomain 根据名称查找域名
|
||||
func (this *DomainManager) FindDomain(name string) (domain *models.NSDomain, ok bool) {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
|
||||
nameMap, ok := this.namesMap[name]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
for _, domain2 := range nameMap {
|
||||
return domain2, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// FindDomainWithId 根据域名ID查询域名
|
||||
func (this *DomainManager) FindDomainWithId(domainId int64) (domain *models.NSDomain) {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
|
||||
return this.domainMap[domainId]
|
||||
}
|
||||
|
||||
// NotifyUpdate 通知更新
|
||||
func (this *DomainManager) NotifyUpdate() {
|
||||
select {
|
||||
case this.notifier <- true:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// SplitDomain 分解域名
|
||||
func (this *DomainManager) SplitDomain(fullDomainName string) (rootDomain *models.NSDomain, recordName string) {
|
||||
if len(fullDomainName) == 0 {
|
||||
return
|
||||
}
|
||||
fullDomainName = strings.TrimSuffix(fullDomainName, ".") // 去除尾部的点(.)
|
||||
fullDomainName = strings.ToLower(fullDomainName) // 转换为小写
|
||||
var domainName = fullDomainName
|
||||
var domain, ok = this.FindDomain(domainName)
|
||||
if !ok {
|
||||
for {
|
||||
var index = strings.Index(domainName, ".")
|
||||
if index < 0 {
|
||||
break
|
||||
}
|
||||
domainName = domainName[index+1:]
|
||||
domain, ok = this.FindDomain(domainName)
|
||||
if ok {
|
||||
recordName = fullDomainName[:len(fullDomainName)-len(domainName)-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return domain, recordName
|
||||
}
|
||||
|
||||
// 处理域名
|
||||
func (this *DomainManager) processDomain(domain *pb.NSDomain) {
|
||||
if !domain.IsOn || domain.IsDeleted || domain.Status != dnsconfigs.NSDomainStatusVerified {
|
||||
this.locker.Lock()
|
||||
delete(this.domainMap, domain.Id)
|
||||
|
||||
nameMap, ok := this.namesMap[domain.Name]
|
||||
if ok {
|
||||
delete(nameMap, domain.Id)
|
||||
if len(nameMap) == 0 {
|
||||
delete(this.namesMap, domain.Name)
|
||||
}
|
||||
}
|
||||
|
||||
this.locker.Unlock()
|
||||
|
||||
// 从数据库中删除
|
||||
if this.db != nil {
|
||||
err := this.db.DeleteDomain(domain.Id)
|
||||
if err != nil {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "delete domain from db failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 存入数据库
|
||||
if this.db != nil {
|
||||
exists, err := this.db.ExistsDomain(domain.Id)
|
||||
if err != nil {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "query failed: "+err.Error())
|
||||
} else {
|
||||
if exists {
|
||||
err = this.db.UpdateDomain(domain.Id, domain.NsCluster.Id, domain.UserId, domain.Name, domain.TsigJSON, domain.Version)
|
||||
if err != nil {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "update failed: "+err.Error())
|
||||
}
|
||||
} else {
|
||||
err = this.db.InsertDomain(domain.Id, domain.NsCluster.Id, domain.UserId, domain.Name, domain.TsigJSON, domain.Version)
|
||||
if err != nil {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "insert failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 同集群的才需要加载
|
||||
if this.clusterId == domain.NsCluster.Id {
|
||||
this.locker.Lock()
|
||||
var tsigConfig = &dnsconfigs.NSTSIGConfig{}
|
||||
if len(domain.TsigJSON) > 0 {
|
||||
err := json.Unmarshal(domain.TsigJSON, tsigConfig)
|
||||
if err != nil {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "decode TSIG json failed: "+err.Error()+", domain: "+domain.Name+", domainId: "+types.String(domain.Id)+", JSON: "+string(domain.TsigJSON))
|
||||
}
|
||||
}
|
||||
|
||||
var nsDomain = &models.NSDomain{
|
||||
Id: domain.Id,
|
||||
ClusterId: domain.NsCluster.Id,
|
||||
UserId: domain.UserId,
|
||||
Name: domain.Name,
|
||||
TSIG: tsigConfig,
|
||||
Version: domain.Version,
|
||||
}
|
||||
this.domainMap[domain.Id] = nsDomain
|
||||
|
||||
nameMap, ok := this.namesMap[domain.Name]
|
||||
if ok {
|
||||
nameMap[nsDomain.Id] = nsDomain
|
||||
} else {
|
||||
this.namesMap[domain.Name] = map[int64]*models.NSDomain{
|
||||
nsDomain.Id: nsDomain,
|
||||
}
|
||||
}
|
||||
this.locker.Unlock()
|
||||
} else {
|
||||
// 不同集群的删除域名
|
||||
this.locker.Lock()
|
||||
delete(this.domainMap, domain.Id)
|
||||
|
||||
nameMap, ok := this.namesMap[domain.Name]
|
||||
if ok {
|
||||
delete(nameMap, domain.Id)
|
||||
if len(nameMap) == 0 {
|
||||
delete(this.namesMap, domain.Name)
|
||||
}
|
||||
}
|
||||
|
||||
this.locker.Unlock()
|
||||
}
|
||||
}
|
||||
63
EdgeDNS/internal/nodes/manager_domain_test.go
Normal file
63
EdgeDNS/internal/nodes/manager_domain_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDomainManager_Loop(t *testing.T) {
|
||||
var db = dbs.NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var manager = NewDomainManager(db, 1)
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err := manager.Loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
logs.PrintAsJSON(manager.domainMap, t)
|
||||
}
|
||||
|
||||
func TestDomainManager_Load(t *testing.T) {
|
||||
var db = dbs.NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
manager := NewDomainManager(db, 2)
|
||||
err = manager.Load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
logs.PrintAsJSON(manager.domainMap, t)
|
||||
t.Log("version:", manager.version)
|
||||
}
|
||||
|
||||
func TestDomainManager_FindDomain(t *testing.T) {
|
||||
var db = dbs.NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var manager = NewDomainManager(db, 2)
|
||||
err = manager.Load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, name := range []string{"hello.com", "teaos.cn"} {
|
||||
domain, ok := manager.FindDomain(name)
|
||||
t.Log(name, ok, domain)
|
||||
}
|
||||
}
|
||||
284
EdgeDNS/internal/nodes/manager_key.go
Normal file
284
EdgeDNS/internal/nodes/manager_key.go
Normal file
@@ -0,0 +1,284 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/models"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// KeyManager 密钥管理器
|
||||
type KeyManager struct {
|
||||
domainKeyMap map[int64]*models.NSKeys // domainId => *NSKeys
|
||||
zoneKeyMap map[int64]*models.NSKeys // zoneId => *NSKeys
|
||||
|
||||
db *dbs.DB
|
||||
locker sync.RWMutex
|
||||
version int64
|
||||
|
||||
notifier chan bool
|
||||
}
|
||||
|
||||
// NewKeyManager 获取密钥管理器
|
||||
func NewKeyManager(db *dbs.DB) *KeyManager {
|
||||
return &KeyManager{
|
||||
domainKeyMap: map[int64]*models.NSKeys{},
|
||||
zoneKeyMap: map[int64]*models.NSKeys{},
|
||||
db: db,
|
||||
notifier: make(chan bool, 8),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动自动任务
|
||||
func (this *KeyManager) Start() {
|
||||
remotelogs.Println("KEY_MANAGER", "starting ...")
|
||||
|
||||
// 从本地数据库中加载数据
|
||||
err := this.Load()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("KEY_MANAGER", "load failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("KEY_MANAGER", "load failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化运行
|
||||
err = this.LoopAll()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("KEY_MANAGER", "loop failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("KEY_MANAGER", "loop failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 更新
|
||||
var ticker = time.NewTicker(1 * time.Minute)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-this.notifier:
|
||||
}
|
||||
|
||||
err := this.LoopAll()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("KEY_MANAGER", "loop failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("KEY_MANAGER", "loop failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load 从数据库中加载数据
|
||||
func (this *KeyManager) Load() error {
|
||||
var offset = 0
|
||||
var size = 10000
|
||||
for {
|
||||
keys, err := this.db.ListKeys(offset, size)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
for _, key := range keys {
|
||||
if key.ZoneId > 0 {
|
||||
keyList, ok := this.zoneKeyMap[key.ZoneId]
|
||||
if ok {
|
||||
keyList.Add(key)
|
||||
} else {
|
||||
keyList = models.NewNSKeys()
|
||||
keyList.Add(key)
|
||||
this.zoneKeyMap[key.ZoneId] = keyList
|
||||
}
|
||||
} else if key.DomainId > 0 {
|
||||
keyList, ok := this.domainKeyMap[key.DomainId]
|
||||
if ok {
|
||||
keyList.Add(key)
|
||||
} else {
|
||||
keyList = models.NewNSKeys()
|
||||
keyList.Add(key)
|
||||
this.domainKeyMap[key.DomainId] = keyList
|
||||
}
|
||||
}
|
||||
|
||||
if key.Version > this.version {
|
||||
this.version = key.Version
|
||||
}
|
||||
}
|
||||
this.locker.Unlock()
|
||||
|
||||
offset += size
|
||||
}
|
||||
|
||||
if this.version > 0 {
|
||||
this.version++
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *KeyManager) LoopAll() error {
|
||||
for {
|
||||
hasNext, err := this.Loop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasNext {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Loop 单次循环任务
|
||||
func (this *KeyManager) Loop() (hasNext bool, err error) {
|
||||
client, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
resp, err := client.NSKeyRPC.ListNSKeysAfterVersion(client.Context(), &pb.ListNSKeysAfterVersionRequest{
|
||||
Version: this.version,
|
||||
Size: 20000,
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var keys = resp.NsKeys
|
||||
if len(keys) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
for _, key := range keys {
|
||||
this.processKey(key)
|
||||
|
||||
if key.Version > this.version {
|
||||
this.version = key.Version
|
||||
}
|
||||
}
|
||||
this.version++
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (this *KeyManager) FindKeysWithDomain(domainId int64) []*models.NSKey {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
|
||||
keys, ok := this.domainKeyMap[domainId]
|
||||
if ok {
|
||||
return keys.All()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NotifyUpdate 通知更新
|
||||
func (this *KeyManager) NotifyUpdate() {
|
||||
select {
|
||||
case this.notifier <- true:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// 处理Key
|
||||
func (this *KeyManager) processKey(key *pb.NSKey) {
|
||||
if key.NsDomain == nil && key.NsZone == nil {
|
||||
return
|
||||
}
|
||||
if !key.IsOn || key.IsDeleted {
|
||||
this.locker.Lock()
|
||||
if key.NsDomain != nil {
|
||||
list, ok := this.domainKeyMap[key.NsDomain.Id]
|
||||
if ok {
|
||||
list.Remove(key.Id)
|
||||
}
|
||||
}
|
||||
if key.NsZone != nil {
|
||||
list, ok := this.zoneKeyMap[key.NsZone.Id]
|
||||
if ok {
|
||||
list.Remove(key.Id)
|
||||
}
|
||||
}
|
||||
this.locker.Unlock()
|
||||
|
||||
// 从数据库中删除
|
||||
if this.db != nil {
|
||||
err := this.db.DeleteKey(key.Id)
|
||||
if err != nil {
|
||||
remotelogs.Error("KEY_MANAGER", "delete key from db failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var domainId int64
|
||||
var zoneId int64
|
||||
if key.NsDomain != nil {
|
||||
domainId = key.NsDomain.Id
|
||||
}
|
||||
if key.NsZone != nil {
|
||||
zoneId = key.NsZone.Id
|
||||
}
|
||||
|
||||
// 存入数据库
|
||||
if this.db != nil {
|
||||
exists, err := this.db.ExistsKey(key.Id)
|
||||
if err != nil {
|
||||
remotelogs.Error("KEY_MANAGER", "query failed: "+err.Error())
|
||||
} else {
|
||||
if exists {
|
||||
err = this.db.UpdateKey(key.Id, domainId, zoneId, key.Algo, key.Secret, key.SecretType, key.Version)
|
||||
if err != nil {
|
||||
remotelogs.Error("KEY_MANAGER", "update failed: "+err.Error())
|
||||
}
|
||||
} else {
|
||||
err = this.db.InsertKey(key.Id, domainId, zoneId, key.Algo, key.Secret, key.SecretType, key.Version)
|
||||
if err != nil {
|
||||
remotelogs.Error("KEY_MANAGER", "insert failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 加入缓存Map
|
||||
this.locker.Lock()
|
||||
var nsKey = &models.NSKey{
|
||||
Id: key.Id,
|
||||
DomainId: domainId,
|
||||
ZoneId: zoneId,
|
||||
Algo: key.Algo,
|
||||
Secret: key.Secret,
|
||||
SecretType: key.SecretType,
|
||||
Version: key.Version,
|
||||
}
|
||||
|
||||
if zoneId > 0 {
|
||||
keyList, ok := this.zoneKeyMap[zoneId]
|
||||
if ok {
|
||||
keyList.Add(nsKey)
|
||||
} else {
|
||||
keyList = models.NewNSKeys()
|
||||
keyList.Add(nsKey)
|
||||
this.zoneKeyMap[zoneId] = keyList
|
||||
}
|
||||
} else if domainId > 0 {
|
||||
keyList, ok := this.domainKeyMap[domainId]
|
||||
if ok {
|
||||
keyList.Add(nsKey)
|
||||
} else {
|
||||
keyList = models.NewNSKeys()
|
||||
keyList.Add(nsKey)
|
||||
this.domainKeyMap[domainId] = keyList
|
||||
}
|
||||
}
|
||||
this.locker.Unlock()
|
||||
}
|
||||
208
EdgeDNS/internal/nodes/manager_node_config.go
Normal file
208
EdgeDNS/internal/nodes/manager_node_config.go
Normal file
@@ -0,0 +1,208 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/accesslogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/configs"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/events"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/utils"
|
||||
"sort"
|
||||
"time"
|
||||
)
|
||||
|
||||
var sharedNodeConfigManager = NewNodeConfigManager()
|
||||
|
||||
type NodeConfigManager struct {
|
||||
notifyChan chan bool
|
||||
ticker *time.Ticker
|
||||
timezone string
|
||||
|
||||
lastConfigJSON []byte
|
||||
|
||||
lastAPINodeVersion int64
|
||||
lastAPINodeAddrs []string // 以前的API节点地址
|
||||
}
|
||||
|
||||
func NewNodeConfigManager() *NodeConfigManager {
|
||||
return &NodeConfigManager{
|
||||
notifyChan: make(chan bool, 2),
|
||||
ticker: time.NewTicker(3 * time.Minute),
|
||||
}
|
||||
}
|
||||
|
||||
func (this *NodeConfigManager) Start() {
|
||||
for {
|
||||
select {
|
||||
case <-this.ticker.C:
|
||||
case <-this.notifyChan:
|
||||
}
|
||||
err := this.Loop()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("NODE_CONFIG_MANAGER", err.Error())
|
||||
} else {
|
||||
remotelogs.Error("NODE_CONFIG_MANAGER", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *NodeConfigManager) Loop() error {
|
||||
client, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := client.NSNodeRPC.FindCurrentNSNodeConfig(client.Context(), &pb.FindCurrentNSNodeConfigRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var configJSON = resp.NsNodeJSON
|
||||
if len(configJSON) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查是否有变化
|
||||
if bytes.Equal(this.lastConfigJSON, configJSON) {
|
||||
return nil
|
||||
}
|
||||
this.lastConfigJSON = configJSON
|
||||
|
||||
var config = &dnsconfigs.NSNodeConfig{}
|
||||
err = json.Unmarshal(configJSON, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = config.Init(context.TODO())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sharedNodeConfig = config
|
||||
configs.SharedNodeConfig = config
|
||||
|
||||
this.reload(config)
|
||||
|
||||
events.Notify(events.EventReload)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *NodeConfigManager) NotifyChange() {
|
||||
select {
|
||||
case this.notifyChan <- true:
|
||||
default:
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// 刷新配置
|
||||
func (this *NodeConfigManager) reload(config *dnsconfigs.NSNodeConfig) {
|
||||
teaconst.IsPlus = config.IsPlus
|
||||
|
||||
accesslogs.SharedDNSFileWriter().SetDirByPolicyPath(config.AccessLogFilePath)
|
||||
accesslogs.SharedDNSFileWriter().SetRotateConfig(config.AccessLogRotate)
|
||||
|
||||
needWriteFile := config.AccessLogWriteTargets == nil || config.AccessLogWriteTargets.File || config.AccessLogWriteTargets.ClickHouse
|
||||
if needWriteFile {
|
||||
_ = accesslogs.SharedDNSFileWriter().EnsureInit()
|
||||
} else {
|
||||
_ = accesslogs.SharedDNSFileWriter().Close()
|
||||
}
|
||||
|
||||
// timezone
|
||||
var timeZone = config.TimeZone
|
||||
if len(timeZone) == 0 {
|
||||
timeZone = "Asia/Shanghai"
|
||||
}
|
||||
|
||||
if this.timezone != timeZone {
|
||||
location, err := time.LoadLocation(timeZone)
|
||||
if err != nil {
|
||||
remotelogs.Error("TIMEZONE", "change time zone failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
remotelogs.Println("TIMEZONE", "change time zone to '"+timeZone+"'")
|
||||
time.Local = location
|
||||
this.timezone = timeZone
|
||||
}
|
||||
|
||||
// API Node地址,这里不限制是否为空,因为在为空时仍然要有对应的处理
|
||||
this.changeAPINodeAddrs(config.APINodeAddrs)
|
||||
}
|
||||
|
||||
// 检查API节点地址
|
||||
func (this *NodeConfigManager) changeAPINodeAddrs(apiNodeAddrs []*serverconfigs.NetworkAddressConfig) {
|
||||
var addrs = []string{}
|
||||
for _, addr := range apiNodeAddrs {
|
||||
err := addr.Init()
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "changeAPINodeAddrs: validate api node address '"+configutils.QuoteIP(addr.Host)+":"+addr.PortRange+"' failed: "+err.Error())
|
||||
} else {
|
||||
addrs = append(addrs, addr.FullAddresses()...)
|
||||
}
|
||||
}
|
||||
sort.Strings(addrs)
|
||||
|
||||
if utils.EqualStrings(this.lastAPINodeAddrs, addrs) {
|
||||
return
|
||||
}
|
||||
|
||||
this.lastAPINodeAddrs = addrs
|
||||
|
||||
config, err := configs.LoadAPIConfig()
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "changeAPINodeAddrs: "+err.Error())
|
||||
return
|
||||
}
|
||||
if config == nil {
|
||||
return
|
||||
}
|
||||
var oldEndpoints = config.RPCEndpoints
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(addrs) > 0 {
|
||||
this.lastAPINodeVersion++
|
||||
var v = this.lastAPINodeVersion
|
||||
|
||||
// 异步检测,防止阻塞
|
||||
go func(v int64) {
|
||||
// 测试新的API节点地址
|
||||
if rpcClient.TestEndpoints(addrs) {
|
||||
config.RPCEndpoints = addrs
|
||||
} else {
|
||||
config.RPCEndpoints = oldEndpoints
|
||||
this.lastAPINodeAddrs = nil // 恢复为空,以便于下次更新重试
|
||||
}
|
||||
|
||||
// 检查测试中间有无新的变更
|
||||
if v != this.lastAPINodeVersion {
|
||||
return
|
||||
}
|
||||
|
||||
err = rpcClient.UpdateConfig(config)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "changeAPINodeAddrs: update rpc config failed: "+err.Error())
|
||||
}
|
||||
}(v)
|
||||
return
|
||||
}
|
||||
|
||||
err = rpcClient.UpdateConfig(config)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "changeAPINodeAddrs: update rpc config failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
254
EdgeDNS/internal/nodes/manager_record.go
Normal file
254
EdgeDNS/internal/nodes/manager_record.go
Normal file
@@ -0,0 +1,254 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/models"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RecordManager 记录管理器
|
||||
type RecordManager struct {
|
||||
recordsMap map[int64]*models.DomainRecords // domainId => RecordsMap
|
||||
|
||||
db *dbs.DB
|
||||
locker sync.RWMutex
|
||||
version int64
|
||||
|
||||
notifier chan bool
|
||||
}
|
||||
|
||||
// NewRecordManager 获取新记录管理器对象
|
||||
func NewRecordManager(db *dbs.DB) *RecordManager {
|
||||
return &RecordManager{
|
||||
db: db,
|
||||
recordsMap: map[int64]*models.DomainRecords{},
|
||||
notifier: make(chan bool, 8),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动自动任务
|
||||
func (this *RecordManager) Start() {
|
||||
remotelogs.Println("RECORD_MANAGER", "starting ...")
|
||||
|
||||
// 从本地数据库中加载数据
|
||||
err := this.Load()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("RECORD_MANAGER", "load failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("RECORD_MANAGER", "load failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化运行
|
||||
err = this.LoopAll()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("RECORD_MANAGER", "loop failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("RECORD_MANAGER", "loop failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 更新
|
||||
var ticker = time.NewTicker(30 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-this.notifier:
|
||||
}
|
||||
|
||||
err := this.LoopAll()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("RECORD_MANAGER", "loop failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("RECORD_MANAGER", "loop failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load 从数据库中加载数据
|
||||
func (this *RecordManager) Load() error {
|
||||
var offset = 0
|
||||
var size = 10000
|
||||
for {
|
||||
records, err := this.db.ListRecords(offset, size)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(records) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
for _, record := range records {
|
||||
domainRecords, ok := this.recordsMap[record.DomainId]
|
||||
if !ok {
|
||||
domainRecords = models.NewDomainRecords()
|
||||
this.recordsMap[record.DomainId] = domainRecords
|
||||
}
|
||||
domainRecords.Add(record)
|
||||
|
||||
if record.Version > this.version {
|
||||
this.version = record.Version
|
||||
}
|
||||
}
|
||||
this.locker.Unlock()
|
||||
|
||||
offset += size
|
||||
}
|
||||
|
||||
if this.version > 0 {
|
||||
this.version++
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *RecordManager) LoopAll() error {
|
||||
for {
|
||||
hasNext, err := this.Loop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasNext {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Loop 单次循环任务
|
||||
func (this *RecordManager) Loop() (hasNext bool, err error) {
|
||||
client, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
resp, err := client.NSRecordRPC.ListNSRecordsAfterVersion(client.Context(), &pb.ListNSRecordsAfterVersionRequest{
|
||||
Version: this.version,
|
||||
Size: 20000,
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var records = resp.NsRecords
|
||||
if len(records) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
for _, record := range records {
|
||||
this.processRecord(record)
|
||||
if record.Version > this.version {
|
||||
this.version = record.Version
|
||||
}
|
||||
}
|
||||
this.version++
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (this *RecordManager) FindRecords(domainId int64, routeCodes []string, recordName string, recordType dnsconfigs.RecordType, strictMode bool) (records []*models.NSRecord, routeCode string) {
|
||||
this.locker.RLock()
|
||||
domainRecords, ok := this.recordsMap[domainId]
|
||||
if ok {
|
||||
records, routeCode = domainRecords.Find(routeCodes, recordName, recordType, sharedNodeConfig.Answer, strictMode)
|
||||
}
|
||||
this.locker.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
// NotifyUpdate 通知更新
|
||||
func (this *RecordManager) NotifyUpdate() {
|
||||
select {
|
||||
case this.notifier <- true:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// 处理单条记录
|
||||
func (this *RecordManager) processRecord(record *pb.NSRecord) {
|
||||
if record.NsDomain == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !record.IsOn || record.IsDeleted {
|
||||
this.locker.Lock()
|
||||
domainRecords, ok := this.recordsMap[record.NsDomain.Id]
|
||||
if ok {
|
||||
domainRecords.Remove(record.Id)
|
||||
}
|
||||
this.locker.Unlock()
|
||||
|
||||
// 从数据库中删除
|
||||
if this.db != nil {
|
||||
err := this.db.DeleteRecord(record.Id)
|
||||
if err != nil {
|
||||
remotelogs.Error("RECORD_MANAGER", "delete record from db failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 存入数据库
|
||||
if this.db != nil {
|
||||
exists, err := this.db.ExistsRecord(record.Id)
|
||||
if err != nil {
|
||||
remotelogs.Error("RECORD_MANAGER", "query failed: "+err.Error())
|
||||
} else {
|
||||
var routeIds = []string{}
|
||||
for _, route := range record.NsRoutes {
|
||||
routeIds = append(routeIds, route.Code)
|
||||
}
|
||||
|
||||
if exists {
|
||||
err = this.db.UpdateRecord(record.Id, record.NsDomain.Id, record.Name, record.Type, record.Value, record.MxPriority, record.SrvPriority, record.SrvWeight, record.SrvPort, record.CaaFlag, record.CaaTag, record.Ttl, record.Weight, routeIds, record.Version)
|
||||
if err != nil {
|
||||
remotelogs.Error("RECORD_MANAGER", "update failed: "+err.Error())
|
||||
}
|
||||
} else {
|
||||
err = this.db.InsertRecord(record.Id, record.NsDomain.Id, record.Name, record.Type, record.Value, record.MxPriority, record.SrvPriority, record.SrvWeight, record.SrvPort, record.CaaFlag, record.CaaTag, record.Ttl, record.Weight, routeIds, record.Version)
|
||||
if err != nil {
|
||||
remotelogs.Error("RECORD_MANAGER", "insert failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 加入缓存Map
|
||||
this.locker.Lock()
|
||||
domainRecords, ok := this.recordsMap[record.NsDomain.Id]
|
||||
if !ok {
|
||||
domainRecords = models.NewDomainRecords()
|
||||
this.recordsMap[record.NsDomain.Id] = domainRecords
|
||||
}
|
||||
var routeIds = []string{}
|
||||
for _, r := range record.NsRoutes {
|
||||
routeIds = append(routeIds, r.Code)
|
||||
}
|
||||
domainRecords.Add(&models.NSRecord{
|
||||
Id: record.Id,
|
||||
Name: record.Name,
|
||||
Type: record.Type,
|
||||
Value: record.Value,
|
||||
MXPriority: record.MxPriority,
|
||||
SRVPriority: record.SrvPriority,
|
||||
SRVWeight: record.SrvWeight,
|
||||
SRVPort: record.SrvPort,
|
||||
CAAFlag: record.CaaFlag,
|
||||
CAATag: record.CaaTag,
|
||||
Ttl: record.Ttl,
|
||||
Weight: record.Weight,
|
||||
Version: record.Version,
|
||||
RouteIds: routeIds,
|
||||
DomainId: record.NsDomain.Id,
|
||||
})
|
||||
this.locker.Unlock()
|
||||
}
|
||||
45
EdgeDNS/internal/nodes/manager_record_test.go
Normal file
45
EdgeDNS/internal/nodes/manager_record_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRecordManager_Loop(t *testing.T) {
|
||||
var db = dbs.NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
manager := NewRecordManager(db)
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err := manager.Loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
logs.PrintAsJSON(manager.recordsMap, t)
|
||||
t.Log("version:", manager.version)
|
||||
}
|
||||
|
||||
func TestRecordManager_Load(t *testing.T) {
|
||||
db := dbs.NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
manager := NewRecordManager(db)
|
||||
err = manager.Load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
logs.PrintAsJSON(manager.recordsMap, t)
|
||||
t.Log("version:", manager.version)
|
||||
}
|
||||
481
EdgeDNS/internal/nodes/manager_route.go
Normal file
481
EdgeDNS/internal/nodes/manager_route.go
Normal file
@@ -0,0 +1,481 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/agents"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/models"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RouteManager 线路管理器
|
||||
type RouteManager struct {
|
||||
allRouteMap map[int64]*models.NSRoute // routeId => route
|
||||
userRouteMap map[int64][]int64 // userId => sorted routeIds
|
||||
|
||||
db *dbs.DB
|
||||
version int64
|
||||
locker sync.RWMutex
|
||||
|
||||
notifier chan bool
|
||||
|
||||
ispRouteMap map[string]string // name => code
|
||||
chinaRouteMap map[string]string // name => code
|
||||
worldRouteMap map[string]string // name => code
|
||||
}
|
||||
|
||||
// NewRouteManager 获取新线路管理器对象
|
||||
func NewRouteManager(db *dbs.DB) *RouteManager {
|
||||
return &RouteManager{
|
||||
db: db,
|
||||
allRouteMap: map[int64]*models.NSRoute{},
|
||||
userRouteMap: map[int64][]int64{},
|
||||
|
||||
notifier: make(chan bool, 8),
|
||||
|
||||
ispRouteMap: map[string]string{},
|
||||
chinaRouteMap: map[string]string{},
|
||||
worldRouteMap: map[string]string{},
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动自动任务
|
||||
func (this *RouteManager) Start() {
|
||||
remotelogs.Println("ROUTE_MANAGER", "starting ...")
|
||||
|
||||
// 初始化公共线路
|
||||
this.loadDefaultRoutes()
|
||||
|
||||
// 从本地数据库中加载数据
|
||||
err := this.Load()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("ROUTE_MANAGER", "load failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("ROUTE_MANAGER", "load failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化运行
|
||||
err = this.LoopAll()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("ROUTE_MANAGER", "loop failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("ROUTE_MANAGER", "loop failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 更新
|
||||
var ticker = time.NewTicker(1 * time.Minute)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-this.notifier:
|
||||
}
|
||||
|
||||
err := this.LoopAll()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("ROUTE_MANAGER", "loop failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("ROUTE_MANAGER", "loop failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load 从数据库中加载数据
|
||||
func (this *RouteManager) Load() error {
|
||||
var offset int64 = 0
|
||||
var size int64 = 10000
|
||||
for {
|
||||
routes, err := this.db.ListRoutes(offset, size)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(routes) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
for _, route := range routes {
|
||||
this.addRoute(route)
|
||||
|
||||
if route.Version > this.version {
|
||||
this.version = route.Version
|
||||
}
|
||||
}
|
||||
this.locker.Unlock()
|
||||
|
||||
offset += size
|
||||
}
|
||||
|
||||
if this.version > 0 {
|
||||
this.version++
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *RouteManager) LoopAll() error {
|
||||
for {
|
||||
hasNext, err := this.Loop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasNext {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Loop 单次循环任务
|
||||
func (this *RouteManager) Loop() (hasNext bool, err error) {
|
||||
client, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
resp, err := client.NSRouteRPC.ListNSRoutesAfterVersion(client.Context(), &pb.ListNSRoutesAfterVersionRequest{
|
||||
Version: this.version,
|
||||
Size: 20000,
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var routes = resp.NsRoutes
|
||||
if len(routes) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
for _, route := range routes {
|
||||
this.processRoute(route)
|
||||
|
||||
if route.Version > this.version {
|
||||
this.version = route.Version
|
||||
}
|
||||
}
|
||||
this.version++
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// FindRouteCodes 查找一个地址对应的线路
|
||||
func (this *RouteManager) FindRouteCodes(ip string, domainUserId int64) (result []string) {
|
||||
var netIP = net.ParseIP(ip)
|
||||
if len(netIP) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 自定义route
|
||||
this.locker.RLock()
|
||||
|
||||
// 先查找用户自定义的
|
||||
if domainUserId > 0 {
|
||||
var userRouteIds = this.userRouteMap[domainUserId]
|
||||
for _, routeId := range userRouteIds {
|
||||
route, ok := this.allRouteMap[routeId]
|
||||
if ok && route.Contains(netIP) {
|
||||
result = append(result, route.RealCode())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 再查找公共的
|
||||
var publicRouteIds = this.userRouteMap[0]
|
||||
for _, routeId := range publicRouteIds {
|
||||
route, ok := this.allRouteMap[routeId]
|
||||
if ok && route.Contains(netIP) {
|
||||
result = append(result, route.RealCode())
|
||||
}
|
||||
}
|
||||
|
||||
this.locker.RUnlock()
|
||||
|
||||
// 解析公用线路
|
||||
var ipResult = iplibrary.LookupIP(ip)
|
||||
if ipResult != nil && ipResult.IsOk() {
|
||||
// 运营商
|
||||
for _, providerCode := range ipResult.ProviderCodes() {
|
||||
code, ok := this.ispRouteMap[providerCode]
|
||||
if ok {
|
||||
result = append(result, code)
|
||||
|
||||
// 单次只能有一个匹配
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 省|州|城市
|
||||
if ipResult.ProvinceId() > 0 {
|
||||
result = append(result, "region:province:"+types.String(ipResult.ProvinceId()))
|
||||
}
|
||||
if ipResult.CityId() > 0 {
|
||||
result = append(result, "region:city:"+types.String(ipResult.CityId()))
|
||||
}
|
||||
if ipResult.TownId() > 0 {
|
||||
result = append(result, "region:town:"+types.String(ipResult.TownId()))
|
||||
}
|
||||
|
||||
// 中国省市
|
||||
for _, provinceCode := range ipResult.ProvinceCodes() {
|
||||
// 中国
|
||||
code, ok := this.chinaRouteMap[provinceCode]
|
||||
if ok {
|
||||
result = append(result, code)
|
||||
|
||||
// 兼容以前的拼写错误
|
||||
switch code {
|
||||
case "china:province:hebei":
|
||||
result = append(result, "china:province:heibei")
|
||||
case "china:province:heibei":
|
||||
result = append(result, "china:province:hebei")
|
||||
case "china:jilin":
|
||||
result = append(result, "china:province:jilin")
|
||||
case "china:province:jilin":
|
||||
result = append(result, "china:jilin")
|
||||
}
|
||||
|
||||
// 香港
|
||||
switch code {
|
||||
case dnsconfigs.ChinaProvinceCodeHK:
|
||||
result = append(result, dnsconfigs.WorldRegionCodeHK, dnsconfigs.WorldRegionCodeChinaAbroad)
|
||||
|
||||
// 澳门
|
||||
case dnsconfigs.ChinaProvinceCodeMO:
|
||||
result = append(result, dnsconfigs.WorldRegionCodeMO, dnsconfigs.WorldRegionCodeChinaAbroad)
|
||||
|
||||
// 台湾
|
||||
case dnsconfigs.ChinaProvinceCodeTW:
|
||||
result = append(result, dnsconfigs.WorldRegionCodeTW, dnsconfigs.WorldRegionCodeChinaAbroad)
|
||||
default:
|
||||
result = append(result, dnsconfigs.WorldRegionCodeChinaMainland)
|
||||
}
|
||||
|
||||
// 单次只能有一个匹配
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 国家/地区
|
||||
for _, countryCode := range ipResult.CountryCodes() {
|
||||
code, ok := this.worldRouteMap[countryCode]
|
||||
if ok {
|
||||
// 中国全境(world:CN)必须优先于「海外」匹配,否则前面中国省市若误判为 HK/MO/TW 已加入 world:CN:abroad 时会先命中海外线
|
||||
if code == dnsconfigs.WorldRegionCodeChina {
|
||||
result = append([]string{code}, result...)
|
||||
} else {
|
||||
result = append(result, code)
|
||||
// 中国海外
|
||||
if code != dnsconfigs.WorldRegionCodeChina {
|
||||
result = append(result, dnsconfigs.WorldRegionCodeChinaAbroad)
|
||||
}
|
||||
}
|
||||
|
||||
// 单次只能有一个匹配
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 搜索引擎线路
|
||||
if agents.SharedManager != nil {
|
||||
var agentCode = agents.SharedManager.LookupIP(ip)
|
||||
if len(agentCode) > 0 {
|
||||
result = append(result, "agent:"+agentCode, "agent" /** 所有搜索引擎 **/)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// NotifyUpdate 通知更新
|
||||
func (this *RouteManager) NotifyUpdate() {
|
||||
select {
|
||||
case this.notifier <- true:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (this *RouteManager) loadDefaultRoutes() {
|
||||
for _, route := range dnsconfigs.AllDefaultISPRoutes {
|
||||
for _, name := range route.AliasNames {
|
||||
this.ispRouteMap[name] = route.Code
|
||||
}
|
||||
}
|
||||
for _, route := range dnsconfigs.AllDefaultChinaProvinceRoutes {
|
||||
for _, name := range route.AliasNames {
|
||||
this.chinaRouteMap[name] = route.Code
|
||||
}
|
||||
}
|
||||
for _, route := range dnsconfigs.AllDefaultWorldRegionRoutes {
|
||||
for _, name := range route.AliasNames {
|
||||
this.worldRouteMap[name] = route.Code
|
||||
}
|
||||
// 用线路 Code 中的国家/地区 ISO 码做映射,使 IP 库返回的 ISO 码(如 US、CN、HK)能命中对应线路
|
||||
if strings.HasPrefix(route.Code, "world:") {
|
||||
parts := strings.Split(route.Code, ":")
|
||||
if len(parts) == 2 && len(parts[1]) == 2 {
|
||||
this.worldRouteMap[parts[1]] = route.Code
|
||||
}
|
||||
if len(parts) == 3 && len(parts[2]) == 2 {
|
||||
this.worldRouteMap[strings.ToUpper(parts[2])] = route.Code
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 添加线路
|
||||
func (this *RouteManager) addRoute(route *models.NSRoute) {
|
||||
// 不需要加锁,因为此函数均在锁内调用
|
||||
|
||||
// 从老的用户中删除
|
||||
oldRoute, ok := this.allRouteMap[route.Id]
|
||||
if ok {
|
||||
var oldUserId = oldRoute.UserId
|
||||
if oldUserId != route.UserId {
|
||||
userRouteIds, ok := this.userRouteMap[oldUserId]
|
||||
if ok {
|
||||
this.userRouteMap[oldUserId] = this.removeId(userRouteIds, route.Id)
|
||||
if len(userRouteIds) == 0 {
|
||||
delete(this.userRouteMap, oldUserId)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 添加
|
||||
this.allRouteMap[route.Id] = route
|
||||
|
||||
userRouteIds, ok := this.userRouteMap[route.UserId]
|
||||
if ok {
|
||||
// 重新按优先级、排序、ID排序
|
||||
var userRoutes = []*models.NSRoute{}
|
||||
for _, userRouteId := range userRouteIds {
|
||||
userRoute, ok := this.allRouteMap[userRouteId]
|
||||
if ok {
|
||||
userRoutes = append(userRoutes, userRoute)
|
||||
}
|
||||
}
|
||||
if !lists.ContainsInt64(userRouteIds, route.Id) {
|
||||
userRoutes = append(userRoutes, route)
|
||||
}
|
||||
|
||||
sort.Slice(userRoutes, func(i, j int) bool {
|
||||
var userRoute1 = userRoutes[i]
|
||||
var userRoute2 = userRoutes[j]
|
||||
if userRoute1.Priority != userRoute2.Priority {
|
||||
return userRoute1.Priority > userRoute2.Priority
|
||||
}
|
||||
if userRoute1.Order != userRoute2.Order {
|
||||
return userRoute1.Order > userRoute2.Order
|
||||
}
|
||||
return userRoute1.Id < userRoute2.Id
|
||||
})
|
||||
|
||||
var newUserRouteIds = []int64{}
|
||||
for _, userRoute := range userRoutes {
|
||||
newUserRouteIds = append(newUserRouteIds, userRoute.Id)
|
||||
}
|
||||
this.userRouteMap[route.UserId] = newUserRouteIds
|
||||
} else {
|
||||
this.userRouteMap[route.UserId] = []int64{route.Id}
|
||||
}
|
||||
}
|
||||
|
||||
// 删除线路
|
||||
func (this *RouteManager) removePBRoute(route *pb.NSRoute) {
|
||||
delete(this.allRouteMap, route.Id)
|
||||
userRouteIds, ok := this.userRouteMap[route.UserId]
|
||||
if ok {
|
||||
userRouteIds = this.removeId(userRouteIds, route.Id)
|
||||
if len(userRouteIds) == 0 {
|
||||
delete(this.userRouteMap, route.UserId)
|
||||
} else {
|
||||
this.userRouteMap[route.UserId] = userRouteIds
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理线路
|
||||
func (this *RouteManager) processRoute(route *pb.NSRoute) {
|
||||
if !route.IsOn || route.IsDeleted {
|
||||
this.locker.Lock()
|
||||
this.removePBRoute(route)
|
||||
this.locker.Unlock()
|
||||
|
||||
// 从数据库中删除
|
||||
if this.db != nil {
|
||||
err := this.db.DeleteRoute(route.Id)
|
||||
if err != nil {
|
||||
remotelogs.Error("ROUTE_MANAGER", "delete route from db failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 存入数据库
|
||||
if this.db != nil {
|
||||
exists, err := this.db.ExistsRoute(route.Id)
|
||||
if err != nil {
|
||||
remotelogs.Error("ROUTE_MANAGER", "query failed: "+err.Error())
|
||||
} else {
|
||||
if exists {
|
||||
err = this.db.UpdateRoute(route.Id, route.UserId, route.RangesJSON, route.Order, route.Priority, route.Version)
|
||||
if err != nil {
|
||||
remotelogs.Error("ROUTE_MANAGER", "update failed: "+err.Error())
|
||||
}
|
||||
} else {
|
||||
err = this.db.InsertRoute(route.Id, route.UserId, route.RangesJSON, route.Order, route.Priority, route.Version)
|
||||
if err != nil {
|
||||
remotelogs.Error("ROUTE_MANAGER", "insert failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ranges, err := models.InitRangesFromJSON(route.RangesJSON)
|
||||
if err != nil {
|
||||
remotelogs.Error("ROUTE_MANAGER", "decode routes '"+strconv.FormatInt(route.Id, 10)+"' failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var nsRoute = &models.NSRoute{
|
||||
Id: route.Id,
|
||||
Ranges: ranges,
|
||||
Priority: route.Priority,
|
||||
Order: route.Order,
|
||||
UserId: route.UserId,
|
||||
Version: route.Version,
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
this.addRoute(nsRoute)
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
// 从一组ID中删除某个ID
|
||||
func (this *RouteManager) removeId(ids []int64, id int64) []int64 {
|
||||
var result = []int64{}
|
||||
for _, id2 := range ids {
|
||||
if id2 == id {
|
||||
continue
|
||||
}
|
||||
result = append(result, id2)
|
||||
}
|
||||
return result
|
||||
}
|
||||
138
EdgeDNS/internal/nodes/manager_route_test.go
Normal file
138
EdgeDNS/internal/nodes/manager_route_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build plus
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/events"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/models"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRouteManager_Loop(t *testing.T) {
|
||||
var db = dbs.NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var manager = NewRouteManager(db)
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err := manager.Loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
logs.PrintAsJSON(manager.allRouteMap, t)
|
||||
}
|
||||
|
||||
func TestRouteManager_Load(t *testing.T) {
|
||||
var db = dbs.NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var manager = NewRouteManager(db)
|
||||
err = manager.Load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
logs.PrintAsJSON(manager.allRouteMap, t)
|
||||
t.Log("version:", manager.version)
|
||||
}
|
||||
|
||||
func TestRouteManager_AddRoute(t *testing.T) {
|
||||
var manager = NewRouteManager(nil)
|
||||
manager.addRoute(&models.NSRoute{
|
||||
Id: 1,
|
||||
UserId: 0,
|
||||
Priority: 0,
|
||||
})
|
||||
manager.addRoute(&models.NSRoute{
|
||||
Id: 2,
|
||||
UserId: 0,
|
||||
Priority: 1,
|
||||
})
|
||||
manager.addRoute(&models.NSRoute{
|
||||
Id: 3,
|
||||
UserId: 0,
|
||||
Priority: 2,
|
||||
Order: 1,
|
||||
})
|
||||
manager.addRoute(&models.NSRoute{
|
||||
Id: 4,
|
||||
UserId: 0,
|
||||
Priority: 2,
|
||||
Order: 1,
|
||||
})
|
||||
manager.addRoute(&models.NSRoute{
|
||||
Id: 4,
|
||||
UserId: 0,
|
||||
Priority: 2,
|
||||
Order: 1,
|
||||
})
|
||||
manager.addRoute(&models.NSRoute{
|
||||
Id: 4,
|
||||
UserId: 1,
|
||||
Priority: 2,
|
||||
Order: 1,
|
||||
})
|
||||
manager.addRoute(&models.NSRoute{
|
||||
Id: 5,
|
||||
UserId: 1,
|
||||
Priority: 2,
|
||||
Order: 1,
|
||||
})
|
||||
logs.PrintAsJSON(manager.allRouteMap, t)
|
||||
logs.PrintAsJSON(manager.userRouteMap, t)
|
||||
}
|
||||
|
||||
func TestRouteManager_FindRouteCodes(t *testing.T) {
|
||||
events.Notify(events.EventLoaded)
|
||||
|
||||
var manager = NewRouteManager(nil)
|
||||
manager.loadDefaultRoutes()
|
||||
{
|
||||
var r = &models.NSRoute{
|
||||
Id: 1,
|
||||
Ranges: []dnsconfigs.NSRouteRangeInterface{
|
||||
&dnsconfigs.NSRouteRangeIPRange{
|
||||
IPFrom: "192.168.1.1",
|
||||
IPTo: "192.168.1.200",
|
||||
},
|
||||
&dnsconfigs.NSRouteRangeIPRange{
|
||||
IPFrom: "192.168.1.200",
|
||||
IPTo: "192.168.1.255",
|
||||
},
|
||||
&dnsconfigs.NSRouteRangeIPRange{
|
||||
IPFrom: "127.0.0.1",
|
||||
IPTo: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, rr := range r.Ranges {
|
||||
err := rr.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
manager.addRoute(r)
|
||||
}
|
||||
for _, ip := range []string{
|
||||
"192.168.1.100",
|
||||
"192.168.1.201",
|
||||
"192.168.2.1",
|
||||
"127.0.0.1",
|
||||
"111.197.174.111",
|
||||
"202.109.116.116",
|
||||
"8.8.8.8",
|
||||
} {
|
||||
t.Log(ip+": ", manager.FindRouteCodes(ip, 0))
|
||||
}
|
||||
}
|
||||
43
EdgeDNS/internal/nodes/node_panic.go
Normal file
43
EdgeDNS/internal/nodes/node_panic.go
Normal file
@@ -0,0 +1,43 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build !arm64 && plus
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// 处理异常
|
||||
func (this *DNSNode) handlePanic() {
|
||||
// 如果是在前台运行就直接返回
|
||||
backgroundEnv, _ := os.LookupEnv("EdgeBackground")
|
||||
if backgroundEnv != "on" {
|
||||
return
|
||||
}
|
||||
|
||||
var panicFile = Tea.Root + "/logs/panic.log"
|
||||
|
||||
// 分析panic
|
||||
data, err := os.ReadFile(panicFile)
|
||||
if err == nil {
|
||||
var index = bytes.Index(data, []byte("panic:"))
|
||||
if index >= 0 {
|
||||
remotelogs.Error("NODE", "系统错误,请上报给开发者: "+string(data[index:]))
|
||||
}
|
||||
}
|
||||
|
||||
fp, err := os.OpenFile(panicFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY|os.O_APPEND, 0777)
|
||||
if err != nil {
|
||||
logs.Println("NODE", "open 'panic.log' failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
err = syscall.Dup2(int(fp.Fd()), int(os.Stderr.Fd()))
|
||||
if err != nil {
|
||||
logs.Println("NODE", "write to 'panic.log' failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
9
EdgeDNS/internal/nodes/node_panic_arm64.go
Normal file
9
EdgeDNS/internal/nodes/node_panic_arm64.go
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build arm64 && plus
|
||||
|
||||
package nodes
|
||||
|
||||
// 处理异常
|
||||
func (this *DNSNode) handlePanic() {
|
||||
|
||||
}
|
||||
226
EdgeDNS/internal/nodes/node_status_executor.go
Normal file
226
EdgeDNS/internal/nodes/node_status_executor.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/configs"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/events"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/monitor"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/utils"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/shirou/gopsutil/v3/cpu"
|
||||
"github.com/shirou/gopsutil/v3/disk"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type NodeStatusExecutor struct {
|
||||
isFirstTime bool
|
||||
|
||||
cpuUpdatedTime time.Time
|
||||
cpuLogicalCount int
|
||||
cpuPhysicalCount int
|
||||
|
||||
apiCallStat *rpc.CallStat
|
||||
|
||||
ticker *time.Ticker
|
||||
}
|
||||
|
||||
func NewNodeStatusExecutor() *NodeStatusExecutor {
|
||||
return &NodeStatusExecutor{
|
||||
ticker: time.NewTicker(30 * time.Second),
|
||||
apiCallStat: rpc.NewCallStat(10),
|
||||
}
|
||||
}
|
||||
|
||||
func (this *NodeStatusExecutor) Listen() {
|
||||
this.isFirstTime = true
|
||||
this.cpuUpdatedTime = time.Now()
|
||||
this.update()
|
||||
|
||||
events.On(events.EventQuit, func() {
|
||||
remotelogs.Println("NODE_STATUS", "quit executor")
|
||||
this.ticker.Stop()
|
||||
})
|
||||
|
||||
for range this.ticker.C {
|
||||
this.isFirstTime = false
|
||||
this.update()
|
||||
}
|
||||
}
|
||||
|
||||
func (this *NodeStatusExecutor) update() {
|
||||
var status = &nodeconfigs.NodeStatus{}
|
||||
status.BuildVersion = teaconst.Version
|
||||
status.BuildVersionCode = utils.VersionToLong(teaconst.Version)
|
||||
status.OS = runtime.GOOS
|
||||
status.Arch = runtime.GOARCH
|
||||
status.ConfigVersion = 0
|
||||
status.IsActive = true
|
||||
status.ConnectionCount = 0 // TODO 将来显示连接数
|
||||
|
||||
apiSuccessPercent, apiAvgCostSeconds := this.apiCallStat.Sum()
|
||||
status.APISuccessPercent = apiSuccessPercent
|
||||
status.APIAvgCostSeconds = apiAvgCostSeconds
|
||||
|
||||
exe, _ := os.Executable()
|
||||
status.ExePath = exe
|
||||
|
||||
hostname, _ := os.Hostname()
|
||||
status.Hostname = hostname
|
||||
|
||||
this.updateCPU(status)
|
||||
this.updateMem(status)
|
||||
this.updateLoad(status)
|
||||
this.updateDisk(status)
|
||||
|
||||
status.UpdatedAt = time.Now().Unix()
|
||||
status.Timestamp = status.UpdatedAt
|
||||
|
||||
// 发送数据
|
||||
jsonData, err := json.Marshal(status)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE_STATUS", "serial NodeStatus fail: "+err.Error())
|
||||
return
|
||||
}
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE_STATUS", "failed to open rpc: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var nodeId = int64(0)
|
||||
sharedAPIConfig, _ := configs.SharedAPIConfig()
|
||||
if sharedAPIConfig != nil {
|
||||
nodeId = sharedAPIConfig.NumberId
|
||||
}
|
||||
var before = time.Now()
|
||||
_, err = rpcClient.NSNodeRPC.UpdateNSNodeStatus(rpcClient.Context(), &pb.UpdateNSNodeStatusRequest{
|
||||
NodeId: nodeId,
|
||||
StatusJSON: jsonData,
|
||||
})
|
||||
var costSeconds = time.Since(before).Seconds()
|
||||
this.apiCallStat.Add(err == nil, costSeconds)
|
||||
if err != nil {
|
||||
if !rpc.IsConnError(err) {
|
||||
remotelogs.Error("NODE_STATUS", "rpc UpdateNSNodeStatus() failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Debug("NODE_STATUS", "rpc UpdateNSNodeStatus() failed: "+err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 更新CPU
|
||||
func (this *NodeStatusExecutor) updateCPU(status *nodeconfigs.NodeStatus) {
|
||||
duration := time.Duration(0)
|
||||
if this.isFirstTime {
|
||||
duration = 100 * time.Millisecond
|
||||
}
|
||||
|
||||
percents, err := cpu.Percent(duration, false)
|
||||
if err != nil {
|
||||
status.Error = "cpu.Percent(): " + err.Error()
|
||||
return
|
||||
}
|
||||
if len(percents) == 0 {
|
||||
return
|
||||
}
|
||||
status.CPUUsage = percents[0] / 100
|
||||
|
||||
if time.Since(this.cpuUpdatedTime) > 300*time.Second { // 每隔5分钟才会更新一次
|
||||
this.cpuUpdatedTime = time.Now()
|
||||
|
||||
status.CPULogicalCount, err = cpu.Counts(true)
|
||||
if err != nil {
|
||||
status.Error = "cpu.Counts(): " + err.Error()
|
||||
return
|
||||
}
|
||||
status.CPUPhysicalCount, err = cpu.Counts(false)
|
||||
if err != nil {
|
||||
status.Error = "cpu.Counts(): " + err.Error()
|
||||
return
|
||||
}
|
||||
this.cpuLogicalCount = status.CPULogicalCount
|
||||
this.cpuPhysicalCount = status.CPUPhysicalCount
|
||||
} else {
|
||||
status.CPULogicalCount = this.cpuLogicalCount
|
||||
status.CPUPhysicalCount = this.cpuPhysicalCount
|
||||
}
|
||||
|
||||
// 记录监控数据
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemCPU, maps.Map{
|
||||
"usage": status.CPUUsage,
|
||||
"cores": runtime.NumCPU(),
|
||||
})
|
||||
}
|
||||
|
||||
// 更新硬盘
|
||||
func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
|
||||
partitions, err := disk.Partitions(false)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE_STATUS", err.Error())
|
||||
return
|
||||
}
|
||||
lists.Sort(partitions, func(i int, j int) bool {
|
||||
p1 := partitions[i]
|
||||
p2 := partitions[j]
|
||||
return p1.Mountpoint > p2.Mountpoint
|
||||
})
|
||||
|
||||
// 当前所在的fs
|
||||
var rootFS = ""
|
||||
var rootTotal = uint64(0)
|
||||
if lists.ContainsString([]string{"darwin", "linux", "freebsd"}, runtime.GOOS) {
|
||||
for _, p := range partitions {
|
||||
if p.Mountpoint == "/" {
|
||||
rootFS = p.Fstype
|
||||
usage, _ := disk.Usage(p.Mountpoint)
|
||||
if usage != nil {
|
||||
rootTotal = usage.Total
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var total = rootTotal
|
||||
var totalUsage = uint64(0)
|
||||
var maxUsage = float64(0)
|
||||
for _, partition := range partitions {
|
||||
if runtime.GOOS != "windows" && !strings.Contains(partition.Device, "/") && !strings.Contains(partition.Device, "\\") {
|
||||
continue
|
||||
}
|
||||
|
||||
// 跳过不同fs的
|
||||
if len(rootFS) > 0 && rootFS != partition.Fstype {
|
||||
continue
|
||||
}
|
||||
|
||||
usage, err := disk.Usage(partition.Mountpoint)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if partition.Mountpoint != "/" && (usage.Total != rootTotal || total == 0) {
|
||||
total += usage.Total
|
||||
}
|
||||
totalUsage += usage.Used
|
||||
if usage.UsedPercent >= maxUsage {
|
||||
maxUsage = usage.UsedPercent
|
||||
status.DiskMaxUsagePartition = partition.Mountpoint
|
||||
}
|
||||
}
|
||||
status.DiskTotal = total
|
||||
if total > 0 {
|
||||
status.DiskUsage = float64(totalUsage) / float64(total)
|
||||
}
|
||||
status.DiskMaxUsage = maxUsage / 100
|
||||
}
|
||||
27
EdgeDNS/internal/nodes/node_status_executor_test.go
Normal file
27
EdgeDNS/internal/nodes/node_status_executor_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/shirou/gopsutil/v3/cpu"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNodeStatusExecutor_CPU(t *testing.T) {
|
||||
countLogicCPU, err := cpu.Counts(true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("logic count:", countLogicCPU)
|
||||
|
||||
countPhysicalCPU, err := cpu.Counts(false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("physical count:", countPhysicalCPU)
|
||||
|
||||
percents, err := cpu.Percent(100*time.Millisecond, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(percents)
|
||||
}
|
||||
58
EdgeDNS/internal/nodes/node_status_executor_unix.go
Normal file
58
EdgeDNS/internal/nodes/node_status_executor_unix.go
Normal file
@@ -0,0 +1,58 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/monitor"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/shirou/gopsutil/v3/load"
|
||||
"github.com/shirou/gopsutil/v3/mem"
|
||||
)
|
||||
|
||||
// 更新内存
|
||||
func (this *NodeStatusExecutor) updateMem(status *nodeconfigs.NodeStatus) {
|
||||
stat, err := mem.VirtualMemory()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 重新计算内存
|
||||
if stat.Total > 0 {
|
||||
stat.Used = stat.Total - stat.Free - stat.Buffers - stat.Cached
|
||||
status.MemoryUsage = float64(stat.Used) / float64(stat.Total)
|
||||
}
|
||||
|
||||
status.MemoryTotal = stat.Total
|
||||
|
||||
// 记录监控数据
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemMemory, maps.Map{
|
||||
"usage": status.MemoryUsage,
|
||||
"total": status.MemoryTotal,
|
||||
"used": stat.Used,
|
||||
})
|
||||
}
|
||||
|
||||
// 更新负载
|
||||
func (this *NodeStatusExecutor) updateLoad(status *nodeconfigs.NodeStatus) {
|
||||
stat, err := load.Avg()
|
||||
if err != nil {
|
||||
status.Error = err.Error()
|
||||
return
|
||||
}
|
||||
if stat == nil {
|
||||
status.Error = "load is nil"
|
||||
return
|
||||
}
|
||||
status.Load1m = stat.Load1
|
||||
status.Load5m = stat.Load5
|
||||
status.Load15m = stat.Load15
|
||||
|
||||
// 记录监控数据
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemLoad, maps.Map{
|
||||
"load1m": status.Load1m,
|
||||
"load5m": status.Load5m,
|
||||
"load15m": status.Load15m,
|
||||
})
|
||||
}
|
||||
102
EdgeDNS/internal/nodes/node_status_executor_windows.go
Normal file
102
EdgeDNS/internal/nodes/node_status_executor_windows.go
Normal file
@@ -0,0 +1,102 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/shirou/gopsutil/v3/cpu"
|
||||
"github.com/shirou/gopsutil/v3/mem"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type WindowsLoadValue struct {
|
||||
Timestamp int64
|
||||
Value int
|
||||
}
|
||||
|
||||
var windowsLoadValues = []*WindowsLoadValue{}
|
||||
var windowsLoadLocker = &sync.Mutex{}
|
||||
|
||||
// 更新内存
|
||||
func (this *NodeStatusExecutor) updateMem(status *NodeStatus) {
|
||||
stat, err := mem.VirtualMemory()
|
||||
if err != nil {
|
||||
status.Error = err.Error()
|
||||
return
|
||||
}
|
||||
status.MemoryUsage = stat.UsedPercent
|
||||
status.MemoryTotal = stat.Total
|
||||
}
|
||||
|
||||
// 更新负载
|
||||
func (this *NodeStatusExecutor) updateLoad(status *NodeStatus) {
|
||||
timestamp := time.Now().Unix()
|
||||
|
||||
currentLoad := 0
|
||||
info, err := cpu.ProcInfo()
|
||||
if err == nil && len(info) > 0 && info[0].ProcessorQueueLength < 1000 {
|
||||
currentLoad = int(info[0].ProcessorQueueLength)
|
||||
}
|
||||
|
||||
// 删除15分钟之前的数据
|
||||
windowsLoadLocker.Lock()
|
||||
result := []*WindowsLoadValue{}
|
||||
for _, v := range windowsLoadValues {
|
||||
if timestamp-v.Timestamp > 15*60 {
|
||||
continue
|
||||
}
|
||||
result = append(result, v)
|
||||
}
|
||||
result = append(result, &WindowsLoadValue{
|
||||
Timestamp: timestamp,
|
||||
Value: currentLoad,
|
||||
})
|
||||
windowsLoadValues = result
|
||||
|
||||
total1 := 0
|
||||
count1 := 0
|
||||
total5 := 0
|
||||
count5 := 0
|
||||
total15 := 0
|
||||
count15 := 0
|
||||
for _, v := range result {
|
||||
if timestamp-v.Timestamp <= 60 {
|
||||
total1 += v.Value
|
||||
count1++
|
||||
}
|
||||
|
||||
if timestamp-v.Timestamp <= 300 {
|
||||
total5 += v.Value
|
||||
count5++
|
||||
}
|
||||
|
||||
total15 += v.Value
|
||||
count15++
|
||||
}
|
||||
|
||||
load1 := float64(0)
|
||||
load5 := float64(0)
|
||||
load15 := float64(0)
|
||||
if count1 > 0 {
|
||||
load1 = math.Round(float64(total1*100)/float64(count1)) / 100
|
||||
}
|
||||
if count5 > 0 {
|
||||
load5 = math.Round(float64(total5*100)/float64(count5)) / 100
|
||||
}
|
||||
if count15 > 0 {
|
||||
load15 = math.Round(float64(total15*100)/float64(count15)) / 100
|
||||
}
|
||||
|
||||
windowsLoadLocker.Unlock()
|
||||
|
||||
// 在老Windows上不显示错误
|
||||
if err == context.DeadlineExceeded {
|
||||
err = nil
|
||||
}
|
||||
status.Load1m = load1
|
||||
status.Load5m = load5
|
||||
status.Load15m = load15
|
||||
}
|
||||
125
EdgeDNS/internal/nodes/ns_access_log_queue.go
Normal file
125
EdgeDNS/internal/nodes/ns_access_log_queue.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/accesslogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
var sharedNSAccessLogQueue = NewNSAccessLogQueue()
|
||||
|
||||
// NSAccessLogQueue NS访问日志队列
|
||||
type NSAccessLogQueue struct {
|
||||
queue chan *pb.NSAccessLog
|
||||
}
|
||||
|
||||
// NewNSAccessLogQueue 获取新对象
|
||||
func NewNSAccessLogQueue() *NSAccessLogQueue {
|
||||
// 队列中最大的值,超出此数量的访问日志会被抛弃
|
||||
// TODO 需要可以在界面中设置
|
||||
var maxSize = 20000
|
||||
queue := &NSAccessLogQueue{
|
||||
queue: make(chan *pb.NSAccessLog, maxSize),
|
||||
}
|
||||
go queue.Start()
|
||||
|
||||
return queue
|
||||
}
|
||||
|
||||
// Start 开始处理访问日志
|
||||
func (this *NSAccessLogQueue) Start() {
|
||||
var ticker = time.NewTicker(1 * time.Second)
|
||||
for range ticker.C {
|
||||
err := this.loop()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("ACCESS_LOG_QUEUE", err.Error())
|
||||
} else {
|
||||
remotelogs.Error("ACCESS_LOG_QUEUE", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Push 加入新访问日志
|
||||
func (this *NSAccessLogQueue) Push(accessLog *pb.NSAccessLog) {
|
||||
select {
|
||||
case this.queue <- accessLog:
|
||||
default:
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// 上传访问日志
|
||||
func (this *NSAccessLogQueue) loop() error {
|
||||
var accessLogs = []*pb.NSAccessLog{}
|
||||
var count = 0
|
||||
var timestamp int64
|
||||
var requestId = 10_000_000
|
||||
|
||||
Loop:
|
||||
for {
|
||||
select {
|
||||
case accessLog := <-this.queue:
|
||||
if accessLog.Timestamp > timestamp {
|
||||
requestId = 10_000_000
|
||||
timestamp = accessLog.Timestamp
|
||||
} else {
|
||||
requestId++
|
||||
}
|
||||
|
||||
// timestamp + nodeId + requestId
|
||||
accessLog.RequestId = strconv.FormatInt(accessLog.Timestamp, 10) + strconv.FormatInt(accessLog.NsNodeId, 10) + strconv.Itoa(requestId)
|
||||
|
||||
accessLogs = append(accessLogs, accessLog)
|
||||
count++
|
||||
|
||||
// 每次只提交 N 条访问日志,防止网络拥堵
|
||||
if count > 2000 {
|
||||
break Loop
|
||||
}
|
||||
default:
|
||||
break Loop
|
||||
}
|
||||
}
|
||||
|
||||
if len(accessLogs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var clusterId int64
|
||||
var needWriteFile = true
|
||||
var needReportAPI = true
|
||||
if sharedNodeConfig != nil {
|
||||
clusterId = sharedNodeConfig.ClusterId
|
||||
if sharedNodeConfig.AccessLogWriteTargets != nil {
|
||||
targets := sharedNodeConfig.AccessLogWriteTargets
|
||||
needWriteFile = targets.File || targets.ClickHouse
|
||||
needReportAPI = targets.MySQL
|
||||
}
|
||||
}
|
||||
|
||||
if needWriteFile {
|
||||
accesslogs.SharedDNSFileWriter().WriteBatch(accessLogs, clusterId)
|
||||
}
|
||||
|
||||
if !needReportAPI {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 发送到API
|
||||
client, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = client.NSAccessLogRPC.CreateNSAccessLogs(client.Context(), &pb.CreateNSAccessLogsRequest{NsAccessLogs: accessLogs})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
932
EdgeDNS/internal/nodes/server.go
Normal file
932
EdgeDNS/internal/nodes/server.go
Normal file
@@ -0,0 +1,932 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/agents"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/models"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/stats"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/miekg/dns"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var sharedRecursionDNSClient = &dns.Client{}
|
||||
|
||||
type httpContextKey struct {
|
||||
key string
|
||||
}
|
||||
|
||||
var HTTPConnContextKey = &httpContextKey{key: "http-conn"}
|
||||
|
||||
const PingDomain = "ping."
|
||||
|
||||
// Server 服务
|
||||
type Server struct {
|
||||
config *ServerConfig
|
||||
|
||||
rawServer *dns.Server
|
||||
httpsServer *http.Server
|
||||
}
|
||||
|
||||
// NewServer 构造新服务
|
||||
func NewServer(config *ServerConfig) (*Server, error) {
|
||||
var server = &Server{
|
||||
config: config,
|
||||
}
|
||||
err := server.init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// ListenAndServe 监听
|
||||
func (this *Server) ListenAndServe() error {
|
||||
if this.rawServer != nil {
|
||||
return this.rawServer.ListenAndServe()
|
||||
}
|
||||
if this.httpsServer != nil {
|
||||
listener, err := net.Listen("tcp", this.httpsServer.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = this.httpsServer.ServeTLS(listener, "", "")
|
||||
if err == http.ErrServerClosed {
|
||||
err = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return errors.New("the server is not initialized")
|
||||
}
|
||||
|
||||
// Shutdown 关闭
|
||||
func (this *Server) Shutdown() error {
|
||||
if this.rawServer != nil {
|
||||
return this.rawServer.Shutdown()
|
||||
}
|
||||
if this.httpsServer != nil {
|
||||
return this.httpsServer.Shutdown(context.Background())
|
||||
}
|
||||
|
||||
return errors.New("the server is not initialized")
|
||||
}
|
||||
|
||||
// Reload 重载配置
|
||||
func (this *Server) Reload(config *ServerConfig) {
|
||||
this.config = config
|
||||
}
|
||||
|
||||
// 初始化
|
||||
func (this *Server) init() error {
|
||||
var rawServer = &dns.Server{
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
rawServer.Handler = dns.HandlerFunc(this.handleDNSMessage)
|
||||
var addr = ""
|
||||
if len(this.config.Host) > 0 {
|
||||
addr += configutils.QuoteIP(this.config.Host)
|
||||
}
|
||||
addr += ":" + types.String(this.config.Port)
|
||||
rawServer.Addr = addr
|
||||
|
||||
switch this.config.Protocol {
|
||||
case serverconfigs.ProtocolTCP:
|
||||
rawServer.Net = "tcp"
|
||||
case serverconfigs.ProtocolTLS:
|
||||
rawServer.Net = "tcp-tls"
|
||||
rawServer.TLSConfig = &tls.Config{
|
||||
Certificates: nil,
|
||||
GetConfigForClient: func(clientInfo *tls.ClientHelloInfo) (config *tls.Config, e error) {
|
||||
return this.config.SSLPolicy.TLSConfig(), nil
|
||||
},
|
||||
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
|
||||
return this.config.SSLPolicy.FirstCert(), nil
|
||||
},
|
||||
}
|
||||
case serverconfigs.ProtocolHTTPS: // DoH
|
||||
rawServer = nil
|
||||
this.httpsServer = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: http.HandlerFunc(this.handleHTTP),
|
||||
TLSConfig: &tls.Config{
|
||||
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
if this.config == nil {
|
||||
return nil, errors.New("invalid 'ServerConfig.config'")
|
||||
}
|
||||
if this.config.SSLPolicy == nil {
|
||||
return nil, errors.New("invalid 'ServerConfig.config.SSLPolicy'")
|
||||
}
|
||||
return this.config.SSLPolicy.TLSConfig(), nil
|
||||
},
|
||||
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
|
||||
if this.config == nil {
|
||||
return nil, errors.New("invalid 'ServerConfig.config'")
|
||||
}
|
||||
if this.config.SSLPolicy == nil {
|
||||
return nil, errors.New("invalid 'ServerConfig.config.SSLPolicy'")
|
||||
}
|
||||
return this.config.SSLPolicy.FirstCert(), nil
|
||||
},
|
||||
},
|
||||
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
|
||||
return context.WithValue(ctx, HTTPConnContextKey, c)
|
||||
},
|
||||
ReadTimeout: 5 * time.Second,
|
||||
ReadHeaderTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
IdleTimeout: 75 * time.Second,
|
||||
MaxHeaderBytes: 4096,
|
||||
}
|
||||
case serverconfigs.ProtocolUDP:
|
||||
rawServer.Net = "udp"
|
||||
}
|
||||
|
||||
this.rawServer = rawServer
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 查询递归DNS
|
||||
func (this *Server) lookupRecursionDNS(req *dns.Msg, resp *dns.Msg) error {
|
||||
var config = sharedNodeConfig.RecursionConfig
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
if !config.IsOn {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 是否允许
|
||||
var domain = strings.TrimSuffix(req.Question[0].Name, ".")
|
||||
if len(config.DenyDomains) > 0 && configutils.MatchDomains(config.DenyDomains, domain) {
|
||||
return nil
|
||||
}
|
||||
if len(config.AllowDomains) > 0 && !configutils.MatchDomains(config.AllowDomains, domain) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if config.UseLocalHosts {
|
||||
// TODO 需要缓存文件内容
|
||||
resolveConfig, err := dns.ClientConfigFromFile("/etc/resolv.conf")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resolveConfig.Servers) == 0 {
|
||||
return errors.New("no dns servers found in config file")
|
||||
}
|
||||
if len(resolveConfig.Port) == 0 {
|
||||
resolveConfig.Port = "53"
|
||||
}
|
||||
r, _, err := sharedRecursionDNSClient.Exchange(req, configutils.QuoteIP(resolveConfig.Servers[rands.Int(0, len(resolveConfig.Servers)-1)])+":"+resolveConfig.Port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Answer = r.Answer
|
||||
} else if len(config.Hosts) > 0 {
|
||||
var host = config.Hosts[rands.Int(0, len(config.Hosts)-1)]
|
||||
if host.Port <= 0 {
|
||||
host.Port = 53
|
||||
}
|
||||
r, _, err := sharedRecursionDNSClient.Exchange(req, configutils.QuoteIP(host.Host)+":"+types.String(host.Port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Answer = r.Answer
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 分析查询中的动作
|
||||
func (this *Server) parseAction(questionName string, remoteAddr *string) (string, error) {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// TODO 需要防止恶意攻击
|
||||
var optionIndex = strings.Index(questionName, "-")
|
||||
if optionIndex > 0 {
|
||||
optionId := types.Int64(questionName[1:optionIndex])
|
||||
optionResp, err := rpcClient.NSQuestionOptionRPC.FindNSQuestionOption(rpcClient.Context(), &pb.FindNSQuestionOptionRequest{NsQuestionOptionId: optionId})
|
||||
if err != nil {
|
||||
return "", errors.New("query question option failed: " + err.Error())
|
||||
} else {
|
||||
var option = optionResp.NsQuestionOption
|
||||
if option != nil {
|
||||
switch option.Name {
|
||||
case "setRemoteAddr":
|
||||
var m = maps.Map{}
|
||||
err = json.Unmarshal(option.ValuesJSON, &m)
|
||||
if err != nil {
|
||||
return "", errors.New("decode question option failed: " + err.Error())
|
||||
} else {
|
||||
var ip = m.GetString("ip")
|
||||
*remoteAddr = ip
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
questionName = questionName[optionIndex+1:]
|
||||
}
|
||||
return questionName, nil
|
||||
}
|
||||
|
||||
// 记录日志
|
||||
func (this *Server) addLog(networking string, question dns.Question, domainId int64, routeCode string, record *models.NSRecord, isRecursive bool, writer dns.ResponseWriter, remoteAddr string, err error) {
|
||||
// 访问日志
|
||||
var accessLogRef = sharedNodeConfig.AccessLogRef
|
||||
if accessLogRef != nil && accessLogRef.IsOn {
|
||||
if domainId == 0 && !accessLogRef.LogMissingDomains {
|
||||
return
|
||||
}
|
||||
|
||||
if accessLogRef.MissingRecordsOnly && record != nil && len(record.Value) > 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var now = time.Now()
|
||||
var pbAccessLog = &pb.NSAccessLog{
|
||||
NsNodeId: sharedNodeConfig.Id,
|
||||
RemoteAddr: remoteAddr,
|
||||
NsDomainId: domainId,
|
||||
QuestionName: question.Name,
|
||||
QuestionType: dns.Type(question.Qtype).String(),
|
||||
IsRecursive: isRecursive,
|
||||
Networking: networking,
|
||||
ServerAddr: writer.LocalAddr().String(),
|
||||
Timestamp: now.Unix(),
|
||||
TimeLocal: now.Format("2/Jan/2006:15:04:05 -0700"),
|
||||
RequestId: "",
|
||||
}
|
||||
if record != nil {
|
||||
pbAccessLog.NsRecordId = record.Id
|
||||
if len(routeCode) > 0 {
|
||||
pbAccessLog.NsRouteCodes = []string{routeCode}
|
||||
}
|
||||
pbAccessLog.RecordName = record.Name
|
||||
pbAccessLog.RecordType = record.Type
|
||||
pbAccessLog.RecordValue = record.Value
|
||||
}
|
||||
if err != nil {
|
||||
pbAccessLog.Error = err.Error()
|
||||
}
|
||||
sharedNSAccessLogQueue.Push(pbAccessLog)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证TSIG
|
||||
func (this *Server) checkTSIG(msg *dns.Msg, domainId int64) error {
|
||||
var tsig = msg.IsTsig()
|
||||
if tsig == nil {
|
||||
return errors.New("tsig: tsig required")
|
||||
}
|
||||
|
||||
var keys = sharedKeyManager.FindKeysWithDomain(domainId)
|
||||
if len(keys) == 0 {
|
||||
return errors.New("tsig: no keys defined")
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
if key.Algo != tsig.Algorithm {
|
||||
continue
|
||||
}
|
||||
|
||||
// 需要重新Pack,每次Pack结果只能校验一次
|
||||
msgData, err := msg.Pack()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(msgData) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var base64Secret = key.Secret
|
||||
if key.SecretType == dnsconfigs.NSKeySecretTypeClear {
|
||||
base64Secret = base64.StdEncoding.EncodeToString([]byte(key.Secret))
|
||||
}
|
||||
err = dns.TsigVerify(msgData, base64Secret, "", false)
|
||||
if err != nil {
|
||||
continue
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return dns.ErrSig
|
||||
}
|
||||
|
||||
// 处理DNS请求
|
||||
func (this *Server) handleDNSMessage(writer dns.ResponseWriter, req *dns.Msg) {
|
||||
if len(req.Question) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if sharedDomainManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var networking = ""
|
||||
if this.config != nil {
|
||||
networking = this.config.Protocol.String()
|
||||
}
|
||||
|
||||
var resultDomainId int64
|
||||
var resultRecordIds [][2]int64 // [] { domainId, recordId}
|
||||
|
||||
var resp = new(dns.Msg)
|
||||
resp.RecursionDesired = true
|
||||
resp.RecursionAvailable = true
|
||||
resp.SetReply(req)
|
||||
|
||||
resp.Answer = []dns.RR{}
|
||||
|
||||
var tsigIsChecked = false
|
||||
var remoteAddr = writer.RemoteAddr().String()
|
||||
|
||||
for _, question := range req.Question {
|
||||
if len(question.Name) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// PING
|
||||
if question.Name == PingDomain {
|
||||
resp.Answer = append(resp.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: question.Qclass,
|
||||
Ttl: 60,
|
||||
},
|
||||
A: net.ParseIP("127.0.0.1"),
|
||||
})
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
err := writer.WriteMsg(resp)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 查询选项
|
||||
if question.Name[0] == '$' {
|
||||
_, port, _ := net.SplitHostPort(remoteAddr)
|
||||
|
||||
questionName, err := this.parseAction(question.Name, &remoteAddr)
|
||||
if err != nil {
|
||||
remotelogs.Error("SERVER", "invalid query option '"+question.Name+"'")
|
||||
continue
|
||||
}
|
||||
question.Name = questionName
|
||||
if len(port) > 0 {
|
||||
if strings.Contains(remoteAddr, ":") { // IPv6
|
||||
remoteAddr = "[" + remoteAddr + "]:" + port
|
||||
} else {
|
||||
remoteAddr += ":" + port
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var fullName = strings.TrimSuffix(question.Name, ".")
|
||||
var recordName string
|
||||
var recordType = dns.Type(question.Qtype).String()
|
||||
var domain *models.NSDomain
|
||||
domain, recordName = sharedDomainManager.SplitDomain(fullName)
|
||||
if domain == nil {
|
||||
// 检查递归DNS
|
||||
if sharedNodeConfig.RecursionConfig != nil && sharedNodeConfig.RecursionConfig.IsOn {
|
||||
err := this.lookupRecursionDNS(req, resp)
|
||||
if err != nil {
|
||||
this.addLog(networking, question, 0, "", nil, true, writer, remoteAddr, err)
|
||||
} else {
|
||||
var recordValue = ""
|
||||
if len(resp.Answer) > 0 {
|
||||
pieces := regexp.MustCompile(`\s+`).Split(resp.Answer[0].String(), 6)
|
||||
if len(pieces) >= 5 {
|
||||
recordValue = pieces[4]
|
||||
}
|
||||
}
|
||||
this.addLog(networking, question, 0, "", &models.NSRecord{
|
||||
Id: 0,
|
||||
Name: recordName,
|
||||
Type: recordType,
|
||||
Value: recordValue,
|
||||
Ttl: 0,
|
||||
}, true, writer, remoteAddr, nil)
|
||||
}
|
||||
|
||||
err = writer.WriteMsg(resp)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 是否为NS记录,用于验证域名所有权
|
||||
if question.Qtype == dns.TypeNS {
|
||||
var hosts = sharedNodeConfig.Hosts
|
||||
var l = len(hosts)
|
||||
var record = &models.NSRecord{
|
||||
Id: 0,
|
||||
Type: dnsconfigs.RecordTypeNS,
|
||||
Ttl: 600, // TODO 可以设置
|
||||
}
|
||||
if l > 0 {
|
||||
l = 1 // 目前只返回一个
|
||||
|
||||
// 随机
|
||||
var indexes = []int{}
|
||||
for i := 0; i < l; i++ {
|
||||
indexes = append(indexes, i)
|
||||
}
|
||||
|
||||
rand.Shuffle(l, func(i, j int) {
|
||||
indexes[i], indexes[j] = indexes[j], indexes[i]
|
||||
})
|
||||
|
||||
record.Value = hosts[0] + "."
|
||||
for _, index := range indexes {
|
||||
resp.Answer = append(resp.Answer, &dns.NS{
|
||||
Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
|
||||
Ns: hosts[index] + ".",
|
||||
})
|
||||
}
|
||||
|
||||
this.addLog(networking, question, 0, "", record, false, writer, remoteAddr, nil)
|
||||
continue
|
||||
}
|
||||
|
||||
this.addLog(networking, question, 0, "", nil, false, writer, remoteAddr, nil)
|
||||
continue
|
||||
}
|
||||
|
||||
this.addLog(networking, question, 0, "", nil, false, writer, remoteAddr, nil)
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查TSIG
|
||||
if domain.TSIG != nil && domain.TSIG.IsOn && !tsigIsChecked {
|
||||
err := this.checkTSIG(req, domain.Id)
|
||||
if err != nil {
|
||||
this.addLog(networking, question, domain.Id, "", nil, false, writer, remoteAddr, err)
|
||||
continue
|
||||
}
|
||||
tsigIsChecked = true
|
||||
}
|
||||
|
||||
resultDomainId = domain.Id
|
||||
|
||||
var clientIP = remoteAddr
|
||||
clientHost, _, err := net.SplitHostPort(clientIP)
|
||||
if err == nil && len(clientHost) > 0 {
|
||||
clientIP = clientHost
|
||||
}
|
||||
|
||||
// 解析Agent
|
||||
if sharedNodeConfig.DetectAgents {
|
||||
agents.SharedQueue.Push(clientIP)
|
||||
}
|
||||
|
||||
var routeCodes = sharedRouteManager.FindRouteCodes(clientIP, domain.UserId)
|
||||
|
||||
var records []*models.NSRecord
|
||||
var matchedRouteCode string
|
||||
if question.Qtype == dns.TypeSOA { // SOA
|
||||
if len(recordName) == 0 { // 只有顶级域名才有SOA记录
|
||||
records = []*models.NSRecord{
|
||||
{
|
||||
Id: 0,
|
||||
Type: dnsconfigs.RecordTypeSOA,
|
||||
Ttl: 600, // TODO 可以设置
|
||||
},
|
||||
}
|
||||
}
|
||||
} else if question.Qtype == dns.TypeNS { // NS
|
||||
if len(recordName) == 0 { // 只有顶级域名才有NS记录
|
||||
records = []*models.NSRecord{
|
||||
{
|
||||
Id: 0,
|
||||
Type: dnsconfigs.RecordTypeNS,
|
||||
Ttl: 600, // TODO 可以设置
|
||||
},
|
||||
}
|
||||
}
|
||||
} else if question.Qtype != dns.TypeCNAME {
|
||||
// 是否有直接的设置
|
||||
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, true)
|
||||
|
||||
// 检查CNAME
|
||||
if len(records) == 0 {
|
||||
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, dnsconfigs.RecordTypeCNAME, false)
|
||||
if len(records) > 0 {
|
||||
question.Qtype = dns.TypeCNAME
|
||||
}
|
||||
}
|
||||
|
||||
// 再次尝试查找默认设置
|
||||
if len(records) == 0 {
|
||||
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, false)
|
||||
}
|
||||
}
|
||||
|
||||
if len(records) == 0 {
|
||||
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, false)
|
||||
}
|
||||
|
||||
// 对 NS.example.com NS|SOA 处理
|
||||
if (question.Qtype == dns.TypeNS || (question.Qtype == dns.TypeSOA && len(records) == 0)) && lists.ContainsString(sharedNodeConfig.Hosts, fullName) {
|
||||
var recordDNSType string
|
||||
switch question.Qtype {
|
||||
case dns.TypeNS:
|
||||
recordDNSType = dnsconfigs.RecordTypeNS
|
||||
case dns.TypeSOA:
|
||||
recordDNSType = dnsconfigs.RecordTypeSOA
|
||||
}
|
||||
this.composeSOAAnswer(question, &models.NSRecord{
|
||||
Type: recordDNSType,
|
||||
Ttl: 600,
|
||||
}, resp)
|
||||
}
|
||||
|
||||
if len(records) > 0 {
|
||||
var firstRecord = records[0]
|
||||
|
||||
for _, record := range records {
|
||||
resultRecordIds = append(resultRecordIds, [2]int64{record.DomainId, record.Id})
|
||||
|
||||
switch record.Type {
|
||||
case dnsconfigs.RecordTypeA:
|
||||
var answer = record.ToRRAnswer(question.Name, question.Qclass)
|
||||
if answer != nil {
|
||||
resp.Answer = append(resp.Answer, answer)
|
||||
}
|
||||
case dnsconfigs.RecordTypeCNAME:
|
||||
var value = record.Value
|
||||
if !strings.HasSuffix(value, ".") {
|
||||
value += "."
|
||||
}
|
||||
var lastRecordValue = value
|
||||
resp.Answer = append(resp.Answer, &dns.CNAME{
|
||||
Hdr: record.ToRRHeader(question.Name, dns.TypeCNAME, question.Qclass),
|
||||
Target: value,
|
||||
})
|
||||
|
||||
// 继续查询CNAME
|
||||
var allCNAMEValues = []string{lastRecordValue}
|
||||
for {
|
||||
// 限制最深32层
|
||||
if len(allCNAMEValues) > 32 {
|
||||
break
|
||||
}
|
||||
|
||||
cnameDomain, cnameRecordName := sharedDomainManager.SplitDomain(lastRecordValue)
|
||||
if cnameDomain == nil {
|
||||
break
|
||||
}
|
||||
cnameRecords, _ := sharedRecordManager.FindRecords(cnameDomain.Id, sharedRouteManager.FindRouteCodes(clientIP, cnameDomain.UserId), cnameRecordName, dnsconfigs.RecordTypeCNAME, false)
|
||||
if len(cnameRecords) == 0 {
|
||||
break
|
||||
}
|
||||
var cnameRecord = cnameRecords[0]
|
||||
if !lists.ContainsString(allCNAMEValues, cnameRecord.Value) {
|
||||
resultRecordIds = append(resultRecordIds, [2]int64{cnameRecord.DomainId, cnameRecord.Id}) // 统计
|
||||
var answer = cnameRecord.ToRRAnswer(lastRecordValue, question.Qclass)
|
||||
if answer == nil {
|
||||
break
|
||||
}
|
||||
resp.Answer = append(resp.Answer, answer)
|
||||
lastRecordValue = cnameRecord.Value
|
||||
allCNAMEValues = append(allCNAMEValues, lastRecordValue)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 再次查询原始问题
|
||||
if len(req.Question) > 0 {
|
||||
var firstQuestion = req.Question[0]
|
||||
if firstQuestion.Qtype != dns.TypeCNAME {
|
||||
finalDomain, finalRecordName := sharedDomainManager.SplitDomain(lastRecordValue)
|
||||
if finalDomain != nil {
|
||||
var realRecords, _ = sharedRecordManager.FindRecords(finalDomain.Id, sharedRouteManager.FindRouteCodes(clientIP, finalDomain.UserId), finalRecordName, dns.Type(firstQuestion.Qtype).String(), false)
|
||||
if len(realRecords) > 0 {
|
||||
for _, realRecord := range realRecords {
|
||||
resultRecordIds = append(resultRecordIds, [2]int64{realRecord.DomainId, realRecord.Id}) // 统计
|
||||
var answer = realRecord.ToRRAnswer(lastRecordValue, question.Qclass)
|
||||
if answer != nil {
|
||||
resp.Answer = append(resp.Answer, answer)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case dnsconfigs.RecordTypeAAAA:
|
||||
var answer = record.ToRRAnswer(question.Name, question.Qclass)
|
||||
if answer != nil {
|
||||
resp.Answer = append(resp.Answer, answer)
|
||||
}
|
||||
case dnsconfigs.RecordTypeNS:
|
||||
if record.Id == 0 {
|
||||
var hosts = sharedNodeConfig.Hosts
|
||||
var l = len(hosts)
|
||||
if l > 0 {
|
||||
// 随机
|
||||
var indexes = []int{}
|
||||
for i := 0; i < l; i++ {
|
||||
indexes = append(indexes, i)
|
||||
}
|
||||
|
||||
rand.Shuffle(l, func(i, j int) {
|
||||
indexes[i], indexes[j] = indexes[j], indexes[i]
|
||||
})
|
||||
|
||||
record.Value = hosts[0] + "."
|
||||
for _, index := range indexes {
|
||||
resp.Answer = append(resp.Answer, &dns.NS{
|
||||
Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
|
||||
Ns: hosts[index] + ".",
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var value = record.Value
|
||||
if !strings.HasSuffix(value, ".") {
|
||||
value += "."
|
||||
}
|
||||
resp.Answer = append(resp.Answer, &dns.NS{
|
||||
Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
|
||||
Ns: value,
|
||||
})
|
||||
}
|
||||
case dnsconfigs.RecordTypeMX:
|
||||
var answer = record.ToRRAnswer(question.Name, question.Qclass)
|
||||
if answer != nil {
|
||||
resp.Answer = append(resp.Answer, answer)
|
||||
}
|
||||
case dnsconfigs.RecordTypeSRV:
|
||||
var answer = record.ToRRAnswer(question.Name, question.Qclass)
|
||||
if answer != nil {
|
||||
resp.Answer = append(resp.Answer, answer)
|
||||
}
|
||||
case dnsconfigs.RecordTypeTXT:
|
||||
var answer = record.ToRRAnswer(question.Name, question.Qclass)
|
||||
if answer != nil {
|
||||
resp.Answer = append(resp.Answer, answer)
|
||||
}
|
||||
case dnsconfigs.RecordTypeCAA:
|
||||
var answer = record.ToRRAnswer(question.Name, question.Qclass)
|
||||
if answer != nil {
|
||||
resp.Answer = append(resp.Answer, answer)
|
||||
}
|
||||
case dnsconfigs.RecordTypeSOA:
|
||||
this.composeSOAAnswer(question, record, resp)
|
||||
}
|
||||
}
|
||||
|
||||
// 访问日志
|
||||
this.addLog(networking, question, resultDomainId, matchedRouteCode, firstRecord, false, writer, remoteAddr, nil)
|
||||
} else {
|
||||
this.addLog(networking, question, resultDomainId, "", nil, false, writer, remoteAddr, nil)
|
||||
}
|
||||
}
|
||||
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
err := writer.WriteMsg(resp)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 统计
|
||||
for _, resultRecordId := range resultRecordIds {
|
||||
stats.SharedManager.Add(resultRecordId[0], resultRecordId[1], int64(resp.Len()))
|
||||
}
|
||||
}
|
||||
|
||||
// 处理HTTP请求
|
||||
// 参考:https://datatracker.ietf.org/doc/html/rfc8484
|
||||
// 参考:https://developers.google.com/speed/public-dns/docs/doh
|
||||
func (this *Server) handleHTTP(writer http.ResponseWriter, req *http.Request) {
|
||||
if req.URL.Path == "/dns-query" {
|
||||
this.handleHTTPDNSMessage(writer, req)
|
||||
return
|
||||
}
|
||||
|
||||
if req.URL.Path == "/resolve" {
|
||||
this.handleHTTPJSONAPI(writer, req)
|
||||
return
|
||||
}
|
||||
|
||||
writer.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
|
||||
func (this *Server) handleHTTPDNSMessage(writer http.ResponseWriter, req *http.Request) {
|
||||
const maxMessageSize = 512
|
||||
|
||||
writer.Header().Set("Accept", "application/dns-message")
|
||||
|
||||
if req.Method != http.MethodGet && req.Method != http.MethodPost {
|
||||
writer.WriteHeader(http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
if req.ContentLength > maxMessageSize {
|
||||
writer.WriteHeader(http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
|
||||
var messageData []byte
|
||||
|
||||
switch req.Method {
|
||||
case http.MethodGet:
|
||||
if len(req.URL.RawQuery) > maxMessageSize {
|
||||
writer.WriteHeader(http.StatusRequestURITooLong)
|
||||
return
|
||||
}
|
||||
var encodedMessage = req.URL.Query().Get("dns")
|
||||
var err error
|
||||
messageData, err = base64.StdEncoding.DecodeString(encodedMessage)
|
||||
if err != nil {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
case http.MethodPost:
|
||||
var contentType = req.Header.Get("Content-Type")
|
||||
if contentType != "application/dns-message" {
|
||||
writer.WriteHeader(http.StatusUnsupportedMediaType)
|
||||
return
|
||||
}
|
||||
data, err := io.ReadAll(io.LimitReader(req.Body, maxMessageSize))
|
||||
if err != nil {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
messageData = data
|
||||
}
|
||||
|
||||
if len(messageData) == 0 {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var msg = &dns.Msg{}
|
||||
err := msg.Unpack(messageData)
|
||||
if err != nil {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var connValue = req.Context().Value(HTTPConnContextKey)
|
||||
if connValue == nil {
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
conn, ok := connValue.(net.Conn)
|
||||
if !ok {
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
this.handleDNSMessage(NewHTTPWriter(writer, conn, "application/dns-message"), msg)
|
||||
}
|
||||
|
||||
func (this *Server) handleHTTPJSONAPI(writer http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodGet {
|
||||
writer.WriteHeader(http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
var query = req.URL.Query()
|
||||
var name = strings.TrimSpace(query.Get("name"))
|
||||
var recordTypeString = strings.ToUpper(strings.TrimSpace(query.Get("type")))
|
||||
|
||||
if len(name) == 0 {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = writer.Write([]byte("invalid 'name' parameter"))
|
||||
return
|
||||
}
|
||||
|
||||
// add '.' to name
|
||||
if !strings.HasSuffix(name, ".") {
|
||||
name += "."
|
||||
}
|
||||
|
||||
if len(recordTypeString) == 0 {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = writer.Write([]byte("invalid 'type' parameter"))
|
||||
return
|
||||
}
|
||||
|
||||
var recordType uint16
|
||||
if regexp.MustCompile(`^\d{1,4}$`).MatchString(recordTypeString) {
|
||||
recordType = types.Uint16(recordTypeString)
|
||||
} else {
|
||||
recordType = dns.StringToType[recordTypeString]
|
||||
}
|
||||
if recordType == 0 {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = writer.Write([]byte("invalid 'type' parameter"))
|
||||
return
|
||||
}
|
||||
|
||||
_, ok := dns.TypeToString[recordType]
|
||||
if !ok {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = writer.Write([]byte("invalid 'type' parameter"))
|
||||
return
|
||||
}
|
||||
|
||||
var msg = &dns.Msg{}
|
||||
msg.Question = []dns.Question{
|
||||
{
|
||||
Name: name,
|
||||
Qtype: recordType,
|
||||
Qclass: dns.ClassINET,
|
||||
},
|
||||
}
|
||||
|
||||
var connValue = req.Context().Value(HTTPConnContextKey)
|
||||
if connValue == nil {
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// conn
|
||||
conn, ok := connValue.(net.Conn)
|
||||
if !ok {
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
this.handleDNSMessage(NewHTTPWriter(writer, conn, "application/x-javascript"), msg)
|
||||
}
|
||||
|
||||
// 组合SOA回复信息
|
||||
func (this *Server) composeSOAAnswer(question dns.Question, record *models.NSRecord, resp *dns.Msg) {
|
||||
var config = sharedNodeConfig.SOA
|
||||
var serial = sharedNodeConfig.SOASerial
|
||||
|
||||
if config == nil {
|
||||
config = dnsconfigs.DefaultNSSOAConfig()
|
||||
}
|
||||
|
||||
var mName = config.MName
|
||||
if len(mName) == 0 {
|
||||
var hosts = sharedNodeConfig.Hosts
|
||||
var l = len(hosts)
|
||||
if l > 0 {
|
||||
var index = rands.Int(0, l-1)
|
||||
mName = hosts[index]
|
||||
}
|
||||
}
|
||||
|
||||
var rName = config.RName
|
||||
if len(rName) == 0 {
|
||||
rName = sharedNodeConfig.Email
|
||||
}
|
||||
rName = strings.ReplaceAll(rName, "@", ".")
|
||||
|
||||
if len(mName) > 0 && len(rName) > 0 {
|
||||
// 设置记录值
|
||||
record.Value = mName + "."
|
||||
|
||||
resp.Answer = append(resp.Answer, &dns.SOA{
|
||||
Hdr: record.ToRRHeader(question.Name, dns.TypeSOA, question.Qclass),
|
||||
Ns: mName + ".",
|
||||
Mbox: rName + ".",
|
||||
Serial: serial,
|
||||
Refresh: config.RefreshSeconds,
|
||||
Retry: config.RetrySeconds,
|
||||
Expire: config.ExpireSeconds,
|
||||
Minttl: config.MinimumTTL,
|
||||
})
|
||||
}
|
||||
}
|
||||
22
EdgeDNS/internal/nodes/server_config.go
Normal file
22
EdgeDNS/internal/nodes/server_config.go
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
)
|
||||
|
||||
// ServerConfig 服务配置
|
||||
type ServerConfig struct {
|
||||
Protocol serverconfigs.Protocol
|
||||
Host string
|
||||
Port int
|
||||
SSLPolicy *sslconfigs.SSLPolicy
|
||||
}
|
||||
|
||||
// FullAddr 服务地址
|
||||
func (this *ServerConfig) FullAddr() string {
|
||||
return this.Protocol.String() + "://" + this.Host + ":" + types.String(this.Port)
|
||||
}
|
||||
114
EdgeDNS/internal/nodes/task_sync_api_nodes.go
Normal file
114
EdgeDNS/internal/nodes/task_sync_api_nodes.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/configs"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/events"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/utils"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
return
|
||||
}
|
||||
|
||||
events.On(events.EventStart, func() {
|
||||
task := NewSyncAPINodesTask()
|
||||
go task.Start()
|
||||
})
|
||||
}
|
||||
|
||||
// SyncAPINodesTask API节点同步任务
|
||||
type SyncAPINodesTask struct {
|
||||
}
|
||||
|
||||
func NewSyncAPINodesTask() *SyncAPINodesTask {
|
||||
return &SyncAPINodesTask{}
|
||||
}
|
||||
|
||||
func (this *SyncAPINodesTask) Start() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
if Tea.IsTesting() {
|
||||
// 快速测试
|
||||
ticker = time.NewTicker(1 * time.Minute)
|
||||
}
|
||||
events.On(events.EventQuit, func() {
|
||||
remotelogs.Println("SYNC_API_NODES_TASK", "quit task")
|
||||
ticker.Stop()
|
||||
})
|
||||
for range ticker.C {
|
||||
err := this.Loop()
|
||||
if err != nil {
|
||||
logs.Println("[TASK][SYNC_API_NODES_TASK]" + err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *SyncAPINodesTask) Loop() error {
|
||||
// 如果有节点定制的API节点地址
|
||||
var hasCustomizedAPINodeAddrs = sharedNodeConfig != nil && len(sharedNodeConfig.APINodeAddrs) > 0
|
||||
|
||||
config, err := configs.LoadAPIConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 是否禁止自动升级
|
||||
if config.RPCDisableUpdate {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 获取所有可用的节点
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := rpcClient.APINodeRPC.FindAllEnabledAPINodes(rpcClient.Context(), &pb.FindAllEnabledAPINodesRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var newEndpoints = []string{}
|
||||
for _, node := range resp.ApiNodes {
|
||||
if !node.IsOn {
|
||||
continue
|
||||
}
|
||||
newEndpoints = append(newEndpoints, node.AccessAddrs...)
|
||||
}
|
||||
|
||||
// 和现有的对比
|
||||
if utils.EqualStrings(newEndpoints, config.RPCEndpoints) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 测试是否有API节点可用
|
||||
var hasOk = rpcClient.TestEndpoints(newEndpoints)
|
||||
if !hasOk {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 修改RPC对象配置
|
||||
config.RPCEndpoints = newEndpoints
|
||||
|
||||
// 更新当前RPC
|
||||
if !hasCustomizedAPINodeAddrs {
|
||||
err = rpcClient.UpdateConfig(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 保存到文件
|
||||
err = config.WriteFile(Tea.ConfigFile(configs.ConfigFileName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
300
EdgeDNS/internal/nodes/upgrade_manager.go
Normal file
300
EdgeDNS/internal/nodes/upgrade_manager.go
Normal file
@@ -0,0 +1,300 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/events"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/utils"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"github.com/iwind/gosock/pkg/gosock"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
return
|
||||
}
|
||||
|
||||
events.On(events.EventStart, func() {
|
||||
go func() {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
remotelogs.Error("UPGRADE_MANAGER", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var ticker = time.NewTicker(1 * time.Minute)
|
||||
for range ticker.C {
|
||||
resp, err := rpcClient.NSNodeRPC.CheckNSNodeLatestVersion(rpcClient.Context(), &pb.CheckNSNodeLatestVersionRequest{
|
||||
Os: runtime.GOOS,
|
||||
Arch: runtime.GOARCH,
|
||||
CurrentVersion: teaconst.Version,
|
||||
})
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("UPGRADE_MANAGER", err.Error())
|
||||
} else {
|
||||
remotelogs.Error("UPGRADE_MANAGER", err.Error())
|
||||
}
|
||||
continue
|
||||
}
|
||||
if resp.HasNewVersion {
|
||||
sharedUpgradeManager.Start()
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
var sharedUpgradeManager = NewUpgradeManager()
|
||||
|
||||
// UpgradeManager 节点升级管理器
|
||||
// TODO 需要在集群中设置是否自动更新
|
||||
type UpgradeManager struct {
|
||||
isInstalling bool
|
||||
lastFile string
|
||||
exe string
|
||||
}
|
||||
|
||||
// NewUpgradeManager 获取新对象
|
||||
func NewUpgradeManager() *UpgradeManager {
|
||||
return &UpgradeManager{}
|
||||
}
|
||||
|
||||
// Start 启动升级
|
||||
func (this *UpgradeManager) Start() {
|
||||
// 必须放在文件解压之前读取可执行文件路径,防止解析之后,当前的可执行文件路径发生改变
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
remotelogs.Error("UPGRADE_MANAGER", "can not find current executable file name")
|
||||
return
|
||||
}
|
||||
this.exe = exe
|
||||
|
||||
// 测试环境下不更新
|
||||
if Tea.IsTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
if this.isInstalling {
|
||||
return
|
||||
}
|
||||
this.isInstalling = true
|
||||
|
||||
// 还原安装状态
|
||||
defer func() {
|
||||
this.isInstalling = false
|
||||
}()
|
||||
|
||||
remotelogs.Println("UPGRADE_MANAGER", "upgrading dns node ...")
|
||||
err = this.install()
|
||||
if err != nil {
|
||||
remotelogs.Error("UPGRADE_MANAGER", "download failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
remotelogs.Println("UPGRADE_MANAGER", "upgrade successfully")
|
||||
|
||||
go func() {
|
||||
err = this.restart()
|
||||
if err != nil {
|
||||
logs.Println("UPGRADE_MANAGER", err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (this *UpgradeManager) install() error {
|
||||
// 检查是否有已下载但未安装成功的
|
||||
if len(this.lastFile) > 0 {
|
||||
_, err := os.Stat(this.lastFile)
|
||||
if err == nil {
|
||||
err = this.unzip(this.lastFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.lastFile = ""
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// 创建临时文件
|
||||
dir := Tea.Root + "/tmp"
|
||||
_, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
err = os.Mkdir(dir, 0777)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
remotelogs.Println("UPGRADE_MANAGER", "downloading new node ...")
|
||||
|
||||
path := dir + "/" + teaconst.ProcessName + ".tmp"
|
||||
fp, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0777)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
isClosed := false
|
||||
defer func() {
|
||||
if !isClosed {
|
||||
_ = fp.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
client, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var offset int64
|
||||
var h = md5.New()
|
||||
var sum = ""
|
||||
var filename = ""
|
||||
for {
|
||||
resp, err := client.NSNodeRPC.DownloadNSNodeInstallationFile(client.Context(), &pb.DownloadNSNodeInstallationFileRequest{
|
||||
Os: runtime.GOOS,
|
||||
Arch: runtime.GOARCH,
|
||||
ChunkOffset: offset,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resp.Sum) == 0 {
|
||||
return nil
|
||||
}
|
||||
sum = resp.Sum
|
||||
filename = resp.Filename
|
||||
if stringutil.VersionCompare(resp.Version, teaconst.Version) <= 0 {
|
||||
return nil
|
||||
}
|
||||
if len(resp.ChunkData) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
_, err = fp.Write(resp.ChunkData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = h.Write(resp.ChunkData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
offset = resp.Offset
|
||||
}
|
||||
|
||||
if len(filename) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
isClosed = true
|
||||
err = fp.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if fmt.Sprintf("%x", h.Sum(nil)) != sum {
|
||||
_ = os.Remove(path)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 改成zip
|
||||
zipPath := dir + "/" + filename
|
||||
err = os.Rename(path, zipPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.lastFile = zipPath
|
||||
|
||||
// 解压
|
||||
err = this.unzip(zipPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 解压
|
||||
func (this *UpgradeManager) unzip(zipPath string) error {
|
||||
var isOk = false
|
||||
defer func() {
|
||||
if isOk {
|
||||
// 只有解压并覆盖成功后才会删除
|
||||
_ = os.Remove(zipPath)
|
||||
}
|
||||
}()
|
||||
|
||||
// 解压
|
||||
var target = Tea.Root
|
||||
if Tea.IsTesting() {
|
||||
// 测试环境下只解压在tmp目录
|
||||
target = Tea.Root + "/tmp"
|
||||
}
|
||||
|
||||
// 先改先前的可执行文件
|
||||
err := os.Rename(target+"/bin/"+teaconst.ProcessName, target+"/bin/."+teaconst.ProcessName+".dist")
|
||||
hasBackup := err == nil
|
||||
defer func() {
|
||||
if !isOk && hasBackup {
|
||||
// 失败时还原
|
||||
_ = os.Rename(target+"/bin/."+teaconst.ProcessName+".dist", target+"/bin/"+teaconst.ProcessName)
|
||||
}
|
||||
}()
|
||||
|
||||
unzip := utils.NewUnzip(zipPath, target, teaconst.ProcessName+"/")
|
||||
err = unzip.Run()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
isOk = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 重启
|
||||
func (this *UpgradeManager) restart() error {
|
||||
// 关闭当前sock,防止无法重启
|
||||
_ = gosock.NewTmpSock(teaconst.ProcessName).Close()
|
||||
|
||||
// 重新启动
|
||||
if DaemonIsOn && DaemonPid == os.Getppid() {
|
||||
os.Exit(0) // TODO 试着更优雅重启
|
||||
} else {
|
||||
// quit
|
||||
events.Notify(events.EventQuit)
|
||||
|
||||
// 启动
|
||||
var exe = filepath.Dir(this.exe) + "/" + teaconst.ProcessName
|
||||
|
||||
logs.Println("restarting ...", exe)
|
||||
cmd := exec.Command(exe, "start")
|
||||
err := cmd.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 退出当前进程
|
||||
time.Sleep(1 * time.Second)
|
||||
os.Exit(0)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
16
EdgeDNS/internal/nodes/upgrade_manager_test.go
Normal file
16
EdgeDNS/internal/nodes/upgrade_manager_test.go
Normal file
@@ -0,0 +1,16 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUpgradeManager_install(t *testing.T) {
|
||||
err := NewUpgradeManager().install()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
Reference in New Issue
Block a user