Initial commit (code only without large binaries)

This commit is contained in:
robin
2026-02-15 18:58:44 +08:00
commit 35df75498f
9442 changed files with 1495866 additions and 0 deletions

View File

@@ -0,0 +1,256 @@
package nodes
import (
"bytes"
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/messageconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
"github.com/TeaOSLab/EdgeDNS/internal/events"
"github.com/TeaOSLab/EdgeDNS/internal/firewalls"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"github.com/TeaOSLab/EdgeDNS/internal/utils"
executils "github.com/TeaOSLab/EdgeDNS/internal/utils/exec"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/maps"
"os/exec"
"regexp"
"runtime"
"strconv"
"time"
)
type APIStream struct {
stream pb.NSNodeService_NsNodeStreamClient
isQuiting bool
cancelFunc context.CancelFunc
}
func NewAPIStream() *APIStream {
return &APIStream{}
}
func (this *APIStream) Start() {
events.On(events.EventQuit, func() {
this.isQuiting = true
if this.cancelFunc != nil {
this.cancelFunc()
}
})
for {
if this.isQuiting {
return
}
err := this.loop()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("API_STREAM", err.Error())
} else {
remotelogs.Error("API_STREAM", err.Error())
}
time.Sleep(10 * time.Second)
continue
}
time.Sleep(1 * time.Second)
}
}
func (this *APIStream) loop() error {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return errors.Wrap(err)
}
ctx, cancelFunc := context.WithCancel(rpcClient.Context())
this.cancelFunc = cancelFunc
defer func() {
cancelFunc()
}()
nodeStream, err := rpcClient.NSNodeRPC.NsNodeStream(ctx)
if err != nil {
if this.isQuiting {
return nil
}
return errors.Wrap(err)
}
this.stream = nodeStream
for {
if this.isQuiting {
logs.Println("API_STREAM", "quit")
break
}
message, err := nodeStream.Recv()
if err != nil {
if this.isQuiting {
remotelogs.Println("API_STREAM", "quit")
return nil
}
return errors.Wrap(err)
}
// 处理消息
switch message.Code {
case messageconfigs.NSMessageCodeConnectedAPINode: // 连接API节点成功
err = this.handleConnectedAPINode(message)
case messageconfigs.NSMessageCodeNewNodeTask: // 有新的任务
err = this.handleNewNodeTask(message)
case messageconfigs.NSMessageCodeCheckSystemdService: // 检查Systemd服务
err = this.handleCheckSystemdService(message)
case messageconfigs.MessageCodeCheckLocalFirewall: // 检查本地防火墙
err = this.handleCheckLocalFirewall(message)
default:
err = this.handleUnknownMessage(message)
}
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("API_STREAM", "handle message failed: "+err.Error())
} else {
remotelogs.Error("API_STREAM", "handle message failed: "+err.Error())
}
}
}
return nil
}
// 连接API节点成功
func (this *APIStream) handleConnectedAPINode(message *pb.NSNodeStreamMessage) error {
// 更改连接的APINode信息
if len(message.DataJSON) == 0 {
return nil
}
msg := &messageconfigs.ConnectedAPINodeMessage{}
err := json.Unmarshal(message.DataJSON, msg)
if err != nil {
return errors.Wrap(err)
}
_, err = rpc.SharedRPC()
if err != nil {
return errors.Wrap(err)
}
remotelogs.Println("API_STREAM", "connected to api node '"+strconv.FormatInt(msg.APINodeId, 10)+"'")
return nil
}
// 处理配置变化
func (this *APIStream) handleNewNodeTask(message *pb.NSNodeStreamMessage) error {
select {
case nodeTaskNotify <- true:
default:
}
this.replyOk(message.RequestId, "ok")
return nil
}
// 检查Systemd服务
func (this *APIStream) handleCheckSystemdService(message *pb.NSNodeStreamMessage) error {
systemctl, err := executils.LookPath("systemctl")
if err != nil {
this.replyFail(message.RequestId, "'systemctl' not found")
return nil
}
if len(systemctl) == 0 {
this.replyFail(message.RequestId, "'systemctl' not found")
return nil
}
cmd := utils.NewCommandExecutor()
shortName := teaconst.SystemdServiceName
cmd.Add(systemctl, "is-enabled", shortName)
output, err := cmd.Run()
if err != nil {
this.replyFail(message.RequestId, "'systemctl' command error: "+err.Error())
return nil
}
if output == "enabled" {
this.replyOk(message.RequestId, "ok")
} else {
this.replyFail(message.RequestId, "not installed")
}
return nil
}
// 检查本地防火墙
func (this *APIStream) handleCheckLocalFirewall(message *pb.NSNodeStreamMessage) error {
var dataMessage = &messageconfigs.CheckLocalFirewallMessage{}
err := json.Unmarshal(message.DataJSON, dataMessage)
if err != nil {
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
return nil
}
// nft
if dataMessage.Name == "nftables" {
if runtime.GOOS != "linux" {
this.replyFail(message.RequestId, "not Linux system")
return nil
}
nft, err := executils.LookPath("nft")
if err != nil {
this.replyFail(message.RequestId, "'nft' not found: "+err.Error())
return nil
}
var cmd = exec.Command(nft, "--version")
var output = &bytes.Buffer{}
cmd.Stdout = output
err = cmd.Run()
if err != nil {
this.replyFail(message.RequestId, "get version failed: "+err.Error())
return nil
}
var outputString = output.String()
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
if len(versionMatches) <= 1 {
this.replyFail(message.RequestId, "can not get nft version")
return nil
}
var version = versionMatches[1]
var result = maps.Map{
"version": version,
}
var protectionConfig = sharedNodeConfig.DDoSProtection
err = firewalls.SharedDDoSProtectionManager.Apply(protectionConfig)
if err != nil {
this.replyFail(message.RequestId, dataMessage.Name+"was installed, but apply DDoS protection config failed: "+err.Error())
} else {
this.replyOk(message.RequestId, string(result.AsJSON()))
}
} else {
this.replyFail(message.RequestId, "invalid firewall name '"+dataMessage.Name+"'")
}
return nil
}
// 处理未知消息
func (this *APIStream) handleUnknownMessage(message *pb.NSNodeStreamMessage) error {
this.replyFail(message.RequestId, "unknown message code '"+message.Code+"'")
return nil
}
// 回复失败
func (this *APIStream) replyFail(requestId int64, message string) {
_ = this.stream.Send(&pb.NSNodeStreamMessage{RequestId: requestId, IsOk: false, Message: message})
}
// 回复成功
func (this *APIStream) replyOk(requestId int64, message string) {
_ = this.stream.Send(&pb.NSNodeStreamMessage{RequestId: requestId, IsOk: true, Message: message})
}

View File

@@ -0,0 +1,455 @@
//go:build plus
package nodes
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
"github.com/TeaOSLab/EdgeDNS/internal/agents"
"github.com/TeaOSLab/EdgeDNS/internal/configs"
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
"github.com/TeaOSLab/EdgeDNS/internal/events"
"github.com/TeaOSLab/EdgeDNS/internal/firewalls"
"github.com/TeaOSLab/EdgeDNS/internal/goman"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"github.com/TeaOSLab/EdgeDNS/internal/utils"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"github.com/iwind/gosock/pkg/gosock"
"log"
"os"
"os/exec"
"os/signal"
"runtime"
"runtime/debug"
"syscall"
"time"
)
var DaemonIsOn = false
var DaemonPid = 0
var nodeTaskNotify = make(chan bool, 8)
var sharedDomainManager *DomainManager
var sharedRecordManager *RecordManager
var sharedRouteManager *RouteManager
var sharedKeyManager *KeyManager
var sharedNodeConfig = &dnsconfigs.NSNodeConfig{}
func NewDNSNode() *DNSNode {
return &DNSNode{
sock: gosock.NewTmpSock(teaconst.ProcessName),
}
}
type DNSNode struct {
sock *gosock.Sock
RPC *rpc.RPCClient
}
func (this *DNSNode) Start() {
// 设置netdns
// 这个需要放在所有网络访问的最前面
_ = os.Setenv("GODEBUG", "netdns=go")
// 判断是否在守护进程下
_, ok := os.LookupEnv("EdgeDaemon")
if ok {
remotelogs.Println("DNS_NODE", "start from daemon")
DaemonIsOn = true
DaemonPid = os.Getppid()
}
// 设置DNS解析库
err := os.Setenv("GODEBUG", "netdns=go")
if err != nil {
remotelogs.Error("DNS_NODE", "[DNS_RESOLVER]set env failed: "+err.Error())
}
// 处理异常
this.handlePanic()
// 监听signal
this.listenSignals()
// 本地Sock
err = this.listenSock()
if err != nil {
logs.Println("[DNS_NODE]" + err.Error())
return
}
// 启动IP库
remotelogs.Println("DNS_NODE", "initializing ip library ...")
err = iplibrary.InitPlus()
if err != nil {
remotelogs.Error("DNS_NODE", "initialize ip library failed: "+err.Error())
}
runtime.GC()
debug.FreeOSMemory()
// 触发事件
events.Notify(events.EventStart)
// 监控状态
go NewNodeStatusExecutor().Listen()
// 连接API
go NewAPIStream().Start()
// 启动
go this.start()
// Hold住进程
logs.Println("[DNS_NODE]started")
select {}
}
// Daemon 实现守护进程
func (this *DNSNode) Daemon() {
var isDebug = lists.ContainsString(os.Args, "debug")
for {
conn, err := this.sock.Dial()
if err != nil {
if isDebug {
log.Println("[DAEMON]starting ...")
}
// 尝试启动
err = func() error {
exe, err := os.Executable()
if err != nil {
return err
}
// 可以标记当前是从守护进程启动的
_ = os.Setenv("EdgeDaemon", "on")
_ = os.Setenv("EdgeBackground", "on")
var cmd = exec.Command(exe)
err = cmd.Start()
if err != nil {
return err
}
err = cmd.Wait()
if err != nil {
return err
}
return nil
}()
if err != nil {
if isDebug {
log.Println("[DAEMON]", err)
}
time.Sleep(1 * time.Second)
} else {
time.Sleep(5 * time.Second)
}
} else {
_ = conn.Close()
time.Sleep(5 * time.Second)
}
}
}
// Test 测试配置
func (this *DNSNode) Test() error {
// 检查是否能连接API
rpcClient, err := rpc.SharedRPC()
if err != nil {
return fmt.Errorf("test rpc failed: %w", err)
}
_, err = rpcClient.APINodeRPC.FindCurrentAPINodeVersion(rpcClient.Context(), &pb.FindCurrentAPINodeVersionRequest{})
if err != nil {
return fmt.Errorf("test rpc failed: %w", err)
}
return nil
}
// InstallSystemService 安装系统服务
func (this *DNSNode) InstallSystemService() error {
shortName := teaconst.SystemdServiceName
exe, err := os.Executable()
if err != nil {
return err
}
manager := utils.NewServiceManager(shortName, teaconst.ProductName)
err = manager.Install(exe, []string{})
if err != nil {
return err
}
return nil
}
// 监听一些信号
func (this *DNSNode) listenSignals() {
var queue = make(chan os.Signal, 8)
signal.Notify(queue, syscall.SIGTERM, syscall.SIGINT, syscall.SIGKILL, syscall.SIGQUIT)
goman.New(func() {
for range queue {
time.Sleep(100 * time.Millisecond)
utils.Exit()
return
}
})
}
// 监听本地sock
func (this *DNSNode) listenSock() error {
// 检查是否在运行
if this.sock.IsListening() {
reply, err := this.sock.Send(&gosock.Command{Code: "pid"})
if err == nil {
return errors.New("error: the process is already running, pid: " + maps.NewMap(reply.Params).GetString("pid"))
} else {
return errors.New("error: the process is already running")
}
}
// 启动监听
go func() {
this.sock.OnCommand(func(cmd *gosock.Command) {
switch cmd.Code {
case "pid":
_ = cmd.Reply(&gosock.Command{
Code: "pid",
Params: map[string]interface{}{
"pid": os.Getpid(),
},
})
case "info":
exePath, _ := os.Executable()
_ = cmd.Reply(&gosock.Command{
Code: "info",
Params: map[string]interface{}{
"pid": os.Getpid(),
"version": teaconst.Version,
"path": exePath,
},
})
case "stop":
_ = cmd.ReplyOk()
// 退出主进程
events.Notify(events.EventQuit)
time.Sleep(100 * time.Millisecond)
os.Exit(0)
case "gc":
runtime.GC()
debug.FreeOSMemory()
_ = cmd.ReplyOk()
}
})
err := this.sock.Listen()
if err != nil {
logs.Println("NODE", err.Error())
}
}()
events.On(events.EventQuit, func() {
logs.Println("[DNS_NODE]", "quit unix sock")
_ = this.sock.Close()
})
return nil
}
// 启动
func (this *DNSNode) start() {
client, err := rpc.SharedRPC()
if err != nil {
remotelogs.Error("DNS_NODE", err.Error())
return
}
this.RPC = client
tryTimes := 0
var configJSON []byte
for {
resp, err := client.NSNodeRPC.FindCurrentNSNodeConfig(client.Context(), &pb.FindCurrentNSNodeConfigRequest{})
if err != nil {
tryTimes++
if tryTimes%10 == 0 {
remotelogs.Error("NODE", "read config from API failed: "+err.Error())
}
time.Sleep(1 * time.Second)
// 不做长时间的无意义的重试
if tryTimes > 1000 {
remotelogs.Error("NODE", "load failed: unable to read config from API")
return
}
} else {
configJSON = resp.NsNodeJSON
break
}
}
if len(configJSON) == 0 {
remotelogs.Error("NODE", "can not find node config")
return
}
var config = &dnsconfigs.NSNodeConfig{}
err = json.Unmarshal(configJSON, config)
if err != nil {
remotelogs.Error("NODE", "decode config failed: "+err.Error())
return
}
err = config.Init(context.TODO())
if err != nil {
remotelogs.Error("NODE", "init config failed: "+err.Error())
return
}
sharedNodeConfig = config
configs.SharedNodeConfig = config
events.Notify(events.EventReload)
sharedNodeConfigManager.reload(config)
apiConfig, _ := configs.SharedAPIConfig()
if apiConfig != nil {
apiConfig.NumberId = config.Id
}
var db = dbs.NewDB(Tea.Root + "/data/data-" + types.String(config.Id) + "-" + config.NodeId + "-v0.1.0.db")
err = db.Init()
if err != nil {
remotelogs.Error("NODE", "init database failed: "+err.Error())
return
}
go sharedNodeConfigManager.Start()
sharedDomainManager = NewDomainManager(db, config.ClusterId)
go sharedDomainManager.Start()
sharedRecordManager = NewRecordManager(db)
go sharedRecordManager.Start()
sharedRouteManager = NewRouteManager(db)
go sharedRouteManager.Start()
sharedKeyManager = NewKeyManager(db)
go sharedKeyManager.Start()
agents.SharedManager = agents.NewManager(db)
go agents.SharedManager.Start()
// 发送通知这里发送通知需要在DomainManager、RecordeManager等加载完成之后
time.Sleep(1 * time.Second)
events.Notify(events.EventLoaded)
// 启动循环
go this.loop()
}
// 更新配置Loop
func (this *DNSNode) loop() {
var ticker = time.NewTicker(60 * time.Second)
for {
select {
case <-ticker.C:
case <-nodeTaskNotify:
}
err := this.processTasks()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("DNS_NODE", "process tasks: "+err.Error())
} else {
remotelogs.Error("DNS_NODE", "process tasks: "+err.Error())
}
}
}
}
// 处理任务
func (this *DNSNode) processTasks() error {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
// 所有的任务
tasksResp, err := rpcClient.NodeTaskRPC.FindNodeTasks(rpcClient.Context(), &pb.FindNodeTasksRequest{})
if err != nil {
return err
}
for _, task := range tasksResp.NodeTasks {
switch task.Type {
case "nsConfigChanged":
sharedNodeConfigManager.NotifyChange()
case "nsDomainChanged":
sharedDomainManager.NotifyUpdate()
case "nsRecordChanged":
sharedRecordManager.NotifyUpdate()
case "nsRouteChanged":
sharedRouteManager.NotifyUpdate()
case "nsKeyChanged":
sharedKeyManager.NotifyUpdate()
case "nsDDoSProtectionChanged":
err := this.updateDDoS(rpcClient)
if err != nil {
remotelogs.Error("DNS_NODE", "apply DDoS config failed: "+err.Error())
}
}
_, err = rpcClient.NodeTaskRPC.ReportNodeTaskDone(rpcClient.Context(), &pb.ReportNodeTaskDoneRequest{
NodeTaskId: task.Id,
IsOk: true,
Error: "",
})
if err != nil {
return err
}
}
return nil
}
func (this *DNSNode) updateDDoS(rpcClient *rpc.RPCClient) error {
resp, err := rpcClient.NSNodeRPC.FindNSNodeDDoSProtection(rpcClient.Context(), &pb.FindNSNodeDDoSProtectionRequest{})
if err != nil {
return err
}
if len(resp.DdosProtectionJSON) == 0 {
if sharedNodeConfig != nil {
sharedNodeConfig.DDoSProtection = nil
}
} else {
var ddosProtectionConfig = &ddosconfigs.ProtectionConfig{}
err = json.Unmarshal(resp.DdosProtectionJSON, ddosProtectionConfig)
if err != nil {
return fmt.Errorf("decode DDoS protection config failed: %w", err)
}
if sharedNodeConfig != nil {
sharedNodeConfig.DDoSProtection = ddosProtectionConfig
}
err = firewalls.SharedDDoSProtectionManager.Apply(ddosProtectionConfig)
if err != nil {
// 不阻塞
remotelogs.Error("NODE", "apply DDoS protection failed: "+err.Error())
}
}
return nil
}

