1.4.5.2
This commit is contained in:
256
EdgeDNS/internal/nodes/api_stream.go
Normal file
256
EdgeDNS/internal/nodes/api_stream.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/messageconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/events"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/utils"
|
||||
executils "github.com/TeaOSLab/EdgeDNS/internal/utils/exec"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type APIStream struct {
|
||||
stream pb.NSNodeService_NsNodeStreamClient
|
||||
|
||||
isQuiting bool
|
||||
cancelFunc context.CancelFunc
|
||||
}
|
||||
|
||||
func NewAPIStream() *APIStream {
|
||||
return &APIStream{}
|
||||
}
|
||||
|
||||
func (this *APIStream) Start() {
|
||||
events.On(events.EventQuit, func() {
|
||||
this.isQuiting = true
|
||||
if this.cancelFunc != nil {
|
||||
this.cancelFunc()
|
||||
}
|
||||
})
|
||||
for {
|
||||
if this.isQuiting {
|
||||
return
|
||||
}
|
||||
err := this.loop()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("API_STREAM", err.Error())
|
||||
} else {
|
||||
remotelogs.Error("API_STREAM", err.Error())
|
||||
}
|
||||
time.Sleep(10 * time.Second)
|
||||
continue
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func (this *APIStream) loop() error {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(rpcClient.Context())
|
||||
this.cancelFunc = cancelFunc
|
||||
|
||||
defer func() {
|
||||
cancelFunc()
|
||||
}()
|
||||
|
||||
nodeStream, err := rpcClient.NSNodeRPC.NsNodeStream(ctx)
|
||||
if err != nil {
|
||||
if this.isQuiting {
|
||||
return nil
|
||||
}
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
|
||||
this.stream = nodeStream
|
||||
|
||||
for {
|
||||
if this.isQuiting {
|
||||
logs.Println("API_STREAM", "quit")
|
||||
break
|
||||
}
|
||||
|
||||
message, err := nodeStream.Recv()
|
||||
if err != nil {
|
||||
if this.isQuiting {
|
||||
remotelogs.Println("API_STREAM", "quit")
|
||||
return nil
|
||||
}
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
|
||||
// 处理消息
|
||||
switch message.Code {
|
||||
case messageconfigs.NSMessageCodeConnectedAPINode: // 连接API节点成功
|
||||
err = this.handleConnectedAPINode(message)
|
||||
case messageconfigs.NSMessageCodeNewNodeTask: // 有新的任务
|
||||
err = this.handleNewNodeTask(message)
|
||||
case messageconfigs.NSMessageCodeCheckSystemdService: // 检查Systemd服务
|
||||
err = this.handleCheckSystemdService(message)
|
||||
case messageconfigs.MessageCodeCheckLocalFirewall: // 检查本地防火墙
|
||||
err = this.handleCheckLocalFirewall(message)
|
||||
default:
|
||||
err = this.handleUnknownMessage(message)
|
||||
}
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("API_STREAM", "handle message failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("API_STREAM", "handle message failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 连接API节点成功
|
||||
func (this *APIStream) handleConnectedAPINode(message *pb.NSNodeStreamMessage) error {
|
||||
// 更改连接的APINode信息
|
||||
if len(message.DataJSON) == 0 {
|
||||
return nil
|
||||
}
|
||||
msg := &messageconfigs.ConnectedAPINodeMessage{}
|
||||
err := json.Unmarshal(message.DataJSON, msg)
|
||||
if err != nil {
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
|
||||
_, err = rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
|
||||
remotelogs.Println("API_STREAM", "connected to api node '"+strconv.FormatInt(msg.APINodeId, 10)+"'")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 处理配置变化
|
||||
func (this *APIStream) handleNewNodeTask(message *pb.NSNodeStreamMessage) error {
|
||||
select {
|
||||
case nodeTaskNotify <- true:
|
||||
default:
|
||||
|
||||
}
|
||||
this.replyOk(message.RequestId, "ok")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查Systemd服务
|
||||
func (this *APIStream) handleCheckSystemdService(message *pb.NSNodeStreamMessage) error {
|
||||
systemctl, err := executils.LookPath("systemctl")
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "'systemctl' not found")
|
||||
return nil
|
||||
}
|
||||
if len(systemctl) == 0 {
|
||||
this.replyFail(message.RequestId, "'systemctl' not found")
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := utils.NewCommandExecutor()
|
||||
shortName := teaconst.SystemdServiceName
|
||||
cmd.Add(systemctl, "is-enabled", shortName)
|
||||
output, err := cmd.Run()
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "'systemctl' command error: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
if output == "enabled" {
|
||||
this.replyOk(message.RequestId, "ok")
|
||||
} else {
|
||||
this.replyFail(message.RequestId, "not installed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查本地防火墙
|
||||
func (this *APIStream) handleCheckLocalFirewall(message *pb.NSNodeStreamMessage) error {
|
||||
var dataMessage = &messageconfigs.CheckLocalFirewallMessage{}
|
||||
err := json.Unmarshal(message.DataJSON, dataMessage)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
// nft
|
||||
if dataMessage.Name == "nftables" {
|
||||
if runtime.GOOS != "linux" {
|
||||
this.replyFail(message.RequestId, "not Linux system")
|
||||
return nil
|
||||
}
|
||||
|
||||
nft, err := executils.LookPath("nft")
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "'nft' not found: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
var cmd = exec.Command(nft, "--version")
|
||||
var output = &bytes.Buffer{}
|
||||
cmd.Stdout = output
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "get version failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
var outputString = output.String()
|
||||
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
|
||||
if len(versionMatches) <= 1 {
|
||||
this.replyFail(message.RequestId, "can not get nft version")
|
||||
return nil
|
||||
}
|
||||
var version = versionMatches[1]
|
||||
|
||||
var result = maps.Map{
|
||||
"version": version,
|
||||
}
|
||||
|
||||
var protectionConfig = sharedNodeConfig.DDoSProtection
|
||||
err = firewalls.SharedDDoSProtectionManager.Apply(protectionConfig)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, dataMessage.Name+"was installed, but apply DDoS protection config failed: "+err.Error())
|
||||
} else {
|
||||
this.replyOk(message.RequestId, string(result.AsJSON()))
|
||||
}
|
||||
} else {
|
||||
this.replyFail(message.RequestId, "invalid firewall name '"+dataMessage.Name+"'")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 处理未知消息
|
||||
func (this *APIStream) handleUnknownMessage(message *pb.NSNodeStreamMessage) error {
|
||||
this.replyFail(message.RequestId, "unknown message code '"+message.Code+"'")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 回复失败
|
||||
func (this *APIStream) replyFail(requestId int64, message string) {
|
||||
_ = this.stream.Send(&pb.NSNodeStreamMessage{RequestId: requestId, IsOk: false, Message: message})
|
||||
}
|
||||
|
||||
// 回复成功
|
||||
func (this *APIStream) replyOk(requestId int64, message string) {
|
||||
_ = this.stream.Send(&pb.NSNodeStreamMessage{RequestId: requestId, IsOk: true, Message: message})
|
||||
}
|
||||
Reference in New Issue
Block a user