View File

@@ -0,0 +1,306 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package nodes_test
import (
"crypto/tls"
"github.com/miekg/dns"
"testing"
"time"
)
func TestDNS_Query_A_UDP(t *testing.T) {
type query struct {
Domain string
Type uint16
}
c := new(dns.Client)
for _, query := range []query{
{"hello.world.teaos.cn", dns.TypeA},
{"cdn.teaos.cn", dns.TypeA},
{"hello.teaos.cn", dns.TypeCNAME},
{"hello.teaos.cn", dns.TypeA},
{"edgecdn.teaos.cn", dns.TypeA},
} {
var m = new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
r, _, err := c.Exchange(m, "127.0.0.1:53")
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
for _, a := range r.Answer {
t.Log(a)
}
}
}
func TestDNS_Query_A_Many(t *testing.T) {
type queryDef struct {
Domain string
Type uint16
}
var c = new(dns.Client)
for i := 0; i < 10000; i++ {
for _, query := range []queryDef{
{"hello.goedge.cn", dns.TypeA},
} {
var m = new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
r, _, err := c.Exchange(m, "127.0.0.1:53")
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
_ = r
}
}
}
func TestDNS_Query_CNAME(t *testing.T) {
type query struct {
Domain string
Type uint16
}
c := new(dns.Client)
for _, query := range []query{
{"hello.teaos.cn", dns.TypeA},
} {
m := new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
r, _, err := c.Exchange(m, "127.0.0.1:53")
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
for _, a := range r.Answer {
t.Log(a)
}
}
}
func TestDNS_Query_A_TCP(t *testing.T) {
type query struct {
Domain string
Type uint16
}
c := new(dns.Client)
for _, query := range []query{
{"hello.world.teaos.cn", dns.TypeA},
{"cdn.teaos.cn", dns.TypeA},
{"hello.teaos.cn", dns.TypeCNAME},
{"goedge.cn", dns.TypeA},
{"www.goedge.cn", dns.TypeA},
} {
m := new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
//r, _, err := c.Exchange(m, "127.0.0.1:54")
conn, err := dns.Dial("tcp", "127.0.0.1:53")
if err != nil {
t.Fatal(err)
}
r, _, err := c.ExchangeWithConn(m, conn)
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
for _, a := range r.Answer {
t.Log(a)
}
}
}
func TestDNS_Query_A_TLS(t *testing.T) {
type query struct {
Domain string
Type uint16
}
c := new(dns.Client)
for _, query := range []query{
{"hello.world.teaos.cn", dns.TypeA},
{"cdn.teaos.cn", dns.TypeA},
{"hello.teaos.cn", dns.TypeCNAME},
} {
m := new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
//r, _, err := c.Exchange(m, "127.0.0.1:54")
conn, err := dns.DialWithTLS("tcp", "127.0.0.1:853", &tls.Config{
InsecureSkipVerify: true,
})
if err != nil {
t.Fatal(err)
}
r, _, err := c.ExchangeWithConn(m, conn)
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
for _, a := range r.Answer {
t.Log(a)
}
}
}
func TestDNS_Query_Internet(t *testing.T) {
type query struct {
Domain string
Type uint16
}
c := new(dns.Client)
for _, query := range []query{
{"1.goedge.cn", dns.TypeA},
} {
m := new(dns.Msg)
//m.RecursionDesired = true
m.SetQuestion(query.Domain+".", query.Type)
conn, err := dns.Dial("udp", "ns1.teaos.cn:53")
if err != nil {
t.Fatal(err)
}
r, _, err := c.ExchangeWithConn(m, conn)
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
for _, a := range r.Answer {
t.Log(a)
}
}
}
func TestDNS_Query_TSIG(t *testing.T) {
type query struct {
Domain string
Type uint16
}
c := new(dns.Client)
c.TsigSecret = map[string]string{"teaos.": "NzhhZDExMzM5NWMwN2Q5OWM5YTFhMzgxZWNkZGMwMDA2ODUzODdiYTM2ODA5N2I2YjYwZWRlNmNlNjlhMzdmM2JmNjcxZmQ4NzVjMjI1Y2QwOTQ2Njk5OWY0MzRkMTJkNTczNjFlZDgwYmQxZWZjZDM4ZjAxNDNmM2Y2NTU1YjE="}
for _, query := range []query{
{"hello.cdn.teaos.cn", dns.TypeA},
{"cdn.teaos.cn", dns.TypeA},
{"hello.teaos.cn", dns.TypeCNAME},
{"hello.teaos.cn", dns.TypeA},
{"edgecdn.teaos.cn", dns.TypeA},
} {
m := new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
m.SetTsig("teaos.", dns.HmacSHA512, 300, time.Now().Unix())
r, _, err := c.Exchange(m, "127.0.0.1:53")
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
for _, a := range r.Answer {
t.Log(a)
}
}
}
func TestDNS_Query_A_Route(t *testing.T) {
type query struct {
Domain string
Type uint16
}
c := new(dns.Client)
for _, query := range []query{
{"route.com", dns.TypeA},
} {
m := new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
r, _, err := c.Exchange(m, "127.0.0.1:53")
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
for _, a := range r.Answer {
t.Log(a)
}
}
}
func TestDNS_Query_Recursion(t *testing.T) {
type query struct {
Domain string
Type uint16
}
c := new(dns.Client)
for _, query := range []query{
{"example.org", dns.TypeA},
} {
m := new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
r, _, err := c.Exchange(m, "127.0.0.1:53")
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
for _, a := range r.Answer {
t.Log(a)
}
}
}
func TestDNS_Flood(t *testing.T) {
type query struct {
Domain string
Type uint16
}
var c = new(dns.Client)
for i := 0; i < 1_000_000; i++ {
for _, query := range []query{
{"hello.world.teaos.cn", dns.TypeA},
{"cdn.teaos.cn", dns.TypeA},
{"hello.teaos.cn", dns.TypeCNAME},
{"hello.teaos.cn", dns.TypeA},
{"edgecdn.teaos.cn", dns.TypeA},
} {
var m = new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
r, _, err := c.Exchange(m, "192.168.2.60:58")
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
_ = r
}
}
}
func BenchmarkDNSNode(b *testing.B) {
var c = new(dns.Client)
conn, err := c.Dial("192.168.2.60:58")
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
type query struct {
Domain string
Type uint16
}
for i := 0; i < b.N; i++ {
for _, query := range []query{
{"cdn.teaos.cn", dns.TypeA},
} {
var m = new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
_, _, err := c.ExchangeWithConn(m, conn)
if err != nil {
b.Fatal(err)
}
}
}
}

View File

@@ -0,0 +1,179 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"encoding/json"
"errors"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/types"
"github.com/miekg/dns"
"net"
"net/http"
"reflect"
)
type HTTPWriter struct {
rawConn net.Conn
rawWriter http.ResponseWriter
contentType string
}
func NewHTTPWriter(rawWriter http.ResponseWriter, rawConn net.Conn, contentType string) *HTTPWriter {
return &HTTPWriter{
rawWriter: rawWriter,
rawConn: rawConn,
contentType: contentType,
}
}
func (this *HTTPWriter) LocalAddr() net.Addr {
return this.rawConn.LocalAddr()
}
func (this *HTTPWriter) RemoteAddr() net.Addr {
return this.rawConn.RemoteAddr()
}
func (this *HTTPWriter) WriteMsg(msg *dns.Msg) error {
if msg == nil {
return errors.New("'msg' should not be nil")
}
msgData, err := this.encodeMsg(msg)
if err != nil {
return err
}
this.rawWriter.Header().Set("Content-Length", types.String(len(msgData)))
this.rawWriter.Header().Set("Content-Type", this.contentType)
// cache-control
if len(msg.Answer) > 0 {
var minTtl uint32
for _, answer := range msg.Answer {
var header = answer.Header()
if header != nil && header.Ttl > 0 && (minTtl == 0 || header.Ttl < minTtl) {
minTtl = header.Ttl
}
}
if minTtl > 0 {
this.rawWriter.Header().Set("Cache-Control", "max-age="+types.String(minTtl))
}
}
this.rawWriter.WriteHeader(http.StatusOK)
_, err = this.rawWriter.Write(msgData)
return err
}
func (this *HTTPWriter) Write(p []byte) (int, error) {
this.rawWriter.Header().Set("Content-Length", types.String(len(p)))
this.rawWriter.WriteHeader(http.StatusOK)
return this.rawWriter.Write(p)
}
func (this *HTTPWriter) Close() error {
return nil
}
func (this *HTTPWriter) TsigStatus() error {
return nil
}
func (this *HTTPWriter) TsigTimersOnly(timersOnly bool) {
}
func (this *HTTPWriter) Hijack() {
hijacker, ok := this.rawWriter.(http.Hijacker)
if ok {
_, _, _ = hijacker.Hijack()
}
}
func (this *HTTPWriter) encodeMsg(msg *dns.Msg) ([]byte, error) {
if this.contentType == "application/x-javascript" || this.contentType == "application/json" {
var result = map[string]any{
"Status": 0,
"TC": msg.Truncated,
"RD": msg.RecursionDesired,
"RA": msg.RecursionAvailable,
"AD": msg.AuthenticatedData,
"CD": msg.CheckingDisabled,
}
// questions
var questionMaps = []map[string]any{}
for _, question := range msg.Question {
questionMaps = append(questionMaps, map[string]any{
"name": question.Name,
"type": question.Qtype,
})
}
result["Question"] = questionMaps
// answers
var answerMaps = []map[string]any{}
for _, answer := range msg.Answer {
var answerMap = map[string]any{
"name": answer.Header().Name,
"type": answer.Header().Rrtype,
"TTL": answer.Header().Ttl,
}
switch x := answer.(type) {
case *dns.A:
answerMap["data"] = x.A.String()
case *dns.AAAA:
answerMap["data"] = x.AAAA.String()
case *dns.CNAME:
answerMap["data"] = x.Target
case *dns.TXT:
answerMap["data"] = x.Txt
case *dns.NS:
answerMap["data"] = x.Ns
case *dns.MX:
answerMap["data"] = x.Mx
answerMap["preference"] = x.Preference
default:
var answerValue = reflect.ValueOf(answer).Elem()
var answerType = answerValue.Type()
var countFields = answerType.NumField()
for i := 0; i < countFields; i++ {
var fieldName = answerType.Field(i).Name
var fieldValue = answerValue.FieldByName(fieldName)
if !fieldValue.IsValid() {
continue
}
var fieldInterface = fieldValue.Interface()
if fieldInterface == nil {
continue
}
_, ok := fieldInterface.(dns.RR_Header)
if ok {
continue
}
if countFields == 2 {
answerMap["data"] = fieldValue.Interface()
} else {
answerMap[fieldName] = fieldValue.Interface()
}
}
}
answerMaps = append(answerMaps, answerMap)
}
result["Answer"] = answerMaps
if Tea.IsTesting() {
return json.MarshalIndent(result, "", " ")
} else {
return json.Marshal(result)
}
} else {
return msg.Pack()
}
}

View File

@@ -0,0 +1,306 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
"github.com/TeaOSLab/EdgeDNS/internal/events"
"github.com/TeaOSLab/EdgeDNS/internal/firewalls"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/utils"
"github.com/iwind/TeaGo/types"
"runtime"
"sort"
"strings"
"sync"
)
var sharedListenManager *ListenManager = nil
func init() {
if !teaconst.IsMain {
return
}
sharedListenManager = NewListenManager()
events.On(events.EventReload, func() {
sharedListenManager.Update(sharedNodeConfig)
})
events.On(events.EventQuit, func() {
_ = sharedListenManager.ShutdownAll()
})
}
// ListenManager 端口监听管理器
type ListenManager struct {
locker sync.Mutex
serverMap map[string]*Server // addr => *Server
firewalld *firewalls.Firewalld
lastPortStrings string
lastTCPPortRanges [][2]int
lastUDPPortRanges [][2]int
}
// NewListenManager 获取新对象
func NewListenManager() *ListenManager {
return &ListenManager{
serverMap: map[string]*Server{},
firewalld: firewalls.NewFirewalld(),
}
}
// Update 修改配置
func (this *ListenManager) Update(config *dnsconfigs.NSNodeConfig) {
// 构造服务配置
var serverConfigs = []*ServerConfig{}
var serverAddrs = []string{}
// 如果没有配置,则配置一些默认的端口
if config.TCP == nil && config.TLS == nil && config.UDP == nil {
config.TCP = &serverconfigs.TCPProtocolConfig{}
config.TCP.IsOn = true
config.TCP.Listen = []*serverconfigs.NetworkAddressConfig{
{
Protocol: serverconfigs.ProtocolTCP,
MinPort: 53,
MaxPort: 53,
},
}
config.UDP = &serverconfigs.UDPProtocolConfig{}
config.UDP.IsOn = true
config.UDP.Listen = []*serverconfigs.NetworkAddressConfig{
{
Protocol: serverconfigs.ProtocolUDP,
MinPort: 53,
MaxPort: 53,
},
}
}
// 读取配置
if config.TCP != nil && config.TCP.IsOn {
for _, listen := range config.TCP.Listen {
for port := listen.MinPort; port <= listen.MaxPort; port++ {
serverConfigs = append(serverConfigs, &ServerConfig{
Protocol: listen.Protocol,
Host: listen.Host,
Port: port,
SSLPolicy: nil,
})
}
}
}
if config.TLS != nil && config.TLS.IsOn {
for _, listen := range config.TLS.Listen {
if config.TLS.SSLPolicy == nil {
continue
}
for port := listen.MinPort; port <= listen.MaxPort; port++ {
serverConfigs = append(serverConfigs, &ServerConfig{
Protocol: listen.Protocol,
Host: listen.Host,
Port: port,
SSLPolicy: config.TLS.SSLPolicy,
})
}
}
}
if config.DoH != nil && config.DoH.IsOn {
for _, listen := range config.DoH.Listen {
if config.DoH.SSLPolicy == nil {
continue
}
for port := listen.MinPort; port <= listen.MaxPort; port++ {
serverConfigs = append(serverConfigs, &ServerConfig{
Protocol: listen.Protocol,
Host: listen.Host,
Port: port,
SSLPolicy: config.DoH.SSLPolicy,
})
}
}
}
if config.UDP != nil && config.UDP.IsOn {
for _, listen := range config.UDP.Listen {
for port := listen.MinPort; port <= listen.MaxPort; port++ {
serverConfigs = append(serverConfigs, &ServerConfig{
Protocol: listen.Protocol,
Host: listen.Host,
Port: port,
SSLPolicy: nil,
})
}
}
}
// 启动新的
var addrMap = map[string]bool{} // addr => bool
for _, serverConfig := range serverConfigs {
var fullAddr = serverConfig.FullAddr()
serverAddrs = append(serverAddrs, fullAddr)
addrMap[fullAddr] = true
this.locker.Lock()
server, ok := this.serverMap[fullAddr]
this.locker.Unlock()
if !ok {
// 启动新的
var err error
server, err = NewServer(serverConfig)
if err != nil {
remotelogs.Error("LISTEN_MANAGER", "create listener '"+fullAddr+"' failed: "+err.Error())
continue
}
this.locker.Lock()
this.serverMap[fullAddr] = server
this.locker.Unlock()
go func() {
remotelogs.Println("LISTEN_MANAGER", "listen '"+fullAddr+"'")
err = server.ListenAndServe()
if err != nil {
this.locker.Lock()
delete(this.serverMap, fullAddr)
this.locker.Unlock()
remotelogs.Error("LISTEN_MANAGER", "listen '"+fullAddr+"' failed: "+err.Error())
}
}()
} else {
// 更新配置
server.Reload(serverConfig)
}
}
// 停止老的
this.locker.Lock()
for fullAddr, server := range this.serverMap {
_, ok := addrMap[fullAddr]
if !ok {
delete(this.serverMap, fullAddr)
remotelogs.Println("LISTEN_MANAGER", "shutdown "+fullAddr)
err := server.Shutdown()
if err != nil {
remotelogs.Error("LISTEN_MANAGER", "shutdown listener '"+fullAddr+"' failed: "+err.Error())
}
}
}
this.locker.Unlock()
// 添加端口到firewalld
go func() {
this.addToFirewalld(serverAddrs)
}()
}
// ShutdownAll 关闭所有的监听端口
func (this *ListenManager) ShutdownAll() error {
this.locker.Lock()
defer this.locker.Unlock()
var lastErr error
for _, server := range this.serverMap {
err := server.Shutdown()
if err != nil {
lastErr = err
}
}
return lastErr
}
func (this *ListenManager) addToFirewalld(serverAddrs []string) {
if runtime.GOOS != "linux" {
return
}
if this.firewalld == nil || !this.firewalld.IsReady() {
return
}
// 组合端口号
var portStrings = []string{}
var udpPorts = []int{}
var tcpPorts = []int{}
for _, addr := range serverAddrs {
var protocol = "tcp"
if strings.HasPrefix(addr, "udp") {
protocol = "udp"
}
var lastIndex = strings.LastIndex(addr, ":")
if lastIndex > 0 {
var portString = addr[lastIndex+1:]
portStrings = append(portStrings, portString+"/"+protocol)
switch protocol {
case "tcp":
tcpPorts = append(tcpPorts, types.Int(portString))
case "udp":
udpPorts = append(udpPorts, types.Int(portString))
}
}
}
if len(portStrings) == 0 {
return
}
// 检查是否有变化
sort.Strings(portStrings)
var newPortStrings = strings.Join(portStrings, ",")
if newPortStrings == this.lastPortStrings {
return
}
this.lastPortStrings = newPortStrings
remotelogs.Println("FIREWALLD", "opening ports automatically ...")
defer func() {
remotelogs.Println("FIREWALLD", "open ports successfully")
}()
// 合并端口
var tcpPortRanges = utils.MergePorts(tcpPorts)
var udpPortRanges = utils.MergePorts(udpPorts)
defer func() {
this.lastTCPPortRanges = tcpPortRanges
this.lastUDPPortRanges = udpPortRanges
}()
// 删除老的不存在的端口
var tcpPortRangesMap = map[string]bool{}
var udpPortRangesMap = map[string]bool{}
for _, portRange := range tcpPortRanges {
tcpPortRangesMap[this.firewalld.PortRangeString(portRange, "tcp")] = true
}
for _, portRange := range udpPortRanges {
udpPortRangesMap[this.firewalld.PortRangeString(portRange, "udp")] = true
}
for _, portRange := range this.lastTCPPortRanges {
var s = this.firewalld.PortRangeString(portRange, "tcp")
_, ok := tcpPortRangesMap[s]
if ok {
continue
}
remotelogs.Println("FIREWALLD", "remove port '"+s+"'")
_ = this.firewalld.RemovePortRangePermanently(portRange, "tcp")
}
for _, portRange := range this.lastUDPPortRanges {
var s = this.firewalld.PortRangeString(portRange, "udp")
_, ok := udpPortRangesMap[s]
if ok {
continue
}
remotelogs.Println("FIREWALLD", "remove port '"+s+"'")
_ = this.firewalld.RemovePortRangePermanently(portRange, "udp")
}
// 添加新的
_ = this.firewalld.AllowPortRangesPermanently(tcpPortRanges, "tcp")
_ = this.firewalld.AllowPortRangesPermanently(udpPortRanges, "udp")
}

View File

@@ -0,0 +1,319 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
"github.com/TeaOSLab/EdgeDNS/internal/models"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"github.com/iwind/TeaGo/types"
"strings"
"sync"
"time"
)
// DomainManager 域名管理器
type DomainManager struct {
domainMap map[int64]*models.NSDomain // domainId => domain
namesMap map[string]map[int64]*models.NSDomain // domain name => { domainId => domain }
clusterId int64
db *dbs.DB
version int64
locker *sync.RWMutex
notifier chan bool
}
// NewDomainManager 获取域名管理器对象
func NewDomainManager(db *dbs.DB, clusterId int64) *DomainManager {
return &DomainManager{
db: db,
domainMap: map[int64]*models.NSDomain{},
namesMap: map[string]map[int64]*models.NSDomain{},
clusterId: clusterId,
notifier: make(chan bool, 8),
locker: &sync.RWMutex{},
}
}
// Start 启动自动任务
func (this *DomainManager) Start() {
remotelogs.Println("DOMAIN_MANAGER", "starting ...")
// 从本地数据库中加载数据
err := this.Load()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("DOMAIN_MANAGER", "load failed: "+err.Error())
} else {
remotelogs.Error("DOMAIN_MANAGER", "load failed: "+err.Error())
}
}
// 初始化运行
err = this.LoopAll()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("DOMAIN_MANAGER", "loop failed: "+err.Error())
} else {
remotelogs.Error("DOMAIN_MANAGER", "loop failed: "+err.Error())
}
}
// 更新
var ticker = time.NewTicker(20 * time.Second)
for {
select {
case <-ticker.C:
case <-this.notifier:
}
err = this.LoopAll()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("DOMAIN_MANAGER", "loop failed: "+err.Error())
} else {
remotelogs.Error("DOMAIN_MANAGER", "loop failed: "+err.Error())
}
}
}
}
func (this *DomainManager) LoopAll() error {
for {
hasNext, err := this.Loop()
if err != nil {
return err
}
if !hasNext {
break
}
}
return nil
}
// Load 从数据库中加载数据
func (this *DomainManager) Load() error {
var offset = 0
var size = 10000
for {
domains, err := this.db.ListDomains(this.clusterId, offset, size)
if err != nil {
return err
}
if len(domains) == 0 {
break
}
this.locker.Lock()
for _, domain := range domains {
this.domainMap[domain.Id] = domain
nameMap, ok := this.namesMap[domain.Name]
if ok {
nameMap[domain.Id] = domain
} else {
this.namesMap[domain.Name] = map[int64]*models.NSDomain{
domain.Id: domain,
}
}
if domain.Version > this.version {
this.version = domain.Version
}
}
this.locker.Unlock()
offset += size
}
if this.version > 0 {
this.version++
}
return nil
}
// Loop 单次循环任务
func (this *DomainManager) Loop() (hasNext bool, err error) {
client, err := rpc.SharedRPC()
if err != nil {
return false, err
}
resp, err := client.NSDomainRPC.ListNSDomainsAfterVersion(client.Context(), &pb.ListNSDomainsAfterVersionRequest{
Version: this.version,
Size: 20000,
})
if err != nil {
return false, err
}
var domains = resp.NsDomains
if len(domains) == 0 {
return false, nil
}
for _, domain := range domains {
this.processDomain(domain)
if domain.Version > this.version {
this.version = domain.Version
}
}
this.version++
return true, nil
}
// FindDomain 根据名称查找域名
func (this *DomainManager) FindDomain(name string) (domain *models.NSDomain, ok bool) {
this.locker.RLock()
defer this.locker.RUnlock()
nameMap, ok := this.namesMap[name]
if !ok {
return nil, false
}
for _, domain2 := range nameMap {
return domain2, true
}
return
}
// FindDomainWithId 根据域名ID查询域名
func (this *DomainManager) FindDomainWithId(domainId int64) (domain *models.NSDomain) {
this.locker.RLock()
defer this.locker.RUnlock()
return this.domainMap[domainId]
}
// NotifyUpdate 通知更新
func (this *DomainManager) NotifyUpdate() {
select {
case this.notifier <- true:
default:
}
}
// SplitDomain 分解域名
func (this *DomainManager) SplitDomain(fullDomainName string) (rootDomain *models.NSDomain, recordName string) {
if len(fullDomainName) == 0 {
return
}
fullDomainName = strings.TrimSuffix(fullDomainName, ".") // 去除尾部的点(.
fullDomainName = strings.ToLower(fullDomainName) // 转换为小写
var domainName = fullDomainName
var domain, ok = this.FindDomain(domainName)
if !ok {
for {
var index = strings.Index(domainName, ".")
if index < 0 {
break
}
domainName = domainName[index+1:]
domain, ok = this.FindDomain(domainName)
if ok {
recordName = fullDomainName[:len(fullDomainName)-len(domainName)-1]
break
}
}
}
return domain, recordName
}
// 处理域名
func (this *DomainManager) processDomain(domain *pb.NSDomain) {
if !domain.IsOn || domain.IsDeleted || domain.Status != dnsconfigs.NSDomainStatusVerified {
this.locker.Lock()
delete(this.domainMap, domain.Id)
nameMap, ok := this.namesMap[domain.Name]
if ok {
delete(nameMap, domain.Id)
if len(nameMap) == 0 {
delete(this.namesMap, domain.Name)
}
}
this.locker.Unlock()
// 从数据库中删除
if this.db != nil {
err := this.db.DeleteDomain(domain.Id)
if err != nil {
remotelogs.Error("DOMAIN_MANAGER", "delete domain from db failed: "+err.Error())
}
}
return
}
// 存入数据库
if this.db != nil {
exists, err := this.db.ExistsDomain(domain.Id)
if err != nil {
remotelogs.Error("DOMAIN_MANAGER", "query failed: "+err.Error())
} else {
if exists {
err = this.db.UpdateDomain(domain.Id, domain.NsCluster.Id, domain.UserId, domain.Name, domain.TsigJSON, domain.Version)
if err != nil {
remotelogs.Error("DOMAIN_MANAGER", "update failed: "+err.Error())
}
} else {
err = this.db.InsertDomain(domain.Id, domain.NsCluster.Id, domain.UserId, domain.Name, domain.TsigJSON, domain.Version)
if err != nil {
remotelogs.Error("DOMAIN_MANAGER", "insert failed: "+err.Error())
}
}
}
}
// 同集群的才需要加载
if this.clusterId == domain.NsCluster.Id {
this.locker.Lock()
var tsigConfig = &dnsconfigs.NSTSIGConfig{}
if len(domain.TsigJSON) > 0 {
err := json.Unmarshal(domain.TsigJSON, tsigConfig)
if err != nil {
remotelogs.Error("DOMAIN_MANAGER", "decode TSIG json failed: "+err.Error()+", domain: "+domain.Name+", domainId: "+types.String(domain.Id)+", JSON: "+string(domain.TsigJSON))
}
}
var nsDomain = &models.NSDomain{
Id: domain.Id,
ClusterId: domain.NsCluster.Id,
UserId: domain.UserId,
Name: domain.Name,
TSIG: tsigConfig,
Version: domain.Version,
}
this.domainMap[domain.Id] = nsDomain
nameMap, ok := this.namesMap[domain.Name]
if ok {
nameMap[nsDomain.Id] = nsDomain
} else {
this.namesMap[domain.Name] = map[int64]*models.NSDomain{
nsDomain.Id: nsDomain,
}
}
this.locker.Unlock()
} else {
// 不同集群的删除域名
this.locker.Lock()
delete(this.domainMap, domain.Id)
nameMap, ok := this.namesMap[domain.Name]
if ok {
delete(nameMap, domain.Id)
if len(nameMap) == 0 {
delete(this.namesMap, domain.Name)
}
}
this.locker.Unlock()
}
}

View File

@@ -0,0 +1,63 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
"github.com/iwind/TeaGo/Tea"
_ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/logs"
"testing"
)
func TestDomainManager_Loop(t *testing.T) {
var db = dbs.NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
var manager = NewDomainManager(db, 1)
for i := 0; i < 10; i++ {
_, err := manager.Loop()
if err != nil {
t.Fatal(err)
}
}
logs.PrintAsJSON(manager.domainMap, t)
}
func TestDomainManager_Load(t *testing.T) {
var db = dbs.NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
manager := NewDomainManager(db, 2)
err = manager.Load()
if err != nil {
t.Fatal(err)
}
logs.PrintAsJSON(manager.domainMap, t)
t.Log("version:", manager.version)
}
func TestDomainManager_FindDomain(t *testing.T) {
var db = dbs.NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
var manager = NewDomainManager(db, 2)
err = manager.Load()
if err != nil {
t.Fatal(err)
}
for _, name := range []string{"hello.com", "teaos.cn"} {
domain, ok := manager.FindDomain(name)
t.Log(name, ok, domain)
}
}

View File

@@ -0,0 +1,284 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
"github.com/TeaOSLab/EdgeDNS/internal/models"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"sync"
"time"
)
// KeyManager 密钥管理器
type KeyManager struct {
domainKeyMap map[int64]*models.NSKeys // domainId => *NSKeys
zoneKeyMap map[int64]*models.NSKeys // zoneId => *NSKeys
db *dbs.DB
locker sync.RWMutex
version int64
notifier chan bool
}
// NewKeyManager 获取密钥管理器
func NewKeyManager(db *dbs.DB) *KeyManager {
return &KeyManager{
domainKeyMap: map[int64]*models.NSKeys{},
zoneKeyMap: map[int64]*models.NSKeys{},
db: db,
notifier: make(chan bool, 8),
}
}
// Start 启动自动任务
func (this *KeyManager) Start() {
remotelogs.Println("KEY_MANAGER", "starting ...")
// 从本地数据库中加载数据
err := this.Load()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("KEY_MANAGER", "load failed: "+err.Error())
} else {
remotelogs.Error("KEY_MANAGER", "load failed: "+err.Error())
}
}
// 初始化运行
err = this.LoopAll()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("KEY_MANAGER", "loop failed: "+err.Error())
} else {
remotelogs.Error("KEY_MANAGER", "loop failed: "+err.Error())
}
}
// 更新
var ticker = time.NewTicker(1 * time.Minute)
for {
select {
case <-ticker.C:
case <-this.notifier:
}
err := this.LoopAll()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("KEY_MANAGER", "loop failed: "+err.Error())
} else {
remotelogs.Error("KEY_MANAGER", "loop failed: "+err.Error())
}
}
}
}
// Load 从数据库中加载数据
func (this *KeyManager) Load() error {
var offset = 0
var size = 10000
for {
keys, err := this.db.ListKeys(offset, size)
if err != nil {
return err
}
if len(keys) == 0 {
break
}
this.locker.Lock()
for _, key := range keys {
if key.ZoneId > 0 {
keyList, ok := this.zoneKeyMap[key.ZoneId]
if ok {
keyList.Add(key)
} else {
keyList = models.NewNSKeys()
keyList.Add(key)
this.zoneKeyMap[key.ZoneId] = keyList
}
} else if key.DomainId > 0 {
keyList, ok := this.domainKeyMap[key.DomainId]
if ok {
keyList.Add(key)
} else {
keyList = models.NewNSKeys()
keyList.Add(key)
this.domainKeyMap[key.DomainId] = keyList
}
}
if key.Version > this.version {
this.version = key.Version
}
}
this.locker.Unlock()
offset += size
}
if this.version > 0 {
this.version++
}
return nil
}
func (this *KeyManager) LoopAll() error {
for {
hasNext, err := this.Loop()
if err != nil {
return err
}
if !hasNext {
break
}
}
return nil
}
// Loop 单次循环任务
func (this *KeyManager) Loop() (hasNext bool, err error) {
client, err := rpc.SharedRPC()
if err != nil {
return false, err
}
resp, err := client.NSKeyRPC.ListNSKeysAfterVersion(client.Context(), &pb.ListNSKeysAfterVersionRequest{
Version: this.version,
Size: 20000,
})
if err != nil {
return false, err
}
var keys = resp.NsKeys
if len(keys) == 0 {
return false, nil
}
for _, key := range keys {
this.processKey(key)
if key.Version > this.version {
this.version = key.Version
}
}
this.version++
return true, nil
}
func (this *KeyManager) FindKeysWithDomain(domainId int64) []*models.NSKey {
this.locker.RLock()
defer this.locker.RUnlock()
keys, ok := this.domainKeyMap[domainId]
if ok {
return keys.All()
}
return nil
}
// NotifyUpdate 通知更新
func (this *KeyManager) NotifyUpdate() {
select {
case this.notifier <- true:
default:
}
}
// 处理Key
func (this *KeyManager) processKey(key *pb.NSKey) {
if key.NsDomain == nil && key.NsZone == nil {
return
}
if !key.IsOn || key.IsDeleted {
this.locker.Lock()
if key.NsDomain != nil {
list, ok := this.domainKeyMap[key.NsDomain.Id]
if ok {
list.Remove(key.Id)
}
}
if key.NsZone != nil {
list, ok := this.zoneKeyMap[key.NsZone.Id]
if ok {
list.Remove(key.Id)
}
}
this.locker.Unlock()
// 从数据库中删除
if this.db != nil {
err := this.db.DeleteKey(key.Id)
if err != nil {
remotelogs.Error("KEY_MANAGER", "delete key from db failed: "+err.Error())
}
}
return
}
var domainId int64
var zoneId int64
if key.NsDomain != nil {
domainId = key.NsDomain.Id
}
if key.NsZone != nil {
zoneId = key.NsZone.Id
}
// 存入数据库
if this.db != nil {
exists, err := this.db.ExistsKey(key.Id)
if err != nil {
remotelogs.Error("KEY_MANAGER", "query failed: "+err.Error())
} else {
if exists {
err = this.db.UpdateKey(key.Id, domainId, zoneId, key.Algo, key.Secret, key.SecretType, key.Version)
if err != nil {
remotelogs.Error("KEY_MANAGER", "update failed: "+err.Error())
}
} else {
err = this.db.InsertKey(key.Id, domainId, zoneId, key.Algo, key.Secret, key.SecretType, key.Version)
if err != nil {
remotelogs.Error("KEY_MANAGER", "insert failed: "+err.Error())
}
}
}
}
// 加入缓存Map
this.locker.Lock()
var nsKey = &models.NSKey{
Id: key.Id,
DomainId: domainId,
ZoneId: zoneId,
Algo: key.Algo,
Secret: key.Secret,
SecretType: key.SecretType,
Version: key.Version,
}
if zoneId > 0 {
keyList, ok := this.zoneKeyMap[zoneId]
if ok {
keyList.Add(nsKey)
} else {
keyList = models.NewNSKeys()
keyList.Add(nsKey)
this.zoneKeyMap[zoneId] = keyList
}
} else if domainId > 0 {
keyList, ok := this.domainKeyMap[domainId]
if ok {
keyList.Add(nsKey)
} else {
keyList = models.NewNSKeys()
keyList.Add(nsKey)
this.domainKeyMap[domainId] = keyList
}
}
this.locker.Unlock()
}

View File

@@ -0,0 +1,208 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"bytes"
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeDNS/internal/accesslogs"
"github.com/TeaOSLab/EdgeDNS/internal/configs"
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
"github.com/TeaOSLab/EdgeDNS/internal/events"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"github.com/TeaOSLab/EdgeDNS/internal/utils"
"sort"
"time"
)
var sharedNodeConfigManager = NewNodeConfigManager()
type NodeConfigManager struct {
notifyChan chan bool
ticker *time.Ticker
timezone string
lastConfigJSON []byte
lastAPINodeVersion int64
lastAPINodeAddrs []string // 以前的API节点地址
}
func NewNodeConfigManager() *NodeConfigManager {
return &NodeConfigManager{
notifyChan: make(chan bool, 2),
ticker: time.NewTicker(3 * time.Minute),
}
}
func (this *NodeConfigManager) Start() {
for {
select {
case <-this.ticker.C:
case <-this.notifyChan:
}
err := this.Loop()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("NODE_CONFIG_MANAGER", err.Error())
} else {
remotelogs.Error("NODE_CONFIG_MANAGER", err.Error())
}
}
}
}
func (this *NodeConfigManager) Loop() error {
client, err := rpc.SharedRPC()
if err != nil {
return err
}
resp, err := client.NSNodeRPC.FindCurrentNSNodeConfig(client.Context(), &pb.FindCurrentNSNodeConfigRequest{})
if err != nil {
return err
}
var configJSON = resp.NsNodeJSON
if len(configJSON) == 0 {
return nil
}
// 检查是否有变化
if bytes.Equal(this.lastConfigJSON, configJSON) {
return nil
}
this.lastConfigJSON = configJSON
var config = &dnsconfigs.NSNodeConfig{}
err = json.Unmarshal(configJSON, config)
if err != nil {
return err
}
err = config.Init(context.TODO())
if err != nil {
return err
}
sharedNodeConfig = config
configs.SharedNodeConfig = config
this.reload(config)
events.Notify(events.EventReload)
return nil
}
func (this *NodeConfigManager) NotifyChange() {
select {
case this.notifyChan <- true:
default:
}
}
// 刷新配置
func (this *NodeConfigManager) reload(config *dnsconfigs.NSNodeConfig) {
teaconst.IsPlus = config.IsPlus
accesslogs.SharedDNSFileWriter().SetDirByPolicyPath(config.AccessLogFilePath)
accesslogs.SharedDNSFileWriter().SetRotateConfig(config.AccessLogRotate)
needWriteFile := config.AccessLogWriteTargets == nil || config.AccessLogWriteTargets.File || config.AccessLogWriteTargets.ClickHouse
if needWriteFile {
_ = accesslogs.SharedDNSFileWriter().EnsureInit()
} else {
_ = accesslogs.SharedDNSFileWriter().Close()
}
// timezone
var timeZone = config.TimeZone
if len(timeZone) == 0 {
timeZone = "Asia/Shanghai"
}
if this.timezone != timeZone {
location, err := time.LoadLocation(timeZone)
if err != nil {
remotelogs.Error("TIMEZONE", "change time zone failed: "+err.Error())
return
}
remotelogs.Println("TIMEZONE", "change time zone to '"+timeZone+"'")
time.Local = location
this.timezone = timeZone
}
// API Node地址这里不限制是否为空因为在为空时仍然要有对应的处理
this.changeAPINodeAddrs(config.APINodeAddrs)
}
// 检查API节点地址
func (this *NodeConfigManager) changeAPINodeAddrs(apiNodeAddrs []*serverconfigs.NetworkAddressConfig) {
var addrs = []string{}
for _, addr := range apiNodeAddrs {
err := addr.Init()
if err != nil {
remotelogs.Error("NODE", "changeAPINodeAddrs: validate api node address '"+configutils.QuoteIP(addr.Host)+":"+addr.PortRange+"' failed: "+err.Error())
} else {
addrs = append(addrs, addr.FullAddresses()...)
}
}
sort.Strings(addrs)
if utils.EqualStrings(this.lastAPINodeAddrs, addrs) {
return
}
this.lastAPINodeAddrs = addrs
config, err := configs.LoadAPIConfig()
if err != nil {
remotelogs.Error("NODE", "changeAPINodeAddrs: "+err.Error())
return
}
if config == nil {
return
}
var oldEndpoints = config.RPCEndpoints
rpcClient, err := rpc.SharedRPC()
if err != nil {
return
}
if len(addrs) > 0 {
this.lastAPINodeVersion++
var v = this.lastAPINodeVersion
// 异步检测,防止阻塞
go func(v int64) {
// 测试新的API节点地址
if rpcClient.TestEndpoints(addrs) {
config.RPCEndpoints = addrs
} else {
config.RPCEndpoints = oldEndpoints
this.lastAPINodeAddrs = nil // 恢复为空,以便于下次更新重试
}
// 检查测试中间有无新的变更
if v != this.lastAPINodeVersion {
return
}
err = rpcClient.UpdateConfig(config)
if err != nil {
remotelogs.Error("NODE", "changeAPINodeAddrs: update rpc config failed: "+err.Error())
}
}(v)
return
}
err = rpcClient.UpdateConfig(config)
if err != nil {
remotelogs.Error("NODE", "changeAPINodeAddrs: update rpc config failed: "+err.Error())
}
}

View File

@@ -0,0 +1,254 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
"github.com/TeaOSLab/EdgeDNS/internal/models"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"sync"
"time"
)
// RecordManager 记录管理器
type RecordManager struct {
recordsMap map[int64]*models.DomainRecords // domainId => RecordsMap
db *dbs.DB
locker sync.RWMutex
version int64
notifier chan bool
}
// NewRecordManager 获取新记录管理器对象
func NewRecordManager(db *dbs.DB) *RecordManager {
return &RecordManager{
db: db,
recordsMap: map[int64]*models.DomainRecords{},
notifier: make(chan bool, 8),
}
}
// Start 启动自动任务
func (this *RecordManager) Start() {
remotelogs.Println("RECORD_MANAGER", "starting ...")
// 从本地数据库中加载数据
err := this.Load()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("RECORD_MANAGER", "load failed: "+err.Error())
} else {
remotelogs.Error("RECORD_MANAGER", "load failed: "+err.Error())
}
}
// 初始化运行
err = this.LoopAll()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("RECORD_MANAGER", "loop failed: "+err.Error())
} else {
remotelogs.Error("RECORD_MANAGER", "loop failed: "+err.Error())
}
}
// 更新
var ticker = time.NewTicker(30 * time.Second)
for {
select {
case <-ticker.C:
case <-this.notifier:
}
err := this.LoopAll()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("RECORD_MANAGER", "loop failed: "+err.Error())
} else {
remotelogs.Error("RECORD_MANAGER", "loop failed: "+err.Error())
}
}
}
}
// Load 从数据库中加载数据
func (this *RecordManager) Load() error {
var offset = 0
var size = 10000
for {
records, err := this.db.ListRecords(offset, size)
if err != nil {
return err
}
if len(records) == 0 {
break
}
this.locker.Lock()
for _, record := range records {
domainRecords, ok := this.recordsMap[record.DomainId]
if !ok {
domainRecords = models.NewDomainRecords()
this.recordsMap[record.DomainId] = domainRecords
}
domainRecords.Add(record)
if record.Version > this.version {
this.version = record.Version
}
}
this.locker.Unlock()
offset += size
}
if this.version > 0 {
this.version++
}
return nil
}
func (this *RecordManager) LoopAll() error {
for {
hasNext, err := this.Loop()
if err != nil {
return err
}
if !hasNext {
break
}
}
return nil
}
// Loop 单次循环任务
func (this *RecordManager) Loop() (hasNext bool, err error) {
client, err := rpc.SharedRPC()
if err != nil {
return false, err
}
resp, err := client.NSRecordRPC.ListNSRecordsAfterVersion(client.Context(), &pb.ListNSRecordsAfterVersionRequest{
Version: this.version,
Size: 20000,
})
if err != nil {
return false, err
}
var records = resp.NsRecords
if len(records) == 0 {
return false, nil
}
for _, record := range records {
this.processRecord(record)
if record.Version > this.version {
this.version = record.Version
}
}
this.version++
return true, nil
}
func (this *RecordManager) FindRecords(domainId int64, routeCodes []string, recordName string, recordType dnsconfigs.RecordType, strictMode bool) (records []*models.NSRecord, routeCode string) {
this.locker.RLock()
domainRecords, ok := this.recordsMap[domainId]
if ok {
records, routeCode = domainRecords.Find(routeCodes, recordName, recordType, sharedNodeConfig.Answer, strictMode)
}
this.locker.RUnlock()
return
}
// NotifyUpdate 通知更新
func (this *RecordManager) NotifyUpdate() {
select {
case this.notifier <- true:
default:
}
}
// 处理单条记录
func (this *RecordManager) processRecord(record *pb.NSRecord) {
if record.NsDomain == nil {
return
}
if !record.IsOn || record.IsDeleted {
this.locker.Lock()
domainRecords, ok := this.recordsMap[record.NsDomain.Id]
if ok {
domainRecords.Remove(record.Id)
}
this.locker.Unlock()
// 从数据库中删除
if this.db != nil {
err := this.db.DeleteRecord(record.Id)
if err != nil {
remotelogs.Error("RECORD_MANAGER", "delete record from db failed: "+err.Error())
}
}
return
}
// 存入数据库
if this.db != nil {
exists, err := this.db.ExistsRecord(record.Id)
if err != nil {
remotelogs.Error("RECORD_MANAGER", "query failed: "+err.Error())
} else {
var routeIds = []string{}
for _, route := range record.NsRoutes {
routeIds = append(routeIds, route.Code)
}
if exists {
err = this.db.UpdateRecord(record.Id, record.NsDomain.Id, record.Name, record.Type, record.Value, record.MxPriority, record.SrvPriority, record.SrvWeight, record.SrvPort, record.CaaFlag, record.CaaTag, record.Ttl, record.Weight, routeIds, record.Version)
if err != nil {
remotelogs.Error("RECORD_MANAGER", "update failed: "+err.Error())
}
} else {
err = this.db.InsertRecord(record.Id, record.NsDomain.Id, record.Name, record.Type, record.Value, record.MxPriority, record.SrvPriority, record.SrvWeight, record.SrvPort, record.CaaFlag, record.CaaTag, record.Ttl, record.Weight, routeIds, record.Version)
if err != nil {
remotelogs.Error("RECORD_MANAGER", "insert failed: "+err.Error())
}
}
}
}
// 加入缓存Map
this.locker.Lock()
domainRecords, ok := this.recordsMap[record.NsDomain.Id]
if !ok {
domainRecords = models.NewDomainRecords()
this.recordsMap[record.NsDomain.Id] = domainRecords
}
var routeIds = []string{}
for _, r := range record.NsRoutes {
routeIds = append(routeIds, r.Code)
}
domainRecords.Add(&models.NSRecord{
Id: record.Id,
Name: record.Name,
Type: record.Type,
Value: record.Value,
MXPriority: record.MxPriority,
SRVPriority: record.SrvPriority,
SRVWeight: record.SrvWeight,
SRVPort: record.SrvPort,
CAAFlag: record.CaaFlag,
CAATag: record.CaaTag,
Ttl: record.Ttl,
Weight: record.Weight,
Version: record.Version,
RouteIds: routeIds,
DomainId: record.NsDomain.Id,
})
this.locker.Unlock()
}

View File

@@ -0,0 +1,45 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
"github.com/iwind/TeaGo/Tea"
_ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/logs"
"testing"
)
func TestRecordManager_Loop(t *testing.T) {
var db = dbs.NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
manager := NewRecordManager(db)
for i := 0; i < 10; i++ {
_, err := manager.Loop()
if err != nil {
t.Fatal(err)
}
}
logs.PrintAsJSON(manager.recordsMap, t)
t.Log("version:", manager.version)
}
func TestRecordManager_Load(t *testing.T) {
db := dbs.NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
manager := NewRecordManager(db)
err = manager.Load()
if err != nil {
t.Fatal(err)
}
logs.PrintAsJSON(manager.recordsMap, t)
t.Log("version:", manager.version)
}

View File

@@ -0,0 +1,481 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeDNS/internal/agents"
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
"github.com/TeaOSLab/EdgeDNS/internal/models"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
"net"
"sort"
"strconv"
"strings"
"sync"
"time"
)
// RouteManager 线路管理器
type RouteManager struct {
allRouteMap map[int64]*models.NSRoute // routeId => route
userRouteMap map[int64][]int64 // userId => sorted routeIds
db *dbs.DB
version int64
locker sync.RWMutex
notifier chan bool
ispRouteMap map[string]string // name => code
chinaRouteMap map[string]string // name => code
worldRouteMap map[string]string // name => code
}
// NewRouteManager 获取新线路管理器对象
func NewRouteManager(db *dbs.DB) *RouteManager {
return &RouteManager{
db: db,
allRouteMap: map[int64]*models.NSRoute{},
userRouteMap: map[int64][]int64{},
notifier: make(chan bool, 8),
ispRouteMap: map[string]string{},
chinaRouteMap: map[string]string{},
worldRouteMap: map[string]string{},
}
}
// Start 启动自动任务
func (this *RouteManager) Start() {
remotelogs.Println("ROUTE_MANAGER", "starting ...")
// 初始化公共线路
this.loadDefaultRoutes()
// 从本地数据库中加载数据
err := this.Load()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("ROUTE_MANAGER", "load failed: "+err.Error())
} else {
remotelogs.Error("ROUTE_MANAGER", "load failed: "+err.Error())
}
}
// 初始化运行
err = this.LoopAll()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("ROUTE_MANAGER", "loop failed: "+err.Error())
} else {
remotelogs.Error("ROUTE_MANAGER", "loop failed: "+err.Error())
}
}
// 更新
var ticker = time.NewTicker(1 * time.Minute)
for {
select {
case <-ticker.C:
case <-this.notifier:
}
err := this.LoopAll()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("ROUTE_MANAGER", "loop failed: "+err.Error())
} else {
remotelogs.Error("ROUTE_MANAGER", "loop failed: "+err.Error())
}
}
}
}
// Load 从数据库中加载数据
func (this *RouteManager) Load() error {
var offset int64 = 0
var size int64 = 10000
for {
routes, err := this.db.ListRoutes(offset, size)
if err != nil {
return err
}
if len(routes) == 0 {
break
}
this.locker.Lock()
for _, route := range routes {
this.addRoute(route)
if route.Version > this.version {
this.version = route.Version
}
}
this.locker.Unlock()
offset += size
}
if this.version > 0 {
this.version++
}
return nil
}
func (this *RouteManager) LoopAll() error {
for {
hasNext, err := this.Loop()
if err != nil {
return err
}
if !hasNext {
break
}
}
return nil
}
// Loop 单次循环任务
func (this *RouteManager) Loop() (hasNext bool, err error) {
client, err := rpc.SharedRPC()
if err != nil {
return false, err
}
resp, err := client.NSRouteRPC.ListNSRoutesAfterVersion(client.Context(), &pb.ListNSRoutesAfterVersionRequest{
Version: this.version,
Size: 20000,
})
if err != nil {
return false, err
}
var routes = resp.NsRoutes
if len(routes) == 0 {
return false, nil
}
for _, route := range routes {
this.processRoute(route)
if route.Version > this.version {
this.version = route.Version
}
}
this.version++
return true, nil
}
// FindRouteCodes 查找一个地址对应的线路
func (this *RouteManager) FindRouteCodes(ip string, domainUserId int64) (result []string) {
var netIP = net.ParseIP(ip)
if len(netIP) == 0 {
return nil
}
// 自定义route
this.locker.RLock()
// 先查找用户自定义的
if domainUserId > 0 {
var userRouteIds = this.userRouteMap[domainUserId]
for _, routeId := range userRouteIds {
route, ok := this.allRouteMap[routeId]
if ok && route.Contains(netIP) {
result = append(result, route.RealCode())
}
}
}
// 再查找公共的
var publicRouteIds = this.userRouteMap[0]
for _, routeId := range publicRouteIds {
route, ok := this.allRouteMap[routeId]
if ok && route.Contains(netIP) {
result = append(result, route.RealCode())
}
}
this.locker.RUnlock()
// 解析公用线路
var ipResult = iplibrary.LookupIP(ip)
if ipResult != nil && ipResult.IsOk() {
// 运营商
for _, providerCode := range ipResult.ProviderCodes() {
code, ok := this.ispRouteMap[providerCode]
if ok {
result = append(result, code)
// 单次只能有一个匹配
break
}
}
// 省|州|城市
if ipResult.ProvinceId() > 0 {
result = append(result, "region:province:"+types.String(ipResult.ProvinceId()))
}
if ipResult.CityId() > 0 {
result = append(result, "region:city:"+types.String(ipResult.CityId()))
}
if ipResult.TownId() > 0 {
result = append(result, "region:town:"+types.String(ipResult.TownId()))
}
// 中国省市
for _, provinceCode := range ipResult.ProvinceCodes() {
// 中国
code, ok := this.chinaRouteMap[provinceCode]
if ok {
result = append(result, code)
// 兼容以前的拼写错误
switch code {
case "china:province:hebei":
result = append(result, "china:province:heibei")
case "china:province:heibei":
result = append(result, "china:province:hebei")
case "china:jilin":
result = append(result, "china:province:jilin")
case "china:province:jilin":
result = append(result, "china:jilin")
}
// 香港
switch code {
case dnsconfigs.ChinaProvinceCodeHK:
result = append(result, dnsconfigs.WorldRegionCodeHK, dnsconfigs.WorldRegionCodeChinaAbroad)
// 澳门
case dnsconfigs.ChinaProvinceCodeMO:
result = append(result, dnsconfigs.WorldRegionCodeMO, dnsconfigs.WorldRegionCodeChinaAbroad)
// 台湾
case dnsconfigs.ChinaProvinceCodeTW:
result = append(result, dnsconfigs.WorldRegionCodeTW, dnsconfigs.WorldRegionCodeChinaAbroad)
default:
result = append(result, dnsconfigs.WorldRegionCodeChinaMainland)
}
// 单次只能有一个匹配
break
}
}
// 国家/地区
for _, countryCode := range ipResult.CountryCodes() {
code, ok := this.worldRouteMap[countryCode]
if ok {
// 中国全境world:CN必须优先于「海外」匹配否则前面中国省市若误判为 HK/MO/TW 已加入 world:CN:abroad 时会先命中海外线
if code == dnsconfigs.WorldRegionCodeChina {
result = append([]string{code}, result...)
} else {
result = append(result, code)
// 中国海外
if code != dnsconfigs.WorldRegionCodeChina {
result = append(result, dnsconfigs.WorldRegionCodeChinaAbroad)
}
}
// 单次只能有一个匹配
break
}
}
}
// 搜索引擎线路
if agents.SharedManager != nil {
var agentCode = agents.SharedManager.LookupIP(ip)
if len(agentCode) > 0 {
result = append(result, "agent:"+agentCode, "agent" /** 所有搜索引擎 **/)
}
}
return
}
// NotifyUpdate 通知更新
func (this *RouteManager) NotifyUpdate() {
select {
case this.notifier <- true:
default:
}
}
func (this *RouteManager) loadDefaultRoutes() {
for _, route := range dnsconfigs.AllDefaultISPRoutes {
for _, name := range route.AliasNames {
this.ispRouteMap[name] = route.Code
}
}
for _, route := range dnsconfigs.AllDefaultChinaProvinceRoutes {
for _, name := range route.AliasNames {
this.chinaRouteMap[name] = route.Code
}
}
for _, route := range dnsconfigs.AllDefaultWorldRegionRoutes {
for _, name := range route.AliasNames {
this.worldRouteMap[name] = route.Code
}
// 用线路 Code 中的国家/地区 ISO 码做映射,使 IP 库返回的 ISO 码(如 US、CN、HK能命中对应线路
if strings.HasPrefix(route.Code, "world:") {
parts := strings.Split(route.Code, ":")
if len(parts) == 2 && len(parts[1]) == 2 {
this.worldRouteMap[parts[1]] = route.Code
}
if len(parts) == 3 && len(parts[2]) == 2 {
this.worldRouteMap[strings.ToUpper(parts[2])] = route.Code
}
}
}
}
// 添加线路
func (this *RouteManager) addRoute(route *models.NSRoute) {
// 不需要加锁,因为此函数均在锁内调用
// 从老的用户中删除
oldRoute, ok := this.allRouteMap[route.Id]
if ok {
var oldUserId = oldRoute.UserId
if oldUserId != route.UserId {
userRouteIds, ok := this.userRouteMap[oldUserId]
if ok {
this.userRouteMap[oldUserId] = this.removeId(userRouteIds, route.Id)
if len(userRouteIds) == 0 {
delete(this.userRouteMap, oldUserId)
}
}
}
}
// 添加
this.allRouteMap[route.Id] = route
userRouteIds, ok := this.userRouteMap[route.UserId]
if ok {
// 重新按优先级、排序、ID排序
var userRoutes = []*models.NSRoute{}
for _, userRouteId := range userRouteIds {
userRoute, ok := this.allRouteMap[userRouteId]
if ok {
userRoutes = append(userRoutes, userRoute)
}
}
if !lists.ContainsInt64(userRouteIds, route.Id) {
userRoutes = append(userRoutes, route)
}
sort.Slice(userRoutes, func(i, j int) bool {
var userRoute1 = userRoutes[i]
var userRoute2 = userRoutes[j]
if userRoute1.Priority != userRoute2.Priority {
return userRoute1.Priority > userRoute2.Priority
}
if userRoute1.Order != userRoute2.Order {
return userRoute1.Order > userRoute2.Order
}
return userRoute1.Id < userRoute2.Id
})
var newUserRouteIds = []int64{}
for _, userRoute := range userRoutes {
newUserRouteIds = append(newUserRouteIds, userRoute.Id)
}
this.userRouteMap[route.UserId] = newUserRouteIds
} else {
this.userRouteMap[route.UserId] = []int64{route.Id}
}
}
// 删除线路
func (this *RouteManager) removePBRoute(route *pb.NSRoute) {
delete(this.allRouteMap, route.Id)
userRouteIds, ok := this.userRouteMap[route.UserId]
if ok {
userRouteIds = this.removeId(userRouteIds, route.Id)
if len(userRouteIds) == 0 {
delete(this.userRouteMap, route.UserId)
} else {
this.userRouteMap[route.UserId] = userRouteIds
}
}
}
// 处理线路
func (this *RouteManager) processRoute(route *pb.NSRoute) {
if !route.IsOn || route.IsDeleted {
this.locker.Lock()
this.removePBRoute(route)
this.locker.Unlock()
// 从数据库中删除
if this.db != nil {
err := this.db.DeleteRoute(route.Id)
if err != nil {
remotelogs.Error("ROUTE_MANAGER", "delete route from db failed: "+err.Error())
}
}
return
}
// 存入数据库
if this.db != nil {
exists, err := this.db.ExistsRoute(route.Id)
if err != nil {
remotelogs.Error("ROUTE_MANAGER", "query failed: "+err.Error())
} else {
if exists {
err = this.db.UpdateRoute(route.Id, route.UserId, route.RangesJSON, route.Order, route.Priority, route.Version)
if err != nil {
remotelogs.Error("ROUTE_MANAGER", "update failed: "+err.Error())
}
} else {
err = this.db.InsertRoute(route.Id, route.UserId, route.RangesJSON, route.Order, route.Priority, route.Version)
if err != nil {
remotelogs.Error("ROUTE_MANAGER", "insert failed: "+err.Error())
}
}
}
}
ranges, err := models.InitRangesFromJSON(route.RangesJSON)
if err != nil {
remotelogs.Error("ROUTE_MANAGER", "decode routes '"+strconv.FormatInt(route.Id, 10)+"' failed: "+err.Error())
return
}
var nsRoute = &models.NSRoute{
Id: route.Id,
Ranges: ranges,
Priority: route.Priority,
Order: route.Order,
UserId: route.UserId,
Version: route.Version,
}
this.locker.Lock()
this.addRoute(nsRoute)
this.locker.Unlock()
}
// 从一组ID中删除某个ID
func (this *RouteManager) removeId(ids []int64, id int64) []int64 {
var result = []int64{}
for _, id2 := range ids {
if id2 == id {
continue
}
result = append(result, id2)
}
return result
}

View File

@@ -0,0 +1,138 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
"github.com/TeaOSLab/EdgeDNS/internal/events"
"github.com/TeaOSLab/EdgeDNS/internal/models"
"github.com/iwind/TeaGo/Tea"
_ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/logs"
"testing"
)
func TestRouteManager_Loop(t *testing.T) {
var db = dbs.NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
var manager = NewRouteManager(db)
for i := 0; i < 10; i++ {
_, err := manager.Loop()
if err != nil {
t.Fatal(err)
}
}
logs.PrintAsJSON(manager.allRouteMap, t)
}
func TestRouteManager_Load(t *testing.T) {
var db = dbs.NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
var manager = NewRouteManager(db)
err = manager.Load()
if err != nil {
t.Fatal(err)
}
logs.PrintAsJSON(manager.allRouteMap, t)
t.Log("version:", manager.version)
}
func TestRouteManager_AddRoute(t *testing.T) {
var manager = NewRouteManager(nil)
manager.addRoute(&models.NSRoute{
Id: 1,
UserId: 0,
Priority: 0,
})
manager.addRoute(&models.NSRoute{
Id: 2,
UserId: 0,
Priority: 1,
})
manager.addRoute(&models.NSRoute{
Id: 3,
UserId: 0,
Priority: 2,
Order: 1,
})
manager.addRoute(&models.NSRoute{
Id: 4,
UserId: 0,
Priority: 2,
Order: 1,
})
manager.addRoute(&models.NSRoute{
Id: 4,
UserId: 0,
Priority: 2,
Order: 1,
})
manager.addRoute(&models.NSRoute{
Id: 4,
UserId: 1,
Priority: 2,
Order: 1,
})
manager.addRoute(&models.NSRoute{
Id: 5,
UserId: 1,
Priority: 2,
Order: 1,
})
logs.PrintAsJSON(manager.allRouteMap, t)
logs.PrintAsJSON(manager.userRouteMap, t)
}
func TestRouteManager_FindRouteCodes(t *testing.T) {
events.Notify(events.EventLoaded)
var manager = NewRouteManager(nil)
manager.loadDefaultRoutes()
{
var r = &models.NSRoute{
Id: 1,
Ranges: []dnsconfigs.NSRouteRangeInterface{
&dnsconfigs.NSRouteRangeIPRange{
IPFrom: "192.168.1.1",
IPTo: "192.168.1.200",
},
&dnsconfigs.NSRouteRangeIPRange{
IPFrom: "192.168.1.200",
IPTo: "192.168.1.255",
},
&dnsconfigs.NSRouteRangeIPRange{
IPFrom: "127.0.0.1",
IPTo: "127.0.0.1",
},
},
}
for _, rr := range r.Ranges {
err := rr.Init()
if err != nil {
t.Fatal(err)
}
}
manager.addRoute(r)
}
for _, ip := range []string{
"192.168.1.100",
"192.168.1.201",
"192.168.2.1",
"127.0.0.1",
"111.197.174.111",
"202.109.116.116",
"8.8.8.8",
} {
t.Log(ip+": ", manager.FindRouteCodes(ip, 0))
}
}

View File

@@ -0,0 +1,43 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !arm64 && plus
package nodes
import (
"bytes"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/logs"
"os"
"syscall"
)
// 处理异常
func (this *DNSNode) handlePanic() {
// 如果是在前台运行就直接返回
backgroundEnv, _ := os.LookupEnv("EdgeBackground")
if backgroundEnv != "on" {
return
}
var panicFile = Tea.Root + "/logs/panic.log"
// 分析panic
data, err := os.ReadFile(panicFile)
if err == nil {
var index = bytes.Index(data, []byte("panic:"))
if index >= 0 {
remotelogs.Error("NODE", "系统错误,请上报给开发者: "+string(data[index:]))
}
}
fp, err := os.OpenFile(panicFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY|os.O_APPEND, 0777)
if err != nil {
logs.Println("NODE", "open 'panic.log' failed: "+err.Error())
return
}
err = syscall.Dup2(int(fp.Fd()), int(os.Stderr.Fd()))
if err != nil {
logs.Println("NODE", "write to 'panic.log' failed: "+err.Error())
}
}

View File

@@ -0,0 +1,9 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build arm64 && plus
package nodes
// 处理异常
func (this *DNSNode) handlePanic() {
}

View File

@@ -0,0 +1,226 @@
package nodes
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeDNS/internal/configs"
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
"github.com/TeaOSLab/EdgeDNS/internal/events"
"github.com/TeaOSLab/EdgeDNS/internal/monitor"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"github.com/TeaOSLab/EdgeDNS/internal/utils"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/shirou/gopsutil/v3/cpu"
"github.com/shirou/gopsutil/v3/disk"
"os"
"runtime"
"strings"
"time"
)
type NodeStatusExecutor struct {
isFirstTime bool
cpuUpdatedTime time.Time
cpuLogicalCount int
cpuPhysicalCount int
apiCallStat *rpc.CallStat
ticker *time.Ticker
}
func NewNodeStatusExecutor() *NodeStatusExecutor {
return &NodeStatusExecutor{
ticker: time.NewTicker(30 * time.Second),
apiCallStat: rpc.NewCallStat(10),
}
}
func (this *NodeStatusExecutor) Listen() {
this.isFirstTime = true
this.cpuUpdatedTime = time.Now()
this.update()
events.On(events.EventQuit, func() {
remotelogs.Println("NODE_STATUS", "quit executor")
this.ticker.Stop()
})
for range this.ticker.C {
this.isFirstTime = false
this.update()
}
}
func (this *NodeStatusExecutor) update() {
var status = &nodeconfigs.NodeStatus{}
status.BuildVersion = teaconst.Version
status.BuildVersionCode = utils.VersionToLong(teaconst.Version)
status.OS = runtime.GOOS
status.Arch = runtime.GOARCH
status.ConfigVersion = 0
status.IsActive = true
status.ConnectionCount = 0 // TODO 将来显示连接数
apiSuccessPercent, apiAvgCostSeconds := this.apiCallStat.Sum()
status.APISuccessPercent = apiSuccessPercent
status.APIAvgCostSeconds = apiAvgCostSeconds
exe, _ := os.Executable()
status.ExePath = exe
hostname, _ := os.Hostname()
status.Hostname = hostname
this.updateCPU(status)
this.updateMem(status)
this.updateLoad(status)
this.updateDisk(status)
status.UpdatedAt = time.Now().Unix()
status.Timestamp = status.UpdatedAt
// 发送数据
jsonData, err := json.Marshal(status)
if err != nil {
remotelogs.Error("NODE_STATUS", "serial NodeStatus fail: "+err.Error())
return
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
remotelogs.Error("NODE_STATUS", "failed to open rpc: "+err.Error())
return
}
var nodeId = int64(0)
sharedAPIConfig, _ := configs.SharedAPIConfig()
if sharedAPIConfig != nil {
nodeId = sharedAPIConfig.NumberId
}
var before = time.Now()
_, err = rpcClient.NSNodeRPC.UpdateNSNodeStatus(rpcClient.Context(), &pb.UpdateNSNodeStatusRequest{
NodeId: nodeId,
StatusJSON: jsonData,
})
var costSeconds = time.Since(before).Seconds()
this.apiCallStat.Add(err == nil, costSeconds)
if err != nil {
if !rpc.IsConnError(err) {
remotelogs.Error("NODE_STATUS", "rpc UpdateNSNodeStatus() failed: "+err.Error())
} else {
remotelogs.Debug("NODE_STATUS", "rpc UpdateNSNodeStatus() failed: "+err.Error())
}
return
}
}
// 更新CPU
func (this *NodeStatusExecutor) updateCPU(status *nodeconfigs.NodeStatus) {
duration := time.Duration(0)
if this.isFirstTime {
duration = 100 * time.Millisecond
}
percents, err := cpu.Percent(duration, false)
if err != nil {
status.Error = "cpu.Percent(): " + err.Error()
return
}
if len(percents) == 0 {
return
}
status.CPUUsage = percents[0] / 100
if time.Since(this.cpuUpdatedTime) > 300*time.Second { // 每隔5分钟才会更新一次
this.cpuUpdatedTime = time.Now()
status.CPULogicalCount, err = cpu.Counts(true)
if err != nil {
status.Error = "cpu.Counts(): " + err.Error()
return
}
status.CPUPhysicalCount, err = cpu.Counts(false)
if err != nil {
status.Error = "cpu.Counts(): " + err.Error()
return
}
this.cpuLogicalCount = status.CPULogicalCount
this.cpuPhysicalCount = status.CPUPhysicalCount
} else {
status.CPULogicalCount = this.cpuLogicalCount
status.CPUPhysicalCount = this.cpuPhysicalCount
}
// 记录监控数据
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemCPU, maps.Map{
"usage": status.CPUUsage,
"cores": runtime.NumCPU(),
})
}
// 更新硬盘
func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
partitions, err := disk.Partitions(false)
if err != nil {
remotelogs.Error("NODE_STATUS", err.Error())
return
}
lists.Sort(partitions, func(i int, j int) bool {
p1 := partitions[i]
p2 := partitions[j]
return p1.Mountpoint > p2.Mountpoint
})
// 当前所在的fs
var rootFS = ""
var rootTotal = uint64(0)
if lists.ContainsString([]string{"darwin", "linux", "freebsd"}, runtime.GOOS) {
for _, p := range partitions {
if p.Mountpoint == "/" {
rootFS = p.Fstype
usage, _ := disk.Usage(p.Mountpoint)
if usage != nil {
rootTotal = usage.Total
}
break
}
}
}
var total = rootTotal
var totalUsage = uint64(0)
var maxUsage = float64(0)
for _, partition := range partitions {
if runtime.GOOS != "windows" && !strings.Contains(partition.Device, "/") && !strings.Contains(partition.Device, "\\") {
continue
}
// 跳过不同fs的
if len(rootFS) > 0 && rootFS != partition.Fstype {
continue
}
usage, err := disk.Usage(partition.Mountpoint)
if err != nil {
continue
}
if partition.Mountpoint != "/" && (usage.Total != rootTotal || total == 0) {
total += usage.Total
}
totalUsage += usage.Used
if usage.UsedPercent >= maxUsage {
maxUsage = usage.UsedPercent
status.DiskMaxUsagePartition = partition.Mountpoint
}
}
status.DiskTotal = total
if total > 0 {
status.DiskUsage = float64(totalUsage) / float64(total)
}
status.DiskMaxUsage = maxUsage / 100
}

View File

@@ -0,0 +1,27 @@
package nodes
import (
"github.com/shirou/gopsutil/v3/cpu"
"testing"
"time"
)
func TestNodeStatusExecutor_CPU(t *testing.T) {
countLogicCPU, err := cpu.Counts(true)
if err != nil {
t.Fatal(err)
}
t.Log("logic count:", countLogicCPU)
countPhysicalCPU, err := cpu.Counts(false)
if err != nil {
t.Fatal(err)
}
t.Log("physical count:", countPhysicalCPU)
percents, err := cpu.Percent(100*time.Millisecond, false)
if err != nil {
t.Fatal(err)
}
t.Log(percents)
}

View File

@@ -0,0 +1,58 @@
//go:build !windows
// +build !windows
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeDNS/internal/monitor"
"github.com/iwind/TeaGo/maps"
"github.com/shirou/gopsutil/v3/load"
"github.com/shirou/gopsutil/v3/mem"
)
// 更新内存
func (this *NodeStatusExecutor) updateMem(status *nodeconfigs.NodeStatus) {
stat, err := mem.VirtualMemory()
if err != nil {
return
}
// 重新计算内存
if stat.Total > 0 {
stat.Used = stat.Total - stat.Free - stat.Buffers - stat.Cached
status.MemoryUsage = float64(stat.Used) / float64(stat.Total)
}
status.MemoryTotal = stat.Total
// 记录监控数据
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemMemory, maps.Map{
"usage": status.MemoryUsage,
"total": status.MemoryTotal,
"used": stat.Used,
})
}
// 更新负载
func (this *NodeStatusExecutor) updateLoad(status *nodeconfigs.NodeStatus) {
stat, err := load.Avg()
if err != nil {
status.Error = err.Error()
return
}
if stat == nil {
status.Error = "load is nil"
return
}
status.Load1m = stat.Load1
status.Load5m = stat.Load5
status.Load15m = stat.Load15
// 记录监控数据
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemLoad, maps.Map{
"load1m": status.Load1m,
"load5m": status.Load5m,
"load15m": status.Load15m,
})
}

View File

@@ -0,0 +1,102 @@
//go:build windows
// +build windows
package nodes
import (
"context"
"github.com/shirou/gopsutil/v3/cpu"
"github.com/shirou/gopsutil/v3/mem"
"math"
"sync"
"time"
)
type WindowsLoadValue struct {
Timestamp int64
Value int
}
var windowsLoadValues = []*WindowsLoadValue{}
var windowsLoadLocker = &sync.Mutex{}
// 更新内存
func (this *NodeStatusExecutor) updateMem(status *NodeStatus) {
stat, err := mem.VirtualMemory()
if err != nil {
status.Error = err.Error()
return
}
status.MemoryUsage = stat.UsedPercent
status.MemoryTotal = stat.Total
}
// 更新负载
func (this *NodeStatusExecutor) updateLoad(status *NodeStatus) {
timestamp := time.Now().Unix()
currentLoad := 0
info, err := cpu.ProcInfo()
if err == nil && len(info) > 0 && info[0].ProcessorQueueLength < 1000 {
currentLoad = int(info[0].ProcessorQueueLength)
}
// 删除15分钟之前的数据
windowsLoadLocker.Lock()
result := []*WindowsLoadValue{}
for _, v := range windowsLoadValues {
if timestamp-v.Timestamp > 15*60 {
continue
}
result = append(result, v)
}
result = append(result, &WindowsLoadValue{
Timestamp: timestamp,
Value: currentLoad,
})
windowsLoadValues = result
total1 := 0
count1 := 0
total5 := 0
count5 := 0
total15 := 0
count15 := 0
for _, v := range result {
if timestamp-v.Timestamp <= 60 {
total1 += v.Value
count1++
}
if timestamp-v.Timestamp <= 300 {
total5 += v.Value
count5++
}
total15 += v.Value
count15++
}
load1 := float64(0)
load5 := float64(0)
load15 := float64(0)
if count1 > 0 {
load1 = math.Round(float64(total1*100)/float64(count1)) / 100
}
if count5 > 0 {
load5 = math.Round(float64(total5*100)/float64(count5)) / 100
}
if count15 > 0 {
load15 = math.Round(float64(total15*100)/float64(count15)) / 100
}
windowsLoadLocker.Unlock()
// 在老Windows上不显示错误
if err == context.DeadlineExceeded {
err = nil
}
status.Load1m = load1
status.Load5m = load5
status.Load15m = load15
}

View File

@@ -0,0 +1,125 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeDNS/internal/accesslogs"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"strconv"
"time"
)
var sharedNSAccessLogQueue = NewNSAccessLogQueue()
// NSAccessLogQueue NS访问日志队列
type NSAccessLogQueue struct {
queue chan *pb.NSAccessLog
}
// NewNSAccessLogQueue 获取新对象
func NewNSAccessLogQueue() *NSAccessLogQueue {
// 队列中最大的值,超出此数量的访问日志会被抛弃
// TODO 需要可以在界面中设置
var maxSize = 20000
queue := &NSAccessLogQueue{
queue: make(chan *pb.NSAccessLog, maxSize),
}
go queue.Start()
return queue
}
// Start 开始处理访问日志
func (this *NSAccessLogQueue) Start() {
var ticker = time.NewTicker(1 * time.Second)
for range ticker.C {
err := this.loop()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("ACCESS_LOG_QUEUE", err.Error())
} else {
remotelogs.Error("ACCESS_LOG_QUEUE", err.Error())
}
}
}
}
// Push 加入新访问日志
func (this *NSAccessLogQueue) Push(accessLog *pb.NSAccessLog) {
select {
case this.queue <- accessLog:
default:
}
}
// 上传访问日志
func (this *NSAccessLogQueue) loop() error {
var accessLogs = []*pb.NSAccessLog{}
var count = 0
var timestamp int64
var requestId = 10_000_000
Loop:
for {
select {
case accessLog := <-this.queue:
if accessLog.Timestamp > timestamp {
requestId = 10_000_000
timestamp = accessLog.Timestamp
} else {
requestId++
}
// timestamp + nodeId + requestId
accessLog.RequestId = strconv.FormatInt(accessLog.Timestamp, 10) + strconv.FormatInt(accessLog.NsNodeId, 10) + strconv.Itoa(requestId)
accessLogs = append(accessLogs, accessLog)
count++
// 每次只提交 N 条访问日志,防止网络拥堵
if count > 2000 {
break Loop
}
default:
break Loop
}
}
if len(accessLogs) == 0 {
return nil
}
var clusterId int64
var needWriteFile = true
var needReportAPI = true
if sharedNodeConfig != nil {
clusterId = sharedNodeConfig.ClusterId
if sharedNodeConfig.AccessLogWriteTargets != nil {
targets := sharedNodeConfig.AccessLogWriteTargets
needWriteFile = targets.File || targets.ClickHouse
needReportAPI = targets.MySQL
}
}
if needWriteFile {
accesslogs.SharedDNSFileWriter().WriteBatch(accessLogs, clusterId)
}
if !needReportAPI {
return nil
}
// 发送到API
client, err := rpc.SharedRPC()
if err != nil {
return err
}
_, err = client.NSAccessLogRPC.CreateNSAccessLogs(client.Context(), &pb.CreateNSAccessLogsRequest{NsAccessLogs: accessLogs})
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,932 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeDNS/internal/agents"
"github.com/TeaOSLab/EdgeDNS/internal/models"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"github.com/TeaOSLab/EdgeDNS/internal/stats"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
"github.com/miekg/dns"
"io"
"math/rand"
"net"
"net/http"
"regexp"
"strings"
"time"
)
var sharedRecursionDNSClient = &dns.Client{}
type httpContextKey struct {
key string
}
var HTTPConnContextKey = &httpContextKey{key: "http-conn"}
const PingDomain = "ping."
// Server 服务
type Server struct {
config *ServerConfig
rawServer *dns.Server
httpsServer *http.Server
}
// NewServer 构造新服务
func NewServer(config *ServerConfig) (*Server, error) {
var server = &Server{
config: config,
}
err := server.init()
if err != nil {
return nil, err
}
return server, nil
}
// ListenAndServe 监听
func (this *Server) ListenAndServe() error {
if this.rawServer != nil {
return this.rawServer.ListenAndServe()
}
if this.httpsServer != nil {
listener, err := net.Listen("tcp", this.httpsServer.Addr)
if err != nil {
return err
}
err = this.httpsServer.ServeTLS(listener, "", "")
if err == http.ErrServerClosed {
err = nil
}
return err
}
return errors.New("the server is not initialized")
}
// Shutdown 关闭
func (this *Server) Shutdown() error {
if this.rawServer != nil {
return this.rawServer.Shutdown()
}
if this.httpsServer != nil {
return this.httpsServer.Shutdown(context.Background())
}
return errors.New("the server is not initialized")
}
// Reload 重载配置
func (this *Server) Reload(config *ServerConfig) {
this.config = config
}
// 初始化
func (this *Server) init() error {
var rawServer = &dns.Server{
ReadTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
}
rawServer.Handler = dns.HandlerFunc(this.handleDNSMessage)
var addr = ""
if len(this.config.Host) > 0 {
addr += configutils.QuoteIP(this.config.Host)
}
addr += ":" + types.String(this.config.Port)
rawServer.Addr = addr
switch this.config.Protocol {
case serverconfigs.ProtocolTCP:
rawServer.Net = "tcp"
case serverconfigs.ProtocolTLS:
rawServer.Net = "tcp-tls"
rawServer.TLSConfig = &tls.Config{
Certificates: nil,
GetConfigForClient: func(clientInfo *tls.ClientHelloInfo) (config *tls.Config, e error) {
return this.config.SSLPolicy.TLSConfig(), nil
},
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
return this.config.SSLPolicy.FirstCert(), nil
},
}
case serverconfigs.ProtocolHTTPS: // DoH
rawServer = nil
this.httpsServer = &http.Server{
Addr: addr,
Handler: http.HandlerFunc(this.handleHTTP),
TLSConfig: &tls.Config{
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
if this.config == nil {
return nil, errors.New("invalid 'ServerConfig.config'")
}
if this.config.SSLPolicy == nil {
return nil, errors.New("invalid 'ServerConfig.config.SSLPolicy'")
}
return this.config.SSLPolicy.TLSConfig(), nil
},
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
if this.config == nil {
return nil, errors.New("invalid 'ServerConfig.config'")
}
if this.config.SSLPolicy == nil {
return nil, errors.New("invalid 'ServerConfig.config.SSLPolicy'")
}
return this.config.SSLPolicy.FirstCert(), nil
},
},
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
return context.WithValue(ctx, HTTPConnContextKey, c)
},
ReadTimeout: 5 * time.Second,
ReadHeaderTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
IdleTimeout: 75 * time.Second,
MaxHeaderBytes: 4096,
}
case serverconfigs.ProtocolUDP:
rawServer.Net = "udp"
}
this.rawServer = rawServer
return nil
}
// 查询递归DNS
func (this *Server) lookupRecursionDNS(req *dns.Msg, resp *dns.Msg) error {
var config = sharedNodeConfig.RecursionConfig
if config == nil {
return nil
}
if !config.IsOn {
return nil
}
// 是否允许
var domain = strings.TrimSuffix(req.Question[0].Name, ".")
if len(config.DenyDomains) > 0 && configutils.MatchDomains(config.DenyDomains, domain) {
return nil
}
if len(config.AllowDomains) > 0 && !configutils.MatchDomains(config.AllowDomains, domain) {
return nil
}
if config.UseLocalHosts {
// TODO 需要缓存文件内容
resolveConfig, err := dns.ClientConfigFromFile("/etc/resolv.conf")
if err != nil {
return err
}
if len(resolveConfig.Servers) == 0 {
return errors.New("no dns servers found in config file")
}
if len(resolveConfig.Port) == 0 {
resolveConfig.Port = "53"
}
r, _, err := sharedRecursionDNSClient.Exchange(req, configutils.QuoteIP(resolveConfig.Servers[rands.Int(0, len(resolveConfig.Servers)-1)])+":"+resolveConfig.Port)
if err != nil {
return err
}
resp.Answer = r.Answer
} else if len(config.Hosts) > 0 {
var host = config.Hosts[rands.Int(0, len(config.Hosts)-1)]
if host.Port <= 0 {
host.Port = 53
}
r, _, err := sharedRecursionDNSClient.Exchange(req, configutils.QuoteIP(host.Host)+":"+types.String(host.Port))
if err != nil {
return err
}
resp.Answer = r.Answer
}
return nil
}
// 分析查询中的动作
func (this *Server) parseAction(questionName string, remoteAddr *string) (string, error) {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return "", err
}
// TODO 需要防止恶意攻击
var optionIndex = strings.Index(questionName, "-")
if optionIndex > 0 {
optionId := types.Int64(questionName[1:optionIndex])
optionResp, err := rpcClient.NSQuestionOptionRPC.FindNSQuestionOption(rpcClient.Context(), &pb.FindNSQuestionOptionRequest{NsQuestionOptionId: optionId})
if err != nil {
return "", errors.New("query question option failed: " + err.Error())
} else {
var option = optionResp.NsQuestionOption
if option != nil {
switch option.Name {
case "setRemoteAddr":
var m = maps.Map{}
err = json.Unmarshal(option.ValuesJSON, &m)
if err != nil {
return "", errors.New("decode question option failed: " + err.Error())
} else {
var ip = m.GetString("ip")
*remoteAddr = ip
}
}
}
}
questionName = questionName[optionIndex+1:]
}
return questionName, nil
}
// 记录日志
func (this *Server) addLog(networking string, question dns.Question, domainId int64, routeCode string, record *models.NSRecord, isRecursive bool, writer dns.ResponseWriter, remoteAddr string, err error) {
// 访问日志
var accessLogRef = sharedNodeConfig.AccessLogRef
if accessLogRef != nil && accessLogRef.IsOn {
if domainId == 0 && !accessLogRef.LogMissingDomains {
return
}
if accessLogRef.MissingRecordsOnly && record != nil && len(record.Value) > 0 {
return
}
var now = time.Now()
var pbAccessLog = &pb.NSAccessLog{
NsNodeId: sharedNodeConfig.Id,
RemoteAddr: remoteAddr,
NsDomainId: domainId,
QuestionName: question.Name,
QuestionType: dns.Type(question.Qtype).String(),
IsRecursive: isRecursive,
Networking: networking,
ServerAddr: writer.LocalAddr().String(),
Timestamp: now.Unix(),
TimeLocal: now.Format("2/Jan/2006:15:04:05 -0700"),
RequestId: "",
}
if record != nil {
pbAccessLog.NsRecordId = record.Id
if len(routeCode) > 0 {
pbAccessLog.NsRouteCodes = []string{routeCode}
}
pbAccessLog.RecordName = record.Name
pbAccessLog.RecordType = record.Type
pbAccessLog.RecordValue = record.Value
}
if err != nil {
pbAccessLog.Error = err.Error()
}
sharedNSAccessLogQueue.Push(pbAccessLog)
}
}
// 验证TSIG
func (this *Server) checkTSIG(msg *dns.Msg, domainId int64) error {
var tsig = msg.IsTsig()
if tsig == nil {
return errors.New("tsig: tsig required")
}
var keys = sharedKeyManager.FindKeysWithDomain(domainId)
if len(keys) == 0 {
return errors.New("tsig: no keys defined")
}
for _, key := range keys {
if key.Algo != tsig.Algorithm {
continue
}
// 需要重新Pack每次Pack结果只能校验一次
msgData, err := msg.Pack()
if err != nil {
return err
}
if len(msgData) == 0 {
return nil
}
var base64Secret = key.Secret
if key.SecretType == dnsconfigs.NSKeySecretTypeClear {
base64Secret = base64.StdEncoding.EncodeToString([]byte(key.Secret))
}
err = dns.TsigVerify(msgData, base64Secret, "", false)
if err != nil {
continue
} else {
return nil
}
}
return dns.ErrSig
}
// 处理DNS请求
func (this *Server) handleDNSMessage(writer dns.ResponseWriter, req *dns.Msg) {
if len(req.Question) == 0 {
return
}
if sharedDomainManager == nil {
return
}
var networking = ""
if this.config != nil {
networking = this.config.Protocol.String()
}
var resultDomainId int64
var resultRecordIds [][2]int64 // [] { domainId, recordId}
var resp = new(dns.Msg)
resp.RecursionDesired = true
resp.RecursionAvailable = true
resp.SetReply(req)
resp.Answer = []dns.RR{}
var tsigIsChecked = false
var remoteAddr = writer.RemoteAddr().String()
for _, question := range req.Question {
if len(question.Name) == 0 {
continue
}
// PING
if question.Name == PingDomain {
resp.Answer = append(resp.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: question.Name,
Rrtype: dns.TypeA,
Class: question.Qclass,
Ttl: 60,
},
A: net.ParseIP("127.0.0.1"),
})
resp.Rcode = dns.RcodeSuccess
err := writer.WriteMsg(resp)
if err != nil {
return
}
return
}
// 查询选项
if question.Name[0] == '$' {
_, port, _ := net.SplitHostPort(remoteAddr)
questionName, err := this.parseAction(question.Name, &remoteAddr)
if err != nil {
remotelogs.Error("SERVER", "invalid query option '"+question.Name+"'")
continue
}
question.Name = questionName
if len(port) > 0 {
if strings.Contains(remoteAddr, ":") { // IPv6
remoteAddr = "[" + remoteAddr + "]:" + port
} else {
remoteAddr += ":" + port
}
}
}
var fullName = strings.TrimSuffix(question.Name, ".")
var recordName string
var recordType = dns.Type(question.Qtype).String()
var domain *models.NSDomain
domain, recordName = sharedDomainManager.SplitDomain(fullName)
if domain == nil {
// 检查递归DNS
if sharedNodeConfig.RecursionConfig != nil && sharedNodeConfig.RecursionConfig.IsOn {
err := this.lookupRecursionDNS(req, resp)
if err != nil {
this.addLog(networking, question, 0, "", nil, true, writer, remoteAddr, err)
} else {
var recordValue = ""
if len(resp.Answer) > 0 {
pieces := regexp.MustCompile(`\s+`).Split(resp.Answer[0].String(), 6)
if len(pieces) >= 5 {
recordValue = pieces[4]
}
}
this.addLog(networking, question, 0, "", &models.NSRecord{
Id: 0,
Name: recordName,
Type: recordType,
Value: recordValue,
Ttl: 0,
}, true, writer, remoteAddr, nil)
}
err = writer.WriteMsg(resp)
if err != nil {
return
}
return
}
// 是否为NS记录用于验证域名所有权
if question.Qtype == dns.TypeNS {
var hosts = sharedNodeConfig.Hosts
var l = len(hosts)
var record = &models.NSRecord{
Id: 0,
Type: dnsconfigs.RecordTypeNS,
Ttl: 600, // TODO 可以设置
}
if l > 0 {
l = 1 // 目前只返回一个
// 随机
var indexes = []int{}
for i := 0; i < l; i++ {
indexes = append(indexes, i)
}
rand.Shuffle(l, func(i, j int) {
indexes[i], indexes[j] = indexes[j], indexes[i]
})
record.Value = hosts[0] + "."
for _, index := range indexes {
resp.Answer = append(resp.Answer, &dns.NS{
Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
Ns: hosts[index] + ".",
})
}
this.addLog(networking, question, 0, "", record, false, writer, remoteAddr, nil)
continue
}
this.addLog(networking, question, 0, "", nil, false, writer, remoteAddr, nil)
continue
}
this.addLog(networking, question, 0, "", nil, false, writer, remoteAddr, nil)
continue
}
// 检查TSIG
if domain.TSIG != nil && domain.TSIG.IsOn && !tsigIsChecked {
err := this.checkTSIG(req, domain.Id)
if err != nil {
this.addLog(networking, question, domain.Id, "", nil, false, writer, remoteAddr, err)
continue
}
tsigIsChecked = true
}
resultDomainId = domain.Id
var clientIP = remoteAddr
clientHost, _, err := net.SplitHostPort(clientIP)
if err == nil && len(clientHost) > 0 {
clientIP = clientHost
}
// 解析Agent
if sharedNodeConfig.DetectAgents {
agents.SharedQueue.Push(clientIP)
}
var routeCodes = sharedRouteManager.FindRouteCodes(clientIP, domain.UserId)
var records []*models.NSRecord
var matchedRouteCode string
if question.Qtype == dns.TypeSOA { // SOA
if len(recordName) == 0 { // 只有顶级域名才有SOA记录
records = []*models.NSRecord{
{
Id: 0,
Type: dnsconfigs.RecordTypeSOA,
Ttl: 600, // TODO 可以设置
},
}
}
} else if question.Qtype == dns.TypeNS { // NS
if len(recordName) == 0 { // 只有顶级域名才有NS记录
records = []*models.NSRecord{
{
Id: 0,
Type: dnsconfigs.RecordTypeNS,
Ttl: 600, // TODO 可以设置
},
}
}
} else if question.Qtype != dns.TypeCNAME {
// 是否有直接的设置
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, true)
// 检查CNAME
if len(records) == 0 {
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, dnsconfigs.RecordTypeCNAME, false)
if len(records) > 0 {
question.Qtype = dns.TypeCNAME
}
}
// 再次尝试查找默认设置
if len(records) == 0 {
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, false)
}
}
if len(records) == 0 {
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, false)
}
// 对 NS.example.com NS|SOA 处理
if (question.Qtype == dns.TypeNS || (question.Qtype == dns.TypeSOA && len(records) == 0)) && lists.ContainsString(sharedNodeConfig.Hosts, fullName) {
var recordDNSType string
switch question.Qtype {
case dns.TypeNS:
recordDNSType = dnsconfigs.RecordTypeNS
case dns.TypeSOA:
recordDNSType = dnsconfigs.RecordTypeSOA
}
this.composeSOAAnswer(question, &models.NSRecord{
Type: recordDNSType,
Ttl: 600,
}, resp)
}
if len(records) > 0 {
var firstRecord = records[0]
for _, record := range records {
resultRecordIds = append(resultRecordIds, [2]int64{record.DomainId, record.Id})
switch record.Type {
case dnsconfigs.RecordTypeA:
var answer = record.ToRRAnswer(question.Name, question.Qclass)
if answer != nil {
resp.Answer = append(resp.Answer, answer)
}
case dnsconfigs.RecordTypeCNAME:
var value = record.Value
if !strings.HasSuffix(value, ".") {
value += "."
}
var lastRecordValue = value
resp.Answer = append(resp.Answer, &dns.CNAME{
Hdr: record.ToRRHeader(question.Name, dns.TypeCNAME, question.Qclass),
Target: value,
})
// 继续查询CNAME
var allCNAMEValues = []string{lastRecordValue}
for {
// 限制最深32层
if len(allCNAMEValues) > 32 {
break
}
cnameDomain, cnameRecordName := sharedDomainManager.SplitDomain(lastRecordValue)
if cnameDomain == nil {
break
}
cnameRecords, _ := sharedRecordManager.FindRecords(cnameDomain.Id, sharedRouteManager.FindRouteCodes(clientIP, cnameDomain.UserId), cnameRecordName, dnsconfigs.RecordTypeCNAME, false)
if len(cnameRecords) == 0 {
break
}
var cnameRecord = cnameRecords[0]
if !lists.ContainsString(allCNAMEValues, cnameRecord.Value) {
resultRecordIds = append(resultRecordIds, [2]int64{cnameRecord.DomainId, cnameRecord.Id}) // 统计
var answer = cnameRecord.ToRRAnswer(lastRecordValue, question.Qclass)
if answer == nil {
break
}
resp.Answer = append(resp.Answer, answer)
lastRecordValue = cnameRecord.Value
allCNAMEValues = append(allCNAMEValues, lastRecordValue)
} else {
break
}
}
// 再次查询原始问题
if len(req.Question) > 0 {
var firstQuestion = req.Question[0]
if firstQuestion.Qtype != dns.TypeCNAME {
finalDomain, finalRecordName := sharedDomainManager.SplitDomain(lastRecordValue)
if finalDomain != nil {
var realRecords, _ = sharedRecordManager.FindRecords(finalDomain.Id, sharedRouteManager.FindRouteCodes(clientIP, finalDomain.UserId), finalRecordName, dns.Type(firstQuestion.Qtype).String(), false)
if len(realRecords) > 0 {
for _, realRecord := range realRecords {
resultRecordIds = append(resultRecordIds, [2]int64{realRecord.DomainId, realRecord.Id}) // 统计
var answer = realRecord.ToRRAnswer(lastRecordValue, question.Qclass)
if answer != nil {
resp.Answer = append(resp.Answer, answer)
}
}
}
}
}
}
case dnsconfigs.RecordTypeAAAA:
var answer = record.ToRRAnswer(question.Name, question.Qclass)
if answer != nil {
resp.Answer = append(resp.Answer, answer)
}
case dnsconfigs.RecordTypeNS:
if record.Id == 0 {
var hosts = sharedNodeConfig.Hosts
var l = len(hosts)
if l > 0 {
// 随机
var indexes = []int{}
for i := 0; i < l; i++ {
indexes = append(indexes, i)
}
rand.Shuffle(l, func(i, j int) {
indexes[i], indexes[j] = indexes[j], indexes[i]
})
record.Value = hosts[0] + "."
for _, index := range indexes {
resp.Answer = append(resp.Answer, &dns.NS{
Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
Ns: hosts[index] + ".",
})
}
}
} else {
var value = record.Value
if !strings.HasSuffix(value, ".") {
value += "."
}
resp.Answer = append(resp.Answer, &dns.NS{
Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
Ns: value,
})
}
case dnsconfigs.RecordTypeMX:
var answer = record.ToRRAnswer(question.Name, question.Qclass)
if answer != nil {
resp.Answer = append(resp.Answer, answer)
}
case dnsconfigs.RecordTypeSRV:
var answer = record.ToRRAnswer(question.Name, question.Qclass)
if answer != nil {
resp.Answer = append(resp.Answer, answer)
}
case dnsconfigs.RecordTypeTXT:
var answer = record.ToRRAnswer(question.Name, question.Qclass)
if answer != nil {
resp.Answer = append(resp.Answer, answer)
}
case dnsconfigs.RecordTypeCAA:
var answer = record.ToRRAnswer(question.Name, question.Qclass)
if answer != nil {
resp.Answer = append(resp.Answer, answer)
}
case dnsconfigs.RecordTypeSOA:
this.composeSOAAnswer(question, record, resp)
}
}
// 访问日志
this.addLog(networking, question, resultDomainId, matchedRouteCode, firstRecord, false, writer, remoteAddr, nil)
} else {
this.addLog(networking, question, resultDomainId, "", nil, false, writer, remoteAddr, nil)
}
}
resp.Rcode = dns.RcodeSuccess
err := writer.WriteMsg(resp)
if err != nil {
return
}
// 统计
for _, resultRecordId := range resultRecordIds {
stats.SharedManager.Add(resultRecordId[0], resultRecordId[1], int64(resp.Len()))
}
}
// 处理HTTP请求
// 参考https://datatracker.ietf.org/doc/html/rfc8484
// 参考https://developers.google.com/speed/public-dns/docs/doh
func (this *Server) handleHTTP(writer http.ResponseWriter, req *http.Request) {
if req.URL.Path == "/dns-query" {
this.handleHTTPDNSMessage(writer, req)
return
}
if req.URL.Path == "/resolve" {
this.handleHTTPJSONAPI(writer, req)
return
}
writer.WriteHeader(http.StatusNotFound)
}
func (this *Server) handleHTTPDNSMessage(writer http.ResponseWriter, req *http.Request) {
const maxMessageSize = 512
writer.Header().Set("Accept", "application/dns-message")
if req.Method != http.MethodGet && req.Method != http.MethodPost {
writer.WriteHeader(http.StatusNotImplemented)
return
}
if req.ContentLength > maxMessageSize {
writer.WriteHeader(http.StatusRequestEntityTooLarge)
return
}
var messageData []byte
switch req.Method {
case http.MethodGet:
if len(req.URL.RawQuery) > maxMessageSize {
writer.WriteHeader(http.StatusRequestURITooLong)
return
}
var encodedMessage = req.URL.Query().Get("dns")
var err error
messageData, err = base64.StdEncoding.DecodeString(encodedMessage)
if err != nil {
writer.WriteHeader(http.StatusBadRequest)
return
}
case http.MethodPost:
var contentType = req.Header.Get("Content-Type")
if contentType != "application/dns-message" {
writer.WriteHeader(http.StatusUnsupportedMediaType)
return
}
data, err := io.ReadAll(io.LimitReader(req.Body, maxMessageSize))
if err != nil {
writer.WriteHeader(http.StatusBadRequest)
return
}
messageData = data
}
if len(messageData) == 0 {
writer.WriteHeader(http.StatusBadRequest)
return
}
var msg = &dns.Msg{}
err := msg.Unpack(messageData)
if err != nil {
writer.WriteHeader(http.StatusBadRequest)
return
}
var connValue = req.Context().Value(HTTPConnContextKey)
if connValue == nil {
writer.WriteHeader(http.StatusInternalServerError)
return
}
conn, ok := connValue.(net.Conn)
if !ok {
writer.WriteHeader(http.StatusInternalServerError)
return
}
this.handleDNSMessage(NewHTTPWriter(writer, conn, "application/dns-message"), msg)
}
func (this *Server) handleHTTPJSONAPI(writer http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodGet {
writer.WriteHeader(http.StatusNotImplemented)
return
}
var query = req.URL.Query()
var name = strings.TrimSpace(query.Get("name"))
var recordTypeString = strings.ToUpper(strings.TrimSpace(query.Get("type")))
if len(name) == 0 {
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid 'name' parameter"))
return
}
// add '.' to name
if !strings.HasSuffix(name, ".") {
name += "."
}
if len(recordTypeString) == 0 {
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid 'type' parameter"))
return
}
var recordType uint16
if regexp.MustCompile(`^\d{1,4}$`).MatchString(recordTypeString) {
recordType = types.Uint16(recordTypeString)
} else {
recordType = dns.StringToType[recordTypeString]
}
if recordType == 0 {
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid 'type' parameter"))
return
}
_, ok := dns.TypeToString[recordType]
if !ok {
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid 'type' parameter"))
return
}
var msg = &dns.Msg{}
msg.Question = []dns.Question{
{
Name: name,
Qtype: recordType,
Qclass: dns.ClassINET,
},
}
var connValue = req.Context().Value(HTTPConnContextKey)
if connValue == nil {
writer.WriteHeader(http.StatusInternalServerError)
return
}
// conn
conn, ok := connValue.(net.Conn)
if !ok {
writer.WriteHeader(http.StatusInternalServerError)
return
}
this.handleDNSMessage(NewHTTPWriter(writer, conn, "application/x-javascript"), msg)
}
// 组合SOA回复信息
func (this *Server) composeSOAAnswer(question dns.Question, record *models.NSRecord, resp *dns.Msg) {
var config = sharedNodeConfig.SOA
var serial = sharedNodeConfig.SOASerial
if config == nil {
config = dnsconfigs.DefaultNSSOAConfig()
}
var mName = config.MName
if len(mName) == 0 {
var hosts = sharedNodeConfig.Hosts
var l = len(hosts)
if l > 0 {
var index = rands.Int(0, l-1)
mName = hosts[index]
}
}
var rName = config.RName
if len(rName) == 0 {
rName = sharedNodeConfig.Email
}
rName = strings.ReplaceAll(rName, "@", ".")
if len(mName) > 0 && len(rName) > 0 {
// 设置记录值
record.Value = mName + "."
resp.Answer = append(resp.Answer, &dns.SOA{
Hdr: record.ToRRHeader(question.Name, dns.TypeSOA, question.Qclass),
Ns: mName + ".",
Mbox: rName + ".",
Serial: serial,
Refresh: config.RefreshSeconds,
Retry: config.RetrySeconds,
Expire: config.ExpireSeconds,
Minttl: config.MinimumTTL,
})
}
}

View File

@@ -0,0 +1,22 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
"github.com/iwind/TeaGo/types"
)
// ServerConfig 服务配置
type ServerConfig struct {
Protocol serverconfigs.Protocol
Host string
Port int
SSLPolicy *sslconfigs.SSLPolicy
}
// FullAddr 服务地址
func (this *ServerConfig) FullAddr() string {
return this.Protocol.String() + "://" + this.Host + ":" + types.String(this.Port)
}

View File

@@ -0,0 +1,114 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeDNS/internal/configs"
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
"github.com/TeaOSLab/EdgeDNS/internal/events"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"github.com/TeaOSLab/EdgeDNS/internal/utils"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/logs"
"time"
)
func init() {
if !teaconst.IsMain {
return
}
events.On(events.EventStart, func() {
task := NewSyncAPINodesTask()
go task.Start()
})
}
// SyncAPINodesTask API节点同步任务
type SyncAPINodesTask struct {
}
func NewSyncAPINodesTask() *SyncAPINodesTask {
return &SyncAPINodesTask{}
}
func (this *SyncAPINodesTask) Start() {
ticker := time.NewTicker(5 * time.Minute)
if Tea.IsTesting() {
// 快速测试
ticker = time.NewTicker(1 * time.Minute)
}
events.On(events.EventQuit, func() {
remotelogs.Println("SYNC_API_NODES_TASK", "quit task")
ticker.Stop()
})
for range ticker.C {
err := this.Loop()
if err != nil {
logs.Println("[TASK][SYNC_API_NODES_TASK]" + err.Error())
}
}
}
func (this *SyncAPINodesTask) Loop() error {
// 如果有节点定制的API节点地址
var hasCustomizedAPINodeAddrs = sharedNodeConfig != nil && len(sharedNodeConfig.APINodeAddrs) > 0
config, err := configs.LoadAPIConfig()
if err != nil {
return err
}
// 是否禁止自动升级
if config.RPCDisableUpdate {
return nil
}
// 获取所有可用的节点
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
resp, err := rpcClient.APINodeRPC.FindAllEnabledAPINodes(rpcClient.Context(), &pb.FindAllEnabledAPINodesRequest{})
if err != nil {
return err
}
var newEndpoints = []string{}
for _, node := range resp.ApiNodes {
if !node.IsOn {
continue
}
newEndpoints = append(newEndpoints, node.AccessAddrs...)
}
// 和现有的对比
if utils.EqualStrings(newEndpoints, config.RPCEndpoints) {
return nil
}
// 测试是否有API节点可用
var hasOk = rpcClient.TestEndpoints(newEndpoints)
if !hasOk {
return nil
}
// 修改RPC对象配置
config.RPCEndpoints = newEndpoints
// 更新当前RPC
if !hasCustomizedAPINodeAddrs {
err = rpcClient.UpdateConfig(config)
if err != nil {
return err
}
}
// 保存到文件
err = config.WriteFile(Tea.ConfigFile(configs.ConfigFileName))
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,300 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"crypto/md5"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
"github.com/TeaOSLab/EdgeDNS/internal/events"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"github.com/TeaOSLab/EdgeDNS/internal/utils"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/logs"
stringutil "github.com/iwind/TeaGo/utils/string"
"github.com/iwind/gosock/pkg/gosock"
"os"
"os/exec"
"path/filepath"
"runtime"
"time"
)
func init() {
if !teaconst.IsMain {
return
}
events.On(events.EventStart, func() {
go func() {
rpcClient, err := rpc.SharedRPC()
if err != nil {
remotelogs.Error("UPGRADE_MANAGER", err.Error())
return
}
var ticker = time.NewTicker(1 * time.Minute)
for range ticker.C {
resp, err := rpcClient.NSNodeRPC.CheckNSNodeLatestVersion(rpcClient.Context(), &pb.CheckNSNodeLatestVersionRequest{
Os: runtime.GOOS,
Arch: runtime.GOARCH,
CurrentVersion: teaconst.Version,
})
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("UPGRADE_MANAGER", err.Error())
} else {
remotelogs.Error("UPGRADE_MANAGER", err.Error())
}
continue
}
if resp.HasNewVersion {
sharedUpgradeManager.Start()
}
}
}()
})
}
var sharedUpgradeManager = NewUpgradeManager()
// UpgradeManager 节点升级管理器
// TODO 需要在集群中设置是否自动更新
type UpgradeManager struct {
isInstalling bool
lastFile string
exe string
}
// NewUpgradeManager 获取新对象
func NewUpgradeManager() *UpgradeManager {
return &UpgradeManager{}
}
// Start 启动升级
func (this *UpgradeManager) Start() {
// 必须放在文件解压之前读取可执行文件路径,防止解析之后,当前的可执行文件路径发生改变
exe, err := os.Executable()
if err != nil {
remotelogs.Error("UPGRADE_MANAGER", "can not find current executable file name")
return
}
this.exe = exe
// 测试环境下不更新
if Tea.IsTesting() {
return
}
if this.isInstalling {
return
}
this.isInstalling = true
// 还原安装状态
defer func() {
this.isInstalling = false
}()
remotelogs.Println("UPGRADE_MANAGER", "upgrading dns node ...")
err = this.install()
if err != nil {
remotelogs.Error("UPGRADE_MANAGER", "download failed: "+err.Error())
return
}
remotelogs.Println("UPGRADE_MANAGER", "upgrade successfully")
go func() {
err = this.restart()
if err != nil {
logs.Println("UPGRADE_MANAGER", err.Error())
}
}()
}
func (this *UpgradeManager) install() error {
// 检查是否有已下载但未安装成功的
if len(this.lastFile) > 0 {
_, err := os.Stat(this.lastFile)
if err == nil {
err = this.unzip(this.lastFile)
if err != nil {
return err
}
this.lastFile = ""
return nil
}
}
// 创建临时文件
dir := Tea.Root + "/tmp"
_, err := os.Stat(dir)
if err != nil {
if os.IsNotExist(err) {
err = os.Mkdir(dir, 0777)
if err != nil {
return err
}
} else {
return err
}
}
remotelogs.Println("UPGRADE_MANAGER", "downloading new node ...")
path := dir + "/" + teaconst.ProcessName + ".tmp"
fp, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0777)
if err != nil {
return err
}
isClosed := false
defer func() {
if !isClosed {
_ = fp.Close()
}
}()
client, err := rpc.SharedRPC()
if err != nil {
return err
}
var offset int64
var h = md5.New()
var sum = ""
var filename = ""
for {
resp, err := client.NSNodeRPC.DownloadNSNodeInstallationFile(client.Context(), &pb.DownloadNSNodeInstallationFileRequest{
Os: runtime.GOOS,
Arch: runtime.GOARCH,
ChunkOffset: offset,
})
if err != nil {
return err
}
if len(resp.Sum) == 0 {
return nil
}
sum = resp.Sum
filename = resp.Filename
if stringutil.VersionCompare(resp.Version, teaconst.Version) <= 0 {
return nil
}
if len(resp.ChunkData) == 0 {
break
}
// 写入文件
_, err = fp.Write(resp.ChunkData)
if err != nil {
return err
}
_, err = h.Write(resp.ChunkData)
if err != nil {
return err
}
offset = resp.Offset
}
if len(filename) == 0 {
return nil
}
isClosed = true
err = fp.Close()
if err != nil {
return err
}
if fmt.Sprintf("%x", h.Sum(nil)) != sum {
_ = os.Remove(path)
return nil
}
// 改成zip
zipPath := dir + "/" + filename
err = os.Rename(path, zipPath)
if err != nil {
return err
}
this.lastFile = zipPath
// 解压
err = this.unzip(zipPath)
if err != nil {
return err
}
return nil
}
// 解压
func (this *UpgradeManager) unzip(zipPath string) error {
var isOk = false
defer func() {
if isOk {
// 只有解压并覆盖成功后才会删除
_ = os.Remove(zipPath)
}
}()
// 解压
var target = Tea.Root
if Tea.IsTesting() {
// 测试环境下只解压在tmp目录
target = Tea.Root + "/tmp"
}
// 先改先前的可执行文件
err := os.Rename(target+"/bin/"+teaconst.ProcessName, target+"/bin/."+teaconst.ProcessName+".dist")
hasBackup := err == nil
defer func() {
if !isOk && hasBackup {
// 失败时还原
_ = os.Rename(target+"/bin/."+teaconst.ProcessName+".dist", target+"/bin/"+teaconst.ProcessName)
}
}()
unzip := utils.NewUnzip(zipPath, target, teaconst.ProcessName+"/")
err = unzip.Run()
if err != nil {
return err
}
isOk = true
return nil
}
// 重启
func (this *UpgradeManager) restart() error {
// 关闭当前sock防止无法重启
_ = gosock.NewTmpSock(teaconst.ProcessName).Close()
// 重新启动
if DaemonIsOn && DaemonPid == os.Getppid() {
os.Exit(0) // TODO 试着更优雅重启
} else {
// quit
events.Notify(events.EventQuit)
// 启动
var exe = filepath.Dir(this.exe) + "/" + teaconst.ProcessName
logs.Println("restarting ...", exe)
cmd := exec.Command(exe, "start")
err := cmd.Start()
if err != nil {
return err
}
// 退出当前进程
time.Sleep(1 * time.Second)
os.Exit(0)
}
return nil
}

View File

@@ -0,0 +1,16 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
_ "github.com/iwind/TeaGo/bootstrap"
"testing"
)
func TestUpgradeManager_install(t *testing.T) {
err := NewUpgradeManager().install()
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}