1.4.5.2
This commit is contained in:
529
EdgeNode/internal/nodes/api_stream.go
Normal file
529
EdgeNode/internal/nodes/api_stream.go
Normal file
@@ -0,0 +1,529 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/messageconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type APIStream struct {
|
||||
stream pb.NodeService_NodeStreamClient
|
||||
|
||||
isQuiting bool
|
||||
cancelFunc context.CancelFunc
|
||||
}
|
||||
|
||||
func NewAPIStream() *APIStream {
|
||||
return &APIStream{}
|
||||
}
|
||||
|
||||
func (this *APIStream) Start() {
|
||||
events.OnKey(events.EventQuit, this, func() {
|
||||
this.isQuiting = true
|
||||
if this.cancelFunc != nil {
|
||||
this.cancelFunc()
|
||||
}
|
||||
})
|
||||
for {
|
||||
if this.isQuiting {
|
||||
return
|
||||
}
|
||||
err := this.loop()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("API_STREAM", err.Error())
|
||||
} else {
|
||||
remotelogs.Warn("API_STREAM", err.Error())
|
||||
}
|
||||
time.Sleep(10 * time.Second)
|
||||
continue
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func (this *APIStream) loop() error {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(rpcClient.Context())
|
||||
this.cancelFunc = cancelFunc
|
||||
|
||||
defer func() {
|
||||
cancelFunc()
|
||||
}()
|
||||
|
||||
nodeStream, err := rpcClient.NodeRPC.NodeStream(ctx)
|
||||
if err != nil {
|
||||
if this.isQuiting {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
this.stream = nodeStream
|
||||
|
||||
for {
|
||||
if this.isQuiting {
|
||||
remotelogs.Println("API_STREAM", "quit")
|
||||
break
|
||||
}
|
||||
|
||||
message, streamErr := nodeStream.Recv()
|
||||
if streamErr != nil {
|
||||
if this.isQuiting {
|
||||
remotelogs.Println("API_STREAM", "quit")
|
||||
return nil
|
||||
}
|
||||
return streamErr
|
||||
}
|
||||
|
||||
// 处理消息
|
||||
switch message.Code {
|
||||
case messageconfigs.MessageCodeConnectedAPINode: // 连接API节点成功
|
||||
err = this.handleConnectedAPINode(message)
|
||||
case messageconfigs.MessageCodeWriteCache: // 写入缓存
|
||||
err = this.handleWriteCache(message)
|
||||
case messageconfigs.MessageCodeReadCache: // 读取缓存
|
||||
err = this.handleReadCache(message)
|
||||
case messageconfigs.MessageCodeStatCache: // 统计缓存
|
||||
err = this.handleStatCache(message)
|
||||
case messageconfigs.MessageCodeCleanCache: // 清理缓存
|
||||
err = this.handleCleanCache(message)
|
||||
case messageconfigs.MessageCodeNewNodeTask: // 有新的任务
|
||||
err = this.handleNewNodeTask(message)
|
||||
case messageconfigs.MessageCodeCheckSystemdService: // 检查Systemd服务
|
||||
err = this.handleCheckSystemdService(message)
|
||||
case messageconfigs.MessageCodeCheckLocalFirewall: // 检查本地防火墙
|
||||
err = this.handleCheckLocalFirewall(message)
|
||||
case messageconfigs.MessageCodeChangeAPINode: // 修改API节点地址
|
||||
err = this.handleChangeAPINode(message)
|
||||
default:
|
||||
err = this.handleUnknownMessage(message)
|
||||
}
|
||||
if err != nil {
|
||||
remotelogs.Error("API_STREAM", "handle message failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 连接API节点成功
|
||||
func (this *APIStream) handleConnectedAPINode(message *pb.NodeStreamMessage) error {
|
||||
// 更改连接的APINode信息
|
||||
if len(message.DataJSON) == 0 {
|
||||
return nil
|
||||
}
|
||||
msg := &messageconfigs.ConnectedAPINodeMessage{}
|
||||
err := json.Unmarshal(message.DataJSON, msg)
|
||||
if err != nil {
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
|
||||
_, err = rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
|
||||
remotelogs.Println("API_STREAM", "connected to api node '"+strconv.FormatInt(msg.APINodeId, 10)+"'")
|
||||
|
||||
// 重新读取配置
|
||||
if nodeConfigUpdatedAt == 0 {
|
||||
select {
|
||||
case nodeConfigChangedNotify <- true:
|
||||
default:
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 写入缓存
|
||||
func (this *APIStream) handleWriteCache(message *pb.NodeStreamMessage) error {
|
||||
msg := &messageconfigs.WriteCacheMessage{}
|
||||
err := json.Unmarshal(message.DataJSON, msg)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if shouldStop {
|
||||
defer func() {
|
||||
storage.Stop()
|
||||
}()
|
||||
}
|
||||
|
||||
expiredAt := time.Now().Unix() + msg.LifeSeconds
|
||||
writer, err := storage.OpenWriter(msg.Key, expiredAt, 200, -1, int64(len(msg.Value)), -1, false)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "prepare writing failed: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// 写入一个空的Header
|
||||
_, err = writer.WriteHeader([]byte(":"))
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "write failed: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// 写入数据
|
||||
_, err = writer.Write(msg.Value)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "write failed: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "write failed: "+err.Error())
|
||||
return err
|
||||
}
|
||||
storage.AddToList(&caches.Item{
|
||||
Type: writer.ItemType(),
|
||||
Key: msg.Key,
|
||||
ExpiresAt: expiredAt,
|
||||
HeaderSize: writer.HeaderSize(),
|
||||
BodySize: writer.BodySize(),
|
||||
})
|
||||
|
||||
this.replyOk(message.RequestId, "write ok")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 读取缓存
|
||||
func (this *APIStream) handleReadCache(message *pb.NodeStreamMessage) error {
|
||||
msg := &messageconfigs.ReadCacheMessage{}
|
||||
err := json.Unmarshal(message.DataJSON, msg)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if shouldStop {
|
||||
defer func() {
|
||||
storage.Stop()
|
||||
}()
|
||||
}
|
||||
|
||||
reader, err := storage.OpenReader(msg.Key, false, false)
|
||||
if err != nil {
|
||||
if err == caches.ErrNotFound {
|
||||
this.replyFail(message.RequestId, "key not found")
|
||||
return nil
|
||||
}
|
||||
this.replyFail(message.RequestId, "read key failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
_ = reader.Close()
|
||||
}()
|
||||
|
||||
this.replyOk(message.RequestId, "value "+strconv.FormatInt(reader.BodySize(), 10)+" bytes")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 统计缓存
|
||||
func (this *APIStream) handleStatCache(message *pb.NodeStreamMessage) error {
|
||||
msg := &messageconfigs.ReadCacheMessage{}
|
||||
err := json.Unmarshal(message.DataJSON, msg)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if shouldStop {
|
||||
defer func() {
|
||||
storage.Stop()
|
||||
}()
|
||||
}
|
||||
|
||||
stat, err := storage.Stat()
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "stat failed: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
sizeFormat := ""
|
||||
if stat.Size < (1 << 10) {
|
||||
sizeFormat = strconv.FormatInt(stat.Size, 10) + " Bytes"
|
||||
} else if stat.Size < (1 << 20) {
|
||||
sizeFormat = fmt.Sprintf("%.2f KiB", float64(stat.Size)/(1<<10))
|
||||
} else if stat.Size < (1 << 30) {
|
||||
sizeFormat = fmt.Sprintf("%.2f MiB", float64(stat.Size)/(1<<20))
|
||||
} else {
|
||||
sizeFormat = fmt.Sprintf("%.2f GiB", float64(stat.Size)/(1<<30))
|
||||
}
|
||||
this.replyOk(message.RequestId, "size:"+sizeFormat+", count:"+strconv.Itoa(stat.Count))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 清理缓存
|
||||
func (this *APIStream) handleCleanCache(message *pb.NodeStreamMessage) error {
|
||||
msg := &messageconfigs.ReadCacheMessage{}
|
||||
err := json.Unmarshal(message.DataJSON, msg)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if shouldStop {
|
||||
defer func() {
|
||||
storage.Stop()
|
||||
}()
|
||||
}
|
||||
|
||||
err = storage.CleanAll()
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "clean cache failed: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
this.replyOk(message.RequestId, "ok")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 处理配置变化
|
||||
func (this *APIStream) handleNewNodeTask(message *pb.NodeStreamMessage) error {
|
||||
select {
|
||||
case nodeTaskNotify <- true:
|
||||
default:
|
||||
|
||||
}
|
||||
this.replyOk(message.RequestId, "ok")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查Systemd服务
|
||||
func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage) error {
|
||||
systemctl, err := executils.LookPath("systemctl")
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "'systemctl' not found")
|
||||
return nil
|
||||
}
|
||||
if len(systemctl) == 0 {
|
||||
this.replyFail(message.RequestId, "'systemctl' not found")
|
||||
return nil
|
||||
}
|
||||
|
||||
var shortName = teaconst.SystemdServiceName
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, systemctl, "is-enabled", shortName)
|
||||
cmd.WithStdout()
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "'systemctl' command error: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
if cmd.Stdout() == "enabled" {
|
||||
this.replyOk(message.RequestId, "ok")
|
||||
} else {
|
||||
this.replyFail(message.RequestId, "not installed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查本地防火墙
|
||||
func (this *APIStream) handleCheckLocalFirewall(message *pb.NodeStreamMessage) error {
|
||||
var dataMessage = &messageconfigs.CheckLocalFirewallMessage{}
|
||||
err := json.Unmarshal(message.DataJSON, dataMessage)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
// nft
|
||||
if dataMessage.Name == "nftables" {
|
||||
if runtime.GOOS != "linux" {
|
||||
this.replyFail(message.RequestId, "not Linux system")
|
||||
return nil
|
||||
}
|
||||
|
||||
nft, err := executils.LookPath("nft")
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "'nft' not found: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, nft, "--version")
|
||||
cmd.WithStdout()
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "get version failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
var outputString = cmd.Stdout()
|
||||
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
|
||||
if len(versionMatches) <= 1 {
|
||||
this.replyFail(message.RequestId, "can not get nft version")
|
||||
return nil
|
||||
}
|
||||
var version = versionMatches[1]
|
||||
|
||||
var result = maps.Map{
|
||||
"version": version,
|
||||
}
|
||||
|
||||
var protectionConfig = sharedNodeConfig.DDoSProtection
|
||||
err = firewalls.SharedDDoSProtectionManager.Apply(protectionConfig)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, dataMessage.Name+" was installed, but apply DDoS protection config failed: "+err.Error())
|
||||
} else {
|
||||
this.replyOk(message.RequestId, string(result.AsJSON()))
|
||||
}
|
||||
} else {
|
||||
this.replyFail(message.RequestId, "invalid firewall name '"+dataMessage.Name+"'")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 修改API地址
|
||||
func (this *APIStream) handleChangeAPINode(message *pb.NodeStreamMessage) error {
|
||||
config, err := configs.LoadAPIConfig()
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "read config error: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
var messageData = &messageconfigs.ChangeAPINodeMessage{}
|
||||
err = json.Unmarshal(message.DataJSON, messageData)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "unmarshal message failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = url.Parse(messageData.Addr)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "invalid new api node address: '"+messageData.Addr+"'")
|
||||
return nil
|
||||
}
|
||||
|
||||
config.RPCEndpoints = []string{messageData.Addr}
|
||||
|
||||
// 保存到文件
|
||||
err = config.WriteFile(Tea.ConfigFile(configs.ConfigFileName))
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "save config file failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
this.replyOk(message.RequestId, "")
|
||||
|
||||
goman.New(func() {
|
||||
// 延后生效,防止变更前的API无法读取到状态
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
remotelogs.Error("API_STREAM", "change rpc endpoint to '"+
|
||||
messageData.Addr+"' failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
rpcClient.Close()
|
||||
|
||||
err = rpcClient.UpdateConfig(config)
|
||||
if err != nil {
|
||||
remotelogs.Error("API_STREAM", "change rpc endpoint to '"+
|
||||
messageData.Addr+"' failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
remotelogs.Println("API_STREAM", "change rpc endpoint to '"+
|
||||
messageData.Addr+"' successfully")
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 处理未知消息
|
||||
func (this *APIStream) handleUnknownMessage(message *pb.NodeStreamMessage) error {
|
||||
this.replyFail(message.RequestId, "unknown message code '"+message.Code+"'")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 回复失败
|
||||
func (this *APIStream) replyFail(requestId int64, message string) {
|
||||
_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: false, Message: message})
|
||||
}
|
||||
|
||||
// 回复成功
|
||||
func (this *APIStream) replyOk(requestId int64, message string) {
|
||||
_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message})
|
||||
}
|
||||
|
||||
// 回复成功并包含数据
|
||||
func (this *APIStream) replyOkData(requestId int64, message string, dataJSON []byte) {
|
||||
_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message, DataJSON: dataJSON})
|
||||
}
|
||||
|
||||
// 获取缓存存取对象
|
||||
func (this *APIStream) cacheStorage(message *pb.NodeStreamMessage, cachePolicyJSON []byte) (storage caches.StorageInterface, shouldStop bool, err error) {
|
||||
cachePolicy := &serverconfigs.HTTPCachePolicy{}
|
||||
err = json.Unmarshal(cachePolicyJSON, cachePolicy)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "decode cache policy config failed: "+err.Error())
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
storage = caches.SharedManager.FindStorageWithPolicy(cachePolicy.Id)
|
||||
if storage == nil {
|
||||
storage = caches.SharedManager.NewStorageWithPolicy(cachePolicy)
|
||||
if storage == nil {
|
||||
this.replyFail(message.RequestId, "invalid storage type '"+cachePolicy.Type+"'")
|
||||
return nil, false, errors.New("invalid storage type '" + cachePolicy.Type + "'")
|
||||
}
|
||||
err = storage.Init()
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "storage init failed: "+err.Error())
|
||||
return nil, false, err
|
||||
}
|
||||
shouldStop = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
15
EdgeNode/internal/nodes/api_stream_test.go
Normal file
15
EdgeNode/internal/nodes/api_stream_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAPIStream_Start(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
apiStream := NewAPIStream()
|
||||
apiStream.Start()
|
||||
}
|
||||
326
EdgeNode/internal/nodes/client_conn.go
Normal file
326
EdgeNode/internal/nodes/client_conn.go
Normal file
@@ -0,0 +1,326 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/conns"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
connutils "github.com/TeaOSLab/EdgeNode/internal/utils/conns"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/counters"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ClientConn 客户端连接
|
||||
type ClientConn struct {
|
||||
BaseClientConn
|
||||
|
||||
createdAt int64
|
||||
|
||||
isTLS bool
|
||||
isHTTP bool
|
||||
hasRead bool
|
||||
|
||||
isLO bool // 是否为环路
|
||||
isNoStat bool // 是否不统计带宽
|
||||
isInAllowList bool
|
||||
|
||||
hasResetSYNFlood bool
|
||||
|
||||
lastReadAt int64
|
||||
lastWriteAt int64
|
||||
lastErr error
|
||||
|
||||
readDeadlineTime int64
|
||||
isShortReading bool // reading header or tls handshake
|
||||
|
||||
isDebugging bool
|
||||
autoReadTimeout bool
|
||||
autoWriteTimeout bool
|
||||
}
|
||||
|
||||
func NewClientConn(rawConn net.Conn, isHTTP bool, isTLS bool, isInAllowList bool) net.Conn {
|
||||
// 是否为环路
|
||||
var remoteAddr = rawConn.RemoteAddr().String()
|
||||
|
||||
var conn = &ClientConn{
|
||||
BaseClientConn: BaseClientConn{rawConn: rawConn},
|
||||
isTLS: isTLS,
|
||||
isHTTP: isHTTP,
|
||||
isLO: strings.HasPrefix(remoteAddr, "127.0.0.1:") || strings.HasPrefix(remoteAddr, "[::1]:"),
|
||||
isNoStat: connutils.IsNoStatConn(remoteAddr),
|
||||
isInAllowList: isInAllowList,
|
||||
createdAt: fasttime.Now().Unix(),
|
||||
}
|
||||
|
||||
if existsLnNodeIP(conn.RawIP()) {
|
||||
conn.SetIsPersistent(true)
|
||||
}
|
||||
|
||||
// 超时等设置
|
||||
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
|
||||
if globalServerConfig != nil {
|
||||
var performanceConfig = globalServerConfig.Performance
|
||||
conn.isDebugging = performanceConfig.Debug
|
||||
conn.autoReadTimeout = performanceConfig.AutoReadTimeout
|
||||
conn.autoWriteTimeout = performanceConfig.AutoWriteTimeout
|
||||
}
|
||||
|
||||
if isHTTP {
|
||||
// TODO 可以在配置中设置此值
|
||||
_ = conn.SetLinger(nodeconfigs.DefaultTCPLinger)
|
||||
}
|
||||
|
||||
// 加入到Map
|
||||
conns.SharedMap.Add(conn)
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
func (this *ClientConn) Read(b []byte) (n int, err error) {
|
||||
if this.isDebugging {
|
||||
this.lastReadAt = fasttime.Now().Unix()
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
this.lastErr = fmt.Errorf("read error: %w", err)
|
||||
} else {
|
||||
this.lastErr = nil
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 环路直接读取
|
||||
if this.isLO {
|
||||
n, err = this.rawConn.Read(b)
|
||||
if n > 0 {
|
||||
atomic.AddUint64(&teaconst.InTrafficBytes, uint64(n))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 设置读超时时间
|
||||
if this.isHTTP && !this.isPersistent && !this.isShortReading && this.autoReadTimeout {
|
||||
this.setHTTPReadTimeout()
|
||||
}
|
||||
|
||||
// 开始读取
|
||||
n, err = this.rawConn.Read(b)
|
||||
if n > 0 {
|
||||
atomic.AddUint64(&teaconst.InTrafficBytes, uint64(n))
|
||||
this.hasRead = true
|
||||
}
|
||||
|
||||
// 检测是否为超时错误
|
||||
var isTimeout = err != nil && os.IsTimeout(err)
|
||||
var isHandshakeError = isTimeout && !this.hasRead
|
||||
|
||||
if err != nil {
|
||||
_ = this.SetLinger(nodeconfigs.DefaultTCPLinger)
|
||||
}
|
||||
|
||||
// 忽略白名单和局域网
|
||||
if !this.isPersistent && this.isHTTP && !this.isInAllowList && !utils.IsLocalIP(this.RawIP()) {
|
||||
// SYN Flood检测
|
||||
if this.serverId == 0 || !this.hasResetSYNFlood {
|
||||
var synFloodConfig = sharedNodeConfig.SYNFloodConfig()
|
||||
if synFloodConfig != nil && synFloodConfig.IsOn {
|
||||
if isHandshakeError {
|
||||
this.increaseSYNFlood(synFloodConfig)
|
||||
} else if err == nil && !this.hasResetSYNFlood {
|
||||
this.hasResetSYNFlood = true
|
||||
this.resetSYNFlood()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (this *ClientConn) Write(b []byte) (n int, err error) {
|
||||
if len(b) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if this.isDebugging {
|
||||
this.lastWriteAt = fasttime.Now().Unix()
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
this.lastErr = fmt.Errorf("write error: %w", err)
|
||||
} else {
|
||||
this.lastErr = nil
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 设置写超时时间
|
||||
if !this.isPersistent && this.autoWriteTimeout {
|
||||
var timeoutSeconds = len(b) / 1024
|
||||
if timeoutSeconds < 3 {
|
||||
timeoutSeconds = 3
|
||||
}
|
||||
_ = this.rawConn.SetWriteDeadline(time.Now().Add(time.Duration(timeoutSeconds) * time.Second)) // TODO 时间可以设置
|
||||
}
|
||||
|
||||
// 延长读超时时间
|
||||
if this.isHTTP && !this.isPersistent && this.autoReadTimeout {
|
||||
this.setHTTPReadTimeout()
|
||||
}
|
||||
|
||||
// 开始写入
|
||||
var before = time.Now()
|
||||
n, err = this.rawConn.Write(b)
|
||||
if n > 0 {
|
||||
atomic.AddInt64(&this.totalSentBytes, int64(n))
|
||||
|
||||
// 统计当前服务带宽
|
||||
if this.serverId > 0 {
|
||||
// TODO 需要加入在serverId绑定之前的带宽
|
||||
if !this.isNoStat || Tea.IsTesting() { // 环路不统计带宽,避免缓存预热等行为产生带宽
|
||||
atomic.AddUint64(&teaconst.OutTrafficBytes, uint64(n))
|
||||
|
||||
var cost = time.Since(before).Seconds()
|
||||
if cost > 1 {
|
||||
stats.SharedBandwidthStatManager.AddBandwidth(this.userId, this.userPlanId, this.serverId, int64(float64(n)/cost), int64(n))
|
||||
} else {
|
||||
stats.SharedBandwidthStatManager.AddBandwidth(this.userId, this.userPlanId, this.serverId, int64(n), int64(n))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果是写入超时,则立即关闭连接
|
||||
if err != nil && os.IsTimeout(err) {
|
||||
// TODO 考虑对多次慢连接的IP做出惩罚
|
||||
conn, ok := this.rawConn.(LingerConn)
|
||||
if ok {
|
||||
_ = conn.SetLinger(0)
|
||||
}
|
||||
|
||||
_ = this.Close()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (this *ClientConn) Close() error {
|
||||
this.isClosed = true
|
||||
|
||||
err := this.rawConn.Close()
|
||||
|
||||
// 单个服务并发数限制
|
||||
// 不能加条件限制,因为服务配置随时有变化
|
||||
sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String())
|
||||
|
||||
// 从conn map中移除
|
||||
conns.SharedMap.Remove(this)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *ClientConn) LocalAddr() net.Addr {
|
||||
return this.rawConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (this *ClientConn) RemoteAddr() net.Addr {
|
||||
return this.rawConn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (this *ClientConn) SetDeadline(t time.Time) error {
|
||||
return this.rawConn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (this *ClientConn) SetReadDeadline(t time.Time) error {
|
||||
// 如果开启了HTTP自动读超时选项,则自动控制超时时间
|
||||
if this.isHTTP && !this.isPersistent && this.autoReadTimeout {
|
||||
this.isShortReading = false
|
||||
|
||||
var unixTime = t.Unix()
|
||||
if unixTime < 10 {
|
||||
return nil
|
||||
}
|
||||
if unixTime == this.readDeadlineTime {
|
||||
return nil
|
||||
}
|
||||
this.readDeadlineTime = unixTime
|
||||
var seconds = -time.Since(t)
|
||||
if seconds <= 0 || seconds > HTTPIdleTimeout {
|
||||
return nil
|
||||
}
|
||||
if seconds < HTTPIdleTimeout-1*time.Second {
|
||||
this.isShortReading = true
|
||||
}
|
||||
}
|
||||
return this.rawConn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (this *ClientConn) SetWriteDeadline(t time.Time) error {
|
||||
return this.rawConn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (this *ClientConn) CreatedAt() int64 {
|
||||
return this.createdAt
|
||||
}
|
||||
|
||||
func (this *ClientConn) LastReadAt() int64 {
|
||||
return this.lastReadAt
|
||||
}
|
||||
|
||||
func (this *ClientConn) LastWriteAt() int64 {
|
||||
return this.lastWriteAt
|
||||
}
|
||||
|
||||
func (this *ClientConn) LastErr() error {
|
||||
return this.lastErr
|
||||
}
|
||||
|
||||
func (this *ClientConn) resetSYNFlood() {
|
||||
counters.SharedCounter.ResetKey("SYN_FLOOD:" + this.RawIP())
|
||||
}
|
||||
|
||||
func (this *ClientConn) increaseSYNFlood(synFloodConfig *firewallconfigs.SYNFloodConfig) {
|
||||
var ip = this.RawIP()
|
||||
if len(ip) > 0 && !iplibrary.IsInWhiteList(ip) && (!synFloodConfig.IgnoreLocal || !utils.IsLocalIP(ip)) {
|
||||
var result = counters.SharedCounter.IncreaseKey("SYN_FLOOD:"+ip, 60)
|
||||
var minAttempts = synFloodConfig.MinAttempts
|
||||
if minAttempts < 5 {
|
||||
minAttempts = 5
|
||||
}
|
||||
if !this.isTLS {
|
||||
// 非TLS,设置为两倍,防止误封
|
||||
minAttempts = 2 * minAttempts
|
||||
}
|
||||
if result >= types.Uint32(minAttempts) {
|
||||
var timeout = synFloodConfig.TimeoutSeconds
|
||||
if timeout <= 0 {
|
||||
timeout = 600
|
||||
}
|
||||
|
||||
// 关闭当前连接
|
||||
_ = this.SetLinger(0)
|
||||
_ = this.Close()
|
||||
|
||||
waf.SharedIPBlackList.RecordIP(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip, fasttime.Now().Unix()+int64(timeout), 0, true, 0, 0, "疑似SYN Flood攻击,当前1分钟"+types.String(result)+"次空连接")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 设置读超时时间
|
||||
func (this *ClientConn) setHTTPReadTimeout() {
|
||||
_ = this.SetReadDeadline(time.Now().Add(HTTPIdleTimeout))
|
||||
}
|
||||
193
EdgeNode/internal/nodes/client_conn_base.go
Normal file
193
EdgeNode/internal/nodes/client_conn_base.go
Normal file
@@ -0,0 +1,193 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type BaseClientConn struct {
|
||||
rawConn net.Conn
|
||||
|
||||
isBound bool
|
||||
userId int64
|
||||
userPlanId int64
|
||||
serverId int64
|
||||
remoteAddr string
|
||||
hasLimit bool
|
||||
|
||||
isPersistent bool // 是否为持久化连接
|
||||
fingerprint []byte
|
||||
|
||||
isClosed bool
|
||||
|
||||
rawIP string
|
||||
|
||||
totalSentBytes int64
|
||||
}
|
||||
|
||||
func (this *BaseClientConn) IsClosed() bool {
|
||||
return this.isClosed
|
||||
}
|
||||
|
||||
// IsBound 是否已绑定服务
|
||||
func (this *BaseClientConn) IsBound() bool {
|
||||
return this.isBound
|
||||
}
|
||||
|
||||
// Bind 绑定服务
|
||||
func (this *BaseClientConn) Bind(serverId int64, remoteAddr string, maxConnsPerServer int, maxConnsPerIP int) bool {
|
||||
if this.isBound {
|
||||
return true
|
||||
}
|
||||
this.isBound = true
|
||||
this.serverId = serverId
|
||||
this.remoteAddr = remoteAddr
|
||||
this.hasLimit = true
|
||||
|
||||
// 检查是否可以连接
|
||||
return sharedClientConnLimiter.Add(this.rawConn.RemoteAddr().String(), serverId, remoteAddr, maxConnsPerServer, maxConnsPerIP)
|
||||
}
|
||||
|
||||
// SetServerId 设置服务ID
|
||||
func (this *BaseClientConn) SetServerId(serverId int64) (goNext bool) {
|
||||
goNext = true
|
||||
|
||||
// 检查服务相关IP黑名单
|
||||
var rawIP = this.RawIP()
|
||||
if serverId > 0 && len(rawIP) > 0 {
|
||||
// 是否在白名单中
|
||||
ok, _, expiresAt := iplibrary.AllowIP(rawIP, serverId)
|
||||
if !ok {
|
||||
_ = this.rawConn.Close()
|
||||
firewalls.DropTemporaryTo(rawIP, expiresAt)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
this.serverId = serverId
|
||||
|
||||
// 设置包装前连接
|
||||
switch conn := this.rawConn.(type) {
|
||||
case *tls.Conn:
|
||||
nativeConn, ok := conn.NetConn().(ClientConnInterface)
|
||||
if ok {
|
||||
nativeConn.SetServerId(serverId)
|
||||
}
|
||||
case *ClientConn:
|
||||
conn.SetServerId(serverId)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// ServerId 读取当前连接绑定的服务ID
|
||||
func (this *BaseClientConn) ServerId() int64 {
|
||||
return this.serverId
|
||||
}
|
||||
|
||||
// SetUserId 设置所属服务的用户ID
|
||||
func (this *BaseClientConn) SetUserId(userId int64) {
|
||||
this.userId = userId
|
||||
|
||||
// 设置包装前连接
|
||||
switch conn := this.rawConn.(type) {
|
||||
case *tls.Conn:
|
||||
nativeConn, ok := conn.NetConn().(ClientConnInterface)
|
||||
if ok {
|
||||
nativeConn.SetUserId(userId)
|
||||
}
|
||||
case *ClientConn:
|
||||
conn.SetUserId(userId)
|
||||
}
|
||||
}
|
||||
|
||||
func (this *BaseClientConn) SetUserPlanId(userPlanId int64) {
|
||||
this.userPlanId = userPlanId
|
||||
|
||||
// 设置包装前连接
|
||||
switch conn := this.rawConn.(type) {
|
||||
case *tls.Conn:
|
||||
nativeConn, ok := conn.NetConn().(ClientConnInterface)
|
||||
if ok {
|
||||
nativeConn.SetUserPlanId(userPlanId)
|
||||
}
|
||||
case *ClientConn:
|
||||
conn.SetUserPlanId(userPlanId)
|
||||
}
|
||||
}
|
||||
|
||||
// UserId 获取当前连接所属服务的用户ID
|
||||
func (this *BaseClientConn) UserId() int64 {
|
||||
return this.userId
|
||||
}
|
||||
|
||||
// UserPlanId 用户套餐ID
|
||||
func (this *BaseClientConn) UserPlanId() int64 {
|
||||
return this.userPlanId
|
||||
}
|
||||
|
||||
// RawIP 原本IP
|
||||
func (this *BaseClientConn) RawIP() string {
|
||||
if len(this.rawIP) > 0 {
|
||||
return this.rawIP
|
||||
}
|
||||
|
||||
ip, _, _ := net.SplitHostPort(this.rawConn.RemoteAddr().String())
|
||||
this.rawIP = ip
|
||||
return ip
|
||||
}
|
||||
|
||||
// TCPConn 转换为TCPConn
|
||||
func (this *BaseClientConn) TCPConn() (tcpConn *net.TCPConn, ok bool) {
|
||||
// 设置包装前连接
|
||||
switch conn := this.rawConn.(type) {
|
||||
case *tls.Conn:
|
||||
var internalConn = conn.NetConn()
|
||||
clientConn, isClientConn := internalConn.(*ClientConn)
|
||||
if isClientConn {
|
||||
return clientConn.TCPConn()
|
||||
}
|
||||
tcpConn, ok = internalConn.(*net.TCPConn)
|
||||
default:
|
||||
tcpConn, ok = this.rawConn.(*net.TCPConn)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// SetLinger 设置Linger
|
||||
func (this *BaseClientConn) SetLinger(seconds int) error {
|
||||
tcpConn, ok := this.TCPConn()
|
||||
if ok {
|
||||
return tcpConn.SetLinger(seconds)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *BaseClientConn) SetIsPersistent(isPersistent bool) {
|
||||
this.isPersistent = isPersistent
|
||||
|
||||
_ = this.rawConn.SetDeadline(time.Time{})
|
||||
}
|
||||
|
||||
// SetFingerprint 设置指纹信息
|
||||
func (this *BaseClientConn) SetFingerprint(fingerprint []byte) {
|
||||
this.fingerprint = fingerprint
|
||||
}
|
||||
|
||||
// Fingerprint 读取指纹信息
|
||||
func (this *BaseClientConn) Fingerprint() []byte {
|
||||
return this.fingerprint
|
||||
}
|
||||
|
||||
// LastRequestBytes 读取上一次请求发送的字节数
|
||||
func (this *BaseClientConn) LastRequestBytes() int64 {
|
||||
var result = atomic.LoadInt64(&this.totalSentBytes)
|
||||
atomic.StoreInt64(&this.totalSentBytes, 0)
|
||||
return result
|
||||
}
|
||||
41
EdgeNode/internal/nodes/client_conn_interface.go
Normal file
41
EdgeNode/internal/nodes/client_conn_interface.go
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
type ClientConnInterface interface {
|
||||
// IsClosed 是否已关闭
|
||||
IsClosed() bool
|
||||
|
||||
// IsBound 是否已绑定服务
|
||||
IsBound() bool
|
||||
|
||||
// Bind 绑定服务
|
||||
Bind(serverId int64, remoteAddr string, maxConnsPerServer int, maxConnsPerIP int) bool
|
||||
|
||||
// ServerId 获取服务ID
|
||||
ServerId() int64
|
||||
|
||||
// SetServerId 设置服务ID
|
||||
SetServerId(serverId int64) (goNext bool)
|
||||
|
||||
// SetUserId 设置所属网站的用户ID
|
||||
SetUserId(userId int64)
|
||||
|
||||
// SetUserPlanId 设置
|
||||
SetUserPlanId(userPlanId int64)
|
||||
|
||||
// UserId 获取当前连接所属服务的用户ID
|
||||
UserId() int64
|
||||
|
||||
// SetIsPersistent 设置是否为持久化
|
||||
SetIsPersistent(isPersistent bool)
|
||||
|
||||
// SetFingerprint 设置指纹信息
|
||||
SetFingerprint(fingerprint []byte)
|
||||
|
||||
// Fingerprint 读取指纹信息
|
||||
Fingerprint() []byte
|
||||
|
||||
// LastRequestBytes 读取上一次请求发送的字节数
|
||||
LastRequestBytes() int64
|
||||
}
|
||||
130
EdgeNode/internal/nodes/client_conn_limiter.go
Normal file
130
EdgeNode/internal/nodes/client_conn_limiter.go
Normal file
@@ -0,0 +1,130 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/zero"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var sharedClientConnLimiter = NewClientConnLimiter()
|
||||
|
||||
// ClientConnRemoteAddr 客户端地址定义
|
||||
type ClientConnRemoteAddr struct {
|
||||
remoteAddr string
|
||||
serverId int64
|
||||
}
|
||||
|
||||
// ClientConnLimiter 客户端连接数限制
|
||||
type ClientConnLimiter struct {
|
||||
remoteAddrMap map[string]*ClientConnRemoteAddr // raw remote addr => remoteAddr
|
||||
ipConns map[string]map[string]zero.Zero // remoteAddr => { raw remote addr => Zero }
|
||||
serverConns map[int64]map[string]zero.Zero // serverId => { remoteAddr => Zero }
|
||||
|
||||
locker sync.Mutex
|
||||
}
|
||||
|
||||
func NewClientConnLimiter() *ClientConnLimiter {
|
||||
return &ClientConnLimiter{
|
||||
remoteAddrMap: map[string]*ClientConnRemoteAddr{},
|
||||
ipConns: map[string]map[string]zero.Zero{},
|
||||
serverConns: map[int64]map[string]zero.Zero{},
|
||||
}
|
||||
}
|
||||
|
||||
// Add 添加新连接
|
||||
// 返回值为true的时候表示允许添加;否则表示不允许添加
|
||||
func (this *ClientConnLimiter) Add(rawRemoteAddr string, serverId int64, remoteAddr string, maxConnsPerServer int, maxConnsPerIP int) bool {
|
||||
if (maxConnsPerServer <= 0 && maxConnsPerIP <= 0) || len(remoteAddr) == 0 || serverId <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
// 检查服务连接数
|
||||
var serverMap = this.serverConns[serverId]
|
||||
if maxConnsPerServer > 0 {
|
||||
if serverMap == nil {
|
||||
serverMap = map[string]zero.Zero{}
|
||||
this.serverConns[serverId] = serverMap
|
||||
}
|
||||
|
||||
if maxConnsPerServer <= len(serverMap) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 检查IP连接数
|
||||
var ipMap = this.ipConns[remoteAddr]
|
||||
if maxConnsPerIP > 0 {
|
||||
if ipMap == nil {
|
||||
ipMap = map[string]zero.Zero{}
|
||||
this.ipConns[remoteAddr] = ipMap
|
||||
}
|
||||
if maxConnsPerIP > 0 && maxConnsPerIP <= len(ipMap) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
this.remoteAddrMap[rawRemoteAddr] = &ClientConnRemoteAddr{
|
||||
remoteAddr: remoteAddr,
|
||||
serverId: serverId,
|
||||
}
|
||||
|
||||
if maxConnsPerServer > 0 {
|
||||
serverMap[rawRemoteAddr] = zero.New()
|
||||
}
|
||||
if maxConnsPerIP > 0 {
|
||||
ipMap[rawRemoteAddr] = zero.New()
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Remove 删除连接
|
||||
func (this *ClientConnLimiter) Remove(rawRemoteAddr string) {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
addr, ok := this.remoteAddrMap[rawRemoteAddr]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
delete(this.remoteAddrMap, rawRemoteAddr)
|
||||
delete(this.ipConns[addr.remoteAddr], rawRemoteAddr)
|
||||
delete(this.serverConns[addr.serverId], rawRemoteAddr)
|
||||
|
||||
if len(this.ipConns[addr.remoteAddr]) == 0 {
|
||||
delete(this.ipConns, addr.remoteAddr)
|
||||
}
|
||||
|
||||
if len(this.serverConns[addr.serverId]) == 0 {
|
||||
delete(this.serverConns, addr.serverId)
|
||||
}
|
||||
}
|
||||
|
||||
// Conns 获取连接信息
|
||||
// 用于调试
|
||||
func (this *ClientConnLimiter) Conns() (ipConns map[string][]string, serverConns map[int64][]string) {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
ipConns = map[string][]string{} // ip => [addr1, addr2, ...]
|
||||
serverConns = map[int64][]string{} // serverId => [addr1, addr2, ...]
|
||||
|
||||
for ip, m := range this.ipConns {
|
||||
for addr := range m {
|
||||
ipConns[ip] = append(ipConns[ip], addr)
|
||||
}
|
||||
}
|
||||
|
||||
for serverId, m := range this.serverConns {
|
||||
for addr := range m {
|
||||
serverConns[serverId] = append(serverConns[serverId], addr)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
38
EdgeNode/internal/nodes/client_conn_limiter_test.go
Normal file
38
EdgeNode/internal/nodes/client_conn_limiter_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClientConnLimiter_Add(t *testing.T) {
|
||||
var limiter = NewClientConnLimiter()
|
||||
{
|
||||
b := limiter.Add("127.0.0.1:1234", 1, "192.168.1.100", 10, 5)
|
||||
t.Log(b)
|
||||
}
|
||||
{
|
||||
b := limiter.Add("127.0.0.1:1235", 1, "192.168.1.100", 10, 5)
|
||||
t.Log(b)
|
||||
}
|
||||
{
|
||||
b := limiter.Add("127.0.0.1:1236", 1, "192.168.1.100", 10, 5)
|
||||
t.Log(b)
|
||||
}
|
||||
{
|
||||
b := limiter.Add("127.0.0.1:1237", 1, "192.168.1.101", 10, 5)
|
||||
t.Log(b)
|
||||
}
|
||||
{
|
||||
b := limiter.Add("127.0.0.1:1238", 1, "192.168.1.100", 5, 5)
|
||||
t.Log(b)
|
||||
}
|
||||
limiter.Remove("127.0.0.1:1238")
|
||||
limiter.Remove("127.0.0.1:1239")
|
||||
limiter.Remove("127.0.0.1:1237")
|
||||
logs.PrintAsJSON(limiter.remoteAddrMap, t)
|
||||
logs.PrintAsJSON(limiter.ipConns, t)
|
||||
logs.PrintAsJSON(limiter.serverConns, t)
|
||||
}
|
||||
45
EdgeNode/internal/nodes/client_conn_traffic.go
Normal file
45
EdgeNode/internal/nodes/client_conn_traffic.go
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/monitor"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 发送监控流量
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
return
|
||||
}
|
||||
|
||||
events.On(events.EventStart, func() {
|
||||
var ticker = time.NewTicker(1 * time.Minute)
|
||||
goman.New(func() {
|
||||
for range ticker.C {
|
||||
// 加入到数据队列中
|
||||
var inBytes = atomic.LoadUint64(&teaconst.InTrafficBytes)
|
||||
atomic.StoreUint64(&teaconst.InTrafficBytes, 0) // 重置数据
|
||||
if inBytes > 0 {
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemTrafficIn, maps.Map{
|
||||
"total": inBytes,
|
||||
})
|
||||
}
|
||||
|
||||
var outBytes = atomic.LoadUint64(&teaconst.OutTrafficBytes)
|
||||
atomic.StoreUint64(&teaconst.OutTrafficBytes, 0) // 重置数据
|
||||
if outBytes > 0 {
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemTrafficOut, maps.Map{
|
||||
"total": outBytes,
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
20
EdgeNode/internal/nodes/client_conn_utils.go
Normal file
20
EdgeNode/internal/nodes/client_conn_utils.go
Normal file
@@ -0,0 +1,20 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// 判断客户端连接是否已关闭
|
||||
func isClientConnClosed(conn net.Conn) bool {
|
||||
if conn == nil {
|
||||
return true
|
||||
}
|
||||
clientConn, ok := conn.(ClientConnInterface)
|
||||
if ok {
|
||||
return clientConn.IsClosed()
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
81
EdgeNode/internal/nodes/client_listener.go
Normal file
81
EdgeNode/internal/nodes/client_listener.go
Normal file
@@ -0,0 +1,81 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"net"
|
||||
)
|
||||
|
||||
// ClientListener 客户端网络监听
|
||||
type ClientListener struct {
|
||||
rawListener net.Listener
|
||||
isHTTP bool
|
||||
isTLS bool
|
||||
}
|
||||
|
||||
func NewClientListener(listener net.Listener, isHTTP bool) *ClientListener {
|
||||
return &ClientListener{
|
||||
rawListener: listener,
|
||||
isHTTP: isHTTP,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *ClientListener) SetIsTLS(isTLS bool) {
|
||||
this.isTLS = isTLS
|
||||
}
|
||||
|
||||
func (this *ClientListener) IsTLS() bool {
|
||||
return this.isTLS
|
||||
}
|
||||
|
||||
func (this *ClientListener) Accept() (net.Conn, error) {
|
||||
conn, err := this.rawListener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 是否在WAF名单中
|
||||
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
var isInAllowList = false
|
||||
if err == nil {
|
||||
canGoNext, inAllowList, expiresAt := iplibrary.AllowIP(ip, 0)
|
||||
isInAllowList = inAllowList
|
||||
if !canGoNext {
|
||||
firewalls.DropTemporaryTo(ip, expiresAt)
|
||||
} else {
|
||||
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
|
||||
var ok bool
|
||||
expiresAt, ok = waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
|
||||
if ok {
|
||||
canGoNext = false
|
||||
firewalls.DropTemporaryTo(ip, expiresAt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !canGoNext {
|
||||
tcpConn, ok := conn.(*net.TCPConn)
|
||||
if ok {
|
||||
_ = tcpConn.SetLinger(0)
|
||||
}
|
||||
|
||||
_ = conn.Close()
|
||||
|
||||
return this.Accept()
|
||||
}
|
||||
}
|
||||
|
||||
return NewClientConn(conn, this.isHTTP, this.isTLS, isInAllowList), nil
|
||||
}
|
||||
|
||||
func (this *ClientListener) Close() error {
|
||||
return this.rawListener.Close()
|
||||
}
|
||||
|
||||
func (this *ClientListener) Addr() net.Addr {
|
||||
return this.rawListener.Addr()
|
||||
}
|
||||
99
EdgeNode/internal/nodes/client_tls_conn.go
Normal file
99
EdgeNode/internal/nodes/client_tls_conn.go
Normal file
@@ -0,0 +1,99 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ClientTLSConn TLS连接封装
|
||||
type ClientTLSConn struct {
|
||||
BaseClientConn
|
||||
}
|
||||
|
||||
func NewClientTLSConn(conn *tls.Conn) net.Conn {
|
||||
return &ClientTLSConn{BaseClientConn{rawConn: conn}}
|
||||
}
|
||||
|
||||
func (this *ClientTLSConn) Read(b []byte) (n int, err error) {
|
||||
n, err = this.rawConn.Read(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (this *ClientTLSConn) Write(b []byte) (n int, err error) {
|
||||
n, err = this.rawConn.Write(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (this *ClientTLSConn) Close() error {
|
||||
this.isClosed = true
|
||||
|
||||
// 单个服务并发数限制
|
||||
sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String())
|
||||
|
||||
return this.rawConn.Close()
|
||||
}
|
||||
|
||||
func (this *ClientTLSConn) LocalAddr() net.Addr {
|
||||
return this.rawConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (this *ClientTLSConn) RemoteAddr() net.Addr {
|
||||
return this.rawConn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (this *ClientTLSConn) SetDeadline(t time.Time) error {
|
||||
return this.rawConn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (this *ClientTLSConn) SetReadDeadline(t time.Time) error {
|
||||
return this.rawConn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (this *ClientTLSConn) SetWriteDeadline(t time.Time) error {
|
||||
return this.rawConn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (this *ClientTLSConn) SetIsPersistent(isPersistent bool) {
|
||||
tlsConn, ok := this.rawConn.(*tls.Conn)
|
||||
if ok {
|
||||
var rawConn = tlsConn.NetConn()
|
||||
if rawConn != nil {
|
||||
clientConn, ok := rawConn.(*ClientConn)
|
||||
if ok {
|
||||
clientConn.SetIsPersistent(isPersistent)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *ClientTLSConn) Fingerprint() []byte {
|
||||
tlsConn, ok := this.rawConn.(*tls.Conn)
|
||||
if ok {
|
||||
var rawConn = tlsConn.NetConn()
|
||||
if rawConn != nil {
|
||||
clientConn, ok := rawConn.(*ClientConn)
|
||||
if ok {
|
||||
return clientConn.fingerprint
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LastRequestBytes 读取上一次请求发送的字节数
|
||||
func (this *ClientTLSConn) LastRequestBytes() int64 {
|
||||
tlsConn, ok := this.rawConn.(*tls.Conn)
|
||||
if ok {
|
||||
var rawConn = tlsConn.NetConn()
|
||||
if rawConn != nil {
|
||||
clientConn, ok := rawConn.(*ClientConn)
|
||||
if ok {
|
||||
return clientConn.LastRequestBytes()
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
7
EdgeNode/internal/nodes/conn_linger.go
Normal file
7
EdgeNode/internal/nodes/conn_linger.go
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
type LingerConn interface {
|
||||
SetLinger(sec int) error
|
||||
}
|
||||
37
EdgeNode/internal/nodes/http3_conn_notifier_plus.go
Normal file
37
EdgeNode/internal/nodes/http3_conn_notifier_plus.go
Normal file
@@ -0,0 +1,37 @@
|
||||
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build plus
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type HTTP3ConnNotifier struct {
|
||||
conn *HTTP3Conn
|
||||
}
|
||||
|
||||
func NewHTTP3ConnNotifier(conn *HTTP3Conn) *HTTP3ConnNotifier {
|
||||
return &HTTP3ConnNotifier{
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *HTTP3ConnNotifier) ReadBytes(bytes int) {
|
||||
|
||||
}
|
||||
|
||||
func (this *HTTP3ConnNotifier) WriteBytes(bytes int) {
|
||||
if this.conn != nil && this.conn.ServerId() > 0 {
|
||||
stats.SharedBandwidthStatManager.AddBandwidth(this.conn.UserId(), this.conn.UserPlanId(), this.conn.ServerId(), int64(bytes), int64(bytes))
|
||||
|
||||
if this.conn.BaseClientConn != nil {
|
||||
atomic.AddInt64(&this.conn.BaseClientConn.totalSentBytes, int64(bytes))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *HTTP3ConnNotifier) Close() {
|
||||
this.conn.NotifyClose()
|
||||
}
|
||||
34
EdgeNode/internal/nodes/http3_conn_plus.go
Normal file
34
EdgeNode/internal/nodes/http3_conn_plus.go
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build plus
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/conns"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/http3"
|
||||
)
|
||||
|
||||
type HTTP3Conn struct {
|
||||
*BaseClientConn
|
||||
http3.Conn
|
||||
}
|
||||
|
||||
func NewHTTP3Conn(http3Conn http3.Conn) *HTTP3Conn {
|
||||
// 添加新的
|
||||
var conn = &HTTP3Conn{
|
||||
BaseClientConn: &BaseClientConn{rawConn: http3Conn},
|
||||
Conn: http3Conn,
|
||||
}
|
||||
|
||||
http3Conn.SetParentConn(conn)
|
||||
http3Conn.SetNotifier(NewHTTP3ConnNotifier(conn))
|
||||
|
||||
// 添加到统计Map
|
||||
conns.SharedMap.Add(conn)
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
func (this *HTTP3Conn) NotifyClose() {
|
||||
conns.SharedMap.Remove(this)
|
||||
}
|
||||
89
EdgeNode/internal/nodes/http3_listener_plus.go
Normal file
89
EdgeNode/internal/nodes/http3_listener_plus.go
Normal file
@@ -0,0 +1,89 @@
|
||||
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build plus
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/http3"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"github.com/quic-go/quic-go"
|
||||
"net"
|
||||
)
|
||||
|
||||
type HTTP3Listener struct {
|
||||
rawListener http3.Listener
|
||||
}
|
||||
|
||||
func ListenHTTP3(addr string, tlsConfig *tls.Config) (http3.Listener, error) {
|
||||
rawListener, err := http3.Listen(addr, tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var listener = &HTTP3Listener{rawListener: rawListener}
|
||||
listener.init()
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
func (this *HTTP3Listener) init() {
|
||||
events.OnKey(events.EventQuit, fmt.Sprintf("http_listener_%p", this), func() {
|
||||
_ = this.Close()
|
||||
})
|
||||
events.OnKey(events.EventTerminated, fmt.Sprintf("http_listener_%p", this), func() {
|
||||
_ = this.Close()
|
||||
})
|
||||
}
|
||||
|
||||
func (this *HTTP3Listener) Addr() net.Addr {
|
||||
return this.rawListener.Addr()
|
||||
}
|
||||
|
||||
func (this *HTTP3Listener) Accept(ctx context.Context, ctxFunc http3.ContextFunc) (quic.EarlyConnection, error) {
|
||||
conn, err := this.rawListener.Accept(ctx, ctxFunc)
|
||||
if err != nil {
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// 是否在WAF名单中
|
||||
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
var isInAllowList = false
|
||||
if err == nil {
|
||||
canGoNext, inAllowList, expiresAt := iplibrary.AllowIP(ip, 0)
|
||||
isInAllowList = inAllowList
|
||||
if !canGoNext {
|
||||
firewalls.DropTemporaryTo(ip, expiresAt)
|
||||
} else {
|
||||
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
|
||||
var ok bool
|
||||
expiresAt, ok = waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
|
||||
if ok {
|
||||
canGoNext = false
|
||||
firewalls.DropTemporaryTo(ip, expiresAt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !canGoNext {
|
||||
_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
|
||||
|
||||
return this.Accept(ctx, ctxFunc)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO 将isInAllowList传递到HTTP3Conn?
|
||||
_ = isInAllowList
|
||||
|
||||
return NewHTTP3Conn(conn.(*http3.BasicConn)), nil
|
||||
}
|
||||
|
||||
func (this *HTTP3Listener) Close() error {
|
||||
events.Remove(fmt.Sprintf("http_listener_%p", this))
|
||||
return this.rawListener.Close()
|
||||
}
|
||||
231
EdgeNode/internal/nodes/http3_manager_plus.go
Normal file
231
EdgeNode/internal/nodes/http3_manager_plus.go
Normal file
@@ -0,0 +1,231 @@
|
||||
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build plus
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/http3"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var sharedHTTP3Manager = NewHTTP3Manager()
|
||||
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
return
|
||||
}
|
||||
|
||||
var listener = &HTTPListener{
|
||||
isHTTPS: true,
|
||||
isHTTP3: true,
|
||||
}
|
||||
events.On(events.EventLoaded, func() {
|
||||
sharedListenerManager.http3Listener = listener // 注册到ListenerManager,以便统计用
|
||||
})
|
||||
|
||||
var eventLocker = sync.Mutex{}
|
||||
events.OnEvents([]events.Event{events.EventReload, events.EventReloadSomeServers}, func() {
|
||||
go func() {
|
||||
eventLocker.Lock()
|
||||
defer eventLocker.Unlock()
|
||||
|
||||
if sharedNodeConfig == nil {
|
||||
return
|
||||
}
|
||||
|
||||
_ = sharedHTTP3Manager.Update(sharedNodeConfig.HTTP3Policies)
|
||||
sharedHTTP3Manager.UpdateHTTPListener(listener)
|
||||
|
||||
listener.Reload(sharedNodeConfig.HTTP3Group())
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
// HTTP3Manager HTTP3管理器
|
||||
type HTTP3Manager struct {
|
||||
locker sync.RWMutex
|
||||
|
||||
hasHTTP3 bool
|
||||
|
||||
policies map[int64]*nodeconfigs.HTTP3Policy // clusterId => *HTTP3Policy
|
||||
serverMap map[int]*http3.Server // port => *Server
|
||||
mobileUserAgentReg *regexp.Regexp
|
||||
|
||||
httpListener *HTTPListener
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
func NewHTTP3Manager() *HTTP3Manager {
|
||||
return &HTTP3Manager{
|
||||
policies: map[int64]*nodeconfigs.HTTP3Policy{},
|
||||
serverMap: map[int]*http3.Server{},
|
||||
mobileUserAgentReg: regexp.MustCompile(`(?i)(iPhone|Android)`),
|
||||
}
|
||||
}
|
||||
|
||||
// Update 更新配置
|
||||
// m: clusterId => *HTTP3Policy
|
||||
func (this *HTTP3Manager) Update(m map[int64]*nodeconfigs.HTTP3Policy) error {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
// 启动新的
|
||||
var newPolicyMap = map[int64]*nodeconfigs.HTTP3Policy{} // clusterId => *HTTP3Policy
|
||||
var newPorts = map[int]bool{} // port => bool
|
||||
for clusterId, policy := range m {
|
||||
if policy.IsOn && policy.Port > 0 {
|
||||
this.policies[clusterId] = policy
|
||||
newPolicyMap[clusterId] = policy
|
||||
|
||||
var port = policy.Port
|
||||
newPorts[port] = true
|
||||
|
||||
_, existPort := this.serverMap[port]
|
||||
if !existPort {
|
||||
server, err := this.createServer(port)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP3_MANAGER", "start port '"+types.String(port)+"' failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
this.serverMap[port] = server
|
||||
remotelogs.Debug("HTTP3_MANAGER", "start port '"+types.String(port)+"'")
|
||||
}
|
||||
}
|
||||
}
|
||||
this.policies = newPolicyMap
|
||||
|
||||
// 关闭老的
|
||||
for port, server := range this.serverMap {
|
||||
if !newPorts[port] {
|
||||
_ = server.Close()
|
||||
delete(this.serverMap, port)
|
||||
remotelogs.Debug("HTTP3_MANAGER", "close port '"+types.String(port)+"'")
|
||||
}
|
||||
}
|
||||
|
||||
this.hasHTTP3 = len(this.serverMap) > 0
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateHTTPListener 更新Listener
|
||||
// 这里的Listener只是为了方便复用HTTPListener的相关方法
|
||||
func (this *HTTP3Manager) UpdateHTTPListener(listener *HTTPListener) {
|
||||
this.locker.Lock()
|
||||
this.httpListener = listener
|
||||
if listener != nil {
|
||||
this.tlsConfig = listener.buildTLSConfig()
|
||||
}
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
// ProcessHTTP3Headers 处理HTTP3相关Headers
|
||||
func (this *HTTP3Manager) ProcessHTTP3Headers(userAgent string, headers http.Header, clusterId int64) {
|
||||
// 这里不要加锁,以便于提升性能
|
||||
if !this.hasHTTP3 {
|
||||
return
|
||||
}
|
||||
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
|
||||
// 再次准确检查
|
||||
if !this.hasHTTP3 {
|
||||
return
|
||||
}
|
||||
|
||||
policy, ok := this.policies[clusterId]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if policy.IsOn && policy.Port > 0 && (policy.SupportMobileBrowsers || !this.mobileUserAgentReg.MatchString(userAgent)) {
|
||||
// TODO 版本好和有效期可以在策略里设置
|
||||
headers.Set("Alt-Svc", `h3=":`+types.String(policy.Port)+`"; ma=2592000,h3-29=":`+types.String(policy.Port)+`"; ma=2592000`)
|
||||
}
|
||||
}
|
||||
|
||||
// 创建server
|
||||
func (this *HTTP3Manager) createServer(port int) (*http3.Server, error) {
|
||||
var addr = ":" + types.String(port)
|
||||
listener, err := ListenHTTP3(addr, &tls.Config{
|
||||
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
this.locker.RLock()
|
||||
var tlsConfig = this.tlsConfig
|
||||
this.locker.RUnlock()
|
||||
|
||||
if tlsConfig != nil && tlsConfig.GetConfigForClient != nil {
|
||||
return tlsConfig.GetConfigForClient(info)
|
||||
}
|
||||
|
||||
return nil, errors.New("http3: no tls config")
|
||||
},
|
||||
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
|
||||
this.locker.RLock()
|
||||
var tlsConfig = this.tlsConfig
|
||||
this.locker.RUnlock()
|
||||
|
||||
if tlsConfig != nil && tlsConfig.GetCertificate != nil {
|
||||
return tlsConfig.GetCertificate(clientInfo)
|
||||
}
|
||||
|
||||
return nil, errors.New("http3: no tls config")
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var server = &http3.Server{
|
||||
Addr: ":" + types.String(port),
|
||||
Handler: http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {
|
||||
if this.httpListener != nil {
|
||||
var servePortString = "443"
|
||||
if len(req.Host) > 0 {
|
||||
_, hostPortString, hostErr := net.SplitHostPort(req.Host)
|
||||
if hostErr == nil && len(hostPortString) > 0 {
|
||||
servePortString = hostPortString
|
||||
}
|
||||
}
|
||||
this.httpListener.ServeHTTPWithAddr(writer, req, ":"+servePortString)
|
||||
}
|
||||
}),
|
||||
ConnState: func(conn net.Conn, state http.ConnState) {
|
||||
if this.httpListener == nil {
|
||||
return
|
||||
}
|
||||
switch state {
|
||||
case http.StateNew:
|
||||
atomic.AddInt64(&this.httpListener.countActiveConnections, 1)
|
||||
case http.StateClosed:
|
||||
atomic.AddInt64(&this.httpListener.countActiveConnections, -1)
|
||||
default:
|
||||
// do nothing
|
||||
}
|
||||
},
|
||||
ConnContext: func(ctx context.Context, conn net.Conn) context.Context {
|
||||
return context.WithValue(ctx, HTTPConnContextKey, conn)
|
||||
},
|
||||
}
|
||||
go func() {
|
||||
err = server.Serve(listener)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP3_MANAGER", "serve '"+addr+"' failed: "+err.Error())
|
||||
this.locker.Lock()
|
||||
delete(this.serverMap, port)
|
||||
this.locker.Unlock()
|
||||
}
|
||||
}()
|
||||
return server, nil
|
||||
}
|
||||
194
EdgeNode/internal/nodes/http_access_log_queue.go
Normal file
194
EdgeNode/internal/nodes/http_access_log_queue.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
|
||||
memutils "github.com/TeaOSLab/EdgeNode/internal/utils/mem"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
var sharedHTTPAccessLogQueue = NewHTTPAccessLogQueue()
|
||||
|
||||
// HTTPAccessLogQueue HTTP访问日志队列
|
||||
type HTTPAccessLogQueue struct {
|
||||
queue chan *pb.HTTPAccessLog
|
||||
|
||||
rpcClient *rpc.RPCClient
|
||||
}
|
||||
|
||||
// NewHTTPAccessLogQueue 获取新对象
|
||||
func NewHTTPAccessLogQueue() *HTTPAccessLogQueue {
|
||||
// 队列中最大的值,超出此数量的访问日志会被丢弃
|
||||
var maxSize = 2_000 * (1 + memutils.SystemMemoryGB()/2)
|
||||
if maxSize > 20_000 {
|
||||
maxSize = 20_000
|
||||
}
|
||||
|
||||
var queue = &HTTPAccessLogQueue{
|
||||
queue: make(chan *pb.HTTPAccessLog, maxSize),
|
||||
}
|
||||
goman.New(func() {
|
||||
queue.Start()
|
||||
})
|
||||
|
||||
return queue
|
||||
}
|
||||
|
||||
// Start 开始处理访问日志
|
||||
func (this *HTTPAccessLogQueue) Start() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
for range ticker.C {
|
||||
err := this.loop()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("ACCESS_LOG_QUEUE", err.Error())
|
||||
} else {
|
||||
remotelogs.Error("ACCESS_LOG_QUEUE", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Push 加入新访问日志
|
||||
func (this *HTTPAccessLogQueue) Push(accessLog *pb.HTTPAccessLog) {
|
||||
select {
|
||||
case this.queue <- accessLog:
|
||||
default:
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// 上传访问日志
|
||||
func (this *HTTPAccessLogQueue) loop() error {
|
||||
const maxLen = 2000
|
||||
var accessLogs = make([]*pb.HTTPAccessLog, 0, maxLen)
|
||||
var count = 0
|
||||
|
||||
Loop:
|
||||
for {
|
||||
select {
|
||||
case accessLog := <-this.queue:
|
||||
accessLogs = append(accessLogs, accessLog)
|
||||
count++
|
||||
|
||||
// 每次只提交 N 条访问日志,防止网络拥堵
|
||||
if count >= maxLen {
|
||||
break Loop
|
||||
}
|
||||
default:
|
||||
break Loop
|
||||
}
|
||||
}
|
||||
|
||||
if len(accessLogs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 发送到本地
|
||||
if sharedHTTPAccessLogViewer.HasConns() {
|
||||
for _, accessLog := range accessLogs {
|
||||
sharedHTTPAccessLogViewer.Send(accessLog)
|
||||
}
|
||||
}
|
||||
|
||||
// 发送到API
|
||||
if this.rpcClient == nil {
|
||||
client, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.rpcClient = client
|
||||
}
|
||||
|
||||
_, err := this.rpcClient.HTTPAccessLogRPC.CreateHTTPAccessLogs(this.rpcClient.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: accessLogs})
|
||||
if err != nil {
|
||||
// 是否包含了invalid UTF-8
|
||||
if strings.Contains(err.Error(), "string field contains invalid UTF-8") {
|
||||
for _, accessLog := range accessLogs {
|
||||
this.ToValidUTF8(accessLog)
|
||||
}
|
||||
|
||||
// 重新提交
|
||||
_, err = this.rpcClient.HTTPAccessLogRPC.CreateHTTPAccessLogs(this.rpcClient.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: accessLogs})
|
||||
return err
|
||||
}
|
||||
|
||||
// 是否请求内容过大
|
||||
statusCode, ok := status.FromError(err)
|
||||
if ok && statusCode.Code() == codes.ResourceExhausted {
|
||||
// 去除Body
|
||||
for _, accessLog := range accessLogs {
|
||||
accessLog.RequestBody = nil
|
||||
}
|
||||
|
||||
// 重新提交
|
||||
_, err = this.rpcClient.HTTPAccessLogRPC.CreateHTTPAccessLogs(this.rpcClient.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: accessLogs})
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToValidUTF8 处理访问日志中的非UTF-8字节
|
||||
func (this *HTTPAccessLogQueue) ToValidUTF8(accessLog *pb.HTTPAccessLog) {
|
||||
accessLog.RemoteAddr = utils.ToValidUTF8string(accessLog.RemoteAddr)
|
||||
accessLog.RemoteUser = utils.ToValidUTF8string(accessLog.RemoteUser)
|
||||
accessLog.RequestURI = utils.ToValidUTF8string(accessLog.RequestURI)
|
||||
accessLog.RequestPath = utils.ToValidUTF8string(accessLog.RequestPath)
|
||||
accessLog.RequestFilename = utils.ToValidUTF8string(accessLog.RequestFilename)
|
||||
accessLog.RequestBody = bytes.ToValidUTF8(accessLog.RequestBody, []byte{})
|
||||
accessLog.Host = utils.ToValidUTF8string(accessLog.Host)
|
||||
accessLog.Hostname = utils.ToValidUTF8string(accessLog.Hostname)
|
||||
|
||||
for k, v := range accessLog.SentHeader {
|
||||
if !utf8.ValidString(k) {
|
||||
delete(accessLog.SentHeader, k)
|
||||
continue
|
||||
}
|
||||
|
||||
for index, s := range v.Values {
|
||||
v.Values[index] = utils.ToValidUTF8string(s)
|
||||
}
|
||||
}
|
||||
|
||||
accessLog.Referer = utils.ToValidUTF8string(accessLog.Referer)
|
||||
accessLog.UserAgent = utils.ToValidUTF8string(accessLog.UserAgent)
|
||||
accessLog.Request = utils.ToValidUTF8string(accessLog.Request)
|
||||
accessLog.ContentType = utils.ToValidUTF8string(accessLog.ContentType)
|
||||
|
||||
for k, c := range accessLog.Cookie {
|
||||
if !utf8.ValidString(k) {
|
||||
delete(accessLog.Cookie, k)
|
||||
continue
|
||||
}
|
||||
accessLog.Cookie[k] = utils.ToValidUTF8string(c)
|
||||
}
|
||||
|
||||
accessLog.Args = utils.ToValidUTF8string(accessLog.Args)
|
||||
accessLog.QueryString = utils.ToValidUTF8string(accessLog.QueryString)
|
||||
|
||||
for k, v := range accessLog.Header {
|
||||
if !utf8.ValidString(k) {
|
||||
delete(accessLog.Header, k)
|
||||
continue
|
||||
}
|
||||
for index, s := range v.Values {
|
||||
v.Values[index] = utils.ToValidUTF8string(s)
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range accessLog.Errors {
|
||||
accessLog.Errors[k] = utils.ToValidUTF8string(v)
|
||||
}
|
||||
}
|
||||
232
EdgeNode/internal/nodes/http_access_log_queue_test.go
Normal file
232
EdgeNode/internal/nodes/http_access_log_queue_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/nodes"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"google.golang.org/grpc/status"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func TestHTTPAccessLogQueue_Push(t *testing.T) {
|
||||
// 发送到API
|
||||
client, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var requestId = 1_000_000
|
||||
|
||||
var utf8Bytes = []byte{}
|
||||
for i := 0; i < 254; i++ {
|
||||
utf8Bytes = append(utf8Bytes, uint8(i))
|
||||
}
|
||||
|
||||
//bytes = []byte("真不错")
|
||||
|
||||
var accessLog = &pb.HTTPAccessLog{
|
||||
ServerId: 23,
|
||||
RequestId: strconv.FormatInt(time.Now().Unix(), 10) + strconv.Itoa(requestId) + strconv.FormatInt(1, 10),
|
||||
NodeId: 48,
|
||||
Host: "www.hello.com",
|
||||
RequestURI: string(utf8Bytes),
|
||||
RequestPath: string(utf8Bytes),
|
||||
Timestamp: time.Now().Unix(),
|
||||
Cookie: map[string]string{"test": string(utf8Bytes)},
|
||||
|
||||
Header: map[string]*pb.Strings{
|
||||
"test": {Values: []string{string(utf8Bytes)}},
|
||||
},
|
||||
}
|
||||
|
||||
new(nodes.HTTPAccessLogQueue).ToValidUTF8(accessLog)
|
||||
|
||||
// logs.PrintAsJSON(accessLog)
|
||||
|
||||
//t.Log(strings.ToValidUTF8(string(utf8Bytes), ""))
|
||||
_, err = client.HTTPAccessLogRPC.CreateHTTPAccessLogs(client.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: []*pb.HTTPAccessLog{
|
||||
accessLog,
|
||||
}})
|
||||
if err != nil {
|
||||
// 这里只是为了重现错误
|
||||
t.Logf("%#v, %s", err, err.Error())
|
||||
|
||||
statusErr, ok := status.FromError(err)
|
||||
if ok {
|
||||
t.Logf("%#v", statusErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestHTTPAccessLogQueue_Push2(t *testing.T) {
|
||||
var utf8Bytes = []byte{}
|
||||
for i := 0; i < 254; i++ {
|
||||
utf8Bytes = append(utf8Bytes, uint8(i))
|
||||
}
|
||||
|
||||
var accessLog = &pb.HTTPAccessLog{
|
||||
ServerId: 23,
|
||||
RequestId: strconv.FormatInt(time.Now().Unix(), 10) + strconv.Itoa(1) + strconv.FormatInt(1, 10),
|
||||
NodeId: 48,
|
||||
Host: "www.hello.com",
|
||||
RequestURI: string(utf8Bytes),
|
||||
RequestPath: string(utf8Bytes),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
var v = reflect.Indirect(reflect.ValueOf(accessLog))
|
||||
var countFields = v.NumField()
|
||||
for i := 0; i < countFields; i++ {
|
||||
var field = v.Field(i)
|
||||
if field.Kind() == reflect.String {
|
||||
field.SetString(strings.ToValidUTF8(field.String(), ""))
|
||||
}
|
||||
}
|
||||
|
||||
client, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = client.HTTPAccessLogRPC.CreateHTTPAccessLogs(client.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: []*pb.HTTPAccessLog{
|
||||
accessLog,
|
||||
}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestHTTPAccessLogQueue_Memory(t *testing.T) {
|
||||
if !testutils.IsSingleTesting() {
|
||||
return
|
||||
}
|
||||
|
||||
testutils.StartMemoryStats(t)
|
||||
|
||||
debug.SetGCPercent(10)
|
||||
|
||||
var accessLogs = []*pb.HTTPAccessLog{}
|
||||
for i := 0; i < 20_000; i++ {
|
||||
accessLogs = append(accessLogs, &pb.HTTPAccessLog{
|
||||
RequestPath: "https://goedge.cn/hello/world",
|
||||
})
|
||||
}
|
||||
|
||||
runtime.GC()
|
||||
_ = accessLogs
|
||||
|
||||
// will not release automatically
|
||||
func() {
|
||||
var accessLogs1 = []*pb.HTTPAccessLog{}
|
||||
for i := 0; i < 2_000_000; i++ {
|
||||
accessLogs1 = append(accessLogs1, &pb.HTTPAccessLog{
|
||||
RequestPath: "https://goedge.cn/hello/world",
|
||||
})
|
||||
}
|
||||
_ = accessLogs1
|
||||
}()
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
|
||||
func TestUTF8_IsValid(t *testing.T) {
|
||||
t.Log(utf8.ValidString("abc"))
|
||||
|
||||
var noneUTF8Bytes = []byte{}
|
||||
for i := 0; i < 254; i++ {
|
||||
noneUTF8Bytes = append(noneUTF8Bytes, uint8(i))
|
||||
}
|
||||
t.Log(utf8.ValidString(string(noneUTF8Bytes)))
|
||||
}
|
||||
|
||||
func BenchmarkHTTPAccessLogQueue_ToValidUTF8(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
var utf8Bytes = []byte{}
|
||||
for i := 0; i < 254; i++ {
|
||||
utf8Bytes = append(utf8Bytes, uint8(i))
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = bytes.ToValidUTF8(utf8Bytes, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHTTPAccessLogQueue_ToValidUTF8String(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
var utf8Bytes = []byte{}
|
||||
for i := 0; i < 254; i++ {
|
||||
utf8Bytes = append(utf8Bytes, uint8(i))
|
||||
}
|
||||
|
||||
var s = string(utf8Bytes)
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = strings.ToValidUTF8(s, "")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAppendAccessLogs(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
var stat1 = &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stat1)
|
||||
|
||||
const count = 20000
|
||||
var a = make([]*pb.HTTPAccessLog, 0, count)
|
||||
for i := 0; i < b.N; i++ {
|
||||
a = append(a, &pb.HTTPAccessLog{
|
||||
RequestPath: "/hello/world",
|
||||
Host: "example.com",
|
||||
RequestBody: bytes.Repeat([]byte{'A'}, 1024),
|
||||
})
|
||||
if len(a) == count {
|
||||
a = make([]*pb.HTTPAccessLog, 0, count)
|
||||
}
|
||||
}
|
||||
|
||||
_ = len(a)
|
||||
|
||||
var stat2 = &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stat2)
|
||||
b.Log((stat2.TotalAlloc-stat1.TotalAlloc)>>20, "MB")
|
||||
}
|
||||
|
||||
func BenchmarkAppendAccessLogs2(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
var stat1 = &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stat1)
|
||||
|
||||
const count = 20000
|
||||
var a = []*pb.HTTPAccessLog{}
|
||||
for i := 0; i < b.N; i++ {
|
||||
a = append(a, &pb.HTTPAccessLog{
|
||||
RequestPath: "/hello/world",
|
||||
Host: "example.com",
|
||||
RequestBody: bytes.Repeat([]byte{'A'}, 1024),
|
||||
})
|
||||
if len(a) == count {
|
||||
a = []*pb.HTTPAccessLog{}
|
||||
}
|
||||
}
|
||||
|
||||
_ = len(a)
|
||||
|
||||
var stat2 = &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stat2)
|
||||
b.Log((stat2.TotalAlloc-stat1.TotalAlloc)>>20, "MB")
|
||||
}
|
||||
117
EdgeNode/internal/nodes/http_access_log_viewer.go
Normal file
117
EdgeNode/internal/nodes/http_access_log_viewer.go
Normal file
@@ -0,0 +1,117 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var sharedHTTPAccessLogViewer = NewHTTPAccessLogViewer()
|
||||
|
||||
// HTTPAccessLogViewer 本地访问日志浏览器
|
||||
type HTTPAccessLogViewer struct {
|
||||
sockFile string
|
||||
|
||||
listener net.Listener
|
||||
connMap map[int64]net.Conn // connId => net.Conn
|
||||
connId int64
|
||||
locker sync.Mutex
|
||||
}
|
||||
|
||||
// NewHTTPAccessLogViewer 获取新对象
|
||||
func NewHTTPAccessLogViewer() *HTTPAccessLogViewer {
|
||||
return &HTTPAccessLogViewer{
|
||||
sockFile: os.TempDir() + "/" + teaconst.AccessLogSockName,
|
||||
connMap: map[int64]net.Conn{},
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动
|
||||
func (this *HTTPAccessLogViewer) Start() error {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
if this.listener == nil {
|
||||
// remove if exists
|
||||
_ = os.Remove(this.sockFile)
|
||||
|
||||
// start listening
|
||||
listener, err := net.Listen("unix", this.sockFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.listener = listener
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := this.listener.Accept()
|
||||
if err != nil {
|
||||
remotelogs.Error("ACCESS_LOG", "start local reading failed: "+err.Error())
|
||||
break
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
var connId = this.nextConnId()
|
||||
this.connMap[connId] = conn
|
||||
go func() {
|
||||
this.startReading(conn, connId)
|
||||
}()
|
||||
this.locker.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasConns 检查是否有连接
|
||||
func (this *HTTPAccessLogViewer) HasConns() bool {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
return len(this.connMap) > 0
|
||||
}
|
||||
|
||||
// Send 发送日志
|
||||
func (this *HTTPAccessLogViewer) Send(accessLog *pb.HTTPAccessLog) {
|
||||
var conns = []net.Conn{}
|
||||
this.locker.Lock()
|
||||
for _, conn := range this.connMap {
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
this.locker.Unlock()
|
||||
|
||||
if len(conns) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, conn := range conns {
|
||||
// ignore error
|
||||
_, _ = conn.Write([]byte(accessLog.RemoteAddr + " [" + accessLog.TimeLocal + "] \"" + accessLog.RequestMethod + " " + accessLog.Scheme + "://" + accessLog.Host + accessLog.RequestURI + " " + accessLog.Proto + "\" " + types.String(accessLog.Status) + " " + types.String(accessLog.BytesSent) + " " + strconv.Quote(accessLog.Referer) + " " + strconv.Quote(accessLog.UserAgent) + " - " + fmt.Sprintf("%.2fms", accessLog.RequestTime*1000) + "\n"))
|
||||
}
|
||||
}
|
||||
|
||||
func (this *HTTPAccessLogViewer) nextConnId() int64 {
|
||||
return atomic.AddInt64(&this.connId, 1)
|
||||
}
|
||||
|
||||
func (this *HTTPAccessLogViewer) startReading(conn net.Conn, connId int64) {
|
||||
var buf = make([]byte, 1024)
|
||||
for {
|
||||
_, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
this.locker.Lock()
|
||||
delete(this.connMap, connId)
|
||||
this.locker.Unlock()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
327
EdgeNode/internal/nodes/http_cache_task_manager.go
Normal file
327
EdgeNode/internal/nodes/http_cache_task_manager.go
Normal file
@@ -0,0 +1,327 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/compressions"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/bytepool"
|
||||
connutils "github.com/TeaOSLab/EdgeNode/internal/utils/conns"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
return
|
||||
}
|
||||
|
||||
events.On(events.EventStart, func() {
|
||||
goman.New(func() {
|
||||
SharedHTTPCacheTaskManager.Start()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
var SharedHTTPCacheTaskManager = NewHTTPCacheTaskManager()
|
||||
|
||||
// HTTPCacheTaskManager 缓存任务管理
|
||||
type HTTPCacheTaskManager struct {
|
||||
ticker *time.Ticker
|
||||
protocolReg *regexp.Regexp
|
||||
|
||||
timeoutClientMap map[time.Duration]*http.Client // timeout seconds=> *http.Client
|
||||
locker sync.Mutex
|
||||
|
||||
taskQueue chan *pb.PurgeServerCacheRequest
|
||||
}
|
||||
|
||||
func NewHTTPCacheTaskManager() *HTTPCacheTaskManager {
|
||||
var duration = 30 * time.Second
|
||||
if Tea.IsTesting() {
|
||||
duration = 10 * time.Second
|
||||
}
|
||||
|
||||
return &HTTPCacheTaskManager{
|
||||
ticker: time.NewTicker(duration),
|
||||
protocolReg: regexp.MustCompile(`^(?i)(http|https)://`),
|
||||
taskQueue: make(chan *pb.PurgeServerCacheRequest, 1024),
|
||||
timeoutClientMap: make(map[time.Duration]*http.Client),
|
||||
}
|
||||
}
|
||||
|
||||
func (this *HTTPCacheTaskManager) Start() {
|
||||
// task queue
|
||||
goman.New(func() {
|
||||
rpcClient, _ := rpc.SharedRPC()
|
||||
|
||||
if rpcClient != nil {
|
||||
for taskReq := range this.taskQueue {
|
||||
_, err := rpcClient.ServerRPC.PurgeServerCache(rpcClient.Context(), taskReq)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_CACHE_TASK_MANAGER", "create purge task failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Loop
|
||||
for range this.ticker.C {
|
||||
err := this.Loop()
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_CACHE_TASK_MANAGER", "execute task failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *HTTPCacheTaskManager) Loop() error {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := rpcClient.HTTPCacheTaskKeyRPC.FindDoingHTTPCacheTaskKeys(rpcClient.Context(), &pb.FindDoingHTTPCacheTaskKeysRequest{})
|
||||
if err != nil {
|
||||
// 忽略连接错误
|
||||
if rpc.IsConnError(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var keys = resp.HttpCacheTaskKeys
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var pbResults = []*pb.UpdateHTTPCacheTaskKeysStatusRequest_KeyResult{}
|
||||
|
||||
var taskGroup = goman.NewTaskGroup()
|
||||
for _, key := range keys {
|
||||
var taskKey = key
|
||||
taskGroup.Run(func() {
|
||||
processErr := this.processKey(taskKey)
|
||||
var pbResult = &pb.UpdateHTTPCacheTaskKeysStatusRequest_KeyResult{
|
||||
Id: taskKey.Id,
|
||||
NodeClusterId: taskKey.NodeClusterId,
|
||||
Error: "",
|
||||
}
|
||||
|
||||
if processErr != nil {
|
||||
pbResult.Error = processErr.Error()
|
||||
}
|
||||
|
||||
taskGroup.Lock()
|
||||
pbResults = append(pbResults, pbResult)
|
||||
taskGroup.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
taskGroup.Wait()
|
||||
|
||||
_, err = rpcClient.HTTPCacheTaskKeyRPC.UpdateHTTPCacheTaskKeysStatus(rpcClient.Context(), &pb.UpdateHTTPCacheTaskKeysStatusRequest{KeyResults: pbResults})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *HTTPCacheTaskManager) PushTaskKeys(keys []string) {
|
||||
select {
|
||||
case this.taskQueue <- &pb.PurgeServerCacheRequest{
|
||||
Keys: keys,
|
||||
Prefixes: nil,
|
||||
}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (this *HTTPCacheTaskManager) processKey(key *pb.HTTPCacheTaskKey) error {
|
||||
switch key.Type {
|
||||
case "purge":
|
||||
var storages = caches.SharedManager.FindAllStorages()
|
||||
for _, storage := range storages {
|
||||
switch key.KeyType {
|
||||
case "key":
|
||||
var cacheKeys = []string{key.Key}
|
||||
if strings.HasPrefix(key.Key, "http://") {
|
||||
cacheKeys = append(cacheKeys, strings.Replace(key.Key, "http://", "https://", 1))
|
||||
} else if strings.HasPrefix(key.Key, "https://") {
|
||||
cacheKeys = append(cacheKeys, strings.Replace(key.Key, "https://", "http://", 1))
|
||||
}
|
||||
|
||||
// TODO 提升效率
|
||||
for _, cacheKey := range cacheKeys {
|
||||
var subKeys = []string{
|
||||
cacheKey,
|
||||
cacheKey + caches.SuffixMethod + "HEAD",
|
||||
cacheKey + caches.SuffixWebP,
|
||||
cacheKey + caches.SuffixPartial,
|
||||
}
|
||||
// TODO 根据实际缓存的内容进行组合
|
||||
for _, encoding := range compressions.AllEncodings() {
|
||||
subKeys = append(subKeys, cacheKey+caches.SuffixCompression+encoding)
|
||||
subKeys = append(subKeys, cacheKey+caches.SuffixWebP+caches.SuffixCompression+encoding)
|
||||
}
|
||||
|
||||
err := storage.Purge(subKeys, "file")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case "prefix":
|
||||
var prefixes = []string{key.Key}
|
||||
if strings.HasPrefix(key.Key, "http://") {
|
||||
prefixes = append(prefixes, strings.Replace(key.Key, "http://", "https://", 1))
|
||||
} else if strings.HasPrefix(key.Key, "https://") {
|
||||
prefixes = append(prefixes, strings.Replace(key.Key, "https://", "http://", 1))
|
||||
}
|
||||
|
||||
err := storage.Purge(prefixes, "dir")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
case "fetch":
|
||||
err := this.fetchKey(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return errors.New("invalid operation type '" + key.Type + "'")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO 增加失败重试
|
||||
func (this *HTTPCacheTaskManager) fetchKey(key *pb.HTTPCacheTaskKey) error {
|
||||
var fullKey = key.Key
|
||||
if !this.protocolReg.MatchString(fullKey) {
|
||||
fullKey = "https://" + fullKey
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, fullKey, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid url: '%s': %w", fullKey, err)
|
||||
}
|
||||
|
||||
// TODO 可以在管理界面自定义Header
|
||||
req.Header.Set("X-Edge-Cache-Action", "fetch")
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/85.0.4183.121 Safari/537.36") // TODO 可以定义
|
||||
req.Header.Set("Accept-Encoding", "gzip, deflate, br")
|
||||
resp, err := this.httpClient().Do(req)
|
||||
if err != nil {
|
||||
err = this.simplifyErr(err)
|
||||
return fmt.Errorf("request failed: '%s': %w", fullKey, err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
// 处理502
|
||||
if resp.StatusCode == http.StatusBadGateway {
|
||||
return errors.New("read origin site timeout")
|
||||
}
|
||||
|
||||
// 读取内容,以便于生成缓存
|
||||
var buf = bytepool.Pool16k.Get()
|
||||
_, err = io.CopyBuffer(io.Discard, resp.Body, buf.Bytes)
|
||||
bytepool.Pool16k.Put(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
err = this.simplifyErr(err)
|
||||
return fmt.Errorf("request failed: '%s': %w", fullKey, err)
|
||||
} else {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *HTTPCacheTaskManager) simplifyErr(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if os.IsTimeout(err) {
|
||||
return errors.New("timeout to read origin site")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *HTTPCacheTaskManager) httpClient() *http.Client {
|
||||
var timeout = serverconfigs.DefaultHTTPCachePolicyFetchTimeout
|
||||
|
||||
var nodeConfig = sharedNodeConfig // copy
|
||||
if nodeConfig != nil {
|
||||
var cachePolicies = nodeConfig.HTTPCachePolicies // copy
|
||||
if len(cachePolicies) > 0 && cachePolicies[0].FetchTimeout != nil && cachePolicies[0].FetchTimeout.Count > 0 {
|
||||
var fetchTimeout = cachePolicies[0].FetchTimeout.Duration()
|
||||
if fetchTimeout > 0 {
|
||||
timeout = fetchTimeout
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
client, ok := this.timeoutClientMap[timeout]
|
||||
if ok {
|
||||
return client
|
||||
}
|
||||
|
||||
client = &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
_, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := net.Dial(network, "127.0.0.1:"+port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return connutils.NewNoStat(conn), nil
|
||||
},
|
||||
MaxIdleConns: 128,
|
||||
MaxIdleConnsPerHost: 32,
|
||||
MaxConnsPerHost: 32,
|
||||
IdleConnTimeout: 2 * time.Minute,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
TLSHandshakeTimeout: 0,
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
this.timeoutClientMap[timeout] = client
|
||||
|
||||
return client
|
||||
}
|
||||
25
EdgeNode/internal/nodes/http_cache_task_manager_test.go
Normal file
25
EdgeNode/internal/nodes/http_cache_task_manager_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/nodes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHTTPCacheTaskManager_Loop(t *testing.T) {
|
||||
// initialize cache policies
|
||||
config, err := nodeconfigs.SharedNodeConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
caches.SharedManager.UpdatePolicies(config.HTTPCachePolicies)
|
||||
|
||||
var manager = nodes.NewHTTPCacheTaskManager()
|
||||
err = manager.Loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
47
EdgeNode/internal/nodes/http_client.go
Normal file
47
EdgeNode/internal/nodes/http_client.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// HTTPClient HTTP客户端
|
||||
type HTTPClient struct {
|
||||
rawClient *http.Client
|
||||
accessAt int64
|
||||
isProxyProtocol bool
|
||||
}
|
||||
|
||||
// NewHTTPClient 获取新客户端对象
|
||||
func NewHTTPClient(rawClient *http.Client, isProxyProtocol bool) *HTTPClient {
|
||||
return &HTTPClient{
|
||||
rawClient: rawClient,
|
||||
accessAt: fasttime.Now().Unix(),
|
||||
isProxyProtocol: isProxyProtocol,
|
||||
}
|
||||
}
|
||||
|
||||
// RawClient 获取原始客户端对象
|
||||
func (this *HTTPClient) RawClient() *http.Client {
|
||||
return this.rawClient
|
||||
}
|
||||
|
||||
// UpdateAccessTime 更新访问时间
|
||||
func (this *HTTPClient) UpdateAccessTime() {
|
||||
this.accessAt = fasttime.Now().Unix()
|
||||
}
|
||||
|
||||
// AccessTime 获取访问时间
|
||||
func (this *HTTPClient) AccessTime() int64 {
|
||||
return this.accessAt
|
||||
}
|
||||
|
||||
// IsProxyProtocol 判断是否为PROXY Protocol
|
||||
func (this *HTTPClient) IsProxyProtocol() bool {
|
||||
return this.isProxyProtocol
|
||||
}
|
||||
|
||||
// Close 关闭
|
||||
func (this *HTTPClient) Close() {
|
||||
this.rawClient.CloseIdleConnections()
|
||||
}
|
||||
301
EdgeNode/internal/nodes/http_client_pool.go
Normal file
301
EdgeNode/internal/nodes/http_client_pool.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/pires/go-proxyproto"
|
||||
"golang.org/x/net/http2"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SharedHTTPClientPool HTTP客户端池单例
|
||||
var SharedHTTPClientPool = NewHTTPClientPool()
|
||||
|
||||
const httpClientProxyProtocolTag = "@ProxyProtocol@"
|
||||
const maxHTTPRedirects = 8
|
||||
|
||||
// HTTPClientPool 客户端池
|
||||
type HTTPClientPool struct {
|
||||
clientsMap map[uint64]*HTTPClient // origin key => client
|
||||
|
||||
cleanTicker *time.Ticker
|
||||
|
||||
locker sync.RWMutex
|
||||
}
|
||||
|
||||
// NewHTTPClientPool 获取新对象
|
||||
func NewHTTPClientPool() *HTTPClientPool {
|
||||
var pool = &HTTPClientPool{
|
||||
cleanTicker: time.NewTicker(1 * time.Hour),
|
||||
clientsMap: map[uint64]*HTTPClient{},
|
||||
}
|
||||
|
||||
goman.New(func() {
|
||||
pool.cleanClients()
|
||||
})
|
||||
|
||||
return pool
|
||||
}
|
||||
|
||||
// Client 根据地址获取客户端
|
||||
func (this *HTTPClientPool) Client(req *HTTPRequest,
|
||||
origin *serverconfigs.OriginConfig,
|
||||
originAddr string,
|
||||
proxyProtocol *serverconfigs.ProxyProtocolConfig,
|
||||
followRedirects bool) (rawClient *http.Client, err error) {
|
||||
if origin.Addr == nil {
|
||||
return nil, errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")")
|
||||
}
|
||||
|
||||
if req == nil || req.RawReq == nil || req.RawReq.URL == nil {
|
||||
err = errors.New("invalid request url")
|
||||
return
|
||||
}
|
||||
var originHost = req.RawReq.URL.Host
|
||||
var urlPort = req.RawReq.URL.Port()
|
||||
if len(urlPort) == 0 {
|
||||
if req.RawReq.URL.Scheme == "http" {
|
||||
urlPort = "80"
|
||||
} else {
|
||||
urlPort = "443"
|
||||
}
|
||||
|
||||
originHost += ":" + urlPort
|
||||
}
|
||||
|
||||
var rawKey = origin.UniqueKey() + "@" + originAddr + "@" + originHost
|
||||
|
||||
// if we are under available ProxyProtocol, we add client ip to key to make every client unique
|
||||
var isProxyProtocol = false
|
||||
if proxyProtocol != nil && proxyProtocol.IsOn {
|
||||
rawKey += httpClientProxyProtocolTag + req.requestRemoteAddr(true)
|
||||
isProxyProtocol = true
|
||||
}
|
||||
|
||||
// follow redirects
|
||||
if followRedirects {
|
||||
rawKey += "@follow"
|
||||
}
|
||||
|
||||
var key = xxhash.Sum64String(rawKey)
|
||||
|
||||
var isLnRequest = origin.Id == 0
|
||||
|
||||
this.locker.RLock()
|
||||
client, found := this.clientsMap[key]
|
||||
this.locker.RUnlock()
|
||||
if found {
|
||||
client.UpdateAccessTime()
|
||||
return client.RawClient(), nil
|
||||
}
|
||||
|
||||
// 这里不能使用RLock,避免因为并发生成多个同样的client实例
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
// 再次查找
|
||||
client, found = this.clientsMap[key]
|
||||
if found {
|
||||
client.UpdateAccessTime()
|
||||
return client.RawClient(), nil
|
||||
}
|
||||
|
||||
var maxConnections = origin.MaxConns
|
||||
var connectionTimeout = origin.ConnTimeoutDuration()
|
||||
var readTimeout = origin.ReadTimeoutDuration()
|
||||
var idleTimeout = origin.IdleTimeoutDuration()
|
||||
var idleConns = origin.MaxIdleConns
|
||||
|
||||
// 超时时间
|
||||
if connectionTimeout <= 0 {
|
||||
connectionTimeout = 15 * time.Second
|
||||
}
|
||||
|
||||
if idleTimeout <= 0 {
|
||||
idleTimeout = 2 * time.Minute
|
||||
}
|
||||
|
||||
var numberCPU = runtime.NumCPU()
|
||||
if numberCPU < 8 {
|
||||
numberCPU = 8
|
||||
}
|
||||
if maxConnections <= 0 {
|
||||
maxConnections = numberCPU * 64
|
||||
}
|
||||
|
||||
if idleConns <= 0 {
|
||||
idleConns = numberCPU * 16
|
||||
}
|
||||
|
||||
if isProxyProtocol { // ProxyProtocol无需保持太多空闲连接
|
||||
idleConns = 3
|
||||
} else if isLnRequest { // 可以判断为Ln节点请求
|
||||
maxConnections *= 8
|
||||
idleConns *= 8
|
||||
idleTimeout *= 4
|
||||
} else if sharedNodeConfig != nil && sharedNodeConfig.Level > 1 {
|
||||
// Ln节点可以适当增加连接数
|
||||
maxConnections *= 2
|
||||
idleConns *= 2
|
||||
}
|
||||
|
||||
// TLS通讯
|
||||
var tlsConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
if origin.Cert != nil {
|
||||
var obj = origin.Cert.CertObject()
|
||||
if obj != nil {
|
||||
tlsConfig.InsecureSkipVerify = false
|
||||
tlsConfig.Certificates = []tls.Certificate{*obj}
|
||||
if len(origin.Cert.ServerName) > 0 {
|
||||
tlsConfig.ServerName = origin.Cert.ServerName
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var transport = &HTTPClientTransport{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||
var realAddr = originAddr
|
||||
|
||||
// for redirections
|
||||
if followRedirects && originHost != addr {
|
||||
realAddr = addr
|
||||
}
|
||||
|
||||
// connect
|
||||
conn, dialErr := (&net.Dialer{
|
||||
Timeout: connectionTimeout,
|
||||
KeepAlive: 1 * time.Minute,
|
||||
}).DialContext(ctx, network, realAddr)
|
||||
if dialErr != nil {
|
||||
return nil, dialErr
|
||||
}
|
||||
|
||||
// handle PROXY protocol
|
||||
proxyErr := this.handlePROXYProtocol(conn, req, proxyProtocol)
|
||||
if proxyErr != nil {
|
||||
return nil, proxyErr
|
||||
}
|
||||
|
||||
return NewOriginConn(conn), nil
|
||||
},
|
||||
MaxIdleConns: 0,
|
||||
MaxIdleConnsPerHost: idleConns,
|
||||
MaxConnsPerHost: maxConnections,
|
||||
IdleConnTimeout: idleTimeout,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
TLSHandshakeTimeout: 5 * time.Second,
|
||||
TLSClientConfig: tlsConfig,
|
||||
ReadBufferSize: 8 * 1024,
|
||||
Proxy: nil,
|
||||
},
|
||||
}
|
||||
|
||||
// support http/2
|
||||
if origin.HTTP2Enabled && origin.Addr != nil && origin.Addr.Protocol == serverconfigs.ProtocolHTTPS {
|
||||
_ = http2.ConfigureTransport(transport.Transport)
|
||||
}
|
||||
|
||||
rawClient = &http.Client{
|
||||
Timeout: readTimeout,
|
||||
Transport: transport,
|
||||
CheckRedirect: func(targetReq *http.Request, via []*http.Request) error {
|
||||
// follow redirects
|
||||
if followRedirects && len(via) <= maxHTTPRedirects {
|
||||
return nil
|
||||
}
|
||||
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
|
||||
this.clientsMap[key] = NewHTTPClient(rawClient, isProxyProtocol)
|
||||
|
||||
return rawClient, nil
|
||||
}
|
||||
|
||||
// 清理不使用的Client
|
||||
func (this *HTTPClientPool) cleanClients() {
|
||||
for range this.cleanTicker.C {
|
||||
var nowTime = fasttime.Now().Unix()
|
||||
|
||||
var expiredKeys []uint64
|
||||
var expiredClients = []*HTTPClient{}
|
||||
|
||||
// lookup expired clients
|
||||
this.locker.RLock()
|
||||
for k, client := range this.clientsMap {
|
||||
if client.AccessTime() < nowTime-86400 ||
|
||||
(client.IsProxyProtocol() && client.AccessTime() < nowTime-3600) { // 超过 N 秒没有调用就关闭
|
||||
expiredKeys = append(expiredKeys, k)
|
||||
expiredClients = append(expiredClients, client)
|
||||
}
|
||||
}
|
||||
this.locker.RUnlock()
|
||||
|
||||
// remove expired keys
|
||||
if len(expiredKeys) > 0 {
|
||||
this.locker.Lock()
|
||||
for _, k := range expiredKeys {
|
||||
delete(this.clientsMap, k)
|
||||
}
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
// close expired clients
|
||||
if len(expiredClients) > 0 {
|
||||
for _, client := range expiredClients {
|
||||
client.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 支持PROXY Protocol
|
||||
func (this *HTTPClientPool) handlePROXYProtocol(conn net.Conn, req *HTTPRequest, proxyProtocol *serverconfigs.ProxyProtocolConfig) error {
|
||||
if proxyProtocol != nil &&
|
||||
proxyProtocol.IsOn &&
|
||||
(proxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || proxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) {
|
||||
var remoteAddr = req.requestRemoteAddr(true)
|
||||
var transportProtocol = proxyproto.TCPv4
|
||||
if strings.Contains(remoteAddr, ":") {
|
||||
transportProtocol = proxyproto.TCPv6
|
||||
}
|
||||
var destAddr = conn.RemoteAddr()
|
||||
var reqConn = req.RawReq.Context().Value(HTTPConnContextKey)
|
||||
if reqConn != nil {
|
||||
destAddr = reqConn.(net.Conn).LocalAddr()
|
||||
}
|
||||
var header = proxyproto.Header{
|
||||
Version: byte(proxyProtocol.Version),
|
||||
Command: proxyproto.PROXY,
|
||||
TransportProtocol: transportProtocol,
|
||||
SourceAddr: &net.TCPAddr{
|
||||
IP: net.ParseIP(remoteAddr),
|
||||
Port: req.requestRemotePort(),
|
||||
},
|
||||
DestinationAddr: destAddr,
|
||||
}
|
||||
_, err := header.WriteTo(conn)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
92
EdgeNode/internal/nodes/http_client_pool_test.go
Normal file
92
EdgeNode/internal/nodes/http_client_pool_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestHTTPClientPool_Client(t *testing.T) {
|
||||
var pool = NewHTTPClientPool()
|
||||
|
||||
rawReq, err := http.NewRequest(http.MethodGet, "https://example.com/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var req = &HTTPRequest{RawReq: rawReq}
|
||||
|
||||
{
|
||||
var origin = &serverconfigs.OriginConfig{
|
||||
Id: 1,
|
||||
Version: 2,
|
||||
Addr: &serverconfigs.NetworkAddressConfig{Host: "127.0.0.1", PortRange: "1234"},
|
||||
}
|
||||
err := origin.Init(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
{
|
||||
client, err := pool.Client(req, origin, origin.Addr.PickAddress(), nil, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("client:", client)
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
client, err := pool.Client(req, origin, origin.Addr.PickAddress(), nil, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("client:", client)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPClientPool_cleanClients(t *testing.T) {
|
||||
var origin = &serverconfigs.OriginConfig{
|
||||
Id: 1,
|
||||
Version: 2,
|
||||
Addr: &serverconfigs.NetworkAddressConfig{Host: "127.0.0.1", PortRange: "1234"},
|
||||
}
|
||||
err := origin.Init(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var pool = NewHTTPClientPool()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
t.Log("get", i)
|
||||
_, _ = pool.Client(nil, origin, origin.Addr.PickAddress(), nil, false)
|
||||
|
||||
if testutils.IsSingleTesting() {
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHTTPClientPool_Client(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
var origin = &serverconfigs.OriginConfig{
|
||||
Id: 1,
|
||||
Version: 2,
|
||||
Addr: &serverconfigs.NetworkAddressConfig{Host: "127.0.0.1", PortRange: "1234"},
|
||||
}
|
||||
err := origin.Init(context.Background())
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
var pool = NewHTTPClientPool()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = pool.Client(nil, origin, origin.Addr.PickAddress(), nil, false)
|
||||
}
|
||||
}
|
||||
26
EdgeNode/internal/nodes/http_client_transport.go
Normal file
26
EdgeNode/internal/nodes/http_client_transport.go
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const emptyHTTPLocation = "/$EmptyHTTPLocation$"
|
||||
|
||||
type HTTPClientTransport struct {
|
||||
*http.Transport
|
||||
}
|
||||
|
||||
func (this *HTTPClientTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
resp, err := this.Transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// 检查在跳转相关状态中Location是否存在
|
||||
if httpStatusIsRedirect(resp.StatusCode) && len(resp.Header.Get("Location")) == 0 {
|
||||
resp.Header.Set("Location", emptyHTTPLocation)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
2108
EdgeNode/internal/nodes/http_request.go
Normal file
2108
EdgeNode/internal/nodes/http_request.go
Normal file
File diff suppressed because it is too large
Load Diff
39
EdgeNode/internal/nodes/http_request_acme.go
Normal file
39
EdgeNode/internal/nodes/http_request_acme.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func (this *HTTPRequest) doACME() (shouldStop bool) {
|
||||
// TODO 对请求进行校验,防止恶意攻击
|
||||
|
||||
var token = filepath.Base(this.RawReq.URL.Path)
|
||||
if token == "acme-challenge" || len(token) <= 32 {
|
||||
return false
|
||||
}
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
remotelogs.Error("RPC", "[ACME]rpc failed: "+err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
keyResp, err := rpcClient.ACMEAuthenticationRPC.FindACMEAuthenticationKeyWithToken(rpcClient.Context(), &pb.FindACMEAuthenticationKeyWithTokenRequest{Token: token})
|
||||
if err != nil {
|
||||
remotelogs.Error("RPC", "[ACME]read key for token failed: "+err.Error())
|
||||
return false
|
||||
}
|
||||
if len(keyResp.Key) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
this.tags = append(this.tags, "ACME")
|
||||
|
||||
this.writer.Header().Set("Content-Type", "text/plain")
|
||||
_, _ = this.writer.WriteString(keyResp.Key)
|
||||
|
||||
return true
|
||||
}
|
||||
72
EdgeNode/internal/nodes/http_request_auth.go
Normal file
72
EdgeNode/internal/nodes/http_request_auth.go
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// 执行认证
|
||||
func (this *HTTPRequest) doAuth() (shouldStop bool) {
|
||||
if this.web.Auth == nil || !this.web.Auth.IsOn {
|
||||
return
|
||||
}
|
||||
|
||||
for _, ref := range this.web.Auth.PolicyRefs {
|
||||
if !ref.IsOn || ref.AuthPolicy == nil || !ref.AuthPolicy.IsOn {
|
||||
continue
|
||||
}
|
||||
if !ref.AuthPolicy.MatchRequest(this.RawReq) {
|
||||
continue
|
||||
}
|
||||
ok, newURI, uriChanged, err := ref.AuthPolicy.Filter(this.RawReq, func(subReq *http.Request) (status int, err error) {
|
||||
subReq.TLS = this.RawReq.TLS
|
||||
subReq.RemoteAddr = this.RawReq.RemoteAddr
|
||||
subReq.Host = this.RawReq.Host
|
||||
subReq.Proto = this.RawReq.Proto
|
||||
subReq.ProtoMinor = this.RawReq.ProtoMinor
|
||||
subReq.ProtoMajor = this.RawReq.ProtoMajor
|
||||
subReq.Body = io.NopCloser(bytes.NewReader([]byte{}))
|
||||
subReq.Header.Set("Referer", this.URL())
|
||||
var writer = NewEmptyResponseWriter(this.writer)
|
||||
this.doSubRequest(writer, subReq)
|
||||
return writer.StatusCode(), nil
|
||||
}, this.Format)
|
||||
if err != nil {
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to execute the AuthPolicy", "认证策略执行失败", false)
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
if uriChanged {
|
||||
this.uri = newURI
|
||||
}
|
||||
this.tags = append(this.tags, "auth:"+ref.AuthPolicy.Type)
|
||||
return
|
||||
} else {
|
||||
// Basic Auth比较特殊
|
||||
if ref.AuthPolicy.Type == serverconfigs.HTTPAuthTypeBasicAuth {
|
||||
method, ok := ref.AuthPolicy.Method().(*serverconfigs.HTTPAuthBasicMethod)
|
||||
if ok {
|
||||
var headerValue = "Basic realm=\""
|
||||
if len(method.Realm) > 0 {
|
||||
headerValue += method.Realm
|
||||
} else {
|
||||
headerValue += this.ReqHost
|
||||
}
|
||||
headerValue += "\""
|
||||
if len(method.Charset) > 0 {
|
||||
headerValue += ", charset=\"" + method.Charset + "\""
|
||||
}
|
||||
this.writer.Header()["WWW-Authenticate"] = []string{headerValue}
|
||||
}
|
||||
}
|
||||
this.writer.WriteHeader(http.StatusUnauthorized)
|
||||
this.tags = append(this.tags, "auth:"+ref.AuthPolicy.Type)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
760
EdgeNode/internal/nodes/http_request_cache.go
Normal file
760
EdgeNode/internal/nodes/http_request_cache.go
Normal file
@@ -0,0 +1,760 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/compressions"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
rangeutils "github.com/TeaOSLab/EdgeNode/internal/utils/ranges"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 读取缓存
|
||||
func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
|
||||
// 需要动态Upgrade的不缓存
|
||||
if len(this.RawReq.Header.Get("Upgrade")) > 0 {
|
||||
return
|
||||
}
|
||||
|
||||
this.cacheCanTryStale = false
|
||||
|
||||
var cachePolicy = this.ReqServer.HTTPCachePolicy
|
||||
if cachePolicy == nil || !cachePolicy.IsOn {
|
||||
return
|
||||
}
|
||||
|
||||
if this.web.Cache == nil || !this.web.Cache.IsOn || (len(cachePolicy.CacheRefs) == 0 && len(this.web.Cache.CacheRefs) == 0) {
|
||||
return
|
||||
}
|
||||
|
||||
// 添加 X-Cache Header
|
||||
var addStatusHeader = this.web.Cache.AddStatusHeader
|
||||
var cacheBypassDescription = ""
|
||||
if addStatusHeader {
|
||||
defer func() {
|
||||
if len(cacheBypassDescription) > 0 {
|
||||
this.writer.Header().Set("X-Cache", cacheBypassDescription)
|
||||
return
|
||||
}
|
||||
var cacheStatus = this.varMapping["cache.status"]
|
||||
if cacheStatus != "HIT" {
|
||||
this.writer.Header().Set("X-Cache", cacheStatus)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 检查服务独立的缓存条件
|
||||
var refType = ""
|
||||
for _, cacheRef := range this.web.Cache.CacheRefs {
|
||||
if !cacheRef.IsOn {
|
||||
continue
|
||||
}
|
||||
if (cacheRef.Conds != nil && cacheRef.Conds.HasRequestConds() && cacheRef.Conds.MatchRequest(this.Format)) ||
|
||||
(cacheRef.SimpleCond != nil && cacheRef.SimpleCond.Match(this.Format)) {
|
||||
if cacheRef.IsReverse {
|
||||
return
|
||||
}
|
||||
this.cacheRef = cacheRef
|
||||
refType = "server"
|
||||
break
|
||||
}
|
||||
}
|
||||
if this.cacheRef == nil && !this.web.Cache.DisablePolicyRefs {
|
||||
// 检查策略默认的缓存条件
|
||||
for _, cacheRef := range cachePolicy.CacheRefs {
|
||||
if !cacheRef.IsOn {
|
||||
continue
|
||||
}
|
||||
if (cacheRef.Conds != nil && cacheRef.Conds.HasRequestConds() && cacheRef.Conds.MatchRequest(this.Format)) ||
|
||||
(cacheRef.SimpleCond != nil && cacheRef.SimpleCond.Match(this.Format)) {
|
||||
if cacheRef.IsReverse {
|
||||
return
|
||||
}
|
||||
this.cacheRef = cacheRef
|
||||
refType = "policy"
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if this.cacheRef == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 是否强制Range回源
|
||||
if this.cacheRef.AlwaysForwardRangeRequest && len(this.RawReq.Header.Get("Range")) > 0 {
|
||||
this.cacheRef = nil
|
||||
cacheBypassDescription = "BYPASS, forward range"
|
||||
return
|
||||
}
|
||||
|
||||
// 是否正在Purge
|
||||
var isPurging = this.web.Cache.PurgeIsOn && strings.ToUpper(this.RawReq.Method) == "PURGE" && this.RawReq.Header.Get("X-Edge-Purge-Key") == this.web.Cache.PurgeKey
|
||||
if isPurging {
|
||||
this.RawReq.Method = http.MethodGet
|
||||
}
|
||||
|
||||
// 校验请求
|
||||
if !this.cacheRef.MatchRequest(this.RawReq) {
|
||||
this.cacheRef = nil
|
||||
cacheBypassDescription = "BYPASS, not match"
|
||||
return
|
||||
}
|
||||
|
||||
// 相关变量
|
||||
this.varMapping["cache.policy.name"] = cachePolicy.Name
|
||||
this.varMapping["cache.policy.id"] = strconv.FormatInt(cachePolicy.Id, 10)
|
||||
this.varMapping["cache.policy.type"] = cachePolicy.Type
|
||||
|
||||
// Cache-Pragma
|
||||
if this.cacheRef.EnableRequestCachePragma {
|
||||
if this.RawReq.Header.Get("Cache-Control") == "no-cache" || this.RawReq.Header.Get("Pragma") == "no-cache" {
|
||||
this.cacheRef = nil
|
||||
cacheBypassDescription = "BYPASS, Cache-Control or Pragma"
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// TODO 支持Vary Header
|
||||
|
||||
// 缓存标签
|
||||
var tags = []string{}
|
||||
|
||||
// 检查是否有缓存
|
||||
var key string
|
||||
if this.web.Cache.Key != nil && this.web.Cache.Key.IsOn && len(this.web.Cache.Key.Host) > 0 {
|
||||
key = configutils.ParseVariables(this.cacheRef.Key, func(varName string) (value string) {
|
||||
switch varName {
|
||||
case "scheme":
|
||||
return this.web.Cache.Key.Scheme
|
||||
case "host":
|
||||
return this.web.Cache.Key.Host
|
||||
default:
|
||||
return this.Format("${" + varName + "}")
|
||||
}
|
||||
})
|
||||
} else {
|
||||
key = this.Format(this.cacheRef.Key)
|
||||
}
|
||||
|
||||
if len(key) == 0 {
|
||||
this.cacheRef = nil
|
||||
cacheBypassDescription = "BYPASS, empty key"
|
||||
return
|
||||
}
|
||||
var method = this.Method()
|
||||
if method != http.MethodGet {
|
||||
key += caches.SuffixMethod + method
|
||||
tags = append(tags, strings.ToLower(method))
|
||||
}
|
||||
|
||||
this.cacheKey = key
|
||||
this.varMapping["cache.key"] = key
|
||||
|
||||
// 读取缓存
|
||||
var storage = caches.SharedManager.FindStorageWithPolicy(cachePolicy.Id)
|
||||
if storage == nil {
|
||||
this.cacheRef = nil
|
||||
cacheBypassDescription = "BYPASS, no policy found"
|
||||
return
|
||||
}
|
||||
this.writer.cacheStorage = storage
|
||||
|
||||
// 如果正在预热,则不读取缓存,等待下一个步骤重新生成
|
||||
if (strings.HasPrefix(this.RawReq.RemoteAddr, "127.") || strings.HasPrefix(this.RawReq.RemoteAddr, "[::1]")) && this.RawReq.Header.Get("X-Edge-Cache-Action") == "fetch" {
|
||||
return
|
||||
}
|
||||
|
||||
// 判断是否在Purge
|
||||
if isPurging {
|
||||
this.varMapping["cache.status"] = "PURGE"
|
||||
|
||||
var subKeys = []string{
|
||||
key,
|
||||
key + caches.SuffixMethod + "HEAD",
|
||||
key + caches.SuffixWebP,
|
||||
key + caches.SuffixPartial,
|
||||
}
|
||||
// TODO 根据实际缓存的内容进行组合
|
||||
for _, encoding := range compressions.AllEncodings() {
|
||||
subKeys = append(subKeys, key+caches.SuffixCompression+encoding)
|
||||
subKeys = append(subKeys, key+caches.SuffixWebP+caches.SuffixCompression+encoding)
|
||||
}
|
||||
for _, subKey := range subKeys {
|
||||
err := storage.Delete(subKey)
|
||||
if err != nil {
|
||||
remotelogs.ErrorServer("HTTP_REQUEST_CACHE", "purge failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 通过API节点清除别节点上的的Key
|
||||
SharedHTTPCacheTaskManager.PushTaskKeys([]string{key})
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// 调用回调
|
||||
this.onRequest()
|
||||
if this.writer.isFinished {
|
||||
return
|
||||
}
|
||||
|
||||
var reader caches.Reader
|
||||
var err error
|
||||
|
||||
var rangeHeader = this.RawReq.Header.Get("Range")
|
||||
var isPartialRequest = len(rangeHeader) > 0
|
||||
|
||||
// 检查是否支持WebP
|
||||
var webPIsEnabled = false
|
||||
var isHeadMethod = method == http.MethodHead
|
||||
if !isPartialRequest &&
|
||||
!isHeadMethod &&
|
||||
this.web.WebP != nil &&
|
||||
this.web.WebP.IsOn &&
|
||||
this.web.WebP.MatchRequest(filepath.Ext(this.Path()), this.Format) &&
|
||||
this.web.WebP.MatchAccept(this.RawReq.Header.Get("Accept")) {
|
||||
webPIsEnabled = true
|
||||
}
|
||||
|
||||
// 检查WebP压缩缓存
|
||||
if webPIsEnabled && !isPartialRequest && !isHeadMethod && reader == nil {
|
||||
if this.web.Compression != nil && this.web.Compression.IsOn {
|
||||
_, encoding, ok := this.web.Compression.MatchAcceptEncoding(this.RawReq.Header.Get("Accept-Encoding"))
|
||||
if ok {
|
||||
reader, err = storage.OpenReader(key+caches.SuffixWebP+caches.SuffixCompression+encoding, useStale, false)
|
||||
if err != nil && caches.IsBusyError(err) {
|
||||
this.varMapping["cache.status"] = "BUSY"
|
||||
this.cacheRef = nil
|
||||
return
|
||||
}
|
||||
if reader != nil {
|
||||
tags = append(tags, "webp", encoding)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查WebP
|
||||
if webPIsEnabled && !isPartialRequest &&
|
||||
!isHeadMethod &&
|
||||
reader == nil {
|
||||
reader, err = storage.OpenReader(key+caches.SuffixWebP, useStale, false)
|
||||
if err != nil && caches.IsBusyError(err) {
|
||||
this.varMapping["cache.status"] = "BUSY"
|
||||
this.cacheRef = nil
|
||||
return
|
||||
}
|
||||
if reader != nil {
|
||||
this.writer.cacheReaderSuffix = caches.SuffixWebP
|
||||
tags = append(tags, "webp")
|
||||
}
|
||||
}
|
||||
|
||||
// 检查普通压缩缓存
|
||||
if !isPartialRequest && !isHeadMethod && reader == nil {
|
||||
if this.web.Compression != nil && this.web.Compression.IsOn {
|
||||
_, encoding, ok := this.web.Compression.MatchAcceptEncoding(this.RawReq.Header.Get("Accept-Encoding"))
|
||||
if ok {
|
||||
reader, err = storage.OpenReader(key+caches.SuffixCompression+encoding, useStale, false)
|
||||
if err != nil && caches.IsBusyError(err) {
|
||||
this.varMapping["cache.status"] = "BUSY"
|
||||
this.cacheRef = nil
|
||||
return
|
||||
}
|
||||
if reader != nil {
|
||||
tags = append(tags, encoding)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查正常的文件
|
||||
var isPartialCache = false
|
||||
var partialRanges []rangeutils.Range
|
||||
var firstRangeEnd int64
|
||||
if reader == nil {
|
||||
reader, err = storage.OpenReader(key, useStale, false)
|
||||
if err != nil && caches.IsBusyError(err) {
|
||||
this.varMapping["cache.status"] = "BUSY"
|
||||
this.cacheRef = nil
|
||||
return
|
||||
}
|
||||
if err != nil && this.cacheRef.AllowPartialContent {
|
||||
// 尝试读取分片的缓存内容
|
||||
if len(rangeHeader) == 0 && this.cacheRef.ForcePartialContent {
|
||||
// 默认读取开头
|
||||
rangeHeader = "bytes=0-"
|
||||
}
|
||||
|
||||
if len(rangeHeader) > 0 {
|
||||
pReader, ranges, rangeEnd, goNext := this.tryPartialReader(storage, key, useStale, rangeHeader, this.cacheRef.ForcePartialContent)
|
||||
if !goNext {
|
||||
this.cacheRef = nil
|
||||
return
|
||||
}
|
||||
if pReader != nil {
|
||||
isPartialCache = true
|
||||
reader = pReader
|
||||
partialRanges = ranges
|
||||
firstRangeEnd = rangeEnd
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, caches.ErrNotFound) {
|
||||
// 移除请求中的 If-None-Match 和 If-Modified-Since,防止源站返回304而无法缓存
|
||||
if this.reverseProxy != nil {
|
||||
this.RawReq.Header.Del("If-None-Match")
|
||||
this.RawReq.Header.Del("If-Modified-Since")
|
||||
}
|
||||
|
||||
// cache相关变量
|
||||
this.varMapping["cache.status"] = "MISS"
|
||||
|
||||
if !useStale && this.web.Cache.Stale != nil && this.web.Cache.Stale.IsOn {
|
||||
this.cacheCanTryStale = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !this.canIgnore(err) {
|
||||
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: open cache failed: "+err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if !this.writer.DelayRead() {
|
||||
_ = reader.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if useStale {
|
||||
this.varMapping["cache.status"] = "STALE"
|
||||
this.logAttrs["cache.status"] = "STALE"
|
||||
} else {
|
||||
this.varMapping["cache.status"] = "HIT"
|
||||
this.logAttrs["cache.status"] = "HIT"
|
||||
}
|
||||
|
||||
// 准备Buffer
|
||||
var fileSize = reader.BodySize()
|
||||
var totalSizeString = types.String(fileSize)
|
||||
if isPartialCache {
|
||||
fileSize = reader.(*caches.PartialFileReader).MaxLength()
|
||||
if totalSizeString == "0" {
|
||||
totalSizeString = "*"
|
||||
}
|
||||
}
|
||||
|
||||
// 读取Header
|
||||
var headerData = []byte{}
|
||||
this.writer.SetSentHeaderBytes(reader.HeaderSize())
|
||||
var headerPool = this.bytePool(reader.HeaderSize())
|
||||
var headerBuf = headerPool.Get()
|
||||
err = reader.ReadHeader(headerBuf.Bytes, func(n int) (goNext bool, readErr error) {
|
||||
headerData = append(headerData, headerBuf.Bytes[:n]...)
|
||||
for {
|
||||
var nIndex = bytes.Index(headerData, []byte{'\n'})
|
||||
if nIndex >= 0 {
|
||||
var row = headerData[:nIndex]
|
||||
var spaceIndex = bytes.Index(row, []byte{':'})
|
||||
if spaceIndex <= 0 {
|
||||
return false, errors.New("invalid header '" + string(row) + "'")
|
||||
}
|
||||
|
||||
this.writer.Header().Set(string(row[:spaceIndex]), string(row[spaceIndex+1:]))
|
||||
headerData = headerData[nIndex+1:]
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
headerPool.Put(headerBuf)
|
||||
if err != nil {
|
||||
if !this.canIgnore(err) {
|
||||
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: read header failed: "+err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 设置cache.age变量
|
||||
var age = strconv.FormatInt(fasttime.Now().Unix()-reader.LastModified(), 10)
|
||||
this.varMapping["cache.age"] = age
|
||||
|
||||
if addStatusHeader {
|
||||
if useStale {
|
||||
this.writer.Header().Set("X-Cache", "STALE, "+refType+", "+reader.TypeName())
|
||||
} else {
|
||||
this.writer.Header().Set("X-Cache", "HIT, "+refType+", "+reader.TypeName())
|
||||
}
|
||||
} else {
|
||||
this.writer.Header().Del("X-Cache")
|
||||
}
|
||||
if this.web.Cache.AddAgeHeader {
|
||||
this.writer.Header().Set("Age", age)
|
||||
}
|
||||
|
||||
// ETag
|
||||
var respHeader = this.writer.Header()
|
||||
var eTag = respHeader.Get("ETag")
|
||||
var lastModifiedAt = reader.LastModified()
|
||||
if len(eTag) == 0 {
|
||||
if lastModifiedAt > 0 {
|
||||
if len(tags) > 0 {
|
||||
eTag = "\"" + strconv.FormatInt(lastModifiedAt, 10) + "_" + strings.Join(tags, "_") + "\""
|
||||
} else {
|
||||
eTag = "\"" + strconv.FormatInt(lastModifiedAt, 10) + "\""
|
||||
}
|
||||
respHeader.Del("Etag")
|
||||
// 修复:移除 isPartialCache 限制,确保所有缓存类型都返回 ETag
|
||||
respHeader["ETag"] = []string{eTag}
|
||||
}
|
||||
}
|
||||
|
||||
// 支持 Last-Modified
|
||||
var modifiedTime = ""
|
||||
if lastModifiedAt > 0 {
|
||||
modifiedTime = time.Unix(utils.GMTUnixTime(lastModifiedAt), 0).Format("Mon, 02 Jan 2006 15:04:05") + " GMT"
|
||||
// 修复:移除 isPartialCache 限制,确保所有缓存类型都返回 Last-Modified
|
||||
respHeader.Set("Last-Modified", modifiedTime)
|
||||
}
|
||||
|
||||
// 支持 If-None-Match
|
||||
// 修复:移除 isPartialCache 限制,允许分片缓存也支持 304 响应
|
||||
if !this.isLnRequest && len(eTag) > 0 && this.requestHeader("If-None-Match") == eTag {
|
||||
// 自定义Header
|
||||
this.ProcessResponseHeaders(this.writer.Header(), http.StatusNotModified)
|
||||
this.addExpiresHeader(reader.ExpiresAt())
|
||||
this.writer.WriteHeader(http.StatusNotModified)
|
||||
this.isCached = true
|
||||
this.cacheRef = nil
|
||||
this.writer.SetOk()
|
||||
return true
|
||||
}
|
||||
|
||||
// 支持 If-Modified-Since
|
||||
// 修复:移除 isPartialCache 限制,允许分片缓存也支持 304 响应
|
||||
if !this.isLnRequest && len(modifiedTime) > 0 && this.requestHeader("If-Modified-Since") == modifiedTime {
|
||||
// 自定义Header
|
||||
this.ProcessResponseHeaders(this.writer.Header(), http.StatusNotModified)
|
||||
this.addExpiresHeader(reader.ExpiresAt())
|
||||
this.writer.WriteHeader(http.StatusNotModified)
|
||||
this.isCached = true
|
||||
this.cacheRef = nil
|
||||
this.writer.SetOk()
|
||||
return true
|
||||
}
|
||||
|
||||
this.ProcessResponseHeaders(this.writer.Header(), reader.Status())
|
||||
this.addExpiresHeader(reader.ExpiresAt())
|
||||
|
||||
// 返回上级节点过期时间
|
||||
if this.isLnRequest {
|
||||
respHeader.Set(LNExpiresHeader, types.String(reader.ExpiresAt()))
|
||||
}
|
||||
|
||||
// 输出Body
|
||||
if this.RawReq.Method == http.MethodHead {
|
||||
this.writer.WriteHeader(reader.Status())
|
||||
} else {
|
||||
ifRangeHeaders, ok := this.RawReq.Header["If-Range"]
|
||||
var supportRange = true
|
||||
if ok {
|
||||
supportRange = false
|
||||
for _, v := range ifRangeHeaders {
|
||||
if v == this.writer.Header().Get("ETag") || v == this.writer.Header().Get("Last-Modified") {
|
||||
supportRange = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 支持Range
|
||||
var ranges = partialRanges
|
||||
if supportRange {
|
||||
if len(rangeHeader) > 0 {
|
||||
if fileSize == 0 {
|
||||
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
|
||||
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return true
|
||||
}
|
||||
|
||||
if len(ranges) == 0 {
|
||||
ranges, ok = httpRequestParseRangeHeader(rangeHeader)
|
||||
if !ok {
|
||||
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
|
||||
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return true
|
||||
}
|
||||
}
|
||||
if len(ranges) > 0 {
|
||||
for k, r := range ranges {
|
||||
r2, ok := r.Convert(fileSize)
|
||||
if !ok {
|
||||
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
|
||||
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return true
|
||||
}
|
||||
|
||||
ranges[k] = r2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(ranges) == 1 {
|
||||
respHeader.Set("Content-Range", ranges[0].ComposeContentRangeHeader(totalSizeString))
|
||||
respHeader.Set("Content-Length", strconv.FormatInt(ranges[0].Length(), 10))
|
||||
this.writer.WriteHeader(http.StatusPartialContent)
|
||||
|
||||
var pool = this.bytePool(fileSize)
|
||||
var bodyBuf = pool.Get()
|
||||
|
||||
var rangeEnd = ranges[0].End()
|
||||
if firstRangeEnd > 0 {
|
||||
rangeEnd = firstRangeEnd
|
||||
}
|
||||
|
||||
err = reader.ReadBodyRange(bodyBuf.Bytes, ranges[0].Start(), rangeEnd, func(n int) (goNext bool, readErr error) {
|
||||
_, readErr = this.writer.Write(bodyBuf.Bytes[:n])
|
||||
if readErr != nil {
|
||||
return false, errWritingToClient
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
pool.Put(bodyBuf)
|
||||
if err != nil {
|
||||
this.varMapping["cache.status"] = "MISS"
|
||||
|
||||
if errors.Is(err, caches.ErrInvalidRange) {
|
||||
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
|
||||
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return true
|
||||
}
|
||||
if !this.canIgnore(err) {
|
||||
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: "+err.Error())
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
} else if len(ranges) > 1 {
|
||||
var boundary = httpRequestGenBoundary()
|
||||
respHeader.Set("Content-Type", "multipart/byteranges; boundary="+boundary)
|
||||
respHeader.Del("Content-Length")
|
||||
var contentType = respHeader.Get("Content-Type")
|
||||
|
||||
this.writer.WriteHeader(http.StatusPartialContent)
|
||||
|
||||
for index, r := range ranges {
|
||||
if index == 0 {
|
||||
_, err = this.writer.WriteString("--" + boundary + "\r\n")
|
||||
} else {
|
||||
_, err = this.writer.WriteString("\r\n--" + boundary + "\r\n")
|
||||
}
|
||||
if err != nil {
|
||||
// 不提示写入客户端错误
|
||||
return true
|
||||
}
|
||||
|
||||
_, err = this.writer.WriteString("Content-Range: " + r.ComposeContentRangeHeader(totalSizeString) + "\r\n")
|
||||
if err != nil {
|
||||
// 不提示写入客户端错误
|
||||
return true
|
||||
}
|
||||
|
||||
if len(contentType) > 0 {
|
||||
_, err = this.writer.WriteString("Content-Type: " + contentType + "\r\n\r\n")
|
||||
if err != nil {
|
||||
// 不提示写入客户端错误
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
var pool = this.bytePool(fileSize)
|
||||
var bodyBuf = pool.Get()
|
||||
err = reader.ReadBodyRange(bodyBuf.Bytes, r.Start(), r.End(), func(n int) (goNext bool, readErr error) {
|
||||
_, readErr = this.writer.Write(bodyBuf.Bytes[:n])
|
||||
if readErr != nil {
|
||||
return false, errWritingToClient
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
pool.Put(bodyBuf)
|
||||
if err != nil {
|
||||
if !this.canIgnore(err) {
|
||||
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: "+err.Error())
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
_, err = this.writer.WriteString("\r\n--" + boundary + "--\r\n")
|
||||
if err != nil {
|
||||
this.varMapping["cache.status"] = "MISS"
|
||||
|
||||
// 不提示写入客户端错误
|
||||
return true
|
||||
}
|
||||
} else { // 没有Range
|
||||
var resp = &http.Response{
|
||||
Body: reader,
|
||||
ContentLength: reader.BodySize(),
|
||||
}
|
||||
this.writer.Prepare(resp, fileSize, reader.Status(), false)
|
||||
this.writer.WriteHeader(reader.Status())
|
||||
|
||||
if storage.CanSendfile() {
|
||||
var pool = this.bytePool(fileSize)
|
||||
var bodyBuf = pool.Get()
|
||||
if fp, canSendFile := this.writer.canSendfile(); canSendFile {
|
||||
this.writer.sentBodyBytes, err = io.CopyBuffer(this.writer.rawWriter, fp, bodyBuf.Bytes)
|
||||
} else {
|
||||
_, err = io.CopyBuffer(this.writer, resp.Body, bodyBuf.Bytes)
|
||||
}
|
||||
pool.Put(bodyBuf)
|
||||
} else {
|
||||
mmapReader, isMMAPReader := reader.(*caches.MMAPFileReader)
|
||||
if isMMAPReader {
|
||||
_, err = mmapReader.CopyBodyTo(this.writer)
|
||||
} else {
|
||||
var pool = this.bytePool(fileSize)
|
||||
var bodyBuf = pool.Get()
|
||||
_, err = io.CopyBuffer(this.writer, resp.Body, bodyBuf.Bytes)
|
||||
pool.Put(bodyBuf)
|
||||
}
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
this.varMapping["cache.status"] = "MISS"
|
||||
|
||||
if !this.canIgnore(err) {
|
||||
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: read body failed: "+err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.isCached = true
|
||||
this.cacheRef = nil
|
||||
|
||||
this.writer.SetOk()
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// 设置Expires Header
|
||||
func (this *HTTPRequest) addExpiresHeader(expiresAt int64) {
|
||||
if this.cacheRef.ExpiresTime != nil && this.cacheRef.ExpiresTime.IsPrior && this.cacheRef.ExpiresTime.IsOn {
|
||||
if this.cacheRef.ExpiresTime.Overwrite || len(this.writer.Header().Get("Expires")) == 0 {
|
||||
if this.cacheRef.ExpiresTime.AutoCalculate {
|
||||
this.writer.Header().Set("Expires", time.Unix(utils.GMTUnixTime(expiresAt), 0).Format("Mon, 2 Jan 2006 15:04:05")+" GMT")
|
||||
this.writer.Header().Del("Cache-Control")
|
||||
} else if this.cacheRef.ExpiresTime.Duration != nil {
|
||||
var duration = this.cacheRef.ExpiresTime.Duration.Duration()
|
||||
if duration > 0 {
|
||||
this.writer.Header().Set("Expires", utils.GMTTime(time.Now().Add(duration)).Format("Mon, 2 Jan 2006 15:04:05")+" GMT")
|
||||
this.writer.Header().Del("Cache-Control")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试读取区间缓存
|
||||
func (this *HTTPRequest) tryPartialReader(storage caches.StorageInterface, key string, useStale bool, rangeHeader string, forcePartialContent bool) (resultReader caches.Reader, ranges []rangeutils.Range, firstRangeEnd int64, goNext bool) {
|
||||
goNext = true
|
||||
|
||||
// 尝试读取Partial cache
|
||||
if len(rangeHeader) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
ranges, ok := httpRequestParseRangeHeader(rangeHeader)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
pReader, pErr := storage.OpenReader(key+caches.SuffixPartial, useStale, true)
|
||||
if pErr != nil {
|
||||
if caches.IsBusyError(pErr) {
|
||||
this.varMapping["cache.status"] = "BUSY"
|
||||
goNext = false
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
partialReader, ok := pReader.(*caches.PartialFileReader)
|
||||
if !ok {
|
||||
_ = pReader.Close()
|
||||
return
|
||||
}
|
||||
var isOk = false
|
||||
defer func() {
|
||||
if !isOk {
|
||||
_ = pReader.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// 检查是否已下载完整
|
||||
if !forcePartialContent &&
|
||||
len(ranges) > 0 &&
|
||||
ranges[0][1] < 0 &&
|
||||
!partialReader.IsCompleted() {
|
||||
if partialReader.BodySize() > 0 {
|
||||
var options = this.ReqServer.HTTPCachePolicy.Options
|
||||
if options != nil {
|
||||
fileStorage, isFileStorage := storage.(*caches.FileStorage)
|
||||
if isFileStorage && fileStorage.Options() != nil && fileStorage.Options().EnableIncompletePartialContent {
|
||||
var r = ranges[0]
|
||||
r2, findOk := partialReader.Ranges().FindRangeAtPosition(r.Start())
|
||||
if findOk && r2.Length() >= (256<<10) /* worth reading */ {
|
||||
isOk = true
|
||||
ranges[0] = [2]int64{r.Start(), partialReader.BodySize() - 1} // Content-Range: bytes 0-[CONTENT_LENGTH - 1]/CONTENT_LENGTH
|
||||
|
||||
pReader.SetNextReader(NewHTTPRequestPartialReader(this, r2.End(), partialReader))
|
||||
return pReader, ranges, r2.End() - 1 /* not include last byte */, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 检查范围
|
||||
// 这里 **切记不要** 为末尾位置指定一个中间值,因为部分软件客户端不支持
|
||||
for index, r := range ranges {
|
||||
r1, ok := r.Convert(partialReader.MaxLength())
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
r2, ok := partialReader.ContainsRange(r1)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ranges[index] = r2
|
||||
}
|
||||
|
||||
isOk = true
|
||||
return pReader, ranges, -1, true
|
||||
}
|
||||
103
EdgeNode/internal/nodes/http_request_cache_partial.go
Normal file
103
EdgeNode/internal/nodes/http_request_cache_partial.go
Normal file
@@ -0,0 +1,103 @@
|
||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// HTTPRequestPartialReader 分区文件读取器
|
||||
type HTTPRequestPartialReader struct {
|
||||
req *HTTPRequest
|
||||
offset int64
|
||||
resp *http.Response
|
||||
|
||||
cacheReader caches.Reader
|
||||
cacheWriter caches.Writer
|
||||
}
|
||||
|
||||
// NewHTTPRequestPartialReader 构建新的分区文件读取器
|
||||
// req 当前请求
|
||||
// offset 读取位置
|
||||
// reader 当前缓存读取器
|
||||
func NewHTTPRequestPartialReader(req *HTTPRequest, offset int64, reader caches.Reader) *HTTPRequestPartialReader {
|
||||
return &HTTPRequestPartialReader{
|
||||
req: req,
|
||||
offset: offset,
|
||||
cacheReader: reader,
|
||||
}
|
||||
}
|
||||
|
||||
// 读取内容
|
||||
func (this *HTTPRequestPartialReader) Read(p []byte) (n int, err error) {
|
||||
if this.resp == nil {
|
||||
_ = this.cacheReader.Close()
|
||||
|
||||
this.req.RawReq.Header.Set("Range", "bytes="+types.String(this.offset)+"-")
|
||||
var resp = this.req.doReverseProxy(false)
|
||||
if resp == nil {
|
||||
err = io.ErrUnexpectedEOF
|
||||
return
|
||||
}
|
||||
|
||||
this.resp = resp
|
||||
|
||||
// 对比Content-MD5
|
||||
partialReader, ok := this.cacheReader.(*caches.PartialFileReader)
|
||||
if ok {
|
||||
if partialReader.Ranges().Version >= 2 && resp.Header.Get("Content-MD5") != partialReader.Ranges().ContentMD5 {
|
||||
err = io.ErrUnexpectedEOF
|
||||
|
||||
var storage = this.req.writer.cacheStorage
|
||||
if storage != nil {
|
||||
_ = storage.Delete(this.req.cacheKey + caches.SuffixPartial)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 准备写入
|
||||
this.prepareCacheWriter()
|
||||
}
|
||||
|
||||
n, err = this.resp.Body.Read(p)
|
||||
|
||||
// 写入到缓存
|
||||
if n > 0 && this.cacheWriter != nil {
|
||||
_ = this.cacheWriter.WriteAt(this.offset, p[:n])
|
||||
this.offset += int64(n)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Close 关闭读取器
|
||||
func (this *HTTPRequestPartialReader) Close() error {
|
||||
if this.cacheWriter != nil {
|
||||
_ = this.cacheWriter.Close()
|
||||
}
|
||||
|
||||
if this.resp != nil && this.resp.Body != nil {
|
||||
return this.resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 准备缓存写入器
|
||||
func (this *HTTPRequestPartialReader) prepareCacheWriter() {
|
||||
var storage = this.req.writer.cacheStorage
|
||||
if storage == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var cacheKey = this.req.cacheKey + caches.SuffixPartial
|
||||
writer, err := storage.OpenWriter(cacheKey, this.cacheReader.ExpiresAt(), this.cacheReader.Status(), int(this.cacheReader.HeaderSize()), this.cacheReader.BodySize(), -1, true)
|
||||
if err == nil {
|
||||
this.cacheWriter = writer
|
||||
}
|
||||
}
|
||||
8
EdgeNode/internal/nodes/http_request_cc.go
Normal file
8
EdgeNode/internal/nodes/http_request_cc.go
Normal file
@@ -0,0 +1,8 @@
|
||||
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build !plus
|
||||
|
||||
package nodes
|
||||
|
||||
func (this *HTTPRequest) doCC() (block bool) {
|
||||
return
|
||||
}
|
||||
355
EdgeNode/internal/nodes/http_request_cc_plus.go
Normal file
355
EdgeNode/internal/nodes/http_request_cc_plus.go
Normal file
@@ -0,0 +1,355 @@
|
||||
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build plus
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
iplib "github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/conns"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/agents"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/counters"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
maputils "github.com/TeaOSLab/EdgeNode/internal/utils/maps"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/ttlcache"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ccBlockedCounter struct {
|
||||
count int32
|
||||
updatedAt int64
|
||||
}
|
||||
|
||||
var ccBlockedMap = maputils.NewFixedMap[string, *ccBlockedCounter](65535)
|
||||
|
||||
func (this *HTTPRequest) doCC() (block bool) {
|
||||
if this.nodeConfig == nil || this.ReqServer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var validatePath = "/GE/CC/VALIDATOR"
|
||||
var maxConnections = 30
|
||||
|
||||
// 策略
|
||||
var httpCCPolicy = this.nodeConfig.FindHTTPCCPolicyWithClusterId(this.ReqServer.ClusterId)
|
||||
var scope = firewallconfigs.FirewallScopeGlobal
|
||||
if httpCCPolicy != nil {
|
||||
if !httpCCPolicy.IsOn {
|
||||
return
|
||||
}
|
||||
if len(httpCCPolicy.RedirectsChecking.ValidatePath) > 0 {
|
||||
validatePath = httpCCPolicy.RedirectsChecking.ValidatePath
|
||||
}
|
||||
if httpCCPolicy.MaxConnectionsPerIP > 0 {
|
||||
maxConnections = httpCCPolicy.MaxConnectionsPerIP
|
||||
}
|
||||
scope = httpCCPolicy.FirewallScope()
|
||||
}
|
||||
|
||||
var ccConfig = this.web.CC
|
||||
if ccConfig == nil || !ccConfig.IsOn || (this.RawReq.URL.Path != validatePath && !ccConfig.MatchURL(this.requestScheme()+"://"+this.ReqHost+this.Path())) {
|
||||
return
|
||||
}
|
||||
|
||||
// 忽略常用文件
|
||||
if ccConfig.IgnoreCommonFiles {
|
||||
if len(this.RawReq.Referer()) > 0 {
|
||||
var ext = filepath.Ext(this.RawReq.URL.Path)
|
||||
if len(ext) > 0 && utils.IsCommonFileExtension(ext) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查白名单
|
||||
var remoteAddr = this.requestRemoteAddr(true)
|
||||
|
||||
// 检查是否为白名单直连
|
||||
if !Tea.IsTesting() && this.nodeConfig.IPIsAutoAllowed(remoteAddr) {
|
||||
return
|
||||
}
|
||||
|
||||
// 是否在全局名单中
|
||||
canGoNext, isInAllowedList, _ := iplibrary.AllowIP(remoteAddr, this.ReqServer.Id)
|
||||
if !canGoNext {
|
||||
this.disableLog = true
|
||||
this.Close()
|
||||
return true
|
||||
}
|
||||
if isInAllowedList {
|
||||
return false
|
||||
}
|
||||
|
||||
// WAF黑名单
|
||||
if waf.SharedIPBlackList.Contains(waf.IPTypeAll, scope, this.ReqServer.Id, remoteAddr) {
|
||||
this.disableLog = true
|
||||
this.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查是否启用QPS
|
||||
if ccConfig.MinQPSPerIP > 0 && this.RawReq.URL.Path != validatePath {
|
||||
var avgQPS = counters.SharedCounter.IncreaseKey("QPS:"+remoteAddr, 60) / 60
|
||||
if avgQPS <= 0 {
|
||||
avgQPS = 1
|
||||
}
|
||||
if avgQPS < types.Uint32(ccConfig.MinQPSPerIP) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 检查连接数
|
||||
if conns.SharedMap.CountIPConns(remoteAddr) >= maxConnections {
|
||||
this.ccForbid(5)
|
||||
|
||||
var forbiddenTimes = this.increaseCCCounter(remoteAddr)
|
||||
waf.SharedIPBlackList.RecordIP(waf.IPTypeAll, scope, this.ReqServer.Id, remoteAddr, fasttime.Now().Unix()+int64(forbiddenTimes*1800), 0, scope == firewallconfigs.FirewallScopeGlobal, 0, 0, "CC防护拦截:并发连接数超出"+types.String(maxConnections)+"个")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// GET302验证
|
||||
if ccConfig.EnableGET302 &&
|
||||
this.RawReq.Method == http.MethodGet &&
|
||||
!agents.SharedManager.ContainsIP(remoteAddr) /** 搜索引擎亲和性 **/ &&
|
||||
!strings.HasPrefix(this.RawReq.URL.Path, "/baidu_verify_") /** 百度验证 **/ &&
|
||||
!strings.HasPrefix(this.RawReq.URL.Path, "/google") /** Google验证 **/ {
|
||||
// 忽略搜索引擎
|
||||
var canSkip302 = false
|
||||
var ipResult = iplib.LookupIP(remoteAddr)
|
||||
if ipResult != nil && ipResult.IsOk() {
|
||||
var providerName = ipResult.ProviderName()
|
||||
canSkip302 = strings.Contains(providerName, "百度") || strings.Contains(providerName, "谷歌") || strings.Contains(providerName, "baidu") || strings.Contains(providerName, "google")
|
||||
}
|
||||
|
||||
if !canSkip302 {
|
||||
// 检查参数
|
||||
var ccWhiteListKey = "HTTP-CC-GET302-" + remoteAddr
|
||||
var currentTime = fasttime.Now().Unix()
|
||||
if this.RawReq.URL.Path == validatePath {
|
||||
this.DisableAccessLog()
|
||||
this.DisableStat()
|
||||
|
||||
// TODO 根据浏览器信息决定是否校验referer
|
||||
|
||||
if !this.checkCCRedirects(httpCCPolicy, remoteAddr) {
|
||||
return true
|
||||
}
|
||||
|
||||
var key = this.RawReq.URL.Query().Get("key")
|
||||
var pieces = strings.Split(key, ".") // key1.key2.timestamp
|
||||
if len(pieces) != 3 {
|
||||
this.ccForbid(1)
|
||||
return true
|
||||
}
|
||||
var urlKey = pieces[0]
|
||||
var timestampKey = pieces[1]
|
||||
var timestamp = pieces[2]
|
||||
|
||||
var targetURL = this.RawReq.URL.Query().Get("url")
|
||||
|
||||
var realURLKey = stringutil.Md5(sharedNodeConfig.Secret + "@" + targetURL + "@" + remoteAddr)
|
||||
if urlKey != realURLKey {
|
||||
this.ccForbid(2)
|
||||
return true
|
||||
}
|
||||
|
||||
// 校验时间
|
||||
if timestampKey != stringutil.Md5(sharedNodeConfig.Secret+"@"+timestamp) {
|
||||
this.ccForbid(3)
|
||||
return true
|
||||
}
|
||||
|
||||
var elapsedSeconds = currentTime - types.Int64(timestamp)
|
||||
if elapsedSeconds > 10 /** 10秒钟 **/ { // 如果校验时间过长,则可能阻止当前访问
|
||||
if elapsedSeconds > 300 /** 5分钟 **/ { // 如果超时时间过长就跳回原URL
|
||||
httpRedirect(this.writer, this.RawReq, targetURL, http.StatusFound)
|
||||
} else {
|
||||
this.ccForbid(4)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// 加入到临时白名单
|
||||
ttlcache.SharedInt64Cache.Write(ccWhiteListKey, 1, currentTime+600 /** 10分钟 **/)
|
||||
|
||||
// 跳转回原来URL
|
||||
httpRedirect(this.writer, this.RawReq, targetURL, http.StatusFound)
|
||||
|
||||
return true
|
||||
} else {
|
||||
// 检查临时白名单
|
||||
if ttlcache.SharedInt64Cache.Read(ccWhiteListKey) == nil {
|
||||
if !this.checkCCRedirects(httpCCPolicy, remoteAddr) {
|
||||
return true
|
||||
}
|
||||
|
||||
var urlKey = stringutil.Md5(sharedNodeConfig.Secret + "@" + this.URL() + "@" + remoteAddr)
|
||||
var timestampKey = stringutil.Md5(sharedNodeConfig.Secret + "@" + types.String(currentTime))
|
||||
|
||||
// 跳转到验证URL
|
||||
this.DisableStat()
|
||||
httpRedirect(this.writer, this.RawReq, validatePath+"?key="+urlKey+"."+timestampKey+"."+types.String(currentTime)+"&url="+url.QueryEscape(this.URL()), http.StatusFound)
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if this.RawReq.URL.Path == validatePath {
|
||||
// 直接跳回
|
||||
var targetURL = this.RawReq.URL.Query().Get("url")
|
||||
httpRedirect(this.writer, this.RawReq, targetURL, http.StatusFound)
|
||||
return true
|
||||
}
|
||||
|
||||
// Key
|
||||
var ccKeys = []string{}
|
||||
if ccConfig.WithRequestPath {
|
||||
ccKeys = append(ccKeys, "HTTP-CC-"+remoteAddr+"-"+this.Path()) // 这里可以忽略域名,因为一个正常用户同时访问多个域名的可能性不大
|
||||
} else {
|
||||
ccKeys = append(ccKeys, "HTTP-CC-"+remoteAddr)
|
||||
}
|
||||
|
||||
// 指纹
|
||||
if this.IsHTTPS && ccConfig.EnableFingerprint {
|
||||
var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
|
||||
if requestConn != nil {
|
||||
clientConn, ok := requestConn.(ClientConnInterface)
|
||||
if ok {
|
||||
var fingerprint = clientConn.Fingerprint()
|
||||
if len(fingerprint) > 0 {
|
||||
var fingerprintString = fmt.Sprintf("%x", fingerprint)
|
||||
if ccConfig.WithRequestPath {
|
||||
ccKeys = append(ccKeys, "HTTP-CC-"+fingerprintString+"-"+this.Path()) // 这里可以忽略域名,因为一个正常用户同时访问多个域名的可能性不大
|
||||
} else {
|
||||
ccKeys = append(ccKeys, "HTTP-CC-"+fingerprintString)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查阈值
|
||||
var thresholds = ccConfig.Thresholds
|
||||
if len(thresholds) == 0 || ccConfig.UseDefaultThresholds {
|
||||
if httpCCPolicy != nil && len(httpCCPolicy.Thresholds) > 0 {
|
||||
thresholds = httpCCPolicy.Thresholds
|
||||
} else {
|
||||
thresholds = serverconfigs.DefaultHTTPCCThresholds
|
||||
}
|
||||
}
|
||||
if len(thresholds) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var currentTime = fasttime.Now().Unix()
|
||||
|
||||
for _, threshold := range thresholds {
|
||||
if threshold.PeriodSeconds <= 0 || threshold.MaxRequests <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, ccKey := range ccKeys {
|
||||
var count = counters.SharedCounter.IncreaseKey(ccKey+"-T"+types.String(threshold.PeriodSeconds), int(threshold.PeriodSeconds))
|
||||
if count >= types.Uint32(threshold.MaxRequests) {
|
||||
this.writeCode(http.StatusTooManyRequests, "Too many requests, please wait for a few minutes.", "访问过于频繁,请稍等片刻后再访问。")
|
||||
|
||||
// 记录到黑名单
|
||||
if threshold.BlockSeconds > 0 {
|
||||
// 如果被重复拦截,则加大惩罚力度
|
||||
var forbiddenTimes = this.increaseCCCounter(remoteAddr)
|
||||
waf.SharedIPBlackList.RecordIP(waf.IPTypeAll, scope, this.ReqServer.Id, remoteAddr, currentTime+int64(threshold.BlockSeconds*forbiddenTimes), 0, scope == firewallconfigs.FirewallScopeGlobal, 0, 0, "CC防护拦截:在"+types.String(threshold.PeriodSeconds)+"秒内请求超过"+types.String(threshold.MaxRequests)+"次")
|
||||
}
|
||||
|
||||
this.tags = append(this.tags, "CCProtection")
|
||||
this.isAttack = true
|
||||
|
||||
// 关闭连接
|
||||
this.writer.Flush()
|
||||
this.Close()
|
||||
|
||||
// 关闭同一个IP其他连接
|
||||
conns.SharedMap.CloseIPConns(remoteAddr)
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// TODO 对forbidden比较多的IP进行惩罚
|
||||
func (this *HTTPRequest) ccForbid(code int) {
|
||||
this.writeCode(http.StatusForbidden, "The request has been blocked by cc policy.", "当前请求已被CC策略拦截。")
|
||||
}
|
||||
|
||||
// 检查跳转次数
|
||||
func (this *HTTPRequest) checkCCRedirects(httpCCPolicy *nodeconfigs.HTTPCCPolicy, remoteAddr string) bool {
|
||||
// 如果无效跳转次数太多,则拦截
|
||||
var ccRedirectKey = "HTTP-CC-GET302-" + remoteAddr + "-REDIRECTS"
|
||||
|
||||
var maxRedirectDurationSeconds = 120
|
||||
var maxRedirects uint32 = 30
|
||||
var blockSeconds int64 = 3600
|
||||
|
||||
if httpCCPolicy != nil && httpCCPolicy.IsOn {
|
||||
if httpCCPolicy.RedirectsChecking.DurationSeconds > 0 {
|
||||
maxRedirectDurationSeconds = httpCCPolicy.RedirectsChecking.DurationSeconds
|
||||
}
|
||||
if httpCCPolicy.RedirectsChecking.MaxRedirects > 0 {
|
||||
maxRedirects = types.Uint32(httpCCPolicy.RedirectsChecking.MaxRedirects)
|
||||
}
|
||||
if httpCCPolicy.RedirectsChecking.BlockSeconds > 0 {
|
||||
blockSeconds = int64(httpCCPolicy.RedirectsChecking.BlockSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
var countRedirects = counters.SharedCounter.IncreaseKey(ccRedirectKey, maxRedirectDurationSeconds)
|
||||
if countRedirects >= maxRedirects {
|
||||
// 加入黑名单
|
||||
var scope = firewallconfigs.FirewallScopeGlobal
|
||||
if httpCCPolicy != nil {
|
||||
scope = httpCCPolicy.FirewallScope()
|
||||
}
|
||||
|
||||
waf.SharedIPBlackList.RecordIP(waf.IPTypeAll, scope, this.ReqServer.Id, remoteAddr, fasttime.Now().Unix()+blockSeconds, 0, scope == firewallconfigs.FirewallScopeGlobal, 0, 0, "CC防护拦截:在"+types.String(maxRedirectDurationSeconds)+"秒内无效跳转"+types.String(maxRedirects)+"次")
|
||||
this.Close()
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// 对CC禁用次数进行计数
|
||||
func (this *HTTPRequest) increaseCCCounter(remoteAddr string) int32 {
|
||||
counter, ok := ccBlockedMap.Get(remoteAddr)
|
||||
if !ok {
|
||||
counter = &ccBlockedCounter{
|
||||
count: 1,
|
||||
updatedAt: fasttime.Now().Unix(),
|
||||
}
|
||||
ccBlockedMap.Put(remoteAddr, counter)
|
||||
} else {
|
||||
if counter.updatedAt < fasttime.Now().Unix()-86400 /** 有效期间不要超过1天,防止无限期封禁 **/ {
|
||||
counter.count = 0
|
||||
}
|
||||
counter.updatedAt = fasttime.Now().Unix()
|
||||
if counter.count < 32 {
|
||||
counter.count++
|
||||
}
|
||||
}
|
||||
return counter.count
|
||||
}
|
||||
531
EdgeNode/internal/nodes/http_request_encryption.go
Normal file
531
EdgeNode/internal/nodes/http_request_encryption.go
Normal file
@@ -0,0 +1,531 @@
|
||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/TeaOSLab/EdgeNode/internal/encryption"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/andybalholm/brotli"
|
||||
)
|
||||
|
||||
const encryptionCacheVersion = "xor-v1"
|
||||
|
||||
// processPageEncryption 处理页面加密
|
||||
func (this *HTTPRequest) processPageEncryption(resp *http.Response) error {
|
||||
// 首先检查是否是 /waf/loader.js,如果是则直接跳过(不应该被加密)
|
||||
// 这个检查必须在所有其他检查之前,确保 loader.js 永远不会被加密
|
||||
if strings.Contains(this.URL(), "/waf/loader.js") {
|
||||
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "skipping /waf/loader.js, should not be encrypted, URL: "+this.URL())
|
||||
return nil
|
||||
}
|
||||
|
||||
if this.web.Encryption == nil {
|
||||
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "encryption config is nil for URL: "+this.URL())
|
||||
return nil
|
||||
}
|
||||
|
||||
if !this.web.Encryption.IsOn {
|
||||
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "encryption switch is off for URL: "+this.URL())
|
||||
return nil
|
||||
}
|
||||
|
||||
if !this.web.Encryption.IsEnabled() {
|
||||
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "encryption is not enabled for URL: "+this.URL())
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查 URL 白名单
|
||||
if this.web.Encryption.MatchExcludeURL(this.URL()) {
|
||||
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "URL is in exclude list: "+this.URL())
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查 Content-Type 和 URL
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
contentTypeLower := strings.ToLower(contentType)
|
||||
urlLower := strings.ToLower(this.URL())
|
||||
|
||||
var isHTML = strings.Contains(contentTypeLower, "text/html")
|
||||
// 判断是否为 JavaScript 文件:通过 Content-Type 或 URL 后缀
|
||||
var isJavaScript = strings.Contains(contentTypeLower, "text/javascript") ||
|
||||
strings.Contains(contentTypeLower, "application/javascript") ||
|
||||
strings.Contains(contentTypeLower, "application/x-javascript") ||
|
||||
strings.Contains(contentTypeLower, "text/ecmascript") ||
|
||||
strings.HasSuffix(urlLower, ".js") ||
|
||||
strings.Contains(urlLower, ".js?") ||
|
||||
strings.Contains(urlLower, ".js&")
|
||||
|
||||
if !isHTML && !isJavaScript {
|
||||
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "content type not match, URL: "+this.URL()+", Content-Type: "+contentType)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查内容大小(仅处理小于 10MB 的内容)
|
||||
if resp.ContentLength > 0 && resp.ContentLength > 10*1024*1024 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 读取响应体
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
// 如果源站返回了压缩内容,先解压再处理
|
||||
decodedBody, decoded, err := decodeResponseBody(bodyBytes, resp.Header.Get("Content-Encoding"))
|
||||
if err == nil && decoded {
|
||||
bodyBytes = decodedBody
|
||||
// 已经解压,移除 Content-Encoding
|
||||
resp.Header.Del("Content-Encoding")
|
||||
}
|
||||
|
||||
// 检查实际大小
|
||||
if len(bodyBytes) > 10*1024*1024 {
|
||||
// 内容太大,恢复原始响应体
|
||||
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
var encryptedBytes []byte
|
||||
|
||||
// 处理 JavaScript 文件
|
||||
if isJavaScript {
|
||||
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "processing JavaScript file, URL: "+this.URL())
|
||||
|
||||
// 检查是否需要加密独立的 JavaScript 文件
|
||||
if this.web.Encryption.Javascript == nil {
|
||||
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "Javascript config is nil for URL: "+this.URL())
|
||||
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
if !this.web.Encryption.Javascript.IsOn {
|
||||
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "Javascript encryption is not enabled for URL: "+this.URL())
|
||||
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查 URL 匹配
|
||||
if !this.web.Encryption.Javascript.MatchURL(this.URL()) {
|
||||
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "URL does not match patterns for URL: "+this.URL())
|
||||
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// 跳过 Loader 文件(必须排除,否则 loader.js 会被错误加密)
|
||||
if strings.Contains(this.URL(), "/waf/loader.js") ||
|
||||
strings.Contains(this.URL(), "waf-loader.js") ||
|
||||
strings.Contains(this.URL(), "__WAF_") {
|
||||
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "skipping loader file, URL: "+this.URL())
|
||||
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// 加密 JavaScript 文件
|
||||
remotelogs.Println("HTTP_REQUEST_ENCRYPTION", "encrypting JavaScript file, URL: "+this.URL())
|
||||
encryptedBytes, err = this.encryptJavaScriptFile(bodyBytes, resp)
|
||||
if err != nil {
|
||||
remotelogs.Warn("HTTP_REQUEST_ENCRYPTION", "encrypt JavaScript file failed: "+err.Error())
|
||||
// 加密失败,恢复原始响应体
|
||||
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
remotelogs.Println("HTTP_REQUEST_ENCRYPTION", "JavaScript file encrypted successfully, URL: "+this.URL())
|
||||
} else if isHTML {
|
||||
// 加密 HTML 内容
|
||||
encryptedBytes, err = this.encryptHTMLScripts(bodyBytes, resp)
|
||||
if err != nil {
|
||||
remotelogs.Warn("HTTP_REQUEST_ENCRYPTION", "encrypt HTML failed: "+err.Error())
|
||||
// 加密失败,恢复原始响应体
|
||||
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
// 未知类型,恢复原始响应体
|
||||
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// 替换响应体
|
||||
resp.Body = io.NopCloser(bytes.NewReader(encryptedBytes))
|
||||
resp.ContentLength = int64(len(encryptedBytes))
|
||||
resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(encryptedBytes)))
|
||||
// 避免旧缓存导致解密算法不匹配
|
||||
resp.Header.Set("Cache-Control", "no-store, no-cache, must-revalidate")
|
||||
// 删除 Content-Encoding(如果存在),因为我们修改了内容
|
||||
resp.Header.Del("Content-Encoding")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// encryptHTMLScripts 加密 HTML 中的脚本
|
||||
func (this *HTTPRequest) encryptHTMLScripts(htmlBytes []byte, resp *http.Response) ([]byte, error) {
|
||||
html := string(htmlBytes)
|
||||
|
||||
// 检查是否需要加密 HTML 脚本
|
||||
if this.web.Encryption.HTML == nil || !this.web.Encryption.HTML.IsOn {
|
||||
return htmlBytes, nil
|
||||
}
|
||||
|
||||
// 检查 URL 匹配
|
||||
if !this.web.Encryption.HTML.MatchURL(this.URL()) {
|
||||
return htmlBytes, nil
|
||||
}
|
||||
|
||||
// 生成密钥
|
||||
remoteIP := this.requestRemoteAddr(true)
|
||||
userAgent := this.RawReq.UserAgent()
|
||||
keyID, actualKey := encryption.GenerateEncryptionKey(remoteIP, userAgent, this.web.Encryption.KeyPolicy)
|
||||
|
||||
// 检查缓存
|
||||
var cacheKey string
|
||||
if this.web.Encryption.Cache != nil && this.web.Encryption.Cache.IsOn {
|
||||
// 生成缓存键:keyID + URL + contentHash
|
||||
contentHash := fmt.Sprintf("%x", bytesHash(htmlBytes))
|
||||
cacheKey = fmt.Sprintf("encrypt_%s_%s_%s_%s", encryptionCacheVersion, keyID, this.URL(), contentHash)
|
||||
|
||||
cache := encryption.SharedEncryptionCache(
|
||||
int(this.web.Encryption.Cache.MaxSize),
|
||||
time.Duration(this.web.Encryption.Cache.TTL)*time.Second,
|
||||
)
|
||||
|
||||
if cached, ok := cache.Get(cacheKey); ok {
|
||||
return cached, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 提取并加密内联脚本
|
||||
if this.web.Encryption.HTML.EncryptInlineScripts {
|
||||
html = this.encryptInlineScripts(html, actualKey, keyID)
|
||||
}
|
||||
|
||||
// 提取并加密外部脚本(通过 src 属性)
|
||||
if this.web.Encryption.HTML.EncryptExternalScripts {
|
||||
html = this.encryptExternalScripts(html, actualKey, keyID)
|
||||
}
|
||||
|
||||
// 注入 Loader
|
||||
html = this.injectLoader(html)
|
||||
|
||||
result := []byte(html)
|
||||
|
||||
// 保存到缓存
|
||||
if this.web.Encryption.Cache != nil && this.web.Encryption.Cache.IsOn && len(cacheKey) > 0 {
|
||||
cache := encryption.SharedEncryptionCache(
|
||||
int(this.web.Encryption.Cache.MaxSize),
|
||||
time.Duration(this.web.Encryption.Cache.TTL)*time.Second,
|
||||
)
|
||||
cache.Set(cacheKey, result, this.web.Encryption.Cache.TTL)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// encryptInlineScripts 加密内联脚本
|
||||
func (this *HTTPRequest) encryptInlineScripts(html string, key string, keyID string) string {
|
||||
// 匹配 <script>...</script>(不包含 src 属性)
|
||||
scriptRegex := regexp.MustCompile(`(?i)<script(?:\s+[^>]*)?>([\s\S]*?)</script>`)
|
||||
|
||||
return scriptRegex.ReplaceAllStringFunc(html, func(match string) string {
|
||||
// 检查是否有 src 属性(外部脚本)
|
||||
if strings.Contains(strings.ToLower(match), "src=") {
|
||||
return match // 跳过外部脚本
|
||||
}
|
||||
|
||||
// 提取脚本内容
|
||||
contentMatch := regexp.MustCompile(`(?i)<script(?:\s+[^>]*)?>([\s\S]*?)</script>`)
|
||||
matches := contentMatch.FindStringSubmatch(match)
|
||||
if len(matches) < 2 {
|
||||
return match
|
||||
}
|
||||
|
||||
scriptContent := matches[1]
|
||||
|
||||
// 跳过空脚本或仅包含空白字符的脚本
|
||||
if strings.TrimSpace(scriptContent) == "" {
|
||||
return match
|
||||
}
|
||||
|
||||
// 跳过已加密的脚本(包含 __WAF_P__)
|
||||
if strings.Contains(scriptContent, "__WAF_P__") {
|
||||
return match
|
||||
}
|
||||
|
||||
// 加密脚本内容
|
||||
encrypted, err := this.encryptScript(scriptContent, key)
|
||||
if err != nil {
|
||||
return match // 加密失败,返回原始内容
|
||||
}
|
||||
|
||||
// 生成元数据(k 是 keyID,用于缓存;key 是实际密钥,用于解密)
|
||||
meta := fmt.Sprintf(`{"k":"%s","key":"%s","t":%d,"alg":"xor"}`, keyID, key, time.Now().Unix())
|
||||
|
||||
// 替换为加密后的脚本(同步解密执行,保证脚本顺序)
|
||||
return fmt.Sprintf(
|
||||
`<script>(function(){
|
||||
function xorDecodeToString(b64,key){
|
||||
var bin=atob(b64);
|
||||
var out=new Uint8Array(bin.length);
|
||||
for(var i=0;i<bin.length;i++){out[i]=bin.charCodeAt(i)^key.charCodeAt(i%%key.length);}
|
||||
if (typeof TextDecoder !== 'undefined') {
|
||||
return new TextDecoder().decode(out);
|
||||
}
|
||||
var s='';for(var j=0;j<out.length;j++){s+=String.fromCharCode(out[j]);}
|
||||
return s;
|
||||
}
|
||||
try{var meta=%s;var code=xorDecodeToString("%s",meta.key);window.eval(code);}catch(e){console.error('WAF inline decrypt/execute failed',e);}
|
||||
})();</script>`,
|
||||
meta,
|
||||
encrypted,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// encryptExternalScripts 加密外部脚本(通过替换 src 为加密后的内容)
|
||||
// 注意:这里我们实际上是将外部脚本的内容内联并加密
|
||||
func (this *HTTPRequest) encryptExternalScripts(html string, key string, keyID string) string {
|
||||
// 匹配 <script src="..."></script>
|
||||
scriptRegex := regexp.MustCompile(`(?i)<script\s+([^>]*src\s*=\s*["']([^"']+)["'][^>]*)>\s*</script>`)
|
||||
|
||||
return scriptRegex.ReplaceAllStringFunc(html, func(match string) string {
|
||||
// 提取 src URL
|
||||
srcMatch := regexp.MustCompile(`(?i)src\s*=\s*["']([^"']+)["']`)
|
||||
srcMatches := srcMatch.FindStringSubmatch(match)
|
||||
if len(srcMatches) < 2 {
|
||||
return match
|
||||
}
|
||||
|
||||
srcURL := srcMatches[1]
|
||||
|
||||
// 跳过已加密的脚本或 Loader
|
||||
if strings.Contains(srcURL, "waf-loader.js") || strings.Contains(srcURL, "__WAF_") {
|
||||
return match
|
||||
}
|
||||
|
||||
// 注意:实际实现中,我们需要获取外部脚本的内容
|
||||
// 这里为了简化,我们只是标记需要加密,实际内容获取需要在响应处理时进行
|
||||
// 当前实现:将外部脚本转换为内联加密脚本的占位符
|
||||
// 实际生产环境需要:1. 获取外部脚本内容 2. 加密 3. 替换
|
||||
|
||||
return match // 暂时返回原始内容,后续可以扩展
|
||||
})
|
||||
}
|
||||
|
||||
// encryptScript 加密脚本内容
|
||||
func (this *HTTPRequest) encryptScript(scriptContent string, key string) (string, error) {
|
||||
// 1. XOR 加密(不压缩,避免浏览器解压依赖导致顺序问题)
|
||||
encrypted := xorEncrypt([]byte(scriptContent), []byte(key))
|
||||
|
||||
// 2. Base64 编码
|
||||
encoded := base64.StdEncoding.EncodeToString(encrypted)
|
||||
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// injectLoader 注入 Loader 脚本
|
||||
func (this *HTTPRequest) injectLoader(html string) string {
|
||||
// 检查是否已经注入
|
||||
if strings.Contains(html, "waf-loader.js") {
|
||||
return html
|
||||
}
|
||||
|
||||
// 在 </head> 之前注入,如果没有 </head>,则在 </body> 之前注入
|
||||
// 不使用 async,确保 loader 在解析阶段优先加载并执行
|
||||
loaderScript := `<script src="/waf/loader.js"></script>`
|
||||
|
||||
if strings.Contains(html, "</head>") {
|
||||
return strings.Replace(html, "</head>", loaderScript+"</head>", 1)
|
||||
} else if strings.Contains(html, "</body>") {
|
||||
return strings.Replace(html, "</body>", loaderScript+"</body>", 1)
|
||||
} else {
|
||||
// 如果都没有,在开头注入
|
||||
return loaderScript + html
|
||||
}
|
||||
}
|
||||
|
||||
// compressGzip 使用 Gzip 压缩(浏览器原生支持)
|
||||
func compressGzip(data []byte) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
writer := gzip.NewWriter(&buf)
|
||||
_, err := writer.Write(data)
|
||||
if err != nil {
|
||||
writer.Close()
|
||||
return nil, err
|
||||
}
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// decodeResponseBody 根据 Content-Encoding 解压响应体
|
||||
func decodeResponseBody(body []byte, encoding string) ([]byte, bool, error) {
|
||||
enc := strings.ToLower(strings.TrimSpace(encoding))
|
||||
if enc == "" || enc == "identity" {
|
||||
return body, false, nil
|
||||
}
|
||||
switch enc {
|
||||
case "gzip":
|
||||
reader, err := gzip.NewReader(bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return body, false, err
|
||||
}
|
||||
defer reader.Close()
|
||||
decoded, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return body, false, err
|
||||
}
|
||||
return decoded, true, nil
|
||||
case "br":
|
||||
reader := brotli.NewReader(bytes.NewReader(body))
|
||||
decoded, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return body, false, err
|
||||
}
|
||||
return decoded, true, nil
|
||||
default:
|
||||
// 未知编码,保持原样
|
||||
return body, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// compressBrotli 使用 Brotli 压缩(保留用于其他用途)
|
||||
func compressBrotli(data []byte, level int) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
writer := brotli.NewWriterOptions(&buf, brotli.WriterOptions{
|
||||
Quality: level,
|
||||
LGWin: 14,
|
||||
})
|
||||
_, err := writer.Write(data)
|
||||
if err != nil {
|
||||
writer.Close()
|
||||
return nil, err
|
||||
}
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// xorEncrypt XOR 加密
|
||||
func xorEncrypt(data []byte, key []byte) []byte {
|
||||
result := make([]byte, len(data))
|
||||
keyLen := len(key)
|
||||
if keyLen == 0 {
|
||||
return data
|
||||
}
|
||||
|
||||
for i := 0; i < len(data); i++ {
|
||||
result[i] = data[i] ^ key[i%keyLen]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// encryptJavaScriptFile 加密独立的 JavaScript 文件
|
||||
func (this *HTTPRequest) encryptJavaScriptFile(jsBytes []byte, resp *http.Response) ([]byte, error) {
|
||||
jsContent := string(jsBytes)
|
||||
|
||||
// 跳过空文件
|
||||
if strings.TrimSpace(jsContent) == "" {
|
||||
return jsBytes, nil
|
||||
}
|
||||
|
||||
// 跳过已加密的脚本(包含 __WAF_P__)
|
||||
if strings.Contains(jsContent, "__WAF_P__") {
|
||||
return jsBytes, nil
|
||||
}
|
||||
|
||||
// 生成密钥
|
||||
remoteIP := this.requestRemoteAddr(true)
|
||||
userAgent := this.RawReq.UserAgent()
|
||||
keyID, actualKey := encryption.GenerateEncryptionKey(remoteIP, userAgent, this.web.Encryption.KeyPolicy)
|
||||
|
||||
// 检查缓存
|
||||
var cacheKey string
|
||||
if this.web.Encryption.Cache != nil && this.web.Encryption.Cache.IsOn {
|
||||
// 生成缓存键:keyID + URL + contentHash
|
||||
contentHash := fmt.Sprintf("%x", bytesHash(jsBytes))
|
||||
cacheKey = fmt.Sprintf("encrypt_js_%s_%s_%s_%s", encryptionCacheVersion, keyID, this.URL(), contentHash)
|
||||
|
||||
cache := encryption.SharedEncryptionCache(
|
||||
int(this.web.Encryption.Cache.MaxSize),
|
||||
time.Duration(this.web.Encryption.Cache.TTL)*time.Second,
|
||||
)
|
||||
|
||||
if cached, ok := cache.Get(cacheKey); ok {
|
||||
return cached, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 加密脚本内容
|
||||
encrypted, err := this.encryptScript(jsContent, actualKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 生成元数据(k 是 keyID,用于缓存;key 是实际密钥,用于解密)
|
||||
meta := fmt.Sprintf(`{"k":"%s","key":"%s","t":%d,"alg":"xor"}`, keyID, actualKey, time.Now().Unix())
|
||||
|
||||
// 生成加密后的 JavaScript 代码(同步解密执行,保证脚本顺序)
|
||||
encryptedJS := fmt.Sprintf(`(function() {
|
||||
try {
|
||||
function xorDecodeToString(b64, key) {
|
||||
var bin = atob(b64);
|
||||
var out = new Uint8Array(bin.length);
|
||||
for (var i = 0; i < bin.length; i++) {
|
||||
out[i] = bin.charCodeAt(i) ^ key.charCodeAt(i %% key.length);
|
||||
}
|
||||
if (typeof TextDecoder !== 'undefined') {
|
||||
return new TextDecoder().decode(out);
|
||||
}
|
||||
var s = '';
|
||||
for (var j = 0; j < out.length; j++) {
|
||||
s += String.fromCharCode(out[j]);
|
||||
}
|
||||
return s;
|
||||
}
|
||||
var meta = %s;
|
||||
var code = xorDecodeToString("%s", meta.key);
|
||||
// 使用全局 eval,尽量保持和 <script> 一致的作用域
|
||||
window.eval(code);
|
||||
} catch (e) {
|
||||
console.error('WAF JS decrypt/execute failed', e);
|
||||
}
|
||||
})();`, meta, encrypted)
|
||||
|
||||
result := []byte(encryptedJS)
|
||||
|
||||
// 保存到缓存
|
||||
if this.web.Encryption.Cache != nil && this.web.Encryption.Cache.IsOn && len(cacheKey) > 0 {
|
||||
cache := encryption.SharedEncryptionCache(
|
||||
int(this.web.Encryption.Cache.MaxSize),
|
||||
time.Duration(this.web.Encryption.Cache.TTL)*time.Second,
|
||||
)
|
||||
cache.Set(cacheKey, result, this.web.Encryption.Cache.TTL)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// bytesHash 计算字节数组的简单哈希(用于缓存键)
|
||||
func bytesHash(data []byte) uint64 {
|
||||
var hash uint64 = 5381
|
||||
for _, b := range data {
|
||||
hash = ((hash << 5) + hash) + uint64(b)
|
||||
}
|
||||
return hash
|
||||
}
|
||||
119
EdgeNode/internal/nodes/http_request_error.go
Normal file
119
EdgeNode/internal/nodes/http_request_error.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const httpStatusPageTemplate = `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>${status} ${statusMessage}</title>
|
||||
<meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
|
||||
<style>
|
||||
address { line-height: 1.8; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<h1>${status} ${statusMessage}</h1>
|
||||
<p>${message}</p>
|
||||
|
||||
<address>Connection: ${remoteAddr} (Client) -> ${serverAddr} (Server)</address>
|
||||
<address>Request ID: ${requestId}.</address>
|
||||
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
func (this *HTTPRequest) write404() {
|
||||
this.writeCode(http.StatusNotFound, "", "")
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) writeCode(statusCode int, enMessage string, zhMessage string) {
|
||||
if this.doPage(statusCode) {
|
||||
return
|
||||
}
|
||||
|
||||
var pageContent = configutils.ParseVariables(httpStatusPageTemplate, func(varName string) (value string) {
|
||||
switch varName {
|
||||
case "status":
|
||||
return types.String(statusCode)
|
||||
case "statusMessage":
|
||||
return http.StatusText(statusCode)
|
||||
case "message":
|
||||
var acceptLanguages = this.RawReq.Header.Get("Accept-Language")
|
||||
if len(acceptLanguages) > 0 {
|
||||
var index = strings.Index(acceptLanguages, ",")
|
||||
if index > 0 {
|
||||
var firstLanguage = acceptLanguages[:index]
|
||||
if firstLanguage == "zh-CN" {
|
||||
return zhMessage
|
||||
}
|
||||
}
|
||||
}
|
||||
return enMessage
|
||||
}
|
||||
return this.Format("${" + varName + "}")
|
||||
})
|
||||
|
||||
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
|
||||
this.writer.WriteHeader(statusCode)
|
||||
|
||||
_, _ = this.writer.Write([]byte(pageContent))
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) write50x(err error, statusCode int, enMessage string, zhMessage string, canTryStale bool) {
|
||||
if err != nil {
|
||||
this.addError(err)
|
||||
}
|
||||
|
||||
// 尝试从缓存中恢复
|
||||
if canTryStale &&
|
||||
this.cacheCanTryStale &&
|
||||
this.web.Cache.Stale != nil &&
|
||||
this.web.Cache.Stale.IsOn &&
|
||||
(len(this.web.Cache.Stale.Status) == 0 || lists.ContainsInt(this.web.Cache.Stale.Status, statusCode)) {
|
||||
var ok = this.doCacheRead(true)
|
||||
if ok {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 显示自定义页面
|
||||
if this.doPage(statusCode) {
|
||||
return
|
||||
}
|
||||
|
||||
// 内置HTML模板
|
||||
var pageContent = configutils.ParseVariables(httpStatusPageTemplate, func(varName string) (value string) {
|
||||
switch varName {
|
||||
case "status":
|
||||
return types.String(statusCode)
|
||||
case "statusMessage":
|
||||
return http.StatusText(statusCode)
|
||||
case "requestId":
|
||||
return this.requestId
|
||||
case "message":
|
||||
var acceptLanguages = this.RawReq.Header.Get("Accept-Language")
|
||||
if len(acceptLanguages) > 0 {
|
||||
var index = strings.Index(acceptLanguages, ",")
|
||||
if index > 0 {
|
||||
var firstLanguage = acceptLanguages[:index]
|
||||
if firstLanguage == "zh-CN" {
|
||||
return "网站出了一点小问题,原因:" + zhMessage + "。"
|
||||
}
|
||||
}
|
||||
}
|
||||
return "The site is unavailable now, cause: " + enMessage + "."
|
||||
}
|
||||
return this.Format("${" + varName + "}")
|
||||
})
|
||||
|
||||
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
|
||||
this.writer.WriteHeader(statusCode)
|
||||
|
||||
_, _ = this.writer.Write([]byte(pageContent))
|
||||
}
|
||||
11
EdgeNode/internal/nodes/http_request_events.go
Normal file
11
EdgeNode/internal/nodes/http_request_events.go
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build !script
|
||||
// +build !script
|
||||
|
||||
package nodes
|
||||
|
||||
func (this *HTTPRequest) onInit() {
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) onRequest() {
|
||||
}
|
||||
124
EdgeNode/internal/nodes/http_request_events_script_plus.go
Normal file
124
EdgeNode/internal/nodes/http_request_events_script_plus.go
Normal file
@@ -0,0 +1,124 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build script
|
||||
// +build script
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"rogchap.com/v8go"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (this *HTTPRequest) onInit() {
|
||||
this.fireScriptEvent("init")
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) onRequest() {
|
||||
if this.isDone || this.writer.isFinished {
|
||||
return
|
||||
}
|
||||
this.fireScriptEvent("request")
|
||||
}
|
||||
|
||||
// 触发事件
|
||||
func (this *HTTPRequest) fireScriptEvent(eventType string) {
|
||||
if SharedJSPool == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if this.web.RequestScripts == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var group *serverconfigs.ScriptGroupConfig
|
||||
if eventType == "init" && this.web.RequestScripts.InitGroup != nil && this.web.RequestScripts.InitGroup.IsOn {
|
||||
group = this.web.RequestScripts.InitGroup
|
||||
} else if eventType == "request" && this.web.RequestScripts.RequestGroup != nil && this.web.RequestScripts.RequestGroup.IsOn {
|
||||
group = this.web.RequestScripts.RequestGroup
|
||||
}
|
||||
|
||||
if group == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, script := range group.Scripts {
|
||||
if this.isDone {
|
||||
return
|
||||
}
|
||||
|
||||
if !script.IsOn {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(script.RealCode()) > 0 {
|
||||
this.runScript(eventType, script.RealCode())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) runScript(eventType string, script string) {
|
||||
ctx, err := SharedJSPool.GetContext()
|
||||
if err != nil {
|
||||
remotelogs.Error("SCRIPT", "get context failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
ctx.SetServerId(this.ReqServer.Id)
|
||||
defer SharedJSPool.PutContext(ctx)
|
||||
|
||||
var reqObjectId = ctx.AddGoRequestObject(this)
|
||||
var respObjectId = ctx.AddGoResponseObject(this.writer)
|
||||
|
||||
script = `(function () {
|
||||
let req = new gojs.net.http.Request()
|
||||
req.setGoObject(` + types.String(reqObjectId) + `)
|
||||
|
||||
let resp = new gojs.net.http.Response()
|
||||
resp.setGoObject(` + types.String(respObjectId) + `)
|
||||
` + script + `
|
||||
})()`
|
||||
_, err = ctx.Run(script, "request."+eventType+".js")
|
||||
if err != nil {
|
||||
var errString = ""
|
||||
jsErr, ok := err.(*v8go.JSError)
|
||||
if ok {
|
||||
var pieces = strings.Split(jsErr.Location, ":")
|
||||
if len(pieces) < 3 {
|
||||
errString = err.Error()
|
||||
} else {
|
||||
if strings.HasPrefix(pieces[0], "gojs") {
|
||||
var line = types.Int(pieces[len(pieces)-2])
|
||||
var scriptLine = ctx.ReadLineFromLibrary(pieces[0], line)
|
||||
if len(scriptLine) > 0 {
|
||||
if len(scriptLine) > 256 {
|
||||
scriptLine = scriptLine[:256] + "..."
|
||||
}
|
||||
|
||||
errString = jsErr.Error() + ", location: " + strings.Join(pieces[:len(pieces)-2], ":") + ":" + types.String(line) + ":" + pieces[len(pieces)-1] + ", code: " + scriptLine
|
||||
}
|
||||
}
|
||||
if len(errString) == 0 {
|
||||
var line = types.Int(pieces[len(pieces)-2])
|
||||
var scriptLines = strings.Split(script, "\n")
|
||||
var scriptLine = ""
|
||||
if len(scriptLines) > line-1 {
|
||||
scriptLine = scriptLines[line-1]
|
||||
if len(scriptLine) > 256 {
|
||||
scriptLine = scriptLine[:256] + "..."
|
||||
}
|
||||
}
|
||||
line -= 6 /* 6是req和resp构造用的行数 */
|
||||
errString = jsErr.Error() + ", location: " + strings.Join(pieces[:len(pieces)-2], ":") + ":" + types.String(line) + ":" + pieces[len(pieces)-1] + ", code: " + scriptLine
|
||||
}
|
||||
}
|
||||
} else {
|
||||
errString = err.Error()
|
||||
}
|
||||
remotelogs.ServerError(this.ReqServer.Id, "SCRIPT", "run "+eventType+" script failed: "+errString, nodeconfigs.NodeLogTypeRunScriptFailed, maps.Map{})
|
||||
return
|
||||
}
|
||||
}
|
||||
230
EdgeNode/internal/nodes/http_request_fastcgi.go
Normal file
230
EdgeNode/internal/nodes/http_request_fastcgi.go
Normal file
@@ -0,0 +1,230 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/iwind/gofcgi/pkg/fcgi"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (this *HTTPRequest) doFastcgi() (shouldStop bool) {
|
||||
fastcgiList := []*serverconfigs.HTTPFastcgiConfig{}
|
||||
for _, fastcgi := range this.web.FastcgiList {
|
||||
if !fastcgi.IsOn {
|
||||
continue
|
||||
}
|
||||
fastcgiList = append(fastcgiList, fastcgi)
|
||||
}
|
||||
if len(fastcgiList) == 0 {
|
||||
return false
|
||||
}
|
||||
shouldStop = true
|
||||
fastcgi := fastcgiList[rands.Int(0, len(fastcgiList)-1)]
|
||||
|
||||
env := fastcgi.FilterParams()
|
||||
if !env.Has("DOCUMENT_ROOT") {
|
||||
env["DOCUMENT_ROOT"] = ""
|
||||
}
|
||||
|
||||
if !env.Has("REMOTE_ADDR") {
|
||||
env["REMOTE_ADDR"] = this.requestRemoteAddr(true)
|
||||
}
|
||||
if !env.Has("QUERY_STRING") {
|
||||
u, err := url.ParseRequestURI(this.uri)
|
||||
if err == nil {
|
||||
env["QUERY_STRING"] = u.RawQuery
|
||||
} else {
|
||||
env["QUERY_STRING"] = this.RawReq.URL.RawQuery
|
||||
}
|
||||
}
|
||||
if !env.Has("SERVER_NAME") {
|
||||
env["SERVER_NAME"] = this.ReqHost
|
||||
}
|
||||
if !env.Has("REQUEST_URI") {
|
||||
env["REQUEST_URI"] = this.uri
|
||||
}
|
||||
if !env.Has("HOST") {
|
||||
env["HOST"] = this.ReqHost
|
||||
}
|
||||
|
||||
if len(this.ServerAddr) > 0 {
|
||||
if !env.Has("SERVER_ADDR") {
|
||||
env["SERVER_ADDR"] = this.ServerAddr
|
||||
}
|
||||
if !env.Has("SERVER_PORT") {
|
||||
_, port, err := net.SplitHostPort(this.ServerAddr)
|
||||
if err == nil {
|
||||
env["SERVER_PORT"] = port
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 设置为持久化连接
|
||||
var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
|
||||
if requestConn == nil {
|
||||
return
|
||||
}
|
||||
requestClientConn, ok := requestConn.(ClientConnInterface)
|
||||
if ok {
|
||||
requestClientConn.SetIsPersistent(true)
|
||||
}
|
||||
|
||||
// 连接池配置
|
||||
poolSize := fastcgi.PoolSize
|
||||
if poolSize <= 0 {
|
||||
poolSize = 32
|
||||
}
|
||||
|
||||
client, err := fcgi.SharedPool(fastcgi.Network(), fastcgi.RealAddress(), uint(poolSize)).Client()
|
||||
if err != nil {
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to create Fastcgi pool", "Fastcgi池生成失败", false)
|
||||
return
|
||||
}
|
||||
|
||||
// 请求相关
|
||||
if !env.Has("REQUEST_METHOD") {
|
||||
env["REQUEST_METHOD"] = this.RawReq.Method
|
||||
}
|
||||
if !env.Has("CONTENT_LENGTH") {
|
||||
env["CONTENT_LENGTH"] = fmt.Sprintf("%d", this.RawReq.ContentLength)
|
||||
}
|
||||
if !env.Has("CONTENT_TYPE") {
|
||||
env["CONTENT_TYPE"] = this.RawReq.Header.Get("Content-Type")
|
||||
}
|
||||
if !env.Has("SERVER_SOFTWARE") {
|
||||
env["SERVER_SOFTWARE"] = teaconst.ProductName + "/v" + teaconst.Version
|
||||
}
|
||||
|
||||
// 处理SCRIPT_FILENAME
|
||||
scriptPath := env.GetString("SCRIPT_FILENAME")
|
||||
if len(scriptPath) > 0 && !strings.Contains(scriptPath, "/") && !strings.Contains(scriptPath, "\\") {
|
||||
env["SCRIPT_FILENAME"] = env.GetString("DOCUMENT_ROOT") + Tea.DS + scriptPath
|
||||
}
|
||||
scriptFilename := filepath.Base(this.RawReq.URL.Path)
|
||||
|
||||
// PATH_INFO
|
||||
pathInfoReg := fastcgi.PathInfoRegexp()
|
||||
pathInfo := ""
|
||||
if pathInfoReg != nil {
|
||||
matches := pathInfoReg.FindStringSubmatch(this.RawReq.URL.Path)
|
||||
countMatches := len(matches)
|
||||
if countMatches == 1 {
|
||||
pathInfo = matches[0]
|
||||
} else if countMatches == 2 {
|
||||
pathInfo = matches[1]
|
||||
} else if countMatches > 2 {
|
||||
scriptFilename = matches[1]
|
||||
pathInfo = matches[2]
|
||||
}
|
||||
|
||||
if !env.Has("PATH_INFO") {
|
||||
env["PATH_INFO"] = pathInfo
|
||||
}
|
||||
}
|
||||
|
||||
this.addVarMapping(map[string]string{
|
||||
"fastcgi.documentRoot": env.GetString("DOCUMENT_ROOT"),
|
||||
"fastcgi.filename": scriptFilename,
|
||||
"fastcgi.pathInfo": pathInfo,
|
||||
})
|
||||
|
||||
params := map[string]string{}
|
||||
for key, value := range env {
|
||||
params[key] = this.Format(types.String(value))
|
||||
}
|
||||
|
||||
this.processRequestHeaders(this.RawReq.Header)
|
||||
for k, v := range this.RawReq.Header {
|
||||
if k == "Connection" {
|
||||
continue
|
||||
}
|
||||
for _, subV := range v {
|
||||
params["HTTP_"+strings.ToUpper(strings.Replace(k, "-", "_", -1))] = subV
|
||||
}
|
||||
}
|
||||
|
||||
host, found := params["HTTP_HOST"]
|
||||
if !found || len(host) == 0 {
|
||||
params["HTTP_HOST"] = this.ReqHost
|
||||
}
|
||||
|
||||
fcgiReq := fcgi.NewRequest()
|
||||
fcgiReq.SetTimeout(fastcgi.ReadTimeoutDuration())
|
||||
fcgiReq.SetParams(params)
|
||||
fcgiReq.SetBody(this.RawReq.Body, uint32(this.requestLength()))
|
||||
|
||||
resp, stderr, err := client.Call(fcgiReq)
|
||||
if err != nil {
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to read Fastcgi", "读取Fastcgi失败", false)
|
||||
return
|
||||
}
|
||||
|
||||
if len(stderr) > 0 {
|
||||
err := errors.New("Fastcgi Error: " + strings.TrimSpace(string(stderr)) + " script: " + maps.NewMap(params).GetString("SCRIPT_FILENAME"))
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to read Fastcgi", "读取Fastcgi失败", false)
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
// 设置Charset
|
||||
// TODO 这里应该可以设置文本类型的列表,以及是否强制覆盖所有文本类型的字符集
|
||||
if this.web.Charset != nil && this.web.Charset.IsOn && len(this.web.Charset.Charset) > 0 {
|
||||
contentTypes, ok := resp.Header["Content-Type"]
|
||||
if ok && len(contentTypes) > 0 {
|
||||
contentType := contentTypes[0]
|
||||
if _, found := textMimeMap[contentType]; found {
|
||||
resp.Header["Content-Type"][0] = contentType + "; charset=" + this.web.Charset.Charset
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 响应Header
|
||||
this.writer.AddHeaders(resp.Header)
|
||||
this.ProcessResponseHeaders(this.writer.Header(), resp.StatusCode)
|
||||
|
||||
// 准备
|
||||
this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true)
|
||||
|
||||
// 设置响应代码
|
||||
this.writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
// 输出到客户端
|
||||
var pool = this.bytePool(resp.ContentLength)
|
||||
var buf = pool.Get()
|
||||
_, err = io.CopyBuffer(this.writer, resp.Body, buf.Bytes)
|
||||
pool.Put(buf)
|
||||
|
||||
closeErr := resp.Body.Close()
|
||||
if closeErr != nil {
|
||||
remotelogs.Warn("HTTP_REQUEST_FASTCGI", closeErr.Error())
|
||||
}
|
||||
|
||||
if err != nil && err != io.EOF {
|
||||
remotelogs.Warn("HTTP_REQUEST_FASTCGI", err.Error())
|
||||
this.addError(err)
|
||||
}
|
||||
|
||||
// 是否成功结束
|
||||
if err == nil && closeErr == nil {
|
||||
this.writer.SetOk()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
35
EdgeNode/internal/nodes/http_request_health_check.go
Normal file
35
EdgeNode/internal/nodes/http_request_health_check.go
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeutils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
)
|
||||
|
||||
// 健康检查
|
||||
func (this *HTTPRequest) doHealthCheck(key string, isHealthCheck *bool) (stop bool) {
|
||||
this.tags = append(this.tags, "healthCheck")
|
||||
|
||||
this.RawReq.Header.Del(serverconfigs.HealthCheckHeaderName)
|
||||
|
||||
data, err := nodeutils.Base64DecodeMap(key)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_REQUEST_HEALTH_CHECK", "decode key failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
*isHealthCheck = true
|
||||
|
||||
this.web.StatRef = nil
|
||||
|
||||
if !data.GetBool("accessLogIsOn") {
|
||||
this.disableLog = true
|
||||
}
|
||||
|
||||
if data.GetBool("onlyBasicRequest") {
|
||||
return true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
16
EdgeNode/internal/nodes/http_request_hls.go
Normal file
16
EdgeNode/internal/nodes/http_request_hls.go
Normal file
@@ -0,0 +1,16 @@
|
||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build !plus
|
||||
|
||||
package nodes
|
||||
|
||||
import "net/http"
|
||||
|
||||
func (this *HTTPRequest) processHLSBefore() (blocked bool) {
|
||||
// stub
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) processM3u8Response(resp *http.Response) error {
|
||||
// stub
|
||||
return nil
|
||||
}
|
||||
190
EdgeNode/internal/nodes/http_request_hls_plus.go
Normal file
190
EdgeNode/internal/nodes/http_request_hls_plus.go
Normal file
@@ -0,0 +1,190 @@
|
||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build plus
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/compressions"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/readers"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultM3u8EncryptParam = "ge_m3u8_token"
|
||||
m3u8EncryptExtension = ".ge-m3u8-key"
|
||||
)
|
||||
|
||||
// process hls
|
||||
func (this *HTTPRequest) processHLSBefore() (blocked bool) {
|
||||
var requestPath = this.RawReq.URL.Path
|
||||
|
||||
// .ge-m3u8-key
|
||||
if !strings.HasSuffix(requestPath, m3u8EncryptExtension) {
|
||||
return
|
||||
}
|
||||
|
||||
blocked = true
|
||||
|
||||
tokenBytes, base64Err := base64.StdEncoding.DecodeString(this.RawReq.URL.Query().Get(defaultM3u8EncryptParam))
|
||||
if base64Err != nil {
|
||||
this.writeCode(http.StatusBadRequest, "invalid m3u8 token: bad format", "invalid m3u8 token: bad format")
|
||||
return
|
||||
}
|
||||
|
||||
if len(tokenBytes) != 32 {
|
||||
this.writeCode(http.StatusBadRequest, "invalid m3u8 token: bad length", "invalid m3u8 token: bad length")
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = this.writer.Write(tokenBytes[:16])
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// process m3u8 file
|
||||
func (this *HTTPRequest) processM3u8Response(resp *http.Response) error {
|
||||
var requestPath = this.RawReq.URL.Path
|
||||
|
||||
// .m3u8
|
||||
if strings.HasSuffix(requestPath, ".m3u8") &&
|
||||
this.web.HLS.Encrypting.MatchURL(this.URL()) {
|
||||
return this.processM3u8File(resp)
|
||||
}
|
||||
|
||||
// .ts
|
||||
if strings.HasSuffix(requestPath, ".ts") {
|
||||
var token = this.RawReq.URL.Query().Get(defaultM3u8EncryptParam)
|
||||
if len(token) > 0 {
|
||||
return this.processTSFile(resp, token)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) processTSFile(resp *http.Response, token string) error {
|
||||
rawToken, err := base64.StdEncoding.DecodeString(token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(rawToken) != 32 {
|
||||
return errors.New("invalid token length")
|
||||
}
|
||||
|
||||
var key = rawToken[:16]
|
||||
var iv = rawToken[16:]
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create cipher failed: %w", err)
|
||||
}
|
||||
var stream = cipher.NewCBCEncrypter(block, iv)
|
||||
var reader = readers.NewFilterReaderCloser(resp.Body)
|
||||
var blockSize = stream.BlockSize()
|
||||
reader.Add(func(p []byte, readErr error) error {
|
||||
var l = len(p)
|
||||
if l == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var mod = l % blockSize
|
||||
if mod != 0 {
|
||||
p = append(p, bytes.Repeat([]byte{'0'}, blockSize-mod)...)
|
||||
}
|
||||
stream.CryptBlocks(p, p)
|
||||
return readErr
|
||||
})
|
||||
resp.Body = reader
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) processM3u8File(resp *http.Response) error {
|
||||
const maxSize = 1 << 20
|
||||
|
||||
// 检查内容长度
|
||||
if resp.Body == nil || resp.ContentLength == 0 || resp.ContentLength > maxSize {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 解压缩
|
||||
compressions.WrapHTTPResponse(resp)
|
||||
|
||||
// 读取内容
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
|
||||
if err != nil {
|
||||
_ = resp.Body.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// 超出尺寸直接返回
|
||||
if len(data) == maxSize {
|
||||
resp.Body = io.NopCloser(io.MultiReader(bytes.NewBuffer(data), resp.Body))
|
||||
return nil
|
||||
}
|
||||
|
||||
var lines = bytes.Split(data, []byte{'\n'})
|
||||
var addedKey = false
|
||||
|
||||
var ivBytes = make([]byte, 16)
|
||||
var keyBytes = make([]byte, 16)
|
||||
|
||||
_, ivErr := rand.Read(ivBytes)
|
||||
_, keyErr := rand.Read(keyBytes)
|
||||
if ivErr != nil || keyErr != nil {
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(data))
|
||||
return nil
|
||||
}
|
||||
|
||||
var ivString = fmt.Sprintf("%x", ivBytes)
|
||||
var token = url.QueryEscape(base64.StdEncoding.EncodeToString(append(keyBytes, ivBytes...)))
|
||||
for index, line := range lines {
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(line, []byte("#EXT-X-KEY")) {
|
||||
goto returnData
|
||||
}
|
||||
|
||||
if !addedKey && bytes.HasPrefix(line, []byte("#EXTINF")) {
|
||||
this.URL()
|
||||
var keyPath = strings.TrimSuffix(this.RawReq.URL.Path, ".m3u8") + m3u8EncryptExtension + "?" + defaultM3u8EncryptParam + "=" + token
|
||||
lines[index] = append([]byte("#EXT-X-KEY:METHOD=AES-128,URI=\""+this.requestScheme()+"://"+this.ReqHost+keyPath+"\",IV=0x"+ivString+
|
||||
"\n"), line...)
|
||||
addedKey = true
|
||||
continue
|
||||
}
|
||||
|
||||
if line[0] != '#' && bytes.Contains(line, []byte(".ts")) {
|
||||
if bytes.ContainsRune(line, '?') {
|
||||
lines[index] = append(line, []byte("&"+defaultM3u8EncryptParam+"="+token)...)
|
||||
} else {
|
||||
lines[index] = append(line, []byte("?"+defaultM3u8EncryptParam+"="+token)...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if addedKey {
|
||||
this.tags = append(this.tags, "m3u8")
|
||||
}
|
||||
|
||||
returnData:
|
||||
data = bytes.Join(lines, []byte{'\n'})
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(data))
|
||||
resp.ContentLength = int64(len(data))
|
||||
resp.Header.Set("Content-Length", types.String(resp.ContentLength))
|
||||
|
||||
return nil
|
||||
}
|
||||
228
EdgeNode/internal/nodes/http_request_host_redirect.go
Normal file
228
EdgeNode/internal/nodes/http_request_host_redirect.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 主机地址快速跳转
|
||||
func (this *HTTPRequest) doHostRedirect() (blocked bool) {
|
||||
var urlPath = this.RawReq.URL.Path
|
||||
if this.web.MergeSlashes {
|
||||
urlPath = utils.CleanPath(urlPath)
|
||||
}
|
||||
for _, u := range this.web.HostRedirects {
|
||||
if !u.IsOn {
|
||||
continue
|
||||
}
|
||||
if !u.MatchRequest(this.Format) {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(u.ExceptDomains) > 0 && configutils.MatchDomains(u.ExceptDomains, this.ReqHost) {
|
||||
continue
|
||||
}
|
||||
if len(u.OnlyDomains) > 0 && !configutils.MatchDomains(u.OnlyDomains, this.ReqHost) {
|
||||
continue
|
||||
}
|
||||
|
||||
var status = u.Status
|
||||
if status <= 0 {
|
||||
if searchEngineRegex.MatchString(this.RawReq.UserAgent()) {
|
||||
status = http.StatusMovedPermanently
|
||||
} else {
|
||||
status = http.StatusTemporaryRedirect
|
||||
}
|
||||
}
|
||||
|
||||
var fullURL string
|
||||
if u.BeforeHasQuery() {
|
||||
fullURL = this.URL()
|
||||
} else {
|
||||
fullURL = this.requestScheme() + "://" + this.ReqHost + urlPath
|
||||
}
|
||||
|
||||
if len(u.Type) == 0 || u.Type == serverconfigs.HTTPHostRedirectTypeURL {
|
||||
if u.MatchPrefix { // 匹配前缀
|
||||
if strings.HasPrefix(fullURL, u.BeforeURL) {
|
||||
var afterURL = u.AfterURL
|
||||
if u.KeepRequestURI {
|
||||
afterURL += this.RawReq.URL.RequestURI()
|
||||
}
|
||||
|
||||
// 前后是否一致
|
||||
if fullURL == afterURL {
|
||||
return false
|
||||
}
|
||||
|
||||
this.ProcessResponseHeaders(this.writer.Header(), status)
|
||||
httpRedirect(this.writer, this.RawReq, afterURL, status)
|
||||
return true
|
||||
}
|
||||
} else if u.MatchRegexp { // 正则匹配
|
||||
var reg = u.BeforeURLRegexp()
|
||||
if reg == nil {
|
||||
continue
|
||||
}
|
||||
var matches = reg.FindStringSubmatch(fullURL)
|
||||
if len(matches) == 0 {
|
||||
continue
|
||||
}
|
||||
var afterURL = u.AfterURL
|
||||
for i, match := range matches {
|
||||
afterURL = strings.ReplaceAll(afterURL, "${"+strconv.Itoa(i)+"}", match)
|
||||
}
|
||||
|
||||
var subNames = reg.SubexpNames()
|
||||
if len(subNames) > 0 {
|
||||
for _, subName := range subNames {
|
||||
if len(subName) > 0 {
|
||||
index := reg.SubexpIndex(subName)
|
||||
if index > -1 {
|
||||
afterURL = strings.ReplaceAll(afterURL, "${"+subName+"}", matches[index])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 前后是否一致
|
||||
if fullURL == afterURL {
|
||||
return false
|
||||
}
|
||||
|
||||
if u.KeepArgs {
|
||||
var qIndex = strings.Index(this.uri, "?")
|
||||
if qIndex >= 0 {
|
||||
afterURL += this.uri[qIndex:]
|
||||
}
|
||||
}
|
||||
|
||||
this.ProcessResponseHeaders(this.writer.Header(), status)
|
||||
httpRedirect(this.writer, this.RawReq, afterURL, status)
|
||||
return true
|
||||
} else { // 精准匹配
|
||||
if fullURL == u.RealBeforeURL() {
|
||||
// 前后是否一致
|
||||
if fullURL == u.AfterURL {
|
||||
return false
|
||||
}
|
||||
|
||||
var afterURL = u.AfterURL
|
||||
if u.KeepArgs {
|
||||
var qIndex = strings.Index(this.uri, "?")
|
||||
if qIndex >= 0 {
|
||||
var afterQIndex = strings.Index(u.AfterURL, "?")
|
||||
if afterQIndex >= 0 {
|
||||
afterURL = u.AfterURL[:afterQIndex] + this.uri[qIndex:]
|
||||
} else {
|
||||
afterURL += this.uri[qIndex:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.ProcessResponseHeaders(this.writer.Header(), status)
|
||||
httpRedirect(this.writer, this.RawReq, afterURL, status)
|
||||
return true
|
||||
}
|
||||
}
|
||||
} else if u.Type == serverconfigs.HTTPHostRedirectTypeDomain {
|
||||
if len(u.DomainAfter) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var reqHost = this.ReqHost
|
||||
|
||||
// 忽略跳转前端口
|
||||
if u.DomainBeforeIgnorePorts {
|
||||
h, _, err := net.SplitHostPort(reqHost)
|
||||
if err == nil && len(h) > 0 {
|
||||
reqHost = h
|
||||
}
|
||||
}
|
||||
|
||||
var scheme = u.DomainAfterScheme
|
||||
if len(scheme) == 0 {
|
||||
scheme = this.requestScheme()
|
||||
}
|
||||
if u.DomainsAll || configutils.MatchDomains(u.DomainsBefore, reqHost) {
|
||||
var afterURL = scheme + "://" + u.DomainAfter + urlPath
|
||||
if fullURL == afterURL {
|
||||
// 终止匹配
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果跳转前后域名一致,则终止
|
||||
if u.DomainAfter == reqHost {
|
||||
return false
|
||||
}
|
||||
|
||||
this.ProcessResponseHeaders(this.writer.Header(), status)
|
||||
|
||||
// 参数
|
||||
var qIndex = strings.Index(this.uri, "?")
|
||||
if qIndex >= 0 {
|
||||
afterURL += this.uri[qIndex:]
|
||||
}
|
||||
|
||||
httpRedirect(this.writer, this.RawReq, afterURL, status)
|
||||
return true
|
||||
}
|
||||
} else if u.Type == serverconfigs.HTTPHostRedirectTypePort {
|
||||
if u.PortAfter <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var scheme = u.PortAfterScheme
|
||||
if len(scheme) == 0 {
|
||||
scheme = this.requestScheme()
|
||||
}
|
||||
|
||||
reqHost, reqPort, _ := net.SplitHostPort(this.ReqHost)
|
||||
if len(reqHost) == 0 {
|
||||
reqHost = this.ReqHost
|
||||
}
|
||||
if len(reqPort) == 0 {
|
||||
switch this.requestScheme() {
|
||||
case "http":
|
||||
reqPort = "80"
|
||||
case "https":
|
||||
reqPort = "443"
|
||||
}
|
||||
}
|
||||
|
||||
// 如果跳转前后端口一致,则终止
|
||||
if reqPort == types.String(u.PortAfter) {
|
||||
return false
|
||||
}
|
||||
|
||||
var containsPort bool
|
||||
if u.PortsAll {
|
||||
containsPort = true
|
||||
} else {
|
||||
containsPort = u.ContainsPort(types.Int(reqPort))
|
||||
}
|
||||
if containsPort {
|
||||
var newReqHost = reqHost
|
||||
if !((scheme == "http" && u.PortAfter == 80) || (scheme == "https" && u.PortAfter == 443)) {
|
||||
newReqHost += ":" + types.String(u.PortAfter)
|
||||
}
|
||||
var afterURL = scheme + "://" + newReqHost + urlPath
|
||||
if fullURL == afterURL {
|
||||
// 终止匹配
|
||||
return false
|
||||
}
|
||||
|
||||
this.ProcessResponseHeaders(this.writer.Header(), status)
|
||||
httpRedirect(this.writer, this.RawReq, afterURL, status)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
10
EdgeNode/internal/nodes/http_request_http3.go
Normal file
10
EdgeNode/internal/nodes/http_request_http3.go
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build !plus
|
||||
|
||||
package nodes
|
||||
|
||||
import "net/http"
|
||||
|
||||
func (this *HTTPRequest) processHTTP3Headers(respHeader http.Header) {
|
||||
// stub
|
||||
}
|
||||
13
EdgeNode/internal/nodes/http_request_http3_plus.go
Normal file
13
EdgeNode/internal/nodes/http_request_http3_plus.go
Normal file
@@ -0,0 +1,13 @@
|
||||
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build plus
|
||||
|
||||
package nodes
|
||||
|
||||
import "net/http"
|
||||
|
||||
func (this *HTTPRequest) processHTTP3Headers(respHeader http.Header) {
|
||||
if this.ReqServer == nil || this.ReqServer.ClusterId <= 0 {
|
||||
return
|
||||
}
|
||||
sharedHTTP3Manager.ProcessHTTP3Headers(this.RawReq.UserAgent(), respHeader, this.ReqServer.ClusterId)
|
||||
}
|
||||
41
EdgeNode/internal/nodes/http_request_limit.go
Normal file
41
EdgeNode/internal/nodes/http_request_limit.go
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (this *HTTPRequest) doRequestLimit() (shouldStop bool) {
|
||||
// 是否在全局名单中
|
||||
_, isInAllowedList, _ := iplibrary.AllowIP(this.RemoteAddr(), this.ReqServer.Id)
|
||||
if isInAllowedList {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查请求Body尺寸
|
||||
// TODO 处理分片提交的内容
|
||||
if this.web.RequestLimit.MaxBodyBytes() > 0 &&
|
||||
this.RawReq.ContentLength > this.web.RequestLimit.MaxBodyBytes() {
|
||||
this.writeCode(http.StatusRequestEntityTooLarge, "", "")
|
||||
return true
|
||||
}
|
||||
|
||||
// 设置连接相关参数
|
||||
if this.web.RequestLimit.MaxConns > 0 || this.web.RequestLimit.MaxConnsPerIP > 0 {
|
||||
var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
|
||||
if requestConn != nil {
|
||||
clientConn, ok := requestConn.(ClientConnInterface)
|
||||
if ok && !clientConn.IsBound() {
|
||||
if !clientConn.Bind(this.ReqServer.Id, this.requestRemoteAddr(true), this.web.RequestLimit.MaxConns, this.web.RequestLimit.MaxConnsPerIP) {
|
||||
this.writeCode(http.StatusTooManyRequests, "", "")
|
||||
this.Close()
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
25
EdgeNode/internal/nodes/http_request_ln.go
Normal file
25
EdgeNode/internal/nodes/http_request_ln.go
Normal file
@@ -0,0 +1,25 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build !plus
|
||||
// +build !plus
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
)
|
||||
|
||||
const (
|
||||
LNExpiresHeader = "X-Edge-Ln-Expires"
|
||||
)
|
||||
|
||||
func existsLnNodeIP(nodeIP string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) checkLnRequest() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) getLnOrigin(excludingNodeIds []int64, urlHash uint64) (originConfig *serverconfigs.OriginConfig, lnNodeId int64, hasMultipleNodes bool) {
|
||||
return nil, 0, false
|
||||
}
|
||||
199
EdgeNode/internal/nodes/http_request_ln_key_plus.go
Normal file
199
EdgeNode/internal/nodes/http_request_ln_key_plus.go
Normal file
@@ -0,0 +1,199 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
"google.golang.org/protobuf/runtime/protoimpl"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func UnmarshalLnRequestKey(data []byte, key proto.Message) error {
|
||||
return proto.Unmarshal(data, key)
|
||||
}
|
||||
|
||||
// LnRequestKey
|
||||
// 使用数字作为json标识意在减少编码后的字节数
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type LnRequestKey struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Timestamp int64 `protobuf:"varint,1,opt,name=Timestamp,proto3" json:"1,omitempty"`
|
||||
NodeId int64 `protobuf:"varint,2,opt,name=NodeId,proto3" json:"2,omitempty"`
|
||||
RequestId string `protobuf:"bytes,3,opt,name=RequestId,proto3" json:"3,omitempty"`
|
||||
RemoteAddr string `protobuf:"bytes,4,opt,name=RemoteAddr,proto3" json:"4,omitempty"`
|
||||
URLMd5 string `protobuf:"bytes,5,opt,name=URLMd5,proto3" json:"5,omitempty"`
|
||||
Method string `protobuf:"bytes,6,opt,name=Method,proto3" json:"6,omitempty"`
|
||||
}
|
||||
|
||||
func (x *LnRequestKey) Reset() {
|
||||
*x = LnRequestKey{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_models_model_ln_request_key_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *LnRequestKey) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*LnRequestKey) ProtoMessage() {}
|
||||
|
||||
func (x *LnRequestKey) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_models_model_ln_request_key_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use LnRequestKey.ProtoReflect.Descriptor instead.
|
||||
func (*LnRequestKey) Descriptor() ([]byte, []int) {
|
||||
return file_models_model_ln_request_key_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *LnRequestKey) GetTimestamp() int64 {
|
||||
if x != nil {
|
||||
return x.Timestamp
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *LnRequestKey) GetNodeId() int64 {
|
||||
if x != nil {
|
||||
return x.NodeId
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *LnRequestKey) GetRequestId() string {
|
||||
if x != nil {
|
||||
return x.RequestId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *LnRequestKey) GetRemoteAddr() string {
|
||||
if x != nil {
|
||||
return x.RemoteAddr
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *LnRequestKey) GetURLMd5() string {
|
||||
if x != nil {
|
||||
return x.URLMd5
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *LnRequestKey) GetMethod() string {
|
||||
if x != nil {
|
||||
return x.Method
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_models_model_ln_request_key_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_models_model_ln_request_key_proto_rawDesc = []byte{
|
||||
0x0a, 0x21, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x2f, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x5f, 0x6c,
|
||||
0x6e, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x6b, 0x65, 0x79, 0x2e, 0x70, 0x72,
|
||||
0x6f, 0x74, 0x6f, 0x12, 0x02, 0x70, 0x62, 0x22, 0xb2, 0x01, 0x0a, 0x0c, 0x4c, 0x6e, 0x52, 0x65,
|
||||
0x71, 0x75, 0x65, 0x73, 0x74, 0x4b, 0x65, 0x79, 0x12, 0x1c, 0x0a, 0x09, 0x54, 0x69, 0x6d, 0x65,
|
||||
0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x54, 0x69, 0x6d,
|
||||
0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x6f, 0x64, 0x65, 0x49, 0x64,
|
||||
0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x6f, 0x64, 0x65, 0x49, 0x64, 0x12, 0x1c,
|
||||
0x0a, 0x09, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x49, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28,
|
||||
0x09, 0x52, 0x09, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x49, 0x64, 0x12, 0x1e, 0x0a, 0x0a,
|
||||
0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09,
|
||||
0x52, 0x0a, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x12, 0x16, 0x0a, 0x06,
|
||||
0x55, 0x52, 0x4c, 0x4d, 0x64, 0x35, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x55, 0x52,
|
||||
0x4c, 0x4d, 0x64, 0x35, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x06,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x42, 0x06, 0x5a, 0x04,
|
||||
0x2e, 0x2f, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
file_models_model_ln_request_key_proto_rawDescOnce sync.Once
|
||||
file_models_model_ln_request_key_proto_rawDescData = file_models_model_ln_request_key_proto_rawDesc
|
||||
)
|
||||
|
||||
func file_models_model_ln_request_key_proto_rawDescGZIP() []byte {
|
||||
file_models_model_ln_request_key_proto_rawDescOnce.Do(func() {
|
||||
file_models_model_ln_request_key_proto_rawDescData = protoimpl.X.CompressGZIP(file_models_model_ln_request_key_proto_rawDescData)
|
||||
})
|
||||
return file_models_model_ln_request_key_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_models_model_ln_request_key_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
|
||||
var file_models_model_ln_request_key_proto_goTypes = []interface{}{
|
||||
(*LnRequestKey)(nil), // 0: pb.LnRequestKey
|
||||
}
|
||||
var file_models_model_ln_request_key_proto_depIdxs = []int32{
|
||||
0, // [0:0] is the sub-list for method output_type
|
||||
0, // [0:0] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_models_model_ln_request_key_proto_init() }
|
||||
func file_models_model_ln_request_key_proto_init() {
|
||||
if File_models_model_ln_request_key_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_models_model_ln_request_key_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*LnRequestKey); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_models_model_ln_request_key_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 1,
|
||||
NumExtensions: 0,
|
||||
NumServices: 0,
|
||||
},
|
||||
GoTypes: file_models_model_ln_request_key_proto_goTypes,
|
||||
DependencyIndexes: file_models_model_ln_request_key_proto_depIdxs,
|
||||
MessageInfos: file_models_model_ln_request_key_proto_msgTypes,
|
||||
}.Build()
|
||||
File_models_model_ln_request_key_proto = out.File
|
||||
file_models_model_ln_request_key_proto_rawDesc = nil
|
||||
file_models_model_ln_request_key_proto_goTypes = nil
|
||||
file_models_model_ln_request_key_proto_depIdxs = nil
|
||||
}
|
||||
|
||||
func (this *LnRequestKey) AsPB() ([]byte, error) {
|
||||
return proto.Marshal(this)
|
||||
}
|
||||
72
EdgeNode/internal/nodes/http_request_ln_key_plus_test.go
Normal file
72
EdgeNode/internal/nodes/http_request_ln_key_plus_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/nodes"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMarshalLnRequestKey(t *testing.T) {
|
||||
var key = &nodes.LnRequestKey{
|
||||
Timestamp: fasttime.Now().Unix(),
|
||||
NodeId: 1024,
|
||||
RequestId: "abc",
|
||||
RemoteAddr: "1.2.3.4",
|
||||
URLMd5: stringutil.Md5("123456"),
|
||||
Method: http.MethodPost,
|
||||
}
|
||||
|
||||
pbData, err := key.AsPB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(len(pbData), "bytes")
|
||||
|
||||
var key2 = &nodes.LnRequestKey{}
|
||||
err = nodes.UnmarshalLnRequestKey(pbData, key2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(key2)
|
||||
}
|
||||
|
||||
func TestMarshalLnRequestKey_JSON(t *testing.T) {
|
||||
var key = &nodes.LnRequestKey{
|
||||
Timestamp: fasttime.Now().Unix(),
|
||||
NodeId: 1024,
|
||||
RequestId: "abc",
|
||||
RemoteAddr: "1.2.3.4",
|
||||
URLMd5: stringutil.Md5("123456"),
|
||||
Method: http.MethodPost,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(string(data))
|
||||
}
|
||||
|
||||
func BenchmarkMarshalLnRequestKey(b *testing.B) {
|
||||
runtime.GOMAXPROCS(4)
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
var key = &nodes.LnRequestKey{
|
||||
Timestamp: fasttime.Now().Unix(),
|
||||
NodeId: 1024,
|
||||
RequestId: "abc",
|
||||
RemoteAddr: "1.2.3.4",
|
||||
URLMd5: stringutil.Md5("123456"),
|
||||
Method: http.MethodPost,
|
||||
}
|
||||
_, _ = key.AsPB()
|
||||
}
|
||||
})
|
||||
}
|
||||
271
EdgeNode/internal/nodes/http_request_ln_plus.go
Normal file
271
EdgeNode/internal/nodes/http_request_ln_plus.go
Normal file
@@ -0,0 +1,271 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build plus
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/iputils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/encrypt"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
setutils "github.com/TeaOSLab/EdgeNode/internal/utils/sets"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
LNKeyHeader = "X-Edge-Ln-Key"
|
||||
LNExpiresHeader = "X-Edge-Ln-Expires"
|
||||
LnExpiresSeconds int64 = 30
|
||||
)
|
||||
|
||||
var lnOriginConfigsMap = map[string]*serverconfigs.OriginConfig{}
|
||||
var lnEncryptMethodMap = map[string]encrypt.MethodInterface{}
|
||||
var lnRequestIdSet = setutils.NewFixedSet(32 * 1024)
|
||||
|
||||
var lnNodeIPMap = map[string]bool{} // node ip => bool
|
||||
var lnNodeIPLocker = &sync.RWMutex{}
|
||||
|
||||
var lnLocker = &sync.RWMutex{}
|
||||
|
||||
func existsLnNodeIP(nodeIP string) bool {
|
||||
lnNodeIPLocker.RLock()
|
||||
defer lnNodeIPLocker.RUnlock()
|
||||
return lnNodeIPMap[nodeIP]
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) checkLnRequest() bool {
|
||||
var keyString = this.RawReq.Header.Get(LNKeyHeader)
|
||||
if len(keyString) <= 20 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 删除
|
||||
this.RawReq.Header.Del(LNKeyHeader)
|
||||
|
||||
// 当前连接IP
|
||||
connIP, _, _ := net.SplitHostPort(this.RawReq.RemoteAddr)
|
||||
var realIP = this.RawReq.Header.Get("X-Real-Ip")
|
||||
if len(realIP) > 0 && !iputils.IsValid(realIP) {
|
||||
realIP = ""
|
||||
}
|
||||
|
||||
// 如果是在已经识别的L[n-1]节点IP列表中,无需再次验证
|
||||
if existsLnNodeIP(connIP) && len(realIP) > 0 {
|
||||
this.tags = append(this.tags, "L"+types.String(this.nodeConfig.Level))
|
||||
this.lnRemoteAddr = realIP
|
||||
return true
|
||||
}
|
||||
|
||||
// 如果已经在允许的IP中,则直接允许通过,无需再次验证
|
||||
// 这个需要放在 existsLnNodeIP() 检查之后
|
||||
if this.nodeConfig != nil && this.nodeConfig.IPIsAutoAllowed(connIP) && len(realIP) > 0 {
|
||||
this.tags = append(this.tags, "L"+types.String(this.nodeConfig.Level))
|
||||
this.lnRemoteAddr = realIP
|
||||
lnNodeIPLocker.Lock()
|
||||
lnNodeIPMap[connIP] = true
|
||||
lnNodeIPLocker.Unlock()
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查Key
|
||||
keyEncodedData, err := base64.StdEncoding.DecodeString(keyString)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var secret = this.nodeConfig.SecretHash()
|
||||
lnLocker.Lock()
|
||||
method, ok := lnEncryptMethodMap[secret]
|
||||
if !ok {
|
||||
method, err = encrypt.NewMethodInstance("aes-192-cfb", secret, secret)
|
||||
if err != nil {
|
||||
lnLocker.Unlock()
|
||||
return false
|
||||
}
|
||||
}
|
||||
lnLocker.Unlock()
|
||||
|
||||
keyData, err := method.Decrypt(keyEncodedData)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var key = &LnRequestKey{}
|
||||
err = UnmarshalLnRequestKey(keyData, key)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Method和URL需要一致
|
||||
if key.URLMd5 != stringutil.Md5(this.URL()) || key.Method != this.Method() {
|
||||
return false
|
||||
}
|
||||
|
||||
// N秒钟过期,这里要求两个节点时间相差不能超过此时间
|
||||
var currentUnixTime = fasttime.Now().Unix()
|
||||
if key.Timestamp < currentUnixTime-LnExpiresSeconds {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查请求ID唯一性
|
||||
// TODO 因为FixedSet是有限的,这里仍然无法避免重放攻击
|
||||
// TODO 而RequestId是并发的,并不能简单的对比大小
|
||||
// TODO 所以为了绝对的安全,需要使用HTTPS并检查子节点的IP
|
||||
if lnRequestIdSet.Has(key.RequestId) {
|
||||
return false
|
||||
}
|
||||
lnRequestIdSet.Push(key.RequestId)
|
||||
|
||||
this.lnRemoteAddr = key.RemoteAddr
|
||||
|
||||
this.tags = append(this.tags, "L"+types.String(this.nodeConfig.Level))
|
||||
|
||||
// 当前连接IP
|
||||
if len(connIP) > 0 {
|
||||
lnNodeIPLocker.Lock()
|
||||
lnNodeIPMap[connIP] = true
|
||||
lnNodeIPLocker.Unlock()
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) getLnOrigin(excludingNodeIds []int64, urlHash uint64) (originConfig *serverconfigs.OriginConfig, lnNodeId int64, hasMultipleNodes bool) {
|
||||
var parentNodesMap = this.nodeConfig.ParentNodes // 需要复制,防止运行过程中修改
|
||||
if len(parentNodesMap) == 0 {
|
||||
return nil, 0, false
|
||||
}
|
||||
|
||||
parentNodes, ok := parentNodesMap[this.ReqServer.ClusterId]
|
||||
var countParentNodes = len(parentNodes)
|
||||
if ok && countParentNodes > 0 {
|
||||
var parentNode *nodeconfigs.ParentNodeConfig
|
||||
|
||||
// 尝试顺序读取
|
||||
if len(excludingNodeIds) > 0 {
|
||||
for _, node := range parentNodes {
|
||||
if !lists.ContainsInt64(excludingNodeIds, node.Id) {
|
||||
parentNode = node
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试随机读取
|
||||
if parentNode == nil {
|
||||
if countParentNodes == 1 {
|
||||
parentNode = parentNodes[0]
|
||||
} else {
|
||||
var method = serverconfigs.LnRequestSchedulingMethodURLMapping
|
||||
if this.nodeConfig != nil {
|
||||
var globalServerConfig = this.nodeConfig.GlobalServerConfig // copy
|
||||
if globalServerConfig != nil {
|
||||
method = globalServerConfig.HTTPAll.LnRequestSchedulingMethod
|
||||
}
|
||||
}
|
||||
|
||||
switch method {
|
||||
case serverconfigs.LnRequestSchedulingMethodRandom:
|
||||
// 随机选取一个L2节点,有利于流量均衡,但同一份缓存可能会存在多个L2节点上,占用更多的空间
|
||||
parentNode = parentNodes[rands.Int(0, countParentNodes-1)]
|
||||
default:
|
||||
// 从固定的L2节点读取内容,优点是能够提升缓存命中率,缺点是可能会导致多个L2节点流量不均衡
|
||||
parentNode = parentNodes[urlHash%uint64(countParentNodes)]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lnNodeId = parentNode.Id
|
||||
|
||||
var countAddrs = len(parentNode.Addrs)
|
||||
if countAddrs == 0 {
|
||||
return nil, 0, false
|
||||
}
|
||||
|
||||
var addr = parentNode.Addrs[rands.Int(0, countAddrs-1)]
|
||||
|
||||
var protocol = this.requestScheme() // http|https TODO 需要可以设置是否强制HTTPS回二级节点
|
||||
var portString = types.String(this.requestServerPort())
|
||||
var originKey = protocol + "@" + addr + "@" + portString
|
||||
lnLocker.RLock()
|
||||
config, ok := lnOriginConfigsMap[originKey]
|
||||
lnLocker.RUnlock()
|
||||
|
||||
if !ok {
|
||||
config = &serverconfigs.OriginConfig{
|
||||
Id: 0,
|
||||
IsOn: true,
|
||||
Addr: &serverconfigs.NetworkAddressConfig{
|
||||
Protocol: serverconfigs.Protocol(protocol),
|
||||
Host: addr,
|
||||
PortRange: portString,
|
||||
},
|
||||
IsOk: true,
|
||||
}
|
||||
|
||||
err := config.Init(context.Background())
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_REQUEST", "create ln origin config failed: "+err.Error())
|
||||
return nil, 0, false
|
||||
}
|
||||
|
||||
lnLocker.Lock()
|
||||
lnOriginConfigsMap[originKey] = config
|
||||
lnLocker.Unlock()
|
||||
}
|
||||
|
||||
// 添加Header
|
||||
this.RawReq.Header.Set(LNKeyHeader, this.encodeLnKey(parentNode.SecretHash))
|
||||
|
||||
return config, lnNodeId, len(parentNodes) > 0
|
||||
}
|
||||
|
||||
return nil, 0, false
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) encodeLnKey(secretHash string) string {
|
||||
var key = &LnRequestKey{
|
||||
NodeId: this.nodeConfig.Id,
|
||||
RequestId: this.requestId,
|
||||
Timestamp: time.Now().Unix(),
|
||||
RemoteAddr: this.RemoteAddr(),
|
||||
URLMd5: stringutil.Md5(this.URL()),
|
||||
Method: this.Method(),
|
||||
}
|
||||
|
||||
data, err := key.AsPB()
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_REQUEST", "ln request: encode key failed: "+err.Error())
|
||||
return ""
|
||||
}
|
||||
|
||||
lnLocker.Lock()
|
||||
method, ok := lnEncryptMethodMap[secretHash]
|
||||
if !ok {
|
||||
method, err = encrypt.NewMethodInstance("aes-192-cfb", secretHash, secretHash)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_REQUEST", "ln request: create encrypt method failed: "+err.Error())
|
||||
lnLocker.Unlock()
|
||||
return ""
|
||||
}
|
||||
lnEncryptMethodMap[secretHash] = method
|
||||
}
|
||||
lnLocker.Unlock()
|
||||
|
||||
dst, err := method.Encrypt(data)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_REQUEST", "ln request: encode key failed: "+err.Error())
|
||||
return ""
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(dst)
|
||||
}
|
||||
211
EdgeNode/internal/nodes/http_request_loader.go
Normal file
211
EdgeNode/internal/nodes/http_request_loader.go
Normal file
@@ -0,0 +1,211 @@
|
||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// serveWAFLoader 提供 WAF Loader JavaScript 文件
|
||||
func (this *HTTPRequest) serveWAFLoader() {
|
||||
loaderJS := `(function() {
|
||||
'use strict';
|
||||
|
||||
// 全局队列与执行器
|
||||
window.__WAF_Q__ = window.__WAF_Q__ || [];
|
||||
if (!window.__WAF_LOADER__) {
|
||||
window.__WAF_LOADER__ = {
|
||||
executing: false,
|
||||
execute: function() {
|
||||
if (this.executing) {
|
||||
return;
|
||||
}
|
||||
this.executing = true;
|
||||
var self = this;
|
||||
var queue = window.__WAF_Q__ || [];
|
||||
var runNext = function() {
|
||||
if (!queue.length) {
|
||||
self.executing = false;
|
||||
return;
|
||||
}
|
||||
var item = queue.shift();
|
||||
executeDecryptedCode(item.p, item.m, runNext);
|
||||
};
|
||||
runNext();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// 1. XOR 解码为字符串(不压缩,避免顺序问题)
|
||||
function xorDecodeToString(b64, key) {
|
||||
try {
|
||||
var bin = atob(b64);
|
||||
var out = new Uint8Array(bin.length);
|
||||
for (var i = 0; i < bin.length; i++) {
|
||||
out[i] = bin.charCodeAt(i) ^ key.charCodeAt(i % key.length);
|
||||
}
|
||||
if (typeof TextDecoder !== 'undefined') {
|
||||
return new TextDecoder().decode(out);
|
||||
}
|
||||
var s = '';
|
||||
for (var j = 0; j < out.length; j++) {
|
||||
s += String.fromCharCode(out[j]);
|
||||
}
|
||||
return s;
|
||||
} catch (e) {
|
||||
console.error('WAF Loader: xor decode failed', e);
|
||||
return '';
|
||||
}
|
||||
}
|
||||
|
||||
// 2. XOR 解密
|
||||
function decryptXOR(payload, key) {
|
||||
try {
|
||||
var binary = atob(payload);
|
||||
var output = [];
|
||||
var keyLen = key.length;
|
||||
if (keyLen === 0) {
|
||||
return '';
|
||||
}
|
||||
for (var i = 0; i < binary.length; i++) {
|
||||
var charCode = binary.charCodeAt(i) ^ key.charCodeAt(i % keyLen);
|
||||
output.push(String.fromCharCode(charCode));
|
||||
}
|
||||
return output.join('');
|
||||
} catch (e) {
|
||||
console.error('WAF Loader: Decrypt failed', e);
|
||||
return '';
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 执行解密后的代码
|
||||
function executeDecryptedCode(cipher, meta, done) {
|
||||
var finish = function() {
|
||||
if (typeof done === 'function') {
|
||||
done();
|
||||
}
|
||||
};
|
||||
try {
|
||||
if (!cipher || !meta || !meta.key) {
|
||||
console.error('WAF Loader: Missing cipher or meta.key');
|
||||
finish();
|
||||
return;
|
||||
}
|
||||
if (meta.alg !== 'xor') {
|
||||
console.error('WAF Loader: Unsupported alg', meta.alg);
|
||||
finish();
|
||||
return;
|
||||
}
|
||||
|
||||
// 1. XOR 解码为字符串
|
||||
var plainJS = xorDecodeToString(cipher, meta.key);
|
||||
if (!plainJS) {
|
||||
console.error('WAF Loader: XOR decode failed');
|
||||
finish();
|
||||
return;
|
||||
}
|
||||
|
||||
// 2. 执行解密后的代码(同步)
|
||||
try {
|
||||
// 使用全局 eval,尽量保持和 <script> 一致的作用域
|
||||
window.eval(plainJS);
|
||||
|
||||
// 3. 计算 Token 并握手
|
||||
calculateAndHandshake(plainJS);
|
||||
} catch (e) {
|
||||
console.error('WAF Loader: Execute failed', e);
|
||||
}
|
||||
finish();
|
||||
} catch (e) {
|
||||
console.error('WAF Loader: Security check failed', e);
|
||||
finish();
|
||||
}
|
||||
}
|
||||
|
||||
// 7. Token 计算和握手
|
||||
function base64EncodeUtf8(str) {
|
||||
try {
|
||||
return btoa(unescape(encodeURIComponent(str)));
|
||||
} catch (e) {
|
||||
console.error('WAF Loader: base64 encode failed', e);
|
||||
return '';
|
||||
}
|
||||
}
|
||||
|
||||
function calculateAndHandshake(decryptedJS) {
|
||||
try {
|
||||
// 计算 Token(简化示例)
|
||||
var tokenData = decryptedJS.substring(0, Math.min(100, decryptedJS.length)) + Date.now();
|
||||
var token = base64EncodeUtf8(tokenData).substring(0, 64); // 限制长度
|
||||
|
||||
// 发送握手请求
|
||||
fetch('/waf/handshake', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ token: token })
|
||||
}).then(function(resp) {
|
||||
if (resp.ok) {
|
||||
// 握手成功,设置全局 Token
|
||||
window.__WAF_TOKEN__ = token;
|
||||
|
||||
// 为后续请求设置 Header
|
||||
if (window.XMLHttpRequest) {
|
||||
var originalOpen = XMLHttpRequest.prototype.open;
|
||||
XMLHttpRequest.prototype.open = function() {
|
||||
originalOpen.apply(this, arguments);
|
||||
if (window.__WAF_TOKEN__) {
|
||||
this.setRequestHeader('X-WAF-TOKEN', window.__WAF_TOKEN__);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// 为 fetch 设置 Header
|
||||
var originalFetch = window.fetch;
|
||||
window.fetch = function() {
|
||||
var args = Array.prototype.slice.call(arguments);
|
||||
if (args.length > 1 && typeof args[1] === 'object') {
|
||||
if (!args[1].headers) {
|
||||
args[1].headers = {};
|
||||
}
|
||||
if (window.__WAF_TOKEN__) {
|
||||
args[1].headers['X-WAF-TOKEN'] = window.__WAF_TOKEN__;
|
||||
}
|
||||
} else {
|
||||
args[1] = {
|
||||
headers: {
|
||||
'X-WAF-TOKEN': window.__WAF_TOKEN__ || ''
|
||||
}
|
||||
};
|
||||
}
|
||||
return originalFetch.apply(this, args);
|
||||
};
|
||||
}
|
||||
}).catch(function(err) {
|
||||
console.error('WAF Loader: Handshake failed', err);
|
||||
});
|
||||
} catch (e) {
|
||||
console.error('WAF Loader: Calculate token failed', e);
|
||||
}
|
||||
}
|
||||
|
||||
// 8. 主逻辑
|
||||
if (window.__WAF_Q__ && window.__WAF_Q__.length) {
|
||||
window.__WAF_LOADER__.execute();
|
||||
} else {
|
||||
// 如果没有加密内容,等待一下再检查(可能是异步加载)
|
||||
setTimeout(function() {
|
||||
if (window.__WAF_Q__ && window.__WAF_Q__.length) {
|
||||
window.__WAF_LOADER__.execute();
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
})();`
|
||||
|
||||
this.writer.Header().Set("Content-Type", "application/javascript; charset=utf-8")
|
||||
this.writer.Header().Set("Cache-Control", "public, max-age=3600")
|
||||
this.writer.WriteHeader(http.StatusOK)
|
||||
_, _ = this.writer.WriteString(loaderJS)
|
||||
this.writer.SetOk()
|
||||
}
|
||||
188
EdgeNode/internal/nodes/http_request_log.go
Normal file
188
EdgeNode/internal/nodes/http_request_log.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// AccessLogMaxRequestBodySize 访问日志存储的请求内容最大尺寸 TODO 此值应该可以在访问日志页设置
|
||||
AccessLogMaxRequestBodySize = 2 << 20
|
||||
)
|
||||
|
||||
// 日志
|
||||
func (this *HTTPRequest) log() {
|
||||
// 检查全局配置
|
||||
if this.nodeConfig != nil && this.nodeConfig.GlobalServerConfig != nil && !this.nodeConfig.GlobalServerConfig.HTTPAccessLog.IsOn {
|
||||
return
|
||||
}
|
||||
|
||||
var ref *serverconfigs.HTTPAccessLogRef
|
||||
if !this.forceLog {
|
||||
if this.disableLog {
|
||||
return
|
||||
}
|
||||
|
||||
// 计算请求时间
|
||||
this.requestCost = time.Since(this.requestFromTime).Seconds()
|
||||
|
||||
ref = this.web.AccessLogRef
|
||||
if ref == nil {
|
||||
ref = serverconfigs.DefaultHTTPAccessLogRef
|
||||
}
|
||||
if !ref.IsOn {
|
||||
return
|
||||
}
|
||||
|
||||
if !ref.Match(this.writer.StatusCode()) {
|
||||
return
|
||||
}
|
||||
|
||||
if ref.FirewallOnly && this.firewallPolicyId == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 是否记录499
|
||||
if !ref.EnableClientClosed && this.writer.StatusCode() == 499 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var addr = this.RawReq.RemoteAddr
|
||||
var index = strings.LastIndex(addr, ":")
|
||||
if index > 0 {
|
||||
addr = addr[:index]
|
||||
}
|
||||
|
||||
var serverGlobalConfig = this.nodeConfig.GlobalServerConfig
|
||||
|
||||
// 请求Cookie
|
||||
var cookies = map[string]string{}
|
||||
var enableCookies = false
|
||||
if serverGlobalConfig == nil || serverGlobalConfig.HTTPAccessLog.EnableCookies {
|
||||
enableCookies = true
|
||||
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldCookie) {
|
||||
for _, cookie := range this.RawReq.Cookies() {
|
||||
cookies[cookie.Name] = cookie.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 请求Header
|
||||
var pbReqHeader = map[string]*pb.Strings{}
|
||||
if serverGlobalConfig == nil || serverGlobalConfig.HTTPAccessLog.EnableRequestHeaders {
|
||||
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldHeader) {
|
||||
// 是否只记录通用Header
|
||||
var commonHeadersOnly = serverGlobalConfig != nil && serverGlobalConfig.HTTPAccessLog.CommonRequestHeadersOnly
|
||||
|
||||
for k, v := range this.RawReq.Header {
|
||||
if commonHeadersOnly && !serverconfigs.IsCommonRequestHeader(k) {
|
||||
continue
|
||||
}
|
||||
if !enableCookies && k == "Cookie" {
|
||||
continue
|
||||
}
|
||||
pbReqHeader[k] = &pb.Strings{Values: v}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 响应Header
|
||||
var pbResHeader = map[string]*pb.Strings{}
|
||||
if serverGlobalConfig == nil || serverGlobalConfig.HTTPAccessLog.EnableResponseHeaders {
|
||||
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldSentHeader) {
|
||||
for k, v := range this.writer.Header() {
|
||||
pbResHeader[k] = &pb.Strings{Values: v}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 参数列表
|
||||
var queryString = ""
|
||||
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldArg) {
|
||||
queryString = this.requestQueryString()
|
||||
}
|
||||
|
||||
// 浏览器
|
||||
var userAgent = ""
|
||||
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldUserAgent) || ref.ContainsField(serverconfigs.HTTPAccessLogFieldExtend) {
|
||||
userAgent = this.RawReq.UserAgent()
|
||||
}
|
||||
|
||||
// 请求来源
|
||||
var referer = ""
|
||||
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldReferer) {
|
||||
referer = this.RawReq.Referer()
|
||||
}
|
||||
|
||||
var accessLog = &pb.HTTPAccessLog{
|
||||
RequestId: this.requestId,
|
||||
NodeId: this.nodeConfig.Id,
|
||||
ServerId: this.ReqServer.Id,
|
||||
RemoteAddr: this.requestRemoteAddr(true),
|
||||
RawRemoteAddr: addr,
|
||||
RemotePort: int32(this.requestRemotePort()),
|
||||
RemoteUser: this.requestRemoteUser(),
|
||||
RequestURI: this.rawURI,
|
||||
RequestPath: this.Path(),
|
||||
RequestLength: this.requestLength(),
|
||||
RequestTime: this.requestCost,
|
||||
RequestMethod: this.RawReq.Method,
|
||||
RequestFilename: this.requestFilename(),
|
||||
Scheme: this.requestScheme(),
|
||||
Proto: this.RawReq.Proto,
|
||||
BytesSent: this.writer.SentBodyBytes(), // TODO 加上Header Size
|
||||
BodyBytesSent: this.writer.SentBodyBytes(),
|
||||
Status: int32(this.writer.StatusCode()),
|
||||
StatusMessage: "",
|
||||
TimeISO8601: this.requestFromTime.Format("2006-01-02T15:04:05.000Z07:00"),
|
||||
TimeLocal: this.requestFromTime.Format("2/Jan/2006:15:04:05 -0700"),
|
||||
Msec: float64(this.requestFromTime.Unix()) + float64(this.requestFromTime.Nanosecond())/1000000000,
|
||||
Timestamp: this.requestFromTime.Unix(),
|
||||
Host: this.ReqHost,
|
||||
Referer: referer,
|
||||
UserAgent: userAgent,
|
||||
Request: this.requestString(),
|
||||
ContentType: this.writer.Header().Get("Content-Type"),
|
||||
Cookie: cookies,
|
||||
Args: queryString,
|
||||
QueryString: queryString,
|
||||
Header: pbReqHeader,
|
||||
ServerName: this.ServerName,
|
||||
ServerPort: int32(this.requestServerPort()),
|
||||
ServerProtocol: this.RawReq.Proto,
|
||||
SentHeader: pbResHeader,
|
||||
Errors: this.errors,
|
||||
Hostname: HOSTNAME,
|
||||
|
||||
FirewallPolicyId: this.firewallPolicyId,
|
||||
FirewallRuleGroupId: this.firewallRuleGroupId,
|
||||
FirewallRuleSetId: this.firewallRuleSetId,
|
||||
FirewallRuleId: this.firewallRuleId,
|
||||
FirewallActions: this.firewallActions,
|
||||
Tags: this.tags,
|
||||
|
||||
Attrs: this.logAttrs,
|
||||
}
|
||||
|
||||
if this.origin != nil {
|
||||
accessLog.OriginId = this.origin.Id
|
||||
accessLog.OriginAddress = this.originAddr
|
||||
accessLog.OriginStatus = this.originStatus
|
||||
}
|
||||
|
||||
// 请求Body
|
||||
if (ref != nil && ref.ContainsField(serverconfigs.HTTPAccessLogFieldRequestBody)) || this.wafHasRequestBody {
|
||||
accessLog.RequestBody = this.requestBodyData
|
||||
|
||||
if len(accessLog.RequestBody) > AccessLogMaxRequestBodySize {
|
||||
accessLog.RequestBody = accessLog.RequestBody[:AccessLogMaxRequestBodySize]
|
||||
}
|
||||
}
|
||||
|
||||
// TODO 记录匹配的 locationId和rewriteId,非必要需求
|
||||
|
||||
sharedHTTPAccessLogQueue.Push(accessLog)
|
||||
}
|
||||
48
EdgeNode/internal/nodes/http_request_metrics.go
Normal file
48
EdgeNode/internal/nodes/http_request_metrics.go
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/metrics"
|
||||
)
|
||||
|
||||
// 指标统计 - 响应
|
||||
// 只需要在结束时调用指标进行统计
|
||||
func (this *HTTPRequest) doMetricsResponse() {
|
||||
metrics.SharedManager.Add(this)
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) MetricKey(key string) string {
|
||||
return this.Format(key)
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) MetricValue(value string) (result int64, ok bool) {
|
||||
// TODO 需要忽略健康检查的请求,但是同时也要防止攻击者模拟健康检查
|
||||
switch value {
|
||||
case "${countRequest}":
|
||||
return 1, true
|
||||
case "${countTrafficOut}":
|
||||
// 这里不包括Header长度
|
||||
return this.writer.SentBodyBytes(), true
|
||||
case "${countTrafficIn}":
|
||||
var hl int64 = 0 // header length
|
||||
for k, values := range this.RawReq.Header {
|
||||
for _, v := range values {
|
||||
hl += int64(len(k) + len(v) + 2 /** k: v **/)
|
||||
}
|
||||
}
|
||||
return this.RawReq.ContentLength + hl, true
|
||||
case "${countConnection}":
|
||||
return 1, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) MetricServerId() int64 {
|
||||
return this.ReqServer.Id
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) MetricCategory() string {
|
||||
return serverconfigs.MetricItemCategoryHTTP
|
||||
}
|
||||
116
EdgeNode/internal/nodes/http_request_mismatch.go
Normal file
116
EdgeNode/internal/nodes/http_request_mismatch.go
Normal file
@@ -0,0 +1,116 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeutils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/ttlcache"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 域名无匹配情况处理
|
||||
func (this *HTTPRequest) doMismatch() {
|
||||
// 是否为健康检查
|
||||
var healthCheckKey = this.RawReq.Header.Get(serverconfigs.HealthCheckHeaderName)
|
||||
if len(healthCheckKey) > 0 {
|
||||
_, err := nodeutils.Base64DecodeMap(healthCheckKey)
|
||||
if err == nil {
|
||||
this.writer.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 是否已经在黑名单
|
||||
var remoteIP = this.RemoteAddr()
|
||||
if waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteIP) {
|
||||
this.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// 根据配置进行相应的处理
|
||||
var nodeConfig = sharedNodeConfig // copy
|
||||
if nodeConfig != nil {
|
||||
var globalServerConfig = nodeConfig.GlobalServerConfig
|
||||
if globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly {
|
||||
var statusCode = 404
|
||||
var httpAllConfig = globalServerConfig.HTTPAll
|
||||
var mismatchAction = httpAllConfig.DomainMismatchAction
|
||||
|
||||
if mismatchAction != nil && mismatchAction.Options != nil {
|
||||
var mismatchStatusCode = mismatchAction.Options.GetInt("statusCode")
|
||||
if mismatchStatusCode > 0 && mismatchStatusCode >= 100 && mismatchStatusCode < 1000 {
|
||||
statusCode = mismatchStatusCode
|
||||
}
|
||||
}
|
||||
|
||||
// 是否正在访问IP
|
||||
if globalServerConfig.HTTPAll.NodeIPShowPage && utils.IsWildIP(this.ReqHost) {
|
||||
this.writer.statusCode = statusCode
|
||||
var contentHTML = this.Format(globalServerConfig.HTTPAll.NodeIPPageHTML)
|
||||
this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
this.writer.Header().Set("Content-Length", types.String(len(contentHTML)))
|
||||
this.writer.WriteHeader(statusCode)
|
||||
_, _ = this.writer.WriteString(contentHTML)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查cc
|
||||
// TODO 可以在管理端配置是否开启以及最多尝试次数
|
||||
// 要考虑到服务在切换集群时,域名未生效状态时,用户访问的仍然是老集群中的节点,就会产生找不到域名的情况
|
||||
if len(remoteIP) > 0 {
|
||||
const maxAttempts = 100
|
||||
if ttlcache.SharedInt64Cache.IncreaseInt64("MISMATCH_DOMAIN:"+remoteIP, int64(1), time.Now().Unix()+60, false) > maxAttempts {
|
||||
// 在加入之前再次检查黑名单
|
||||
if !waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteIP) {
|
||||
waf.SharedIPBlackList.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteIP, time.Now().Unix()+3600)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理当前连接
|
||||
if mismatchAction != nil {
|
||||
if mismatchAction.Code == serverconfigs.DomainMismatchActionPage {
|
||||
if mismatchAction.Options != nil {
|
||||
this.writer.statusCode = statusCode
|
||||
var contentHTML = this.Format(mismatchAction.Options.GetString("contentHTML"))
|
||||
this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
this.writer.Header().Set("Content-Length", types.String(len(contentHTML)))
|
||||
this.writer.WriteHeader(statusCode)
|
||||
_, _ = this.writer.Write([]byte(contentHTML))
|
||||
} else {
|
||||
http.Error(this.writer, "404 page not found: '"+this.URL()+"'", http.StatusNotFound)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if mismatchAction.Code == serverconfigs.DomainMismatchActionRedirect {
|
||||
var url = this.Format(mismatchAction.Options.GetString("url"))
|
||||
if len(url) > 0 {
|
||||
httpRedirect(this.writer, this.RawReq, url, http.StatusTemporaryRedirect)
|
||||
} else {
|
||||
http.Error(this.writer, "404 page not found: '"+this.URL()+"'", http.StatusNotFound)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if mismatchAction.Code == serverconfigs.DomainMismatchActionClose {
|
||||
http.Error(this.writer, "404 page not found: '"+this.URL()+"'", http.StatusNotFound)
|
||||
this.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
http.Error(this.writer, "404 page not found: '"+this.URL()+"'", http.StatusNotFound)
|
||||
this.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
http.Error(this.writer, "404 page not found: '"+this.URL()+"'", http.StatusNotFound)
|
||||
}
|
||||
15
EdgeNode/internal/nodes/http_request_oss.go
Normal file
15
EdgeNode/internal/nodes/http_request_oss.go
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build !plus
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (this *HTTPRequest) doOSSOrigin(origin *serverconfigs.OriginConfig) (resp *http.Response, goNext bool, errorCode string, ossBucketName string, err error) {
|
||||
// stub
|
||||
return nil, false, "", "", errors.New("not implemented")
|
||||
}
|
||||
78
EdgeNode/internal/nodes/http_request_oss_plus.go
Normal file
78
EdgeNode/internal/nodes/http_request_oss_plus.go
Normal file
@@ -0,0 +1,78 @@
|
||||
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build plus
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/oss"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// 请求OSS源站
|
||||
func (this *HTTPRequest) doOSSOrigin(origin *serverconfigs.OriginConfig) (resp *http.Response, goNext bool, errorCode string, ossBucketName string, err error) {
|
||||
if origin == nil || origin.OSS == nil {
|
||||
err = errors.New("'origin' or 'origin.OSS' should not be nil")
|
||||
return
|
||||
}
|
||||
|
||||
// 只支持少数方法
|
||||
var isHeadRequest = this.RawReq.Method != http.MethodGet &&
|
||||
this.RawReq.Method != http.MethodPost &&
|
||||
this.RawReq.Method != http.MethodPut
|
||||
|
||||
var rangeBytes = this.RawReq.Header.Get("Range")
|
||||
if isHeadRequest && len(rangeBytes) == 0 {
|
||||
resp, errorCode, ossBucketName, err = oss.SharedManager.Head(this.RawReq, this.ReqHost, origin.OSS)
|
||||
} else {
|
||||
if len(rangeBytes) > 0 {
|
||||
resp, errorCode, ossBucketName, err = oss.SharedManager.GetRange(this.RawReq, this.ReqHost, rangeBytes, origin.OSS)
|
||||
} else {
|
||||
resp, errorCode, ossBucketName, err = oss.SharedManager.Get(this.RawReq, this.ReqHost, origin.OSS)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ossBucketName) == 0 {
|
||||
this.originAddr = origin.OSS.Type
|
||||
} else {
|
||||
this.originAddr = origin.OSS.Type + "/" + ossBucketName
|
||||
}
|
||||
this.tags = append(this.tags, "oss")
|
||||
|
||||
if err != nil {
|
||||
if oss.IsNotFound(err) {
|
||||
this.write404()
|
||||
return nil, false, errorCode, ossBucketName, nil
|
||||
}
|
||||
|
||||
if oss.IsTimeout(err) {
|
||||
this.writeCode(http.StatusGatewayTimeout, "Read object timeout.", "读取对象超时。")
|
||||
return nil, false, errorCode, ossBucketName, nil
|
||||
}
|
||||
|
||||
return nil, false, errorCode, ossBucketName, fmt.Errorf("OSS: [%s]: %s: %w", origin.OSS.Type, errorCode, err)
|
||||
}
|
||||
|
||||
if isHeadRequest {
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(&bytes.Buffer{})
|
||||
}
|
||||
|
||||
// fix Content-Length
|
||||
if resp.Header == nil {
|
||||
resp.Header = http.Header{}
|
||||
}
|
||||
if resp.ContentLength > 0 {
|
||||
_, ok := resp.Header["Content-Length"]
|
||||
if !ok {
|
||||
resp.Header.Set("Content-Length", types.String(resp.ContentLength))
|
||||
}
|
||||
}
|
||||
|
||||
return resp, true, "", ossBucketName, nil
|
||||
}
|
||||
170
EdgeNode/internal/nodes/http_request_page.go
Normal file
170
EdgeNode/internal/nodes/http_request_page.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/bytepool"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const defaultPageContentType = "text/html; charset=utf-8"
|
||||
|
||||
// 请求特殊页面
|
||||
func (this *HTTPRequest) doPage(status int) (shouldStop bool) {
|
||||
if len(this.web.Pages) == 0 {
|
||||
// 集群自定义页面
|
||||
if this.nodeConfig != nil && this.ReqServer != nil && this.web.EnableGlobalPages {
|
||||
var httpPagesPolicy = this.nodeConfig.FindHTTPPagesPolicyWithClusterId(this.ReqServer.ClusterId)
|
||||
if httpPagesPolicy != nil && httpPagesPolicy.IsOn && len(httpPagesPolicy.Pages) > 0 {
|
||||
return this.doPageLookup(httpPagesPolicy.Pages, status)
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// 查找当前网站自定义页面
|
||||
shouldStop = this.doPageLookup(this.web.Pages, status)
|
||||
if shouldStop {
|
||||
return
|
||||
}
|
||||
|
||||
// 集群自定义页面
|
||||
if this.nodeConfig != nil && this.ReqServer != nil && this.web.EnableGlobalPages {
|
||||
var httpPagesPolicy = this.nodeConfig.FindHTTPPagesPolicyWithClusterId(this.ReqServer.ClusterId)
|
||||
if httpPagesPolicy != nil && httpPagesPolicy.IsOn && len(httpPagesPolicy.Pages) > 0 {
|
||||
return this.doPageLookup(httpPagesPolicy.Pages, status)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) doPageLookup(pages []*serverconfigs.HTTPPageConfig, status int) (shouldStop bool) {
|
||||
var url = this.URL()
|
||||
for _, page := range pages {
|
||||
if !page.MatchURL(url) {
|
||||
continue
|
||||
}
|
||||
if page.Match(status) {
|
||||
if len(page.BodyType) == 0 || page.BodyType == serverconfigs.HTTPPageBodyTypeURL {
|
||||
if urlSchemeRegexp.MatchString(page.URL) {
|
||||
var newStatus = page.NewStatus
|
||||
if newStatus <= 0 {
|
||||
newStatus = status
|
||||
}
|
||||
this.doURL(http.MethodGet, page.URL, "", newStatus, true)
|
||||
return true
|
||||
} else {
|
||||
var realpath = path.Clean(page.URL)
|
||||
if !strings.HasPrefix(realpath, "/pages/") && !strings.HasPrefix(realpath, "pages/") { // only files under "/pages/" can be used
|
||||
var msg = "404 page not found: '" + page.URL + "'"
|
||||
this.writer.Header().Set("Content-Type", defaultPageContentType)
|
||||
this.writer.WriteHeader(http.StatusNotFound)
|
||||
_, _ = this.writer.Write([]byte(msg))
|
||||
return true
|
||||
}
|
||||
var file = Tea.Root + Tea.DS + realpath
|
||||
fp, err := os.Open(file)
|
||||
if err != nil {
|
||||
var msg = "404 page not found: '" + page.URL + "'"
|
||||
this.writer.Header().Set("Content-Type", defaultPageContentType)
|
||||
this.writer.WriteHeader(http.StatusNotFound)
|
||||
_, _ = this.writer.Write([]byte(msg))
|
||||
return true
|
||||
}
|
||||
defer func() {
|
||||
_ = fp.Close()
|
||||
}()
|
||||
|
||||
stat, err := fp.Stat()
|
||||
if err != nil {
|
||||
var msg = "404 could not read page content: '" + page.URL + "'"
|
||||
this.writer.Header().Set("Content-Type", defaultPageContentType)
|
||||
this.writer.WriteHeader(http.StatusNotFound)
|
||||
_, _ = this.writer.Write([]byte(msg))
|
||||
return true
|
||||
}
|
||||
|
||||
// 修改状态码
|
||||
if page.NewStatus > 0 {
|
||||
// 自定义响应Headers
|
||||
this.writer.Header().Set("Content-Type", defaultPageContentType)
|
||||
this.ProcessResponseHeaders(this.writer.Header(), page.NewStatus)
|
||||
this.writer.Prepare(nil, stat.Size(), page.NewStatus, true)
|
||||
this.writer.WriteHeader(page.NewStatus)
|
||||
} else {
|
||||
this.writer.Header().Set("Content-Type", defaultPageContentType)
|
||||
this.ProcessResponseHeaders(this.writer.Header(), status)
|
||||
this.writer.Prepare(nil, stat.Size(), status, true)
|
||||
this.writer.WriteHeader(status)
|
||||
}
|
||||
var buf = bytepool.Pool1k.Get()
|
||||
_, err = utils.CopyWithFilter(this.writer, fp, buf.Bytes, func(p []byte) []byte {
|
||||
return []byte(this.Format(string(p)))
|
||||
})
|
||||
bytepool.Pool1k.Put(buf)
|
||||
if err != nil {
|
||||
if !this.canIgnore(err) {
|
||||
remotelogs.Warn("HTTP_REQUEST_PAGE", "write to client failed: "+err.Error())
|
||||
}
|
||||
} else {
|
||||
this.writer.SetOk()
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
} else if page.BodyType == serverconfigs.HTTPPageBodyTypeHTML {
|
||||
// 这里需要实现设置Status,因为在Format()中可以获取${status}等变量
|
||||
if page.NewStatus > 0 {
|
||||
this.writer.statusCode = page.NewStatus
|
||||
} else {
|
||||
this.writer.statusCode = status
|
||||
}
|
||||
var content = this.Format(page.Body)
|
||||
|
||||
// 修改状态码
|
||||
if page.NewStatus > 0 {
|
||||
// 自定义响应Headers
|
||||
this.writer.Header().Set("Content-Type", defaultPageContentType)
|
||||
this.ProcessResponseHeaders(this.writer.Header(), page.NewStatus)
|
||||
this.writer.Prepare(nil, int64(len(content)), page.NewStatus, true)
|
||||
this.writer.WriteHeader(page.NewStatus)
|
||||
} else {
|
||||
this.writer.Header().Set("Content-Type", defaultPageContentType)
|
||||
this.ProcessResponseHeaders(this.writer.Header(), status)
|
||||
this.writer.Prepare(nil, int64(len(content)), status, true)
|
||||
this.writer.WriteHeader(status)
|
||||
}
|
||||
|
||||
_, err := this.writer.WriteString(content)
|
||||
if err != nil {
|
||||
if !this.canIgnore(err) {
|
||||
remotelogs.Warn("HTTP_REQUEST_PAGE", "write to client failed: "+err.Error())
|
||||
}
|
||||
} else {
|
||||
this.writer.SetOk()
|
||||
}
|
||||
return true
|
||||
} else if page.BodyType == serverconfigs.HTTPPageBodyTypeRedirectURL {
|
||||
var newURL = this.Format(page.URL)
|
||||
if len(newURL) == 0 {
|
||||
newURL = "/"
|
||||
}
|
||||
if page.NewStatus > 0 && httpStatusIsRedirect(page.NewStatus) {
|
||||
httpRedirect(this.writer, this.RawReq, newURL, page.NewStatus)
|
||||
} else {
|
||||
httpRedirect(this.writer, this.RawReq, newURL, http.StatusTemporaryRedirect)
|
||||
}
|
||||
this.writer.SetOk()
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
10
EdgeNode/internal/nodes/http_request_plan_before.go
Normal file
10
EdgeNode/internal/nodes/http_request_plan_before.go
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build !plus
|
||||
|
||||
package nodes
|
||||
|
||||
// 检查套餐
|
||||
func (this *HTTPRequest) doPlanBefore() (blocked bool) {
|
||||
// stub
|
||||
return false
|
||||
}
|
||||
38
EdgeNode/internal/nodes/http_request_plan_before_plus.go
Normal file
38
EdgeNode/internal/nodes/http_request_plan_before_plus.go
Normal file
@@ -0,0 +1,38 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build plus
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// 检查套餐
|
||||
func (this *HTTPRequest) doPlanBefore() (blocked bool) {
|
||||
// check date
|
||||
if !this.ReqServer.UserPlan.IsAvailable() {
|
||||
this.tags = append(this.tags, "plan")
|
||||
|
||||
var statusCode = http.StatusNotFound
|
||||
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
|
||||
|
||||
this.writer.WriteHeader(statusCode)
|
||||
_, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultPlanExpireNoticePageBody))
|
||||
return true
|
||||
}
|
||||
|
||||
// check max upload size
|
||||
if this.RawReq.ContentLength > 0 {
|
||||
var plan = sharedNodeConfig.FindPlan(this.ReqServer.UserPlan.PlanId)
|
||||
if plan != nil && plan.MaxUploadSize != nil && plan.MaxUploadSize.Count > 0 {
|
||||
if this.RawReq.ContentLength > plan.MaxUploadSize.Bytes() {
|
||||
this.writeCode(http.StatusRequestEntityTooLarge, "Reached max upload size limitation in your plan.", "触发套餐中最大文件上传尺寸限制。")
|
||||
this.tags = append(this.tags, "plan")
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
49
EdgeNode/internal/nodes/http_request_redirect_https.go
Normal file
49
EdgeNode/internal/nodes/http_request_redirect_https.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (this *HTTPRequest) doRedirectToHTTPS(redirectToHTTPSConfig *serverconfigs.HTTPRedirectToHTTPSConfig) (shouldBreak bool) {
|
||||
var host = this.RawReq.Host
|
||||
|
||||
// 检查域名是否匹配
|
||||
if !redirectToHTTPSConfig.MatchDomain(host) {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(redirectToHTTPSConfig.Host) > 0 {
|
||||
if redirectToHTTPSConfig.Port > 0 && redirectToHTTPSConfig.Port != 443 {
|
||||
host = redirectToHTTPSConfig.Host + ":" + strconv.Itoa(redirectToHTTPSConfig.Port)
|
||||
} else {
|
||||
host = redirectToHTTPSConfig.Host
|
||||
}
|
||||
} else if redirectToHTTPSConfig.Port > 0 {
|
||||
var lastIndex = strings.LastIndex(host, ":")
|
||||
if lastIndex > 0 {
|
||||
host = host[:lastIndex]
|
||||
}
|
||||
if redirectToHTTPSConfig.Port != 443 {
|
||||
host = host + ":" + strconv.Itoa(redirectToHTTPSConfig.Port)
|
||||
}
|
||||
} else {
|
||||
var lastIndex = strings.LastIndex(host, ":")
|
||||
if lastIndex > 0 {
|
||||
host = host[:lastIndex]
|
||||
}
|
||||
}
|
||||
|
||||
var statusCode = http.StatusMovedPermanently
|
||||
if redirectToHTTPSConfig.Status > 0 {
|
||||
statusCode = redirectToHTTPSConfig.Status
|
||||
}
|
||||
|
||||
var newURL = "https://" + host + this.RawReq.RequestURI
|
||||
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
|
||||
httpRedirect(this.writer, this.RawReq, newURL, statusCode)
|
||||
|
||||
return true
|
||||
}
|
||||
78
EdgeNode/internal/nodes/http_request_referers.go
Normal file
78
EdgeNode/internal/nodes/http_request_referers.go
Normal file
@@ -0,0 +1,78 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func (this *HTTPRequest) doCheckReferers() (shouldStop bool) {
|
||||
if this.web.Referers == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查URL
|
||||
if !this.web.Referers.MatchURL(this.URL()) {
|
||||
return
|
||||
}
|
||||
|
||||
var origin = this.RawReq.Header.Get("Origin")
|
||||
|
||||
const cacheSeconds = "3600" // 时间不能过长,防止修改设置后长期无法生效
|
||||
|
||||
// 处理用到Origin的特殊功能
|
||||
if this.web.Referers.CheckOrigin && len(origin) > 0 {
|
||||
// 处理Websocket
|
||||
if this.web.Websocket != nil && this.web.Websocket.IsOn && this.RawReq.Header.Get("Upgrade") == "websocket" {
|
||||
originHost, _ := httpParseHost(origin)
|
||||
if len(originHost) > 0 && this.web.Websocket.MatchOrigin(originHost) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var refererURL = this.RawReq.Header.Get("Referer")
|
||||
if len(refererURL) == 0 && this.web.Referers.CheckOrigin {
|
||||
if len(origin) > 0 && origin != "null" {
|
||||
if urlSchemeRegexp.MatchString(origin) {
|
||||
refererURL = origin
|
||||
} else {
|
||||
refererURL = "https://" + origin
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(refererURL) == 0 {
|
||||
if this.web.Referers.MatchDomain(this.ReqHost, "") {
|
||||
return
|
||||
}
|
||||
|
||||
this.tags = append(this.tags, "refererCheck")
|
||||
this.writer.Header().Set("Cache-Control", "max-age="+cacheSeconds)
|
||||
this.writeCode(http.StatusForbidden, "The referer has been blocked.", "当前访问已被防盗链系统拦截。")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
u, err := url.Parse(refererURL)
|
||||
if err != nil {
|
||||
if this.web.Referers.MatchDomain(this.ReqHost, "") {
|
||||
return
|
||||
}
|
||||
|
||||
this.tags = append(this.tags, "refererCheck")
|
||||
this.writer.Header().Set("Cache-Control", "max-age="+cacheSeconds)
|
||||
this.writeCode(http.StatusForbidden, "The referer has been blocked.", "当前访问已被防盗链系统拦截。")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
if !this.web.Referers.MatchDomain(this.ReqHost, u.Host) {
|
||||
this.tags = append(this.tags, "refererCheck")
|
||||
this.writer.Header().Set("Cache-Control", "max-age="+cacheSeconds)
|
||||
this.writeCode(http.StatusForbidden, "The referer has been blocked.", "当前访问已被防盗链系统拦截。")
|
||||
return true
|
||||
}
|
||||
return
|
||||
}
|
||||
684
EdgeNode/internal/nodes/http_request_reverse_proxy.go
Normal file
684
EdgeNode/internal/nodes/http_request_reverse_proxy.go
Normal file
@@ -0,0 +1,684 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/compressions"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/bytepool"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fnv"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/minifiers"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 处理反向代理
|
||||
// writeToClient 读取响应并发送到客户端
|
||||
func (this *HTTPRequest) doReverseProxy(writeToClient bool) (resultResp *http.Response) {
|
||||
if this.reverseProxy == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var retries = 3
|
||||
|
||||
var failedOriginIds []int64
|
||||
var failedLnNodeIds []int64
|
||||
var failStatusCode int
|
||||
|
||||
for i := 0; i < retries; i++ {
|
||||
originId, lnNodeId, shouldRetry, resp := this.doOriginRequest(failedOriginIds, failedLnNodeIds, i == 0, i == retries-1, &failStatusCode, writeToClient)
|
||||
if !shouldRetry {
|
||||
resultResp = resp
|
||||
break
|
||||
}
|
||||
if originId > 0 {
|
||||
failedOriginIds = append(failedOriginIds, originId)
|
||||
}
|
||||
if lnNodeId > 0 {
|
||||
failedLnNodeIds = append(failedLnNodeIds, lnNodeId)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 请求源站
|
||||
func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeIds []int64, isFirstTry bool, isLastRetry bool, failStatusCode *int, writeToClient bool) (originId int64, lnNodeId int64, shouldRetry bool, resultResp *http.Response) {
|
||||
// 对URL的处理
|
||||
var stripPrefix = this.reverseProxy.StripPrefix
|
||||
var requestURI = this.reverseProxy.RequestURI
|
||||
var requestURIHasVariables = this.reverseProxy.RequestURIHasVariables()
|
||||
var oldURI = this.uri
|
||||
|
||||
var requestHost = ""
|
||||
if this.reverseProxy.RequestHostType == serverconfigs.RequestHostTypeCustomized {
|
||||
requestHost = this.reverseProxy.RequestHost
|
||||
}
|
||||
var requestHostHasVariables = this.reverseProxy.RequestHostHasVariables()
|
||||
|
||||
// 源站
|
||||
var requestCall = shared.NewRequestCall()
|
||||
requestCall.Request = this.RawReq
|
||||
requestCall.Formatter = this.Format
|
||||
requestCall.Domain = this.ReqHost
|
||||
|
||||
var origin *serverconfigs.OriginConfig
|
||||
|
||||
// 二级节点
|
||||
var hasMultipleLnNodes = false
|
||||
if this.cacheRef != nil || (this.nodeConfig != nil && this.nodeConfig.GlobalServerConfig != nil && this.nodeConfig.GlobalServerConfig.HTTPAll.ForceLnRequest) {
|
||||
origin, lnNodeId, hasMultipleLnNodes = this.getLnOrigin(failedLnNodeIds, fnv.HashString(this.URL()))
|
||||
if origin != nil {
|
||||
// 强制变更原来访问的域名
|
||||
requestHost = this.ReqHost
|
||||
}
|
||||
|
||||
if this.cacheRef != nil {
|
||||
// 回源Header中去除If-None-Match和If-Modified-Since
|
||||
if !this.cacheRef.EnableIfNoneMatch {
|
||||
this.DeleteHeader("If-None-Match")
|
||||
}
|
||||
if !this.cacheRef.EnableIfModifiedSince {
|
||||
this.DeleteHeader("If-Modified-Since")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 自定义源站
|
||||
if origin == nil {
|
||||
if !isFirstTry {
|
||||
origin = this.reverseProxy.AnyOrigin(requestCall, failedOriginIds)
|
||||
}
|
||||
if origin == nil {
|
||||
origin = this.reverseProxy.NextOrigin(requestCall)
|
||||
if origin != nil && origin.Id > 0 && (*failStatusCode >= 403 && *failStatusCode <= 404) && lists.ContainsInt64(failedOriginIds, origin.Id) {
|
||||
shouldRetry = false
|
||||
isLastRetry = true
|
||||
}
|
||||
}
|
||||
requestCall.CallResponseCallbacks(this.writer)
|
||||
if origin == nil {
|
||||
err := errors.New(this.URL() + ": no available origin sites for reverse proxy")
|
||||
remotelogs.ServerError(this.ReqServer.Id, "HTTP_REQUEST_REVERSE_PROXY", err.Error(), "", nil)
|
||||
this.write50x(err, http.StatusBadGateway, "No origin site yet", "尚未配置源站", true)
|
||||
return
|
||||
}
|
||||
originId = origin.Id
|
||||
|
||||
if len(origin.StripPrefix) > 0 {
|
||||
stripPrefix = origin.StripPrefix
|
||||
}
|
||||
if len(origin.RequestURI) > 0 {
|
||||
requestURI = origin.RequestURI
|
||||
requestURIHasVariables = origin.RequestURIHasVariables()
|
||||
}
|
||||
}
|
||||
|
||||
this.origin = origin // 设置全局变量是为了日志等处理
|
||||
|
||||
if len(origin.RequestHost) > 0 {
|
||||
requestHost = origin.RequestHost
|
||||
requestHostHasVariables = origin.RequestHostHasVariables()
|
||||
}
|
||||
|
||||
// 处理OSS
|
||||
var isHTTPOrigin = origin.OSS == nil
|
||||
|
||||
// 处理Scheme
|
||||
if isHTTPOrigin && origin.Addr == nil {
|
||||
err := errors.New(this.URL() + ": Origin '" + strconv.FormatInt(origin.Id, 10) + "' does not has a address")
|
||||
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", err.Error())
|
||||
this.write50x(err, http.StatusBadGateway, "Origin site did not has a valid address", "源站尚未配置地址", true)
|
||||
return
|
||||
}
|
||||
|
||||
if isHTTPOrigin {
|
||||
this.RawReq.URL.Scheme = origin.Addr.Protocol.Primary().Scheme()
|
||||
}
|
||||
|
||||
// StripPrefix
|
||||
if len(stripPrefix) > 0 {
|
||||
if stripPrefix[0] != '/' {
|
||||
stripPrefix = "/" + stripPrefix
|
||||
}
|
||||
this.uri = strings.TrimPrefix(this.uri, stripPrefix)
|
||||
if len(this.uri) == 0 || this.uri[0] != '/' {
|
||||
this.uri = "/" + this.uri
|
||||
}
|
||||
}
|
||||
|
||||
// RequestURI
|
||||
if len(requestURI) > 0 {
|
||||
if requestURIHasVariables {
|
||||
this.uri = this.Format(requestURI)
|
||||
} else {
|
||||
this.uri = requestURI
|
||||
}
|
||||
if len(this.uri) == 0 || this.uri[0] != '/' {
|
||||
this.uri = "/" + this.uri
|
||||
}
|
||||
|
||||
// 处理RequestURI中的问号
|
||||
var questionMark = strings.LastIndex(this.uri, "?")
|
||||
if questionMark > 0 {
|
||||
var path = this.uri[:questionMark]
|
||||
if strings.Contains(path, "?") {
|
||||
this.uri = path + "&" + this.uri[questionMark+1:]
|
||||
}
|
||||
}
|
||||
|
||||
// 去除多个/
|
||||
this.uri = utils.CleanPath(this.uri)
|
||||
}
|
||||
|
||||
var originAddr = ""
|
||||
if isHTTPOrigin {
|
||||
// 获取源站地址
|
||||
originAddr = origin.Addr.PickAddress()
|
||||
if origin.Addr.HostHasVariables() {
|
||||
originAddr = this.Format(originAddr)
|
||||
}
|
||||
|
||||
// 端口跟随
|
||||
if origin.FollowPort {
|
||||
var originHostIndex = strings.Index(originAddr, ":")
|
||||
if originHostIndex < 0 {
|
||||
var originErr = errors.New(this.URL() + ": Invalid origin address '" + originAddr + "', lacking port")
|
||||
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", originErr.Error())
|
||||
this.write50x(originErr, http.StatusBadGateway, "No port in origin site address", "源站地址中没有配置端口", true)
|
||||
return
|
||||
}
|
||||
originAddr = originAddr[:originHostIndex+1] + types.String(this.requestServerPort())
|
||||
}
|
||||
this.originAddr = originAddr
|
||||
|
||||
// RequestHost
|
||||
if len(requestHost) > 0 {
|
||||
if requestHostHasVariables {
|
||||
this.RawReq.Host = this.Format(requestHost)
|
||||
} else {
|
||||
this.RawReq.Host = requestHost
|
||||
}
|
||||
|
||||
// 是否移除端口
|
||||
if this.reverseProxy.RequestHostExcludingPort {
|
||||
this.RawReq.Host = utils.ParseAddrHost(this.RawReq.Host)
|
||||
}
|
||||
|
||||
this.RawReq.URL.Host = this.RawReq.Host
|
||||
} else if this.reverseProxy.RequestHostType == serverconfigs.RequestHostTypeOrigin {
|
||||
// 源站主机名
|
||||
var hostname = originAddr
|
||||
if origin.Addr.Protocol.IsHTTPFamily() {
|
||||
hostname = strings.TrimSuffix(hostname, ":80")
|
||||
} else if origin.Addr.Protocol.IsHTTPSFamily() {
|
||||
hostname = strings.TrimSuffix(hostname, ":443")
|
||||
}
|
||||
|
||||
this.RawReq.Host = hostname
|
||||
|
||||
// 是否移除端口
|
||||
if this.reverseProxy.RequestHostExcludingPort {
|
||||
this.RawReq.Host = utils.ParseAddrHost(this.RawReq.Host)
|
||||
}
|
||||
|
||||
this.RawReq.URL.Host = this.RawReq.Host
|
||||
} else {
|
||||
this.RawReq.URL.Host = this.ReqHost
|
||||
|
||||
// 是否移除端口
|
||||
if this.reverseProxy.RequestHostExcludingPort {
|
||||
this.RawReq.Host = utils.ParseAddrHost(this.RawReq.Host)
|
||||
this.RawReq.URL.Host = utils.ParseAddrHost(this.RawReq.URL.Host)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 重组请求URL
|
||||
var questionMark = strings.Index(this.uri, "?")
|
||||
if questionMark > -1 {
|
||||
this.RawReq.URL.Path = this.uri[:questionMark]
|
||||
this.RawReq.URL.RawQuery = this.uri[questionMark+1:]
|
||||
} else {
|
||||
this.RawReq.URL.Path = this.uri
|
||||
this.RawReq.URL.RawQuery = ""
|
||||
}
|
||||
this.RawReq.RequestURI = ""
|
||||
|
||||
// 处理Header
|
||||
this.setForwardHeaders(this.RawReq.Header)
|
||||
this.processRequestHeaders(this.RawReq.Header)
|
||||
|
||||
// 调用回调
|
||||
this.onRequest()
|
||||
if this.writer.isFinished {
|
||||
return
|
||||
}
|
||||
|
||||
// 判断是否为Websocket请求
|
||||
if isHTTPOrigin && this.RawReq.Header.Get("Upgrade") == "websocket" {
|
||||
shouldRetry = this.doWebsocket(requestHost, isLastRetry)
|
||||
return
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
var respBodyIsClosed bool
|
||||
var requestErr error
|
||||
var requestErrCode string
|
||||
if isHTTPOrigin { // 普通HTTP(S)源站
|
||||
// 修复空User-Agent问题
|
||||
_, existsUserAgent := this.RawReq.Header["User-Agent"]
|
||||
if !existsUserAgent {
|
||||
this.RawReq.Header["User-Agent"] = []string{""}
|
||||
}
|
||||
|
||||
// 获取请求客户端
|
||||
client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol, this.reverseProxy.FollowRedirects)
|
||||
if err != nil {
|
||||
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Create client failed: "+err.Error())
|
||||
this.write50x(err, http.StatusBadGateway, "Failed to create origin site client", "构造源站客户端失败", true)
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试自动纠正源站地址中的scheme
|
||||
if this.RawReq.URL.Scheme == "http" && strings.HasSuffix(originAddr, ":443") {
|
||||
this.RawReq.URL.Scheme = "https"
|
||||
} else if this.RawReq.URL.Scheme == "https" && strings.HasSuffix(originAddr, ":80") {
|
||||
this.RawReq.URL.Scheme = "http"
|
||||
}
|
||||
|
||||
// request origin with Accept-Encoding: gzip, ...
|
||||
var rawAcceptEncoding string
|
||||
var acceptEncodingChanged bool
|
||||
if this.nodeConfig != nil &&
|
||||
this.nodeConfig.GlobalServerConfig != nil &&
|
||||
this.nodeConfig.GlobalServerConfig.HTTPAll.RequestOriginsWithEncodings &&
|
||||
this.RawReq.ProtoAtLeast(1, 1) &&
|
||||
this.RawReq.Header != nil {
|
||||
rawAcceptEncoding = this.RawReq.Header.Get("Accept-Encoding")
|
||||
if len(rawAcceptEncoding) == 0 {
|
||||
this.RawReq.Header.Set("Accept-Encoding", "gzip")
|
||||
acceptEncodingChanged = true
|
||||
} else if strings.Index(rawAcceptEncoding, "gzip") < 0 {
|
||||
this.RawReq.Header.Set("Accept-Encoding", rawAcceptEncoding+", gzip")
|
||||
acceptEncodingChanged = true
|
||||
}
|
||||
}
|
||||
|
||||
// 开始请求
|
||||
resp, requestErr = client.Do(this.RawReq)
|
||||
|
||||
// recover Accept-Encoding
|
||||
if acceptEncodingChanged {
|
||||
if len(rawAcceptEncoding) > 0 {
|
||||
this.RawReq.Header.Set("Accept-Encoding", rawAcceptEncoding)
|
||||
} else {
|
||||
this.RawReq.Header.Del("Accept-Encoding")
|
||||
}
|
||||
|
||||
if resp != nil && resp.Header != nil && resp.Header.Get("Content-Encoding") == "gzip" {
|
||||
bodyReader, gzipErr := compressions.NewGzipReader(resp.Body)
|
||||
if gzipErr == nil {
|
||||
resp.Body = bodyReader
|
||||
}
|
||||
resp.TransferEncoding = nil
|
||||
resp.Header.Del("Content-Encoding")
|
||||
}
|
||||
}
|
||||
} else if origin.OSS != nil { // OSS源站
|
||||
var goNext bool
|
||||
resp, goNext, requestErrCode, _, requestErr = this.doOSSOrigin(origin)
|
||||
if requestErr == nil {
|
||||
if resp == nil || !goNext {
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
this.writeCode(http.StatusBadGateway, "The type of origin site has not been supported", "设置的源站类型尚未支持")
|
||||
return
|
||||
}
|
||||
|
||||
if resp != nil && resp.Body != nil {
|
||||
defer func() {
|
||||
if !respBodyIsClosed {
|
||||
if writeToClient {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if requestErr != nil {
|
||||
// 客户端取消请求,则不提示
|
||||
var httpErr *url.Error
|
||||
var ok = errors.As(requestErr, &httpErr)
|
||||
if !ok {
|
||||
if isHTTPOrigin {
|
||||
SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() {
|
||||
this.reverseProxy.ResetScheduling()
|
||||
})
|
||||
}
|
||||
|
||||
if len(requestErrCode) > 0 {
|
||||
this.write50x(requestErr, http.StatusBadGateway, "Failed to read origin site (error code: "+requestErrCode+")", "源站读取失败(错误代号:"+requestErrCode+")", true)
|
||||
} else {
|
||||
this.write50x(requestErr, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
|
||||
}
|
||||
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.RawReq.URL.String()+": Request origin server failed: "+requestErr.Error())
|
||||
} else if !errors.Is(httpErr, context.Canceled) {
|
||||
if isHTTPOrigin {
|
||||
SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() {
|
||||
this.reverseProxy.ResetScheduling()
|
||||
})
|
||||
}
|
||||
|
||||
// 是否需要重试
|
||||
if (originId > 0 || (lnNodeId > 0 && hasMultipleLnNodes)) && !isLastRetry {
|
||||
shouldRetry = true
|
||||
this.uri = oldURI // 恢复备份
|
||||
|
||||
if httpErr.Err != io.EOF && !errors.Is(httpErr.Err, http.ErrBodyReadAfterClose) {
|
||||
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Request origin server failed: "+requestErr.Error())
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if httpErr.Timeout() {
|
||||
this.write50x(requestErr, http.StatusGatewayTimeout, "Read origin site timeout", "源站读取超时", true)
|
||||
} else if httpErr.Temporary() {
|
||||
this.write50x(requestErr, http.StatusServiceUnavailable, "Origin site unavailable now", "源站当前不可用", true)
|
||||
} else {
|
||||
this.write50x(requestErr, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
|
||||
}
|
||||
if httpErr.Err != io.EOF && !errors.Is(httpErr.Err, http.ErrBodyReadAfterClose) {
|
||||
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Request origin server failed: "+requestErr.Error())
|
||||
}
|
||||
} else {
|
||||
// 是否为客户端方面的错误
|
||||
var isClientError = false
|
||||
if errors.Is(httpErr, context.Canceled) {
|
||||
// 如果是服务器端主动关闭,则无需提示
|
||||
if this.isConnClosed() {
|
||||
this.disableLog = true
|
||||
return
|
||||
}
|
||||
|
||||
isClientError = true
|
||||
this.addError(errors.New(httpErr.Op + " " + httpErr.URL + ": client closed the connection"))
|
||||
this.writer.WriteHeader(499) // 仿照nginx
|
||||
}
|
||||
|
||||
if !isClientError {
|
||||
this.write50x(requestErr, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if resp == nil {
|
||||
this.write50x(requestErr, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
|
||||
return
|
||||
}
|
||||
|
||||
if !writeToClient {
|
||||
resultResp = resp
|
||||
return
|
||||
}
|
||||
|
||||
// fix Content-Type
|
||||
if resp.Header["Content-Type"] == nil {
|
||||
resp.Header["Content-Type"] = []string{}
|
||||
}
|
||||
|
||||
// 40x && 50x
|
||||
*failStatusCode = resp.StatusCode
|
||||
if ((resp.StatusCode >= 500 && resp.StatusCode < 510 && this.reverseProxy.Retry50X) ||
|
||||
(resp.StatusCode >= 403 && resp.StatusCode <= 404 && this.reverseProxy.Retry40X)) &&
|
||||
(originId > 0 || (lnNodeId > 0 && hasMultipleLnNodes)) &&
|
||||
!isLastRetry {
|
||||
if resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
shouldRetry = true
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试从缓存中恢复
|
||||
if resp.StatusCode >= 500 && // support 50X only
|
||||
resp.StatusCode < 510 &&
|
||||
this.cacheCanTryStale &&
|
||||
this.web.Cache.Stale != nil &&
|
||||
this.web.Cache.Stale.IsOn &&
|
||||
(len(this.web.Cache.Stale.Status) == 0 || lists.ContainsInt(this.web.Cache.Stale.Status, resp.StatusCode)) {
|
||||
var ok = this.doCacheRead(true)
|
||||
if ok {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 记录相关数据
|
||||
this.originStatus = int32(resp.StatusCode)
|
||||
|
||||
// 恢复源站状态
|
||||
if !origin.IsOk && isHTTPOrigin {
|
||||
SharedOriginStateManager.Success(origin, func() {
|
||||
this.reverseProxy.ResetScheduling()
|
||||
})
|
||||
}
|
||||
|
||||
// WAF对出站进行检查
|
||||
if this.web.FirewallRef != nil && this.web.FirewallRef.IsOn {
|
||||
if this.doWAFResponse(resp) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 特殊页面
|
||||
if this.doPage(resp.StatusCode) {
|
||||
return
|
||||
}
|
||||
|
||||
// Page encryption (必须在页面优化之前)
|
||||
if this.web.Encryption != nil && resp.Body != nil {
|
||||
err := this.processPageEncryption(resp)
|
||||
if err != nil {
|
||||
remotelogs.Warn("HTTP_REQUEST_ENCRYPTION", "encrypt page failed: "+err.Error())
|
||||
// 加密失败不影响正常响应,继续处理
|
||||
}
|
||||
}
|
||||
|
||||
// Page optimization (如果已加密,跳过优化)
|
||||
if this.web.Optimization != nil && resp.Body != nil && this.cacheRef != nil /** must under cache **/ {
|
||||
// 如果已加密,跳过优化
|
||||
if this.web.Encryption == nil || !this.web.Encryption.IsOn || !this.web.Encryption.IsEnabled() {
|
||||
err := minifiers.MinifyResponse(this.web.Optimization, this.URL(), resp)
|
||||
if err != nil {
|
||||
this.write50x(err, http.StatusBadGateway, "Page Optimization: fail to read content from origin", "内容优化:从源站读取内容失败", false)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HLS
|
||||
if this.web.HLS != nil &&
|
||||
this.web.HLS.Encrypting != nil &&
|
||||
this.web.HLS.Encrypting.IsOn &&
|
||||
resp.StatusCode == http.StatusOK {
|
||||
m3u8Err := this.processM3u8Response(resp)
|
||||
if m3u8Err != nil {
|
||||
this.write50x(m3u8Err, http.StatusBadGateway, "m3u8 encrypt: fail to read content from origin", "m3u8加密:从源站读取内容失败", false)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 设置Charset
|
||||
// TODO 这里应该可以设置文本类型的列表
|
||||
if this.web.Charset != nil && this.web.Charset.IsOn && len(this.web.Charset.Charset) > 0 {
|
||||
contentTypes, ok := resp.Header["Content-Type"]
|
||||
if ok && len(contentTypes) > 0 {
|
||||
var contentType = contentTypes[0]
|
||||
if this.web.Charset.Force {
|
||||
var semiIndex = strings.Index(contentType, ";")
|
||||
if semiIndex > 0 {
|
||||
contentType = contentType[:semiIndex]
|
||||
}
|
||||
}
|
||||
if _, found := textMimeMap[contentType]; found {
|
||||
var newCharset = this.web.Charset.Charset
|
||||
if this.web.Charset.IsUpper {
|
||||
newCharset = strings.ToUpper(newCharset)
|
||||
}
|
||||
resp.Header["Content-Type"][0] = contentType + "; charset=" + newCharset
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 替换Location中的源站地址
|
||||
var locationHeader = resp.Header.Get("Location")
|
||||
if len(locationHeader) > 0 {
|
||||
// 空Location处理
|
||||
if locationHeader == emptyHTTPLocation {
|
||||
resp.Header.Del("Location")
|
||||
} else {
|
||||
// 自动修正Location中的源站地址
|
||||
locationURL, err := url.Parse(locationHeader)
|
||||
if err == nil && locationURL.Host != this.ReqHost && (locationURL.Host == originAddr || strings.HasPrefix(originAddr, locationURL.Host+":")) {
|
||||
locationURL.Host = this.ReqHost
|
||||
|
||||
var oldScheme = locationURL.Scheme
|
||||
|
||||
// 尝试和当前Scheme一致
|
||||
if this.IsHTTP {
|
||||
locationURL.Scheme = "http"
|
||||
} else if this.IsHTTPS {
|
||||
locationURL.Scheme = "https"
|
||||
}
|
||||
|
||||
// 如果和当前URL一样,则可能是http -> https,防止无限循环
|
||||
if locationURL.String() == this.URL() {
|
||||
locationURL.Scheme = oldScheme
|
||||
resp.Header.Set("Location", locationURL.String())
|
||||
} else {
|
||||
resp.Header.Set("Location", locationURL.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 响应Header
|
||||
this.writer.AddHeaders(resp.Header)
|
||||
this.ProcessResponseHeaders(this.writer.Header(), resp.StatusCode)
|
||||
|
||||
// 是否需要刷新
|
||||
var shouldAutoFlush = this.reverseProxy.AutoFlush || (resp.Header != nil && strings.Contains(resp.Header.Get("Content-Type"), "stream"))
|
||||
|
||||
// 设置当前连接为Persistence
|
||||
if shouldAutoFlush && this.nodeConfig != nil && this.nodeConfig.HasConnTimeoutSettings() {
|
||||
var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
|
||||
if requestConn == nil {
|
||||
return
|
||||
}
|
||||
requestClientConn, ok := requestConn.(ClientConnInterface)
|
||||
if ok {
|
||||
requestClientConn.SetIsPersistent(true)
|
||||
}
|
||||
}
|
||||
|
||||
// 准备
|
||||
var delayHeaders = this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true)
|
||||
|
||||
// 设置响应代码
|
||||
if !delayHeaders {
|
||||
this.writer.WriteHeader(resp.StatusCode)
|
||||
}
|
||||
|
||||
// 是否有内容
|
||||
if resp.ContentLength == 0 && len(resp.TransferEncoding) == 0 {
|
||||
// 即使内容为0,也需要读取一次,以便于触发相关事件
|
||||
var buf = bytepool.Pool4k.Get()
|
||||
_, _ = io.CopyBuffer(this.writer, resp.Body, buf.Bytes)
|
||||
bytepool.Pool4k.Put(buf)
|
||||
_ = resp.Body.Close()
|
||||
respBodyIsClosed = true
|
||||
|
||||
this.writer.SetOk()
|
||||
return
|
||||
}
|
||||
|
||||
// 输出到客户端
|
||||
var pool = this.bytePool(resp.ContentLength)
|
||||
var buf = pool.Get()
|
||||
var err error
|
||||
if shouldAutoFlush {
|
||||
for {
|
||||
n, readErr := resp.Body.Read(buf.Bytes)
|
||||
if n > 0 {
|
||||
_, err = this.writer.Write(buf.Bytes[:n])
|
||||
this.writer.Flush()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
err = readErr
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if this.cacheRef != nil &&
|
||||
this.cacheRef.EnableReadingOriginAsync &&
|
||||
resp.ContentLength > 0 &&
|
||||
resp.ContentLength < (128<<20) { // TODO configure max content-length in cache policy OR CacheRef
|
||||
var requestIsCanceled = false
|
||||
for {
|
||||
n, readErr := resp.Body.Read(buf.Bytes)
|
||||
|
||||
if n > 0 && !requestIsCanceled {
|
||||
_, err = this.writer.Write(buf.Bytes[:n])
|
||||
if err != nil {
|
||||
requestIsCanceled = true
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
err = readErr
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
_, err = io.CopyBuffer(this.writer, resp.Body, buf.Bytes)
|
||||
}
|
||||
}
|
||||
pool.Put(buf)
|
||||
|
||||
var closeErr = resp.Body.Close()
|
||||
respBodyIsClosed = true
|
||||
if closeErr != nil {
|
||||
if !this.canIgnore(closeErr) {
|
||||
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Closing error: "+closeErr.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil && err != io.EOF {
|
||||
if !this.canIgnore(err) {
|
||||
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Writing error: "+err.Error())
|
||||
this.addError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// 是否成功结束
|
||||
if (err == nil || err == io.EOF) && (closeErr == nil || closeErr == io.EOF) {
|
||||
this.writer.SetOk()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
43
EdgeNode/internal/nodes/http_request_rewrite.go
Normal file
43
EdgeNode/internal/nodes/http_request_rewrite.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// 调用Rewrite
|
||||
func (this *HTTPRequest) doRewrite() (shouldShop bool) {
|
||||
if this.rewriteRule == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 代理
|
||||
if this.rewriteRule.Mode == serverconfigs.HTTPRewriteModeProxy {
|
||||
// 外部URL
|
||||
if this.rewriteIsExternalURL {
|
||||
host := this.ReqHost
|
||||
if len(this.rewriteRule.ProxyHost) > 0 {
|
||||
host = this.rewriteRule.ProxyHost
|
||||
}
|
||||
this.doURL(this.RawReq.Method, this.rewriteReplace, host, 0, false)
|
||||
return true
|
||||
}
|
||||
|
||||
// 内部URL继续
|
||||
return false
|
||||
}
|
||||
|
||||
// 跳转
|
||||
if this.rewriteRule.Mode == serverconfigs.HTTPRewriteModeRedirect {
|
||||
if this.rewriteRule.RedirectStatus > 0 {
|
||||
this.ProcessResponseHeaders(this.writer.Header(), this.rewriteRule.RedirectStatus)
|
||||
httpRedirect(this.writer, this.RawReq, this.rewriteReplace, this.rewriteRule.RedirectStatus)
|
||||
} else {
|
||||
this.ProcessResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
|
||||
httpRedirect(this.writer, this.RawReq, this.rewriteReplace, http.StatusTemporaryRedirect)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
460
EdgeNode/internal/nodes/http_request_root.go
Normal file
460
EdgeNode/internal/nodes/http_request_root.go
Normal file
@@ -0,0 +1,460 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
rangeutils "github.com/TeaOSLab/EdgeNode/internal/utils/ranges"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/zero"
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io"
|
||||
"io/fs"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 文本mime-type列表
|
||||
var textMimeMap = map[string]zero.Zero{
|
||||
"application/atom+xml": {},
|
||||
"application/javascript": {},
|
||||
"application/x-javascript": {},
|
||||
"application/json": {},
|
||||
"application/rss+xml": {},
|
||||
"application/x-web-app-manifest+json": {},
|
||||
"application/xhtml+xml": {},
|
||||
"application/xml": {},
|
||||
"image/svg+xml": {},
|
||||
"text/css": {},
|
||||
"text/plain": {},
|
||||
"text/javascript": {},
|
||||
"text/xml": {},
|
||||
"text/html": {},
|
||||
"text/xhtml": {},
|
||||
"text/sgml": {},
|
||||
}
|
||||
|
||||
// 调用本地静态资源
|
||||
// 如果返回true,则终止请求
|
||||
func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
if this.web.Root == nil || !this.web.Root.IsOn {
|
||||
return
|
||||
}
|
||||
|
||||
if len(this.uri) == 0 {
|
||||
this.write404()
|
||||
return true
|
||||
}
|
||||
|
||||
var rootDir = this.web.Root.Dir
|
||||
if this.web.Root.HasVariables() {
|
||||
rootDir = this.Format(rootDir)
|
||||
}
|
||||
if !filepath.IsAbs(rootDir) {
|
||||
rootDir = Tea.Root + Tea.DS + rootDir
|
||||
}
|
||||
|
||||
var requestPath = this.uri
|
||||
|
||||
var questionMarkIndex = strings.Index(this.uri, "?")
|
||||
if questionMarkIndex > -1 {
|
||||
requestPath = this.uri[:questionMarkIndex]
|
||||
}
|
||||
|
||||
// except hidden files
|
||||
if this.web.Root.ExceptHiddenFiles &&
|
||||
(strings.Contains(requestPath, "/.") || strings.Contains(requestPath, "\\.")) {
|
||||
this.write404()
|
||||
return true
|
||||
}
|
||||
|
||||
// except and only files
|
||||
if !this.web.Root.MatchURL(this.URL()) {
|
||||
this.write404()
|
||||
return true
|
||||
}
|
||||
|
||||
// 去掉其中的奇怪的路径
|
||||
requestPath = strings.Replace(requestPath, "..\\", "", -1)
|
||||
|
||||
// 进行URL Decode
|
||||
if this.web.Root.DecodePath {
|
||||
p, err := url.QueryUnescape(requestPath)
|
||||
if err == nil {
|
||||
requestPath = p
|
||||
} else {
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 去掉前缀
|
||||
stripPrefix := this.web.Root.StripPrefix
|
||||
if len(stripPrefix) > 0 {
|
||||
if stripPrefix[0] != '/' {
|
||||
stripPrefix = "/" + stripPrefix
|
||||
}
|
||||
|
||||
requestPath = strings.TrimPrefix(requestPath, stripPrefix)
|
||||
if len(requestPath) == 0 || requestPath[0] != '/' {
|
||||
requestPath = "/" + requestPath
|
||||
}
|
||||
}
|
||||
|
||||
var filename = strings.Replace(requestPath, "/", Tea.DS, -1)
|
||||
var filePath string
|
||||
if len(filename) > 0 && filename[0:1] == Tea.DS {
|
||||
filePath = rootDir + filename
|
||||
} else {
|
||||
filePath = rootDir + Tea.DS + filename
|
||||
}
|
||||
|
||||
this.filePath = filePath // 用来记录日志
|
||||
|
||||
stat, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
_, isPathError := err.(*fs.PathError)
|
||||
if os.IsNotExist(err) || isPathError {
|
||||
if this.web.Root.IsBreak {
|
||||
this.write404()
|
||||
return true
|
||||
}
|
||||
return
|
||||
} else {
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to stat the file", "查看文件统计信息失败", true)
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
if stat.IsDir() {
|
||||
indexFile, indexStat := this.findIndexFile(filePath)
|
||||
if len(indexFile) > 0 {
|
||||
filePath += Tea.DS + indexFile
|
||||
} else {
|
||||
if this.web.Root.IsBreak {
|
||||
this.write404()
|
||||
return true
|
||||
}
|
||||
return
|
||||
}
|
||||
this.filePath = filePath
|
||||
|
||||
// stat again
|
||||
if indexStat == nil {
|
||||
stat, err = os.Stat(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
if this.web.Root.IsBreak {
|
||||
this.write404()
|
||||
return true
|
||||
}
|
||||
return
|
||||
} else {
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to stat the file", "查看文件统计信息失败", true)
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
stat = indexStat
|
||||
}
|
||||
}
|
||||
|
||||
// 响应header
|
||||
var respHeader = this.writer.Header()
|
||||
|
||||
// mime type
|
||||
var contentType = ""
|
||||
if this.web.ResponseHeaderPolicy == nil || !this.web.ResponseHeaderPolicy.IsOn || !this.web.ResponseHeaderPolicy.ContainsHeader("CONTENT-TYPE") {
|
||||
var ext = filepath.Ext(filePath)
|
||||
if len(ext) > 0 {
|
||||
mimeType := mime.TypeByExtension(ext)
|
||||
if len(mimeType) > 0 {
|
||||
var semicolonIndex = strings.Index(mimeType, ";")
|
||||
var mimeTypeKey = mimeType
|
||||
if semicolonIndex > 0 {
|
||||
mimeTypeKey = mimeType[:semicolonIndex]
|
||||
}
|
||||
|
||||
if _, found := textMimeMap[mimeTypeKey]; found {
|
||||
if this.web.Charset != nil && this.web.Charset.IsOn && len(this.web.Charset.Charset) > 0 {
|
||||
var charset = this.web.Charset.Charset
|
||||
if this.web.Charset.IsUpper {
|
||||
charset = strings.ToUpper(charset)
|
||||
}
|
||||
contentType = mimeTypeKey + "; charset=" + charset
|
||||
respHeader.Set("Content-Type", mimeTypeKey+"; charset="+charset)
|
||||
} else {
|
||||
contentType = mimeType
|
||||
respHeader.Set("Content-Type", mimeType)
|
||||
}
|
||||
} else {
|
||||
contentType = mimeType
|
||||
respHeader.Set("Content-Type", mimeType)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// length
|
||||
var fileSize = stat.Size()
|
||||
|
||||
// 支持 Last-Modified
|
||||
modifiedTime := stat.ModTime().Format("Mon, 02 Jan 2006 15:04:05 GMT")
|
||||
if len(respHeader.Get("Last-Modified")) == 0 {
|
||||
respHeader.Set("Last-Modified", modifiedTime)
|
||||
}
|
||||
|
||||
// 支持 ETag
|
||||
var eTag = "\"e" + fmt.Sprintf("%0x", xxhash.Sum64String(filename+strconv.FormatInt(stat.ModTime().UnixNano(), 10)+strconv.FormatInt(stat.Size(), 10))) + "\""
|
||||
if len(respHeader.Get("ETag")) == 0 {
|
||||
respHeader.Set("ETag", eTag)
|
||||
}
|
||||
|
||||
// 调用回调
|
||||
this.onRequest()
|
||||
if this.writer.isFinished {
|
||||
return
|
||||
}
|
||||
|
||||
// 支持 If-None-Match
|
||||
if this.requestHeader("If-None-Match") == eTag {
|
||||
// 自定义Header
|
||||
this.ProcessResponseHeaders(this.writer.Header(), http.StatusNotModified)
|
||||
this.writer.WriteHeader(http.StatusNotModified)
|
||||
return true
|
||||
}
|
||||
|
||||
// 支持 If-Modified-Since
|
||||
if this.requestHeader("If-Modified-Since") == modifiedTime {
|
||||
// 自定义Header
|
||||
this.ProcessResponseHeaders(this.writer.Header(), http.StatusNotModified)
|
||||
this.writer.WriteHeader(http.StatusNotModified)
|
||||
return true
|
||||
}
|
||||
|
||||
// 支持Range
|
||||
respHeader.Set("Accept-Ranges", "bytes")
|
||||
ifRangeHeaders, ok := this.RawReq.Header["If-Range"]
|
||||
var supportRange = true
|
||||
if ok {
|
||||
supportRange = false
|
||||
for _, v := range ifRangeHeaders {
|
||||
if v == eTag || v == modifiedTime {
|
||||
supportRange = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !supportRange {
|
||||
respHeader.Del("Accept-Ranges")
|
||||
}
|
||||
}
|
||||
|
||||
// 支持Range
|
||||
var ranges = []rangeutils.Range{}
|
||||
if supportRange {
|
||||
var contentRange = this.RawReq.Header.Get("Range")
|
||||
if len(contentRange) > 0 {
|
||||
if fileSize == 0 {
|
||||
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
|
||||
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return true
|
||||
}
|
||||
|
||||
set, ok := httpRequestParseRangeHeader(contentRange)
|
||||
if !ok {
|
||||
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
|
||||
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return true
|
||||
}
|
||||
if len(set) > 0 {
|
||||
ranges = set
|
||||
for k, r := range ranges {
|
||||
r2, ok := r.Convert(fileSize)
|
||||
if !ok {
|
||||
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
|
||||
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return true
|
||||
}
|
||||
ranges[k] = r2
|
||||
}
|
||||
}
|
||||
} else {
|
||||
respHeader.Set("Content-Length", strconv.FormatInt(fileSize, 10))
|
||||
}
|
||||
} else {
|
||||
respHeader.Set("Content-Length", strconv.FormatInt(fileSize, 10))
|
||||
}
|
||||
|
||||
fileReader, err := os.OpenFile(filePath, os.O_RDONLY, 0444)
|
||||
if err != nil {
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to open the file", "试图打开文件失败", true)
|
||||
return true
|
||||
}
|
||||
|
||||
// 自定义Header
|
||||
this.ProcessResponseHeaders(this.writer.Header(), http.StatusOK)
|
||||
|
||||
// 在Range请求中不能缓存
|
||||
if len(ranges) > 0 {
|
||||
this.cacheRef = nil // 不支持缓存
|
||||
}
|
||||
|
||||
var resp = &http.Response{
|
||||
ContentLength: fileSize,
|
||||
Body: fileReader,
|
||||
StatusCode: http.StatusOK,
|
||||
}
|
||||
this.writer.Prepare(resp, fileSize, http.StatusOK, true)
|
||||
|
||||
var pool = this.bytePool(fileSize)
|
||||
var buf = pool.Get()
|
||||
defer func() {
|
||||
pool.Put(buf)
|
||||
}()
|
||||
|
||||
if len(ranges) == 1 {
|
||||
respHeader.Set("Content-Range", ranges[0].ComposeContentRangeHeader(types.String(fileSize)))
|
||||
this.writer.WriteHeader(http.StatusPartialContent)
|
||||
|
||||
ok, err := httpRequestReadRange(resp.Body, buf.Bytes, ranges[0].Start(), ranges[0].End(), func(buf []byte, n int) error {
|
||||
_, err := this.writer.Write(buf[:n])
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if !ok {
|
||||
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
|
||||
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return true
|
||||
}
|
||||
} else if len(ranges) > 1 {
|
||||
var boundary = httpRequestGenBoundary()
|
||||
respHeader.Set("Content-Type", "multipart/byteranges; boundary="+boundary)
|
||||
|
||||
this.writer.WriteHeader(http.StatusPartialContent)
|
||||
|
||||
for index, r := range ranges {
|
||||
if index == 0 {
|
||||
_, err = this.writer.WriteString("--" + boundary + "\r\n")
|
||||
} else {
|
||||
_, err = this.writer.WriteString("\r\n--" + boundary + "\r\n")
|
||||
}
|
||||
if err != nil {
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
_, err = this.writer.WriteString("Content-Range: " + r.ComposeContentRangeHeader(types.String(fileSize)) + "\r\n")
|
||||
if err != nil {
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
if len(contentType) > 0 {
|
||||
_, err = this.writer.WriteString("Content-Type: " + contentType + "\r\n\r\n")
|
||||
if err != nil {
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
ok, err := httpRequestReadRange(resp.Body, buf.Bytes, r.Start(), r.End(), func(buf []byte, n int) error {
|
||||
_, err := this.writer.Write(buf[:n])
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if !ok {
|
||||
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
|
||||
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
_, err = this.writer.WriteString("\r\n--" + boundary + "--\r\n")
|
||||
if err != nil {
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
_, err = io.CopyBuffer(this.writer, resp.Body, buf.Bytes)
|
||||
if err != nil {
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// 设置成功
|
||||
this.writer.SetOk()
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// 查找首页文件
|
||||
func (this *HTTPRequest) findIndexFile(dir string) (filename string, stat os.FileInfo) {
|
||||
if this.web.Root == nil || !this.web.Root.IsOn {
|
||||
return "", nil
|
||||
}
|
||||
if len(this.web.Root.Indexes) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
for _, index := range this.web.Root.Indexes {
|
||||
if len(index) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 模糊查找
|
||||
if strings.Contains(index, "*") {
|
||||
indexFiles, err := filepath.Glob(dir + Tea.DS + index)
|
||||
if err != nil {
|
||||
if !this.canIgnore(err) {
|
||||
logs.Error(err)
|
||||
}
|
||||
this.addError(err)
|
||||
continue
|
||||
}
|
||||
if len(indexFiles) > 0 {
|
||||
return filepath.Base(indexFiles[0]), nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 精确查找
|
||||
filePath := dir + Tea.DS + index
|
||||
stat, err := os.Stat(filePath)
|
||||
if err != nil || !stat.Mode().IsRegular() {
|
||||
continue
|
||||
}
|
||||
return index, stat
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
115
EdgeNode/internal/nodes/http_request_shutdown.go
Normal file
115
EdgeNode/internal/nodes/http_request_shutdown.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/bytepool"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 调用临时关闭页面
|
||||
func (this *HTTPRequest) doShutdown() {
|
||||
var shutdown = this.web.Shutdown
|
||||
if shutdown == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(shutdown.BodyType) == 0 || shutdown.BodyType == serverconfigs.HTTPPageBodyTypeURL {
|
||||
// URL
|
||||
if urlSchemeRegexp.MatchString(shutdown.URL) {
|
||||
this.doURL(http.MethodGet, shutdown.URL, "", shutdown.Status, true)
|
||||
return
|
||||
}
|
||||
|
||||
// URL为空,则显示文本
|
||||
if len(shutdown.URL) == 0 {
|
||||
// 自定义响应Headers
|
||||
if shutdown.Status > 0 {
|
||||
this.ProcessResponseHeaders(this.writer.Header(), shutdown.Status)
|
||||
this.writer.WriteHeader(shutdown.Status)
|
||||
} else {
|
||||
this.ProcessResponseHeaders(this.writer.Header(), http.StatusOK)
|
||||
this.writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
_, _ = this.writer.WriteString("The site have been shutdown.")
|
||||
return
|
||||
}
|
||||
|
||||
// 从本地文件中读取
|
||||
var realpath = path.Clean(shutdown.URL)
|
||||
if !strings.HasPrefix(realpath, "/pages/") && !strings.HasPrefix(realpath, "pages/") { // only files under "/pages/" can be used
|
||||
var msg = "404 page not found: '" + shutdown.URL + "'"
|
||||
this.writer.WriteHeader(http.StatusNotFound)
|
||||
_, _ = this.writer.Write([]byte(msg))
|
||||
return
|
||||
}
|
||||
var file = Tea.Root + Tea.DS + shutdown.URL
|
||||
fp, err := os.Open(file)
|
||||
if err != nil {
|
||||
var msg = "404 page not found: '" + shutdown.URL + "'"
|
||||
this.writer.WriteHeader(http.StatusNotFound)
|
||||
_, _ = this.writer.Write([]byte(msg))
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = fp.Close()
|
||||
}()
|
||||
|
||||
// 自定义响应Headers
|
||||
if shutdown.Status > 0 {
|
||||
this.ProcessResponseHeaders(this.writer.Header(), shutdown.Status)
|
||||
this.writer.WriteHeader(shutdown.Status)
|
||||
} else {
|
||||
this.ProcessResponseHeaders(this.writer.Header(), http.StatusOK)
|
||||
this.writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
var buf = bytepool.Pool1k.Get()
|
||||
_, err = utils.CopyWithFilter(this.writer, fp, buf.Bytes, func(p []byte) []byte {
|
||||
return []byte(this.Format(string(p)))
|
||||
})
|
||||
bytepool.Pool1k.Put(buf)
|
||||
if err != nil {
|
||||
if !this.canIgnore(err) {
|
||||
remotelogs.Warn("HTTP_REQUEST_SHUTDOWN", "write to client failed: "+err.Error())
|
||||
}
|
||||
} else {
|
||||
this.writer.SetOk()
|
||||
}
|
||||
} else if shutdown.BodyType == serverconfigs.HTTPPageBodyTypeHTML {
|
||||
// 自定义响应Headers
|
||||
if shutdown.Status > 0 {
|
||||
this.ProcessResponseHeaders(this.writer.Header(), shutdown.Status)
|
||||
this.writer.WriteHeader(shutdown.Status)
|
||||
} else {
|
||||
this.ProcessResponseHeaders(this.writer.Header(), http.StatusOK)
|
||||
this.writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
_, err := this.writer.WriteString(this.Format(shutdown.Body))
|
||||
if err != nil {
|
||||
if !this.canIgnore(err) {
|
||||
remotelogs.Warn("HTTP_REQUEST_SHUTDOWN", "write to client failed: "+err.Error())
|
||||
}
|
||||
} else {
|
||||
this.writer.SetOk()
|
||||
}
|
||||
} else if shutdown.BodyType == serverconfigs.HTTPPageBodyTypeRedirectURL {
|
||||
var newURL = shutdown.URL
|
||||
if len(newURL) == 0 {
|
||||
newURL = "/"
|
||||
}
|
||||
|
||||
if shutdown.Status > 0 && httpStatusIsRedirect(shutdown.Status) {
|
||||
httpRedirect(this.writer, this.RawReq, newURL, shutdown.Status)
|
||||
} else {
|
||||
httpRedirect(this.writer, this.RawReq, newURL, http.StatusTemporaryRedirect)
|
||||
}
|
||||
this.writer.SetOk()
|
||||
}
|
||||
}
|
||||
16
EdgeNode/internal/nodes/http_request_stat.go
Normal file
16
EdgeNode/internal/nodes/http_request_stat.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
)
|
||||
|
||||
// 统计
|
||||
func (this *HTTPRequest) doStat() {
|
||||
if this.ReqServer == nil || this.web == nil || this.web.StatRef == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 内置的统计
|
||||
stats.SharedHTTPRequestStatManager.AddRemoteAddr(this.ReqServer.Id, this.requestRemoteAddr(true), this.writer.SentBodyBytes(), this.isAttack)
|
||||
stats.SharedHTTPRequestStatManager.AddUserAgent(this.ReqServer.Id, this.requestHeader("User-Agent"), this.remoteAddr)
|
||||
}
|
||||
22
EdgeNode/internal/nodes/http_request_sub.go
Normal file
22
EdgeNode/internal/nodes/http_request_sub.go
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import "net/http"
|
||||
|
||||
// 执行子请求
|
||||
func (this *HTTPRequest) doSubRequest(writer http.ResponseWriter, rawReq *http.Request) {
|
||||
// 包装新请求对象
|
||||
req := &HTTPRequest{
|
||||
RawReq: rawReq,
|
||||
RawWriter: writer,
|
||||
ReqServer: this.ReqServer,
|
||||
ReqHost: this.ReqHost,
|
||||
ServerName: this.ServerName,
|
||||
ServerAddr: this.ServerAddr,
|
||||
IsHTTP: this.IsHTTP,
|
||||
IsHTTPS: this.IsHTTPS,
|
||||
}
|
||||
req.isSubRequest = true
|
||||
req.Do()
|
||||
}
|
||||
71
EdgeNode/internal/nodes/http_request_test.go
Normal file
71
EdgeNode/internal/nodes/http_request_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHTTPRequest_RedirectToHTTPS(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
{
|
||||
rawReq, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var req = &HTTPRequest{
|
||||
RawReq: rawReq,
|
||||
RawWriter: NewEmptyResponseWriter(nil),
|
||||
ReqServer: &serverconfigs.ServerConfig{
|
||||
IsOn: true,
|
||||
Web: &serverconfigs.HTTPWebConfig{
|
||||
IsOn: true,
|
||||
RedirectToHttps: &serverconfigs.HTTPRedirectToHTTPSConfig{},
|
||||
},
|
||||
},
|
||||
}
|
||||
req.init()
|
||||
req.Do()
|
||||
|
||||
a.IsBool(req.web.RedirectToHttps.IsOn == false)
|
||||
}
|
||||
{
|
||||
rawReq, err := http.NewRequest(http.MethodGet, "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var req = &HTTPRequest{
|
||||
RawReq: rawReq,
|
||||
RawWriter: NewEmptyResponseWriter(nil),
|
||||
ReqServer: &serverconfigs.ServerConfig{
|
||||
IsOn: true,
|
||||
Web: &serverconfigs.HTTPWebConfig{
|
||||
IsOn: true,
|
||||
RedirectToHttps: &serverconfigs.HTTPRedirectToHTTPSConfig{
|
||||
IsOn: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
req.init()
|
||||
req.Do()
|
||||
a.IsBool(req.web.RedirectToHttps.IsOn == true)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPRequest_Memory(t *testing.T) {
|
||||
var stat1 = &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stat1)
|
||||
|
||||
var requests = []*HTTPRequest{}
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
requests = append(requests, &HTTPRequest{})
|
||||
}
|
||||
|
||||
var stat2 = &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stat2)
|
||||
t.Log((stat2.HeapInuse-stat1.HeapInuse)/1024/1024, "MB,")
|
||||
t.Log(len(requests))
|
||||
}
|
||||
120
EdgeNode/internal/nodes/http_request_token.go
Normal file
120
EdgeNode/internal/nodes/http_request_token.go
Normal file
@@ -0,0 +1,120 @@
|
||||
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/ttlcache"
|
||||
)
|
||||
|
||||
// TokenCache Token 缓存
|
||||
type TokenCache struct {
|
||||
cache *ttlcache.Cache[string] // IP => Token
|
||||
}
|
||||
|
||||
var sharedTokenCache *TokenCache
|
||||
var sharedTokenCacheOnce sync.Once
|
||||
|
||||
// SharedTokenCache 获取共享 Token 缓存实例
|
||||
func SharedTokenCache() *TokenCache {
|
||||
sharedTokenCacheOnce.Do(func() {
|
||||
sharedTokenCache = NewTokenCache()
|
||||
})
|
||||
return sharedTokenCache
|
||||
}
|
||||
|
||||
// NewTokenCache 创建新 Token 缓存
|
||||
func NewTokenCache() *TokenCache {
|
||||
cache := ttlcache.NewCache[string](
|
||||
ttlcache.NewMaxItemsOption(10000),
|
||||
ttlcache.NewGCConfigOption().
|
||||
WithBaseInterval(5*time.Minute).
|
||||
WithMinInterval(2*time.Minute).
|
||||
WithMaxInterval(10*time.Minute).
|
||||
WithAdaptive(true),
|
||||
)
|
||||
|
||||
return &TokenCache{
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
// Set 设置 Token
|
||||
func (this *TokenCache) Set(ip string, token string) {
|
||||
expiresAt := fasttime.Now().Unix() + 300 // 5 分钟过期
|
||||
this.cache.Write(ip, token, expiresAt)
|
||||
}
|
||||
|
||||
// Get 获取 Token
|
||||
func (this *TokenCache) Get(ip string) (string, bool) {
|
||||
item := this.cache.Read(ip)
|
||||
if item == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if item.ExpiresAt() < fasttime.Now().Unix() {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return item.Value, true
|
||||
}
|
||||
|
||||
// ValidateToken 验证 Token
|
||||
func (this *HTTPRequest) validateEncryptionToken(token string) bool {
|
||||
if this.web.Encryption == nil || !this.web.Encryption.IsOn || !this.web.Encryption.IsEnabled() {
|
||||
return true // 未启用加密,直接通过
|
||||
}
|
||||
|
||||
if len(token) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
remoteIP := this.requestRemoteAddr(true)
|
||||
cache := SharedTokenCache()
|
||||
storedToken, ok := cache.Get(remoteIP)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return storedToken == token
|
||||
}
|
||||
|
||||
// handleTokenHandshake 处理 Token 握手请求
|
||||
func (this *HTTPRequest) handleTokenHandshake() {
|
||||
if this.RawReq.Method != http.MethodPost {
|
||||
this.writeCode(http.StatusMethodNotAllowed, "Method Not Allowed", "方法不允许")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求体
|
||||
var requestBody struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
err := json.NewDecoder(this.RawReq.Body).Decode(&requestBody)
|
||||
if err != nil {
|
||||
this.writeCode(http.StatusBadRequest, "Invalid Request", "请求格式错误")
|
||||
return
|
||||
}
|
||||
|
||||
if len(requestBody.Token) == 0 {
|
||||
this.writeCode(http.StatusBadRequest, "Token Required", "Token 不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 保存 Token
|
||||
remoteIP := this.requestRemoteAddr(true)
|
||||
cache := SharedTokenCache()
|
||||
cache.Set(remoteIP, requestBody.Token)
|
||||
|
||||
// 返回成功响应
|
||||
this.writer.Header().Set("Content-Type", "application/json")
|
||||
this.writer.WriteHeader(http.StatusOK)
|
||||
_, _ = this.writer.WriteString(`{"success":true}`)
|
||||
this.writer.SetOk()
|
||||
}
|
||||
47
EdgeNode/internal/nodes/http_request_traffic_limit.go
Normal file
47
EdgeNode/internal/nodes/http_request_traffic_limit.go
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
)
|
||||
|
||||
// 流量限制
|
||||
func (this *HTTPRequest) doTrafficLimit(status *serverconfigs.TrafficLimitStatus) (blocked bool) {
|
||||
if status == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果是网站单独设置的流量限制,则检查是否已关闭
|
||||
var config = this.ReqServer.TrafficLimit
|
||||
if (config == nil || !config.IsOn) && status.PlanId == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果是套餐设置的流量限制,即使套餐变更了(变更套餐或者变更套餐的限制),仍然会提示流量超限
|
||||
|
||||
this.tags = append(this.tags, "trafficLimit")
|
||||
|
||||
var statusCode = 509
|
||||
this.writer.statusCode = statusCode
|
||||
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
|
||||
|
||||
this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
this.writer.WriteHeader(statusCode)
|
||||
|
||||
// check plan traffic limit
|
||||
if (config == nil || !config.IsOn) && this.ReqServer.PlanId() > 0 && this.nodeConfig != nil {
|
||||
var planConfig = this.nodeConfig.FindPlan(this.ReqServer.PlanId())
|
||||
if planConfig != nil && planConfig.TrafficLimit != nil && planConfig.TrafficLimit.IsOn {
|
||||
config = planConfig.TrafficLimit
|
||||
}
|
||||
}
|
||||
|
||||
if config != nil && len(config.NoticePageBody) != 0 {
|
||||
_, _ = this.writer.WriteString(this.Format(config.NoticePageBody))
|
||||
} else {
|
||||
_, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultTrafficLimitNoticePageBody))
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
16
EdgeNode/internal/nodes/http_request_uam.go
Normal file
16
EdgeNode/internal/nodes/http_request_uam.go
Normal file
@@ -0,0 +1,16 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build !plus
|
||||
// +build !plus
|
||||
|
||||
package nodes
|
||||
|
||||
func (this *HTTPRequest) isUAMRequest() bool {
|
||||
// stub
|
||||
return false
|
||||
}
|
||||
|
||||
// UAM
|
||||
func (this *HTTPRequest) doUAM() (block bool) {
|
||||
// stub
|
||||
return false
|
||||
}
|
||||
210
EdgeNode/internal/nodes/http_request_uam_plus.go
Normal file
210
EdgeNode/internal/nodes/http_request_uam_plus.go
Normal file
@@ -0,0 +1,210 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build plus
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/uam"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/agents"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/counters"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/ttlcache"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var sharedUAMManager *uam.Manager
|
||||
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
return
|
||||
}
|
||||
|
||||
events.On(events.EventLoaded, func() {
|
||||
if sharedUAMManager != nil {
|
||||
return
|
||||
}
|
||||
manager, _ := uam.NewManager(sharedNodeConfig.NodeId, sharedNodeConfig.Secret)
|
||||
if manager != nil {
|
||||
sharedUAMManager = manager
|
||||
}
|
||||
})
|
||||
|
||||
events.On(events.EventReload, func() {
|
||||
if sharedUAMManager != nil {
|
||||
return
|
||||
}
|
||||
manager, _ := uam.NewManager(sharedNodeConfig.NodeId, sharedNodeConfig.Secret)
|
||||
if manager != nil {
|
||||
sharedUAMManager = manager
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) isUAMRequest() bool {
|
||||
if this.web.UAM == nil || !this.web.UAM.IsOn || this.RawReq.Method != http.MethodPost {
|
||||
return false
|
||||
}
|
||||
|
||||
var cookiesString = this.RawReq.Header.Get("Cookie")
|
||||
if len(cookiesString) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return strings.HasPrefix(cookiesString, uam.CookiePrevKey+"=") ||
|
||||
strings.Contains(cookiesString, " "+uam.CookiePrevKey+"=")
|
||||
}
|
||||
|
||||
// UAM
|
||||
// TODO 需要检查是否为plus
|
||||
func (this *HTTPRequest) doUAM() (block bool) {
|
||||
var serverId int64
|
||||
if this.ReqServer != nil {
|
||||
serverId = this.ReqServer.Id
|
||||
}
|
||||
|
||||
var uamConfig = this.web.UAM
|
||||
if uamConfig == nil ||
|
||||
!uamConfig.IsOn ||
|
||||
!uamConfig.MatchURL(this.requestScheme()+"://"+this.ReqHost+this.Path()) ||
|
||||
!uamConfig.MatchRequest(this.Format) {
|
||||
return
|
||||
}
|
||||
|
||||
var policy = this.nodeConfig.FindUAMPolicyWithClusterId(this.ReqServer.ClusterId)
|
||||
if policy == nil {
|
||||
policy = nodeconfigs.NewUAMPolicy()
|
||||
}
|
||||
if policy == nil || !policy.IsOn {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取UAM管理器
|
||||
var manager = sharedUAMManager
|
||||
if manager == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 忽略URL白名单
|
||||
if this.RawReq.URL.Path == "/favicon.ico" || this.RawReq.URL.Path == "/favicon.png" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查白名单
|
||||
var remoteAddr = this.requestRemoteAddr(true)
|
||||
|
||||
// 检查UAM白名单
|
||||
if uamConfig.AddToWhiteList && ttlcache.SharedInt64Cache.Read("UAM:WHITE:"+remoteAddr) != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否为白名单直连
|
||||
if !Tea.IsTesting() && this.nodeConfig.IPIsAutoAllowed(remoteAddr) {
|
||||
return
|
||||
}
|
||||
|
||||
// 是否在全局名单中
|
||||
canGoNext, isInAllowedList, _ := iplibrary.AllowIP(remoteAddr, serverId)
|
||||
if !canGoNext {
|
||||
this.disableLog = true
|
||||
this.Close()
|
||||
return true
|
||||
}
|
||||
if isInAllowedList {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果是搜索引擎直接通过
|
||||
var userAgent = this.RawReq.UserAgent()
|
||||
if len(userAgent) == 0 {
|
||||
// 如果User-Agent为空,则直接阻止
|
||||
this.writer.WriteHeader(http.StatusForbidden)
|
||||
|
||||
// 增加失败次数
|
||||
if manager.IncreaseFails(policy, remoteAddr, serverId) {
|
||||
this.isAttack = true
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// 不管是否开启允许搜索引擎,这里都放行,避免收录了拦截的代码
|
||||
if agents.SharedManager.ContainsIP(remoteAddr) {
|
||||
return false
|
||||
}
|
||||
if policy.AllowSearchEngines {
|
||||
if searchEngineRegex.MatchString(userAgent) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 如果是python之类的直接拦截
|
||||
if policy.DenySpiders && spiderRegexp.MatchString(userAgent) {
|
||||
this.writer.WriteHeader(http.StatusForbidden)
|
||||
|
||||
// 增加失败次数
|
||||
if manager.IncreaseFails(policy, remoteAddr, serverId) {
|
||||
this.isAttack = true
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查预生成Key
|
||||
var step = this.Header().Get("X-GE-UA-Step")
|
||||
if step == uam.StepPrev {
|
||||
if this.Method() != http.MethodPost {
|
||||
this.writer.WriteHeader(http.StatusForbidden)
|
||||
return true
|
||||
}
|
||||
if manager.CheckPrevKey(policy, uamConfig, this.RawReq, remoteAddr, this.writer) {
|
||||
_, _ = this.writer.Write([]byte(`{"ok": true}`))
|
||||
} else {
|
||||
_, _ = this.writer.Write([]byte(`{"ok": false}`))
|
||||
}
|
||||
|
||||
// 增加失败次数
|
||||
if manager.IncreaseFails(policy, remoteAddr, serverId) {
|
||||
this.isAttack = true
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查是否启用QPS
|
||||
if uamConfig.MinQPSPerIP > 0 && len(step) == 0 {
|
||||
var avgQPS = counters.SharedCounter.IncreaseKey("UAM:"+remoteAddr, 60) / 60
|
||||
if avgQPS <= 0 {
|
||||
avgQPS = 1
|
||||
}
|
||||
if avgQPS < types.Uint32(uamConfig.MinQPSPerIP) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 检查Cookie中的Key
|
||||
isVerified, isAttack, _ := manager.CheckKey(policy, this.RawReq, this.writer, remoteAddr, serverId, uamConfig.KeyLife)
|
||||
if isVerified {
|
||||
return false
|
||||
}
|
||||
this.isAttack = isAttack
|
||||
|
||||
// 检查是否已生成有效的prev key,如果已经生成,则表示当前请求是附带请求(比如favicon.ico),不再重新生成新的
|
||||
// TODO 考虑这里的必要性
|
||||
if manager.ExistsActivePreKey(this.RawReq) {
|
||||
return true
|
||||
}
|
||||
|
||||
// 显示加载页面
|
||||
err := manager.LoadPage(policy, this.RawReq, this.Format, remoteAddr, this.writer)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
88
EdgeNode/internal/nodes/http_request_url.go
Normal file
88
EdgeNode/internal/nodes/http_request_url.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 请求某个URL
|
||||
func (this *HTTPRequest) doURL(method string, url string, host string, statusCode int, supportVariables bool) {
|
||||
req, err := http.NewRequest(method, url, this.RawReq.Body)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
// 修改Host
|
||||
if len(host) > 0 {
|
||||
req.Host = this.Format(host)
|
||||
}
|
||||
|
||||
// 添加当前Header
|
||||
req.Header = this.RawReq.Header
|
||||
|
||||
// 代理头部
|
||||
this.setForwardHeaders(req.Header)
|
||||
|
||||
// 自定义请求Header
|
||||
this.processRequestHeaders(req.Header)
|
||||
|
||||
var client = utils.SharedHttpClient(60 * time.Second)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_REQUEST_URL", req.URL.String()+": "+err.Error())
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to read url", "读取URL失败", false)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
// Header
|
||||
if statusCode <= 0 {
|
||||
this.ProcessResponseHeaders(this.writer.Header(), resp.StatusCode)
|
||||
} else {
|
||||
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
|
||||
}
|
||||
|
||||
if supportVariables {
|
||||
resp.Header.Del("Content-Length")
|
||||
}
|
||||
this.writer.AddHeaders(resp.Header)
|
||||
if statusCode <= 0 {
|
||||
this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true)
|
||||
} else {
|
||||
this.writer.Prepare(resp, resp.ContentLength, statusCode, true)
|
||||
}
|
||||
|
||||
// 设置响应代码
|
||||
if statusCode <= 0 {
|
||||
this.writer.WriteHeader(resp.StatusCode)
|
||||
} else {
|
||||
this.writer.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
// 输出内容
|
||||
var pool = this.bytePool(resp.ContentLength)
|
||||
var buf = pool.Get()
|
||||
if supportVariables {
|
||||
_, err = utils.CopyWithFilter(this.writer, resp.Body, buf.Bytes, func(p []byte) []byte {
|
||||
return []byte(this.Format(string(p)))
|
||||
})
|
||||
} else {
|
||||
_, err = io.CopyBuffer(this.writer, resp.Body, buf.Bytes)
|
||||
}
|
||||
pool.Put(buf)
|
||||
|
||||
if err != nil {
|
||||
if !this.canIgnore(err) {
|
||||
remotelogs.Warn("HTTP_REQUEST_URL", "write to client failed: "+err.Error())
|
||||
}
|
||||
} else {
|
||||
this.writer.SetOk()
|
||||
}
|
||||
}
|
||||
24
EdgeNode/internal/nodes/http_request_user_agent.go
Normal file
24
EdgeNode/internal/nodes/http_request_user_agent.go
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (this *HTTPRequest) doCheckUserAgent() (shouldStop bool) {
|
||||
if this.web.UserAgent == nil || !this.web.UserAgent.IsOn {
|
||||
return
|
||||
}
|
||||
|
||||
const cacheSeconds = "3600" // 时间不能过长,防止修改设置后长期无法生效
|
||||
|
||||
if this.web.UserAgent.MatchURL(this.URL()) && !this.web.UserAgent.AllowRequest(this.RawReq) {
|
||||
this.tags = append(this.tags, "userAgentCheck")
|
||||
this.writer.Header().Set("Cache-Control", "max-age="+cacheSeconds)
|
||||
this.writeCode(http.StatusForbidden, "The User-Agent has been blocked.", "当前访问已被UA名单拦截。")
|
||||
return true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
241
EdgeNode/internal/nodes/http_request_utils.go
Normal file
241
EdgeNode/internal/nodes/http_request_utils.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/ranges"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// 搜索引擎和爬虫正则
|
||||
var searchEngineRegex = regexp.MustCompile(`(?i)(60spider|adldxbot|adsbot-google|applebot|admantx|alexa|baidu|bingbot|bingpreview|facebookexternalhit|googlebot|proximic|slurp|sogou|twitterbot|yandex)`)
|
||||
var spiderRegexp = regexp.MustCompile(`(?i)(python|pycurl|http-client|httpclient|apachebench|nethttp|http_request|java|perl|ruby|scrapy|php|rust)`)
|
||||
|
||||
// 内容范围正则,其中的每个括号里的内容都在被引用,不能轻易修改
|
||||
var contentRangeRegexp = regexp.MustCompile(`^bytes (\d+)-(\d+)/(\d+|\*)`)
|
||||
|
||||
// URL协议前缀
|
||||
var urlSchemeRegexp = regexp.MustCompile("^(?i)(http|https|ftp)://")
|
||||
|
||||
// 分解Range
|
||||
func httpRequestParseRangeHeader(rangeValue string) (result []rangeutils.Range, ok bool) {
|
||||
// 参考RFC:https://tools.ietf.org/html/rfc7233
|
||||
index := strings.Index(rangeValue, "=")
|
||||
if index == -1 {
|
||||
return
|
||||
}
|
||||
unit := rangeValue[:index]
|
||||
if unit != "bytes" {
|
||||
return
|
||||
}
|
||||
|
||||
var rangeSetString = rangeValue[index+1:]
|
||||
if len(rangeSetString) == 0 {
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
var pieces = strings.Split(rangeSetString, ", ")
|
||||
for _, piece := range pieces {
|
||||
index = strings.Index(piece, "-")
|
||||
if index == -1 {
|
||||
return
|
||||
}
|
||||
first := piece[:index]
|
||||
firstInt := int64(-1)
|
||||
|
||||
var err error
|
||||
last := piece[index+1:]
|
||||
var lastInt = int64(-1)
|
||||
|
||||
if len(first) > 0 {
|
||||
firstInt, err = strconv.ParseInt(first, 10, 64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(last) > 0 {
|
||||
lastInt, err = strconv.ParseInt(last, 10, 64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if lastInt < firstInt {
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if len(last) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
lastInt, err = strconv.ParseInt(last, 10, 64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
lastInt = -lastInt
|
||||
}
|
||||
|
||||
result = append(result, [2]int64{firstInt, lastInt})
|
||||
}
|
||||
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
// 读取内容Range
|
||||
func httpRequestReadRange(reader io.Reader, buf []byte, start int64, end int64, callback func(buf []byte, n int) error) (ok bool, err error) {
|
||||
if start < 0 || end < 0 {
|
||||
return
|
||||
}
|
||||
seeker, ok := reader.(io.Seeker)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
_, err = seeker.Seek(start, io.SeekStart)
|
||||
if err != nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
offset := start
|
||||
for {
|
||||
n, err := reader.Read(buf)
|
||||
if n > 0 {
|
||||
offset += int64(n)
|
||||
if end < offset {
|
||||
err = callback(buf, n-int(offset-end-1))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
} else {
|
||||
err = callback(buf, n)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return true, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 分解Content-Range
|
||||
func httpRequestParseContentRangeHeader(contentRange string) (start int64, total int64) {
|
||||
var matches = contentRangeRegexp.FindStringSubmatch(contentRange)
|
||||
if len(matches) < 4 {
|
||||
return -1, -1
|
||||
}
|
||||
|
||||
start = types.Int64(matches[1])
|
||||
var sizeString = matches[3]
|
||||
if sizeString != "*" {
|
||||
total = types.Int64(sizeString)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 生成boundary
|
||||
// 仿照Golang自带的函数(multipart包)
|
||||
func httpRequestGenBoundary() string {
|
||||
var buf [8]byte
|
||||
_, err := io.ReadFull(rand.Reader, buf[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return fmt.Sprintf("%x", buf[:])
|
||||
}
|
||||
|
||||
// 从content-type中读取boundary
|
||||
func httpRequestParseBoundary(contentType string) string {
|
||||
var delim = "boundary="
|
||||
var boundaryIndex = strings.Index(contentType, delim)
|
||||
if boundaryIndex < 0 {
|
||||
return ""
|
||||
}
|
||||
var boundary = contentType[boundaryIndex+len(delim):]
|
||||
semicolonIndex := strings.Index(boundary, ";")
|
||||
if semicolonIndex >= 0 {
|
||||
return boundary[:semicolonIndex]
|
||||
}
|
||||
return boundary
|
||||
}
|
||||
|
||||
// 判断状态是否为跳转
|
||||
func httpStatusIsRedirect(statusCode int) bool {
|
||||
return statusCode == http.StatusPermanentRedirect ||
|
||||
statusCode == http.StatusTemporaryRedirect ||
|
||||
statusCode == http.StatusMovedPermanently ||
|
||||
statusCode == http.StatusSeeOther ||
|
||||
statusCode == http.StatusFound
|
||||
}
|
||||
|
||||
// 生成请求ID
|
||||
var httpRequestTimestamp int64
|
||||
var httpRequestId int32 = 1_000_000
|
||||
|
||||
func httpRequestNextId() string {
|
||||
unixTime, unixTimeString := fasttime.Now().UnixMilliString()
|
||||
if unixTime > httpRequestTimestamp {
|
||||
atomic.StoreInt32(&httpRequestId, 1_000_000)
|
||||
httpRequestTimestamp = unixTime
|
||||
}
|
||||
|
||||
// timestamp + nodeId + requestId
|
||||
return unixTimeString + teaconst.NodeIdString + strconv.Itoa(int(atomic.AddInt32(&httpRequestId, 1)))
|
||||
}
|
||||
|
||||
// 检查是否可以接受某个编码
|
||||
func httpAcceptEncoding(acceptEncodings string, encoding string) bool {
|
||||
if len(acceptEncodings) == 0 {
|
||||
return false
|
||||
}
|
||||
var pieces = strings.Split(acceptEncodings, ",")
|
||||
for _, piece := range pieces {
|
||||
var qualityIndex = strings.Index(piece, ";")
|
||||
if qualityIndex >= 0 {
|
||||
piece = piece[:qualityIndex]
|
||||
}
|
||||
|
||||
if strings.TrimSpace(piece) == encoding {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 跳转到某个URL
|
||||
func httpRedirect(writer http.ResponseWriter, req *http.Request, url string, code int) {
|
||||
if len(writer.Header().Get("Content-Type")) == 0 {
|
||||
// 设置Content-Type,是为了让页面不输出链接
|
||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
}
|
||||
|
||||
http.Redirect(writer, req, url, code)
|
||||
}
|
||||
|
||||
// 分析URL中的Host部分
|
||||
func httpParseHost(urlString string) (host string, err error) {
|
||||
if !urlSchemeRegexp.MatchString(urlString) {
|
||||
urlString = "https://" + urlString
|
||||
}
|
||||
|
||||
u, err := url.Parse(urlString)
|
||||
if err != nil && u != nil {
|
||||
return "", err
|
||||
}
|
||||
return u.Host, nil
|
||||
}
|
||||
173
EdgeNode/internal/nodes/http_request_utils_test.go
Normal file
173
EdgeNode/internal/nodes/http_request_utils_test.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/zero"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestHTTPRequest_httpRequestGenBoundary(t *testing.T) {
|
||||
for i := 0; i < 10; i++ {
|
||||
var boundary = httpRequestGenBoundary()
|
||||
t.Log(boundary, "[", len(boundary), "bytes", "]")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPRequest_httpRequestParseBoundary(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
a.IsTrue(httpRequestParseBoundary("multipart/byteranges") == "")
|
||||
a.IsTrue(httpRequestParseBoundary("multipart/byteranges; boundary=123") == "123")
|
||||
a.IsTrue(httpRequestParseBoundary("multipart/byteranges; boundary=123; 456") == "123")
|
||||
}
|
||||
|
||||
func TestHTTPRequest_httpRequestParseRangeHeader(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
{
|
||||
_, ok := httpRequestParseRangeHeader("")
|
||||
a.IsFalse(ok)
|
||||
}
|
||||
{
|
||||
_, ok := httpRequestParseRangeHeader("byte=")
|
||||
a.IsFalse(ok)
|
||||
}
|
||||
{
|
||||
_, ok := httpRequestParseRangeHeader("byte=")
|
||||
a.IsFalse(ok)
|
||||
}
|
||||
{
|
||||
set, ok := httpRequestParseRangeHeader("bytes=")
|
||||
a.IsTrue(ok)
|
||||
a.IsTrue(len(set) == 0)
|
||||
}
|
||||
{
|
||||
_, ok := httpRequestParseRangeHeader("bytes=60-50")
|
||||
a.IsFalse(ok)
|
||||
}
|
||||
{
|
||||
set, ok := httpRequestParseRangeHeader("bytes=0-50")
|
||||
a.IsTrue(ok)
|
||||
a.IsTrue(len(set) > 0)
|
||||
t.Log(set)
|
||||
}
|
||||
{
|
||||
set, ok := httpRequestParseRangeHeader("bytes=0-")
|
||||
a.IsTrue(ok)
|
||||
a.IsTrue(len(set) > 0)
|
||||
if len(set) > 0 {
|
||||
a.IsTrue(set[0][0] == 0)
|
||||
}
|
||||
t.Log(set)
|
||||
}
|
||||
{
|
||||
set, ok := httpRequestParseRangeHeader("bytes=-50")
|
||||
a.IsTrue(ok)
|
||||
a.IsTrue(len(set) > 0)
|
||||
t.Log(set)
|
||||
}
|
||||
{
|
||||
set, ok := httpRequestParseRangeHeader("bytes=0-50, 60-100")
|
||||
a.IsTrue(ok)
|
||||
a.IsTrue(len(set) > 0)
|
||||
t.Log(set)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPRequest_httpRequestParseContentRangeHeader(t *testing.T) {
|
||||
{
|
||||
var c1 = "bytes 0-100/*"
|
||||
t.Log(httpRequestParseContentRangeHeader(c1))
|
||||
}
|
||||
{
|
||||
var c1 = "bytes 30-100/*"
|
||||
t.Log(httpRequestParseContentRangeHeader(c1))
|
||||
}
|
||||
{
|
||||
var c1 = "bytes1 0-100/*"
|
||||
t.Log(httpRequestParseContentRangeHeader(c1))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHTTPRequest_httpRequestParseContentRangeHeader(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var c1 = "bytes 0-100/*"
|
||||
httpRequestParseContentRangeHeader(c1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPRequest_httpRequestNextId(t *testing.T) {
|
||||
teaconst.NodeId = 123
|
||||
teaconst.NodeIdString = "123"
|
||||
t.Log(httpRequestNextId())
|
||||
t.Log(httpRequestNextId())
|
||||
t.Log(httpRequestNextId())
|
||||
time.Sleep(1 * time.Second)
|
||||
t.Log(httpRequestNextId())
|
||||
t.Log(httpRequestNextId())
|
||||
time.Sleep(1 * time.Second)
|
||||
t.Log(httpRequestNextId())
|
||||
}
|
||||
|
||||
func TestHTTPRequest_httpRequestNextId_Concurrent(t *testing.T) {
|
||||
var m = map[string]zero.Zero{}
|
||||
var locker = sync.Mutex{}
|
||||
|
||||
var count = 4000
|
||||
var wg = &sync.WaitGroup{}
|
||||
wg.Add(count)
|
||||
|
||||
var countDuplicated = 0
|
||||
for i := 0; i < count; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
var requestId = httpRequestNextId()
|
||||
|
||||
locker.Lock()
|
||||
|
||||
_, ok := m[requestId]
|
||||
if ok {
|
||||
t.Log("duplicated:", requestId)
|
||||
countDuplicated++
|
||||
}
|
||||
|
||||
m[requestId] = zero.New()
|
||||
locker.Unlock()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
t.Log("ok", countDuplicated, "duplicated")
|
||||
|
||||
var a = assert.NewAssertion(t)
|
||||
a.IsTrue(countDuplicated == 0)
|
||||
}
|
||||
|
||||
func TestHTTPParseURL(t *testing.T) {
|
||||
for _, s := range []string{
|
||||
"",
|
||||
"null",
|
||||
"example.com",
|
||||
"https://example.com",
|
||||
"https://example.com/hello",
|
||||
} {
|
||||
host, err := httpParseHost(s)
|
||||
if err == nil {
|
||||
t.Log(s, "=>", host)
|
||||
} else {
|
||||
t.Log(s, "=>")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHTTPRequest_httpRequestNextId(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
teaconst.NodeIdString = "123"
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = httpRequestNextId()
|
||||
}
|
||||
}
|
||||
553
EdgeNode/internal/nodes/http_request_waf.go
Normal file
553
EdgeNode/internal/nodes/http_request_waf.go
Normal file
@@ -0,0 +1,553 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
iplib "github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
wafutils "github.com/TeaOSLab/EdgeNode/internal/waf/utils"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
)
|
||||
|
||||
// 调用WAF
|
||||
func (this *HTTPRequest) doWAFRequest() (blocked bool) {
|
||||
if this.web.FirewallRef == nil || !this.web.FirewallRef.IsOn {
|
||||
return
|
||||
}
|
||||
|
||||
var remoteAddr = this.requestRemoteAddr(true)
|
||||
|
||||
// 检查是否为白名单直连
|
||||
if !Tea.IsTesting() && this.nodeConfig.IPIsAutoAllowed(remoteAddr) {
|
||||
return
|
||||
}
|
||||
|
||||
// 当前连接是否已关闭
|
||||
if this.isConnClosed() {
|
||||
this.disableLog = true
|
||||
return true
|
||||
}
|
||||
|
||||
// 是否在全局名单中
|
||||
canGoNext, isInAllowedList, _ := iplibrary.AllowIP(remoteAddr, this.ReqServer.Id)
|
||||
if !canGoNext {
|
||||
this.disableLog = true
|
||||
this.Close()
|
||||
return true
|
||||
}
|
||||
if isInAllowedList {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否在临时黑名单中
|
||||
if waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeServer, this.ReqServer.Id, remoteAddr) || waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteAddr) {
|
||||
this.disableLog = true
|
||||
this.Close()
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
var forceLog = false
|
||||
var forceLogRequestBody = false
|
||||
var forceLogRegionDenying = false
|
||||
if this.ReqServer.HTTPFirewallPolicy != nil &&
|
||||
this.ReqServer.HTTPFirewallPolicy.IsOn &&
|
||||
this.ReqServer.HTTPFirewallPolicy.Log != nil &&
|
||||
this.ReqServer.HTTPFirewallPolicy.Log.IsOn {
|
||||
forceLog = true
|
||||
forceLogRequestBody = this.ReqServer.HTTPFirewallPolicy.Log.RequestBody
|
||||
forceLogRegionDenying = this.ReqServer.HTTPFirewallPolicy.Log.RegionDenying
|
||||
}
|
||||
|
||||
// 检查IP名单
|
||||
{
|
||||
// 当前服务的独立设置
|
||||
if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
|
||||
blockedRequest, breakChecking := this.checkWAFRemoteAddr(this.web.FirewallPolicy)
|
||||
if blockedRequest {
|
||||
return true
|
||||
}
|
||||
if breakChecking {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 公用的防火墙设置
|
||||
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn {
|
||||
blockedRequest, breakChecking := this.checkWAFRemoteAddr(this.ReqServer.HTTPFirewallPolicy)
|
||||
if blockedRequest {
|
||||
return true
|
||||
}
|
||||
if breakChecking {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查WAF规则
|
||||
{
|
||||
// 当前服务的独立设置
|
||||
if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
|
||||
blockedRequest, breakChecking := this.checkWAFRequest(this.web.FirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, false)
|
||||
if blockedRequest {
|
||||
return true
|
||||
}
|
||||
if breakChecking {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 公用的防火墙设置
|
||||
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn {
|
||||
blockedRequest, breakChecking := this.checkWAFRequest(this.ReqServer.HTTPFirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, this.web.FirewallRef.IgnoreGlobalRules)
|
||||
if blockedRequest {
|
||||
return true
|
||||
}
|
||||
if breakChecking {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// check client remote address
|
||||
func (this *HTTPRequest) checkWAFRemoteAddr(firewallPolicy *firewallconfigs.HTTPFirewallPolicy) (blocked bool, breakChecking bool) {
|
||||
if firewallPolicy == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var isDefendMode = firewallPolicy.Mode == firewallconfigs.FirewallModeDefend
|
||||
|
||||
// 检查IP白名单
|
||||
var remoteAddrs []string
|
||||
if len(this.remoteAddr) > 0 {
|
||||
remoteAddrs = []string{this.remoteAddr}
|
||||
} else {
|
||||
remoteAddrs = this.requestRemoteAddrs()
|
||||
}
|
||||
|
||||
var inbound = firewallPolicy.Inbound
|
||||
if inbound == nil {
|
||||
return
|
||||
}
|
||||
for _, ref := range inbound.AllAllowListRefs() {
|
||||
if ref.IsOn && ref.ListId > 0 {
|
||||
var list = iplibrary.SharedIPListManager.FindList(ref.ListId)
|
||||
if list != nil {
|
||||
_, found := list.ContainsIPStrings(remoteAddrs)
|
||||
if found {
|
||||
breakChecking = true
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查IP黑名单
|
||||
if isDefendMode {
|
||||
for _, ref := range inbound.AllDenyListRefs() {
|
||||
if ref.IsOn && ref.ListId > 0 {
|
||||
var list = iplibrary.SharedIPListManager.FindList(ref.ListId)
|
||||
if list != nil {
|
||||
item, found := list.ContainsIPStrings(remoteAddrs)
|
||||
if found {
|
||||
// 触发事件
|
||||
if item != nil && len(item.EventLevel) > 0 {
|
||||
actions := iplibrary.SharedActionManager.FindEventActions(item.EventLevel)
|
||||
for _, action := range actions {
|
||||
goNext, err := action.DoHTTP(this.RawReq, this.RawWriter)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_REQUEST_WAF", "do action '"+err.Error()+"' failed: "+err.Error())
|
||||
return true, false
|
||||
}
|
||||
if !goNext {
|
||||
this.disableLog = true
|
||||
return true, false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO 考虑是否需要记录日志信息吗,可能数据量非常庞大,所以暂时不记录
|
||||
|
||||
this.writer.WriteHeader(http.StatusForbidden)
|
||||
this.writer.Close()
|
||||
|
||||
// 停止日志
|
||||
this.disableLog = true
|
||||
|
||||
return true, false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// check waf inbound rules
|
||||
func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFirewallPolicy, forceLog bool, logRequestBody bool, logDenying bool, ignoreRules bool) (blocked bool, breakChecking bool) {
|
||||
// 检查配置是否为空
|
||||
if firewallPolicy == nil || !firewallPolicy.IsOn || firewallPolicy.Inbound == nil || !firewallPolicy.Inbound.IsOn || firewallPolicy.Mode == firewallconfigs.FirewallModeBypass {
|
||||
return
|
||||
}
|
||||
|
||||
var isDefendMode = firewallPolicy.Mode == firewallconfigs.FirewallModeDefend
|
||||
|
||||
// 检查IP白名单
|
||||
var remoteAddrs []string
|
||||
if len(this.remoteAddr) > 0 {
|
||||
remoteAddrs = []string{this.remoteAddr}
|
||||
} else {
|
||||
remoteAddrs = this.requestRemoteAddrs()
|
||||
}
|
||||
|
||||
var inbound = firewallPolicy.Inbound
|
||||
if inbound == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查地区封禁
|
||||
if firewallPolicy.Inbound.Region != nil && firewallPolicy.Inbound.Region.IsOn {
|
||||
var regionConfig = firewallPolicy.Inbound.Region
|
||||
if regionConfig.IsNotEmpty() {
|
||||
for _, remoteAddr := range remoteAddrs {
|
||||
var result = iplib.LookupIP(remoteAddr)
|
||||
if result != nil && result.IsOk() {
|
||||
var currentURL = this.URL()
|
||||
if regionConfig.MatchCountryURL(currentURL) {
|
||||
// 检查国家/地区级别封禁
|
||||
if !regionConfig.IsAllowedCountry(result.CountryId(), result.ProvinceId()) && (!regionConfig.AllowSearchEngine || wafutils.CheckSearchEngine(remoteAddr)) {
|
||||
this.firewallPolicyId = firewallPolicy.Id
|
||||
|
||||
if isDefendMode {
|
||||
var promptHTML string
|
||||
if len(regionConfig.CountryHTML) > 0 {
|
||||
promptHTML = regionConfig.CountryHTML
|
||||
} else if this.ReqServer != nil && this.ReqServer.HTTPFirewallPolicy != nil && len(this.ReqServer.HTTPFirewallPolicy.DenyCountryHTML) > 0 {
|
||||
promptHTML = this.ReqServer.HTTPFirewallPolicy.DenyCountryHTML
|
||||
}
|
||||
|
||||
if len(promptHTML) > 0 {
|
||||
var formattedHTML = this.Format(promptHTML)
|
||||
this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
this.writer.Header().Set("Content-Length", types.String(len(formattedHTML)))
|
||||
this.writer.WriteHeader(http.StatusForbidden)
|
||||
_, _ = this.writer.Write([]byte(formattedHTML))
|
||||
} else {
|
||||
this.writeCode(http.StatusForbidden, "The region has been denied.", "当前区域禁止访问")
|
||||
}
|
||||
|
||||
// 延时返回,避免攻击
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
// 停止日志
|
||||
if !logDenying {
|
||||
this.disableLog = true
|
||||
} else {
|
||||
this.tags = append(this.tags, "denyCountry")
|
||||
}
|
||||
|
||||
if isDefendMode {
|
||||
return true, false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if regionConfig.MatchProvinceURL(currentURL) {
|
||||
// 检查省份封禁
|
||||
if !regionConfig.IsAllowedProvince(result.CountryId(), result.ProvinceId()) {
|
||||
this.firewallPolicyId = firewallPolicy.Id
|
||||
|
||||
if isDefendMode {
|
||||
var promptHTML string
|
||||
if len(regionConfig.ProvinceHTML) > 0 {
|
||||
promptHTML = regionConfig.ProvinceHTML
|
||||
} else if this.ReqServer != nil && this.ReqServer.HTTPFirewallPolicy != nil && len(this.ReqServer.HTTPFirewallPolicy.DenyProvinceHTML) > 0 {
|
||||
promptHTML = this.ReqServer.HTTPFirewallPolicy.DenyProvinceHTML
|
||||
}
|
||||
|
||||
if len(promptHTML) > 0 {
|
||||
var formattedHTML = this.Format(promptHTML)
|
||||
this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
this.writer.Header().Set("Content-Length", types.String(len(formattedHTML)))
|
||||
this.writer.WriteHeader(http.StatusForbidden)
|
||||
_, _ = this.writer.Write([]byte(formattedHTML))
|
||||
} else {
|
||||
this.writeCode(http.StatusForbidden, "The region has been denied.", "当前区域禁止访问")
|
||||
}
|
||||
|
||||
// 延时返回,避免攻击
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
// 停止日志
|
||||
if !logDenying {
|
||||
this.disableLog = true
|
||||
} else {
|
||||
this.tags = append(this.tags, "denyProvince")
|
||||
}
|
||||
|
||||
if isDefendMode {
|
||||
return true, false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 是否执行规则
|
||||
if ignoreRules {
|
||||
return
|
||||
}
|
||||
|
||||
// 规则测试
|
||||
var w = waf.SharedWAFManager.FindWAF(firewallPolicy.Id)
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := w.MatchRequest(this, this.writer, this.web.FirewallRef.DefaultCaptchaType)
|
||||
if err != nil {
|
||||
if !this.canIgnore(err) {
|
||||
remotelogs.Warn("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
if result.IsAllowed && (len(result.AllowScope) == 0 || result.AllowScope == waf.AllowScopeGlobal) {
|
||||
breakChecking = true
|
||||
}
|
||||
if forceLog && logRequestBody && result.HasRequestBody && result.Set != nil && result.Set.HasAttackActions() {
|
||||
this.wafHasRequestBody = true
|
||||
}
|
||||
|
||||
if result.Set != nil {
|
||||
if forceLog {
|
||||
this.forceLog = true
|
||||
}
|
||||
|
||||
if result.Set.HasSpecialActions() {
|
||||
this.firewallPolicyId = firewallPolicy.Id
|
||||
this.firewallRuleGroupId = types.Int64(result.Group.Id)
|
||||
this.firewallRuleSetId = types.Int64(result.Set.Id)
|
||||
|
||||
if result.Set.HasAttackActions() {
|
||||
this.isAttack = true
|
||||
}
|
||||
|
||||
// 添加统计
|
||||
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, result.Set.Actions)
|
||||
}
|
||||
|
||||
this.firewallActions = append(result.Set.ActionCodes(), firewallPolicy.Mode)
|
||||
}
|
||||
|
||||
return !result.GoNext, breakChecking
|
||||
}
|
||||
|
||||
// call response waf
|
||||
func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) {
|
||||
if this.web.FirewallRef == nil || !this.web.FirewallRef.IsOn {
|
||||
return
|
||||
}
|
||||
|
||||
// 当前服务的独立设置
|
||||
var forceLog = false
|
||||
var forceLogRequestBody = false
|
||||
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn && this.ReqServer.HTTPFirewallPolicy.Log != nil && this.ReqServer.HTTPFirewallPolicy.Log.IsOn {
|
||||
forceLog = true
|
||||
forceLogRequestBody = this.ReqServer.HTTPFirewallPolicy.Log.RequestBody
|
||||
}
|
||||
|
||||
if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
|
||||
blockedRequest, breakChecking := this.checkWAFResponse(this.web.FirewallPolicy, resp, forceLog, forceLogRequestBody, false)
|
||||
if blockedRequest {
|
||||
return true
|
||||
}
|
||||
if breakChecking {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 公用的防火墙设置
|
||||
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn {
|
||||
blockedRequest, _ := this.checkWAFResponse(this.ReqServer.HTTPFirewallPolicy, resp, forceLog, forceLogRequestBody, this.web.FirewallRef.IgnoreGlobalRules)
|
||||
if blockedRequest {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// check waf outbound rules
|
||||
func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFirewallPolicy, resp *http.Response, forceLog bool, logRequestBody bool, ignoreRules bool) (blocked bool, breakChecking bool) {
|
||||
if firewallPolicy == nil || !firewallPolicy.IsOn || !firewallPolicy.Outbound.IsOn || firewallPolicy.Mode == firewallconfigs.FirewallModeBypass {
|
||||
return
|
||||
}
|
||||
|
||||
// 是否执行规则
|
||||
if ignoreRules {
|
||||
return
|
||||
}
|
||||
|
||||
var w = waf.SharedWAFManager.FindWAF(firewallPolicy.Id)
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := w.MatchResponse(this, resp, this.writer)
|
||||
if err != nil {
|
||||
if !this.canIgnore(err) {
|
||||
remotelogs.Warn("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
if result.IsAllowed && (len(result.AllowScope) == 0 || result.AllowScope == waf.AllowScopeGlobal) {
|
||||
breakChecking = true
|
||||
}
|
||||
if forceLog && logRequestBody && result.HasRequestBody && result.Set != nil && result.Set.HasAttackActions() {
|
||||
this.wafHasRequestBody = true
|
||||
}
|
||||
|
||||
if result.Set != nil {
|
||||
if forceLog {
|
||||
this.forceLog = true
|
||||
}
|
||||
|
||||
if result.Set.HasSpecialActions() {
|
||||
this.firewallPolicyId = firewallPolicy.Id
|
||||
this.firewallRuleGroupId = types.Int64(result.Group.Id)
|
||||
this.firewallRuleSetId = types.Int64(result.Set.Id)
|
||||
|
||||
if result.Set.HasAttackActions() {
|
||||
this.isAttack = true
|
||||
}
|
||||
|
||||
// 添加统计
|
||||
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, result.Set.Actions)
|
||||
}
|
||||
|
||||
this.firewallActions = append(result.Set.ActionCodes(), firewallPolicy.Mode)
|
||||
}
|
||||
|
||||
return !result.GoNext, breakChecking
|
||||
}
|
||||
|
||||
// WAFRaw 原始请求
|
||||
func (this *HTTPRequest) WAFRaw() *http.Request {
|
||||
return this.RawReq
|
||||
}
|
||||
|
||||
// WAFRemoteIP 客户端IP
|
||||
func (this *HTTPRequest) WAFRemoteIP() string {
|
||||
return this.requestRemoteAddr(true)
|
||||
}
|
||||
|
||||
// WAFGetCacheBody 获取缓存中的Body
|
||||
func (this *HTTPRequest) WAFGetCacheBody() []byte {
|
||||
return this.requestBodyData
|
||||
}
|
||||
|
||||
// WAFSetCacheBody 设置Body
|
||||
func (this *HTTPRequest) WAFSetCacheBody(body []byte) {
|
||||
this.requestBodyData = body
|
||||
}
|
||||
|
||||
// WAFReadBody 读取Body
|
||||
func (this *HTTPRequest) WAFReadBody(max int64) (data []byte, err error) {
|
||||
if this.RawReq.ContentLength > 0 {
|
||||
data, err = io.ReadAll(io.LimitReader(this.RawReq.Body, max))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// WAFRestoreBody 恢复Body
|
||||
func (this *HTTPRequest) WAFRestoreBody(data []byte) {
|
||||
if len(data) > 0 {
|
||||
this.RawReq.Body = io.NopCloser(io.MultiReader(bytes.NewBuffer(data), this.RawReq.Body))
|
||||
}
|
||||
}
|
||||
|
||||
// WAFServerId 服务ID
|
||||
func (this *HTTPRequest) WAFServerId() int64 {
|
||||
return this.ReqServer.Id
|
||||
}
|
||||
|
||||
// WAFClose 关闭连接
|
||||
func (this *HTTPRequest) WAFClose() {
|
||||
this.Close()
|
||||
|
||||
// 这里不要强关IP所有连接,避免因为单个服务而影响所有
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) WAFOnAction(action interface{}) (goNext bool) {
|
||||
if action == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
instance, ok := action.(waf.ActionInterface)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
switch instance.Code() {
|
||||
case waf.ActionTag:
|
||||
this.tags = append(this.tags, action.(*waf.TagAction).Tags...)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) WAFFingerprint() []byte {
|
||||
// 目前只有HTTPS请求才有指纹
|
||||
if !this.IsHTTPS {
|
||||
return nil
|
||||
}
|
||||
|
||||
var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
|
||||
if requestConn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
clientConn, ok := requestConn.(ClientConnInterface)
|
||||
if ok {
|
||||
return clientConn.Fingerprint()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) WAFMaxRequestSize() int64 {
|
||||
var maxRequestSize = firewallconfigs.DefaultMaxRequestBodySize
|
||||
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.MaxRequestBodySize > 0 {
|
||||
maxRequestSize = this.ReqServer.HTTPFirewallPolicy.MaxRequestBodySize
|
||||
}
|
||||
return maxRequestSize
|
||||
}
|
||||
|
||||
// DisableAccessLog 在当前请求中不使用访问日志
|
||||
func (this *HTTPRequest) DisableAccessLog() {
|
||||
this.disableLog = true
|
||||
}
|
||||
|
||||
// DisableStat 停用统计
|
||||
func (this *HTTPRequest) DisableStat() {
|
||||
if this.web != nil {
|
||||
this.web.StatRef = nil
|
||||
}
|
||||
|
||||
this.disableMetrics = true
|
||||
}
|
||||
205
EdgeNode/internal/nodes/http_request_websocket.go
Normal file
205
EdgeNode/internal/nodes/http_request_websocket.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/bytepool"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// WebsocketResponseReader Websocket响应Reader
|
||||
type WebsocketResponseReader struct {
|
||||
rawReader io.Reader
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func NewWebsocketResponseReader(rawReader io.Reader) *WebsocketResponseReader {
|
||||
return &WebsocketResponseReader{
|
||||
rawReader: rawReader,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *WebsocketResponseReader) Read(p []byte) (n int, err error) {
|
||||
n, err = this.rawReader.Read(p)
|
||||
if n > 0 {
|
||||
if len(this.buf) == 0 {
|
||||
this.buf = make([]byte, n)
|
||||
copy(this.buf, p[:n])
|
||||
} else {
|
||||
this.buf = append(this.buf, p[:n]...)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 处理Websocket请求
|
||||
func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shouldRetry bool) {
|
||||
// 设置不缓存
|
||||
this.web.Cache = nil
|
||||
|
||||
if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn {
|
||||
this.writer.WriteHeader(http.StatusForbidden)
|
||||
this.addError(errors.New("websocket have not been enabled yet"))
|
||||
return
|
||||
}
|
||||
|
||||
// TODO 实现handshakeTimeout
|
||||
|
||||
// 校验来源
|
||||
var requestOrigin = this.RawReq.Header.Get("Origin")
|
||||
if len(requestOrigin) > 0 {
|
||||
u, err := url.Parse(requestOrigin)
|
||||
if err == nil {
|
||||
if !this.web.Websocket.MatchOrigin(u.Host) {
|
||||
this.writer.WriteHeader(http.StatusForbidden)
|
||||
this.addError(errors.New("websocket origin '" + requestOrigin + "' not been allowed"))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 标记
|
||||
this.isWebsocketResponse = true
|
||||
|
||||
// 设置指定的来源域
|
||||
if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 {
|
||||
var newRequestOrigin = this.web.Websocket.RequestOrigin
|
||||
if this.web.Websocket.RequestOriginHasVariables() {
|
||||
newRequestOrigin = this.Format(newRequestOrigin)
|
||||
}
|
||||
this.RawReq.Header.Set("Origin", newRequestOrigin)
|
||||
}
|
||||
|
||||
// 获取当前连接
|
||||
var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
|
||||
if requestConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 连接源站
|
||||
originConn, _, err := OriginConnect(this.origin, this.requestServerPort(), this.RawReq.RemoteAddr, requestHost)
|
||||
if err != nil {
|
||||
if isLastRetry {
|
||||
this.write50x(err, http.StatusBadGateway, "Failed to connect origin site", "源站连接失败", false)
|
||||
}
|
||||
|
||||
// 增加失败次数
|
||||
SharedOriginStateManager.Fail(this.origin, requestHost, this.reverseProxy, func() {
|
||||
this.reverseProxy.ResetScheduling()
|
||||
})
|
||||
|
||||
shouldRetry = true
|
||||
return
|
||||
}
|
||||
|
||||
if !this.origin.IsOk {
|
||||
SharedOriginStateManager.Success(this.origin, func() {
|
||||
this.reverseProxy.ResetScheduling()
|
||||
})
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = originConn.Close()
|
||||
}()
|
||||
|
||||
err = this.RawReq.Write(originConn)
|
||||
if err != nil {
|
||||
this.write50x(err, http.StatusBadGateway, "Failed to write request to origin site", "源站请求初始化失败", false)
|
||||
return
|
||||
}
|
||||
|
||||
requestClientConn, ok := requestConn.(ClientConnInterface)
|
||||
if ok {
|
||||
requestClientConn.SetIsPersistent(true)
|
||||
}
|
||||
|
||||
clientConn, _, err := this.writer.Hijack()
|
||||
if err != nil || clientConn == nil {
|
||||
this.write50x(err, http.StatusInternalServerError, "Failed to get origin site connection", "获取源站连接失败", false)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = clientConn.Close()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
// 读取第一个响应
|
||||
var respReader = NewWebsocketResponseReader(originConn)
|
||||
resp, respErr := http.ReadResponse(bufio.NewReader(respReader), this.RawReq)
|
||||
if respErr != nil || resp == nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
_ = clientConn.Close()
|
||||
_ = originConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
this.ProcessResponseHeaders(resp.Header, resp.StatusCode)
|
||||
this.writer.statusCode = resp.StatusCode
|
||||
|
||||
// 将响应写回客户端
|
||||
err = resp.Write(clientConn)
|
||||
if err != nil {
|
||||
if resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
_ = clientConn.Close()
|
||||
_ = originConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// 剩余已经从源站读取的内容
|
||||
var headerBytes = respReader.buf
|
||||
var headerIndex = bytes.Index(headerBytes, []byte{'\r', '\n', '\r', '\n'}) // CRLF
|
||||
if headerIndex > 0 {
|
||||
var leftBytes = headerBytes[headerIndex+4:]
|
||||
if len(leftBytes) > 0 {
|
||||
_, writeErr := clientConn.Write(leftBytes)
|
||||
if writeErr != nil {
|
||||
if resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
_ = clientConn.Close()
|
||||
_ = originConn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
// 复制剩余的数据
|
||||
var buf = bytepool.Pool4k.Get()
|
||||
defer bytepool.Pool4k.Put(buf)
|
||||
for {
|
||||
n, readErr := originConn.Read(buf.Bytes)
|
||||
if n > 0 {
|
||||
this.writer.sentBodyBytes += int64(n)
|
||||
_, writeErr := clientConn.Write(buf.Bytes[:n])
|
||||
if writeErr != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
_ = clientConn.Close()
|
||||
_ = originConn.Close()
|
||||
}()
|
||||
|
||||
var buf = bytepool.Pool4k.Get()
|
||||
_, _ = io.CopyBuffer(originConn, clientConn, buf.Bytes)
|
||||
bytepool.Pool4k.Put(buf)
|
||||
|
||||
return
|
||||
}
|
||||
1328
EdgeNode/internal/nodes/http_writer.go
Normal file
1328
EdgeNode/internal/nodes/http_writer.go
Normal file
File diff suppressed because it is too large
Load Diff
63
EdgeNode/internal/nodes/http_writer_empty.go
Normal file
63
EdgeNode/internal/nodes/http_writer_empty.go
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// EmptyResponseWriter 空的响应Writer
|
||||
type EmptyResponseWriter struct {
|
||||
header http.Header
|
||||
parentWriter http.ResponseWriter
|
||||
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func NewEmptyResponseWriter(parentWriter http.ResponseWriter) *EmptyResponseWriter {
|
||||
return &EmptyResponseWriter{
|
||||
header: http.Header{},
|
||||
parentWriter: parentWriter,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *EmptyResponseWriter) Header() http.Header {
|
||||
return this.header
|
||||
}
|
||||
|
||||
func (this *EmptyResponseWriter) Write(data []byte) (int, error) {
|
||||
if this.statusCode > 300 && this.parentWriter != nil {
|
||||
return this.parentWriter.Write(data)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (this *EmptyResponseWriter) WriteHeader(statusCode int) {
|
||||
this.statusCode = statusCode
|
||||
|
||||
if this.statusCode > 300 && this.parentWriter != nil {
|
||||
var parentHeader = this.parentWriter.Header()
|
||||
for k, v := range this.header {
|
||||
parentHeader[k] = v
|
||||
}
|
||||
this.parentWriter.WriteHeader(this.statusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func (this *EmptyResponseWriter) StatusCode() int {
|
||||
return this.statusCode
|
||||
}
|
||||
|
||||
// Hijack Hijack
|
||||
func (this *EmptyResponseWriter) Hijack() (conn net.Conn, buf *bufio.ReadWriter, err error) {
|
||||
if this.parentWriter == nil {
|
||||
return
|
||||
}
|
||||
hijack, ok := this.parentWriter.(http.Hijacker)
|
||||
if ok {
|
||||
return hijack.Hijack()
|
||||
}
|
||||
return
|
||||
}
|
||||
17
EdgeNode/internal/nodes/http_writer_ext.go
Normal file
17
EdgeNode/internal/nodes/http_writer_ext.go
Normal file
@@ -0,0 +1,17 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build !plus
|
||||
// +build !plus
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
func (this *HTTPWriter) canSendfile() (*os.File, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) checkPlanBandwidth(n int) {
|
||||
// stub
|
||||
}
|
||||
50
EdgeNode/internal/nodes/http_writer_ext_plus.go
Normal file
50
EdgeNode/internal/nodes/http_writer_ext_plus.go
Normal file
@@ -0,0 +1,50 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build plus
|
||||
// +build plus
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/writers"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
func (this *HTTPWriter) canSendfile() (*os.File, bool) {
|
||||
if this.cacheWriter != nil || this.webpIsEncoding || this.isPartial || this.delayRead || this.compressionCacheWriter != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if this.rawReader == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
fileReader, ok := this.rawReader.(*caches.FileReader)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
_, ok = this.rawWriter.(io.ReaderFrom)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
counterWriter, ok := this.writer.(*writers.BytesCounterWriter)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if counterWriter.RawWriter() != this.rawWriter {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return fileReader.FP(), true
|
||||
}
|
||||
|
||||
// 检查套餐带宽限速
|
||||
func (this *HTTPWriter) checkPlanBandwidth(n int) {
|
||||
if this.req.ReqServer != nil && this.req.ReqServer.PlanId() > 0 {
|
||||
sharedPlanBandwidthLimiter.Ack(this.req.RawReq.Context(), this.req.ReqServer.PlanId(), n)
|
||||
}
|
||||
}
|
||||
230
EdgeNode/internal/nodes/listener.go
Normal file
230
EdgeNode/internal/nodes/listener.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
group *serverconfigs.ServerAddressGroup
|
||||
listener ListenerInterface // 监听器
|
||||
|
||||
locker sync.RWMutex
|
||||
}
|
||||
|
||||
func NewListener() *Listener {
|
||||
return &Listener{}
|
||||
}
|
||||
|
||||
func (this *Listener) Reload(group *serverconfigs.ServerAddressGroup) {
|
||||
this.locker.Lock()
|
||||
this.group = group
|
||||
if this.listener != nil {
|
||||
this.listener.Reload(group)
|
||||
}
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
func (this *Listener) FullAddr() string {
|
||||
if this.group != nil {
|
||||
return this.group.FullAddr()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (this *Listener) Listen() error {
|
||||
if this.group == nil {
|
||||
return nil
|
||||
}
|
||||
var protocol = this.group.Protocol()
|
||||
if protocol.IsUDPFamily() {
|
||||
return this.listenUDP()
|
||||
}
|
||||
return this.listenTCP()
|
||||
}
|
||||
|
||||
func (this *Listener) listenTCP() error {
|
||||
if this.group == nil {
|
||||
return nil
|
||||
}
|
||||
var protocol = this.group.Protocol()
|
||||
|
||||
tcpListener, err := this.createTCPListener()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var netListener = NewClientListener(tcpListener, protocol.IsHTTPFamily() || protocol.IsHTTPSFamily())
|
||||
events.OnKey(events.EventQuit, this, func() {
|
||||
remotelogs.Println("LISTENER", "quit "+this.group.FullAddr())
|
||||
_ = netListener.Close()
|
||||
})
|
||||
|
||||
switch protocol {
|
||||
case serverconfigs.ProtocolHTTP, serverconfigs.ProtocolHTTP4, serverconfigs.ProtocolHTTP6:
|
||||
this.listener = &HTTPListener{
|
||||
BaseListener: BaseListener{Group: this.group},
|
||||
Listener: netListener,
|
||||
}
|
||||
case serverconfigs.ProtocolHTTPS, serverconfigs.ProtocolHTTPS4, serverconfigs.ProtocolHTTPS6:
|
||||
netListener.SetIsTLS(true)
|
||||
this.listener = &HTTPListener{
|
||||
BaseListener: BaseListener{Group: this.group},
|
||||
Listener: netListener,
|
||||
}
|
||||
case serverconfigs.ProtocolTCP, serverconfigs.ProtocolTCP4, serverconfigs.ProtocolTCP6:
|
||||
this.listener = &TCPListener{
|
||||
BaseListener: BaseListener{Group: this.group},
|
||||
Listener: netListener,
|
||||
}
|
||||
case serverconfigs.ProtocolTLS, serverconfigs.ProtocolTLS4, serverconfigs.ProtocolTLS6:
|
||||
netListener.SetIsTLS(true)
|
||||
this.listener = &TCPListener{
|
||||
BaseListener: BaseListener{Group: this.group},
|
||||
Listener: netListener,
|
||||
}
|
||||
default:
|
||||
return errors.New("unknown protocol '" + protocol.String() + "'")
|
||||
}
|
||||
|
||||
this.listener.Init()
|
||||
|
||||
goman.New(func() {
|
||||
err := this.listener.Serve()
|
||||
if err != nil {
|
||||
// 在这里屏蔽accept错误,防止在优雅关闭的时候有多余的提示
|
||||
opErr, ok := err.(*net.OpError)
|
||||
if ok && opErr.Op == "accept" {
|
||||
return
|
||||
}
|
||||
|
||||
// 打印其他错误
|
||||
remotelogs.Error("LISTENER", err.Error())
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Listener) listenUDP() error {
|
||||
var addr = this.group.Addr()
|
||||
|
||||
var ipv4PacketListener *ipv4.PacketConn
|
||||
var ipv6PacketListener *ipv6.PacketConn
|
||||
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(host) == 0 {
|
||||
// ipv4
|
||||
ipv4Listener, err := this.createUDPIPv4Listener()
|
||||
if err == nil {
|
||||
ipv4PacketListener = ipv4.NewPacketConn(ipv4Listener)
|
||||
} else {
|
||||
remotelogs.Error("LISTENER", "create udp ipv4 listener '"+addr+"': "+err.Error())
|
||||
}
|
||||
|
||||
// ipv6
|
||||
ipv6Listener, err := this.createUDPIPv6Listener()
|
||||
if err == nil {
|
||||
ipv6PacketListener = ipv6.NewPacketConn(ipv6Listener)
|
||||
} else {
|
||||
remotelogs.Error("LISTENER", "create udp ipv6 listener '"+addr+"': "+err.Error())
|
||||
}
|
||||
} else if strings.Contains(host, ":") { // ipv6
|
||||
ipv6Listener, err := this.createUDPIPv6Listener()
|
||||
if err == nil {
|
||||
ipv6PacketListener = ipv6.NewPacketConn(ipv6Listener)
|
||||
} else {
|
||||
remotelogs.Error("LISTENER", "create udp ipv6 listener '"+addr+"': "+err.Error())
|
||||
}
|
||||
} else { // ipv4
|
||||
ipv4Listener, err := this.createUDPIPv4Listener()
|
||||
if err == nil {
|
||||
ipv4PacketListener = ipv4.NewPacketConn(ipv4Listener)
|
||||
} else {
|
||||
remotelogs.Error("LISTENER", "create udp ipv4 listener '"+addr+"': "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
events.OnKey(events.EventQuit, this, func() {
|
||||
remotelogs.Println("LISTENER", "quit "+this.group.FullAddr())
|
||||
|
||||
if ipv4PacketListener != nil {
|
||||
_ = ipv4PacketListener.Close()
|
||||
}
|
||||
|
||||
if ipv6PacketListener != nil {
|
||||
_ = ipv6PacketListener.Close()
|
||||
}
|
||||
})
|
||||
|
||||
this.listener = &UDPListener{
|
||||
BaseListener: BaseListener{Group: this.group},
|
||||
IPv4Listener: ipv4PacketListener,
|
||||
IPv6Listener: ipv6PacketListener,
|
||||
}
|
||||
|
||||
goman.New(func() {
|
||||
err := this.listener.Serve()
|
||||
if err != nil {
|
||||
remotelogs.Error("LISTENER", err.Error())
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Listener) Close() error {
|
||||
events.Remove(this)
|
||||
|
||||
if this.listener == nil {
|
||||
return nil
|
||||
}
|
||||
return this.listener.Close()
|
||||
}
|
||||
|
||||
// 创建TCP监听器
|
||||
func (this *Listener) createTCPListener() (net.Listener, error) {
|
||||
var listenConfig = net.ListenConfig{
|
||||
Control: nil,
|
||||
KeepAlive: 0,
|
||||
}
|
||||
|
||||
switch this.group.Protocol() {
|
||||
case serverconfigs.ProtocolHTTP4, serverconfigs.ProtocolHTTPS4, serverconfigs.ProtocolTLS4:
|
||||
return listenConfig.Listen(context.Background(), "tcp4", this.group.Addr())
|
||||
case serverconfigs.ProtocolHTTP6, serverconfigs.ProtocolHTTPS6, serverconfigs.ProtocolTLS6:
|
||||
return listenConfig.Listen(context.Background(), "tcp6", this.group.Addr())
|
||||
}
|
||||
|
||||
return listenConfig.Listen(context.Background(), "tcp", this.group.Addr())
|
||||
}
|
||||
|
||||
// 创建UDP IPv4监听器
|
||||
func (this *Listener) createUDPIPv4Listener() (*net.UDPConn, error) {
|
||||
addr, err := net.ResolveUDPAddr("udp", this.group.Addr())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.ListenUDP("udp4", addr)
|
||||
}
|
||||
|
||||
// 创建UDP监听器
|
||||
func (this *Listener) createUDPIPv6Listener() (*net.UDPConn, error) {
|
||||
addr, err := net.ResolveUDPAddr("udp", this.group.Addr())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.ListenUDP("udp6", addr)
|
||||
}
|
||||
276
EdgeNode/internal/nodes/listener_base.go
Normal file
276
EdgeNode/internal/nodes/listener_base.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
)
|
||||
|
||||
type BaseListener struct {
|
||||
Group *serverconfigs.ServerAddressGroup
|
||||
|
||||
countActiveConnections int64 // 当前活跃的连接数
|
||||
}
|
||||
|
||||
// Init 初始化
|
||||
func (this *BaseListener) Init() {
|
||||
}
|
||||
|
||||
// Reset 清除既有配置
|
||||
func (this *BaseListener) Reset() {
|
||||
|
||||
}
|
||||
|
||||
// CountActiveConnections 获取当前活跃连接数
|
||||
func (this *BaseListener) CountActiveConnections() int {
|
||||
return types.Int(this.countActiveConnections)
|
||||
}
|
||||
|
||||
// 构造TLS配置
|
||||
func (this *BaseListener) buildTLSConfig() *tls.Config {
|
||||
return &tls.Config{
|
||||
Certificates: nil,
|
||||
GetConfigForClient: func(clientInfo *tls.ClientHelloInfo) (config *tls.Config, e error) {
|
||||
// 指纹信息
|
||||
var fingerprint = this.calculateFingerprint(clientInfo)
|
||||
if len(fingerprint) > 0 && clientInfo.Conn != nil {
|
||||
clientConn, ok := clientInfo.Conn.(ClientConnInterface)
|
||||
if ok {
|
||||
clientConn.SetFingerprint(fingerprint)
|
||||
}
|
||||
}
|
||||
|
||||
tlsPolicy, _, err := this.matchSSL(this.helloServerNames(clientInfo))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tlsPolicy == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tlsPolicy.CheckOCSP()
|
||||
|
||||
return tlsPolicy.TLSConfig(), nil
|
||||
},
|
||||
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
|
||||
// 指纹信息
|
||||
var fingerprint = this.calculateFingerprint(clientInfo)
|
||||
if len(fingerprint) > 0 && clientInfo.Conn != nil {
|
||||
clientConn, ok := clientInfo.Conn.(ClientConnInterface)
|
||||
if ok {
|
||||
clientConn.SetFingerprint(fingerprint)
|
||||
}
|
||||
}
|
||||
|
||||
tlsPolicy, cert, err := this.matchSSL(this.helloServerNames(clientInfo))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cert == nil {
|
||||
return nil, errors.New("no ssl certs found for '" + clientInfo.ServerName + "'")
|
||||
}
|
||||
|
||||
tlsPolicy.CheckOCSP()
|
||||
|
||||
return cert, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 根据域名匹配证书
|
||||
func (this *BaseListener) matchSSL(domains []string) (*sslconfigs.SSLPolicy, *tls.Certificate, error) {
|
||||
var group = this.Group
|
||||
|
||||
if group == nil {
|
||||
return nil, nil, errors.New("no configure found")
|
||||
}
|
||||
|
||||
var globalServerConfig *serverconfigs.GlobalServerConfig
|
||||
if sharedNodeConfig != nil {
|
||||
globalServerConfig = sharedNodeConfig.GlobalServerConfig
|
||||
}
|
||||
|
||||
// 如果域名为空,则取第一个
|
||||
// 通常域名为空是因为是直接通过IP访问的
|
||||
if len(domains) == 0 {
|
||||
if group.IsHTTPS() && globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly {
|
||||
return nil, nil, errors.New("no tls server name matched")
|
||||
}
|
||||
|
||||
firstServer := group.FirstTLSServer()
|
||||
if firstServer == nil {
|
||||
return nil, nil, errors.New("no tls server available")
|
||||
}
|
||||
sslConfig := firstServer.SSLPolicy()
|
||||
|
||||
if sslConfig != nil {
|
||||
return sslConfig, sslConfig.FirstCert(), nil
|
||||
|
||||
}
|
||||
return nil, nil, errors.New("no tls server name found")
|
||||
}
|
||||
var firstDomain = domains[0]
|
||||
|
||||
// 通过网站域名配置匹配
|
||||
var server *serverconfigs.ServerConfig
|
||||
var matchedDomain string
|
||||
for _, domain := range domains {
|
||||
server, _ = this.findNamedServer(domain, true)
|
||||
if server != nil {
|
||||
matchedDomain = domain
|
||||
break
|
||||
}
|
||||
}
|
||||
if server == nil {
|
||||
server, _ = this.findNamedServer(firstDomain, false)
|
||||
if server != nil {
|
||||
matchedDomain = firstDomain
|
||||
}
|
||||
}
|
||||
|
||||
if server == nil {
|
||||
// 找不到或者此时的服务没有配置证书,需要搜索所有的Server,通过SSL证书内容中的DNSName匹配
|
||||
// 此功能仅为了兼容以往版本(v1.0.4),不应该作为常态启用
|
||||
if globalServerConfig != nil && globalServerConfig.HTTPAll.MatchCertFromAllServers {
|
||||
for _, searchingServer := range group.Servers() {
|
||||
if searchingServer.SSLPolicy() == nil || !searchingServer.SSLPolicy().IsOn {
|
||||
continue
|
||||
}
|
||||
cert, ok := searchingServer.SSLPolicy().MatchDomain(firstDomain)
|
||||
if ok {
|
||||
return searchingServer.SSLPolicy(), cert, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, errors.New("no server found for '" + firstDomain + "'")
|
||||
}
|
||||
if server.SSLPolicy() == nil || !server.SSLPolicy().IsOn {
|
||||
// 找不到或者此时的服务没有配置证书,需要搜索所有的Server,通过SSL证书内容中的DNSName匹配
|
||||
// 此功能仅为了兼容以往版本(v1.0.4),不应该作为常态启用
|
||||
if globalServerConfig != nil && globalServerConfig.HTTPAll.MatchCertFromAllServers {
|
||||
for _, searchingServer := range group.Servers() {
|
||||
if searchingServer.SSLPolicy() == nil || !searchingServer.SSLPolicy().IsOn {
|
||||
continue
|
||||
}
|
||||
cert, ok := searchingServer.SSLPolicy().MatchDomain(matchedDomain)
|
||||
if ok {
|
||||
return searchingServer.SSLPolicy(), cert, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, errors.New("no cert found for '" + matchedDomain + "'")
|
||||
}
|
||||
|
||||
// 证书是否匹配
|
||||
var sslConfig = server.SSLPolicy()
|
||||
cert, ok := sslConfig.MatchDomain(matchedDomain)
|
||||
if ok {
|
||||
return sslConfig, cert, nil
|
||||
}
|
||||
|
||||
if len(sslConfig.Certs) == 0 {
|
||||
remotelogs.ServerError(server.Id, "BASE_LISTENER", "no ssl certs found for '"+matchedDomain+"', server id: "+types.String(server.Id), "", nil)
|
||||
}
|
||||
|
||||
return sslConfig, sslConfig.FirstCert(), nil
|
||||
}
|
||||
|
||||
// 根据域名来查找匹配的域名
|
||||
func (this *BaseListener) findNamedServer(name string, exactly bool) (serverConfig *serverconfigs.ServerConfig, serverName string) {
|
||||
serverConfig, serverName = this.findNamedServerMatched(name)
|
||||
if serverConfig != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
|
||||
var matchDomainStrictly = globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly
|
||||
|
||||
if globalServerConfig != nil &&
|
||||
len(globalServerConfig.HTTPAll.DefaultDomain) > 0 &&
|
||||
(!matchDomainStrictly || configutils.MatchDomains(globalServerConfig.HTTPAll.AllowMismatchDomains, name) || (globalServerConfig.HTTPAll.AllowNodeIP && utils.IsWildIP(name))) {
|
||||
if globalServerConfig.HTTPAll.AllowNodeIP &&
|
||||
globalServerConfig.HTTPAll.NodeIPShowPage &&
|
||||
utils.IsWildIP(name) {
|
||||
return
|
||||
} else {
|
||||
var defaultDomain = globalServerConfig.HTTPAll.DefaultDomain
|
||||
serverConfig, serverName = this.findNamedServerMatched(defaultDomain)
|
||||
if serverConfig != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if matchDomainStrictly && !configutils.MatchDomains(globalServerConfig.HTTPAll.AllowMismatchDomains, name) && (!globalServerConfig.HTTPAll.AllowNodeIP || (!utils.IsWildIP(name) || globalServerConfig.HTTPAll.NodeIPShowPage)) {
|
||||
return
|
||||
}
|
||||
|
||||
if !exactly {
|
||||
// 如果没有找到,则匹配到第一个
|
||||
var group = this.Group
|
||||
var currentServers = group.Servers()
|
||||
var countServers = len(currentServers)
|
||||
if countServers == 0 {
|
||||
return nil, ""
|
||||
}
|
||||
return currentServers[0], name
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 严格查找域名
|
||||
func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *serverconfigs.ServerConfig, serverName string) {
|
||||
var group = this.Group
|
||||
if group == nil {
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
server := group.MatchServerName(name)
|
||||
if server != nil {
|
||||
return server, name
|
||||
}
|
||||
|
||||
// 是否严格匹配域名
|
||||
var matchDomainStrictly = sharedNodeConfig.GlobalServerConfig != nil && sharedNodeConfig.GlobalServerConfig.HTTPAll.MatchDomainStrictly
|
||||
|
||||
// 如果只有一个server,则默认为这个
|
||||
var currentServers = group.Servers()
|
||||
var countServers = len(currentServers)
|
||||
if countServers == 1 && !matchDomainStrictly {
|
||||
return currentServers[0], name
|
||||
}
|
||||
|
||||
return nil, name
|
||||
}
|
||||
|
||||
// 从Hello信息中获取服务名称
|
||||
func (this *BaseListener) helloServerNames(clientInfo *tls.ClientHelloInfo) (serverNames []string) {
|
||||
if len(clientInfo.ServerName) != 0 {
|
||||
serverNames = append(serverNames, clientInfo.ServerName)
|
||||
return
|
||||
}
|
||||
|
||||
if clientInfo.Conn != nil {
|
||||
var localAddr = clientInfo.Conn.LocalAddr()
|
||||
if localAddr != nil {
|
||||
tcpAddr, ok := localAddr.(*net.TCPAddr)
|
||||
if ok {
|
||||
serverNames = append(serverNames, tcpAddr.IP.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
serverNames = append(serverNames, sharedNodeConfig.IPAddresses...)
|
||||
|
||||
return
|
||||
}
|
||||
10
EdgeNode/internal/nodes/listener_base_ext.go
Normal file
10
EdgeNode/internal/nodes/listener_base_ext.go
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build !plus
|
||||
|
||||
package nodes
|
||||
|
||||
import "crypto/tls"
|
||||
|
||||
func (this *BaseListener) calculateFingerprint(clientInfo *tls.ClientHelloInfo) []byte {
|
||||
return nil
|
||||
}
|
||||
50
EdgeNode/internal/nodes/listener_base_ext_plus.go
Normal file
50
EdgeNode/internal/nodes/listener_base_ext_plus.go
Normal file
@@ -0,0 +1,50 @@
|
||||
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build plus
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"hash"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var md5Pool = &sync.Pool{
|
||||
New: func() any {
|
||||
return md5.New()
|
||||
},
|
||||
}
|
||||
|
||||
func md5Bytes(b []byte) []byte {
|
||||
var h = md5Pool.Get().(hash.Hash)
|
||||
h.Write(b)
|
||||
var sum = h.Sum(nil)
|
||||
h.Reset()
|
||||
md5Pool.Put(h)
|
||||
return sum
|
||||
}
|
||||
|
||||
func (this *BaseListener) calculateFingerprint(clientInfo *tls.ClientHelloInfo) []byte {
|
||||
var b = []byte{}
|
||||
for _, c := range clientInfo.CipherSuites {
|
||||
b = binary.BigEndian.AppendUint16(b, c)
|
||||
}
|
||||
b = append(b, 0)
|
||||
|
||||
for _, c := range clientInfo.SupportedCurves {
|
||||
b = binary.BigEndian.AppendUint16(b, uint16(c))
|
||||
}
|
||||
b = append(b, 0)
|
||||
|
||||
b = append(b, clientInfo.SupportedPoints...)
|
||||
b = append(b, 0)
|
||||
|
||||
for _, s := range clientInfo.SignatureSchemes {
|
||||
b = binary.BigEndian.AppendUint16(b, uint16(s))
|
||||
}
|
||||
b = append(b, 0)
|
||||
|
||||
return md5Bytes(b)
|
||||
}
|
||||
37
EdgeNode/internal/nodes/listener_base_test.go
Normal file
37
EdgeNode/internal/nodes/listener_base_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBaseListener_FindServer(t *testing.T) {
|
||||
sharedNodeConfig = &nodeconfigs.NodeConfig{}
|
||||
|
||||
var listener = &BaseListener{}
|
||||
listener.Group = serverconfigs.NewServerAddressGroup("https://*:443")
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
var server = &serverconfigs.ServerConfig{
|
||||
IsOn: true,
|
||||
Name: types.String(i) + ".hello.com",
|
||||
ServerNames: []*serverconfigs.ServerNameConfig{
|
||||
{Name: types.String(i) + ".hello.com"},
|
||||
},
|
||||
}
|
||||
_ = server.Init(context.Background())
|
||||
listener.Group.Add(server)
|
||||
}
|
||||
|
||||
var before = time.Now()
|
||||
defer func() {
|
||||
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||
}()
|
||||
|
||||
t.Log(listener.findNamedServerMatched("855555.hello.com"))
|
||||
}
|
||||
278
EdgeNode/internal/nodes/listener_http.go
Normal file
278
EdgeNode/internal/nodes/listener_http.go
Normal file
@@ -0,0 +1,278 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var httpErrorLogger = log.New(io.Discard, "", 0)
|
||||
|
||||
const HTTPIdleTimeout = 75 * time.Second
|
||||
|
||||
type contextKey struct {
|
||||
key string
|
||||
}
|
||||
|
||||
var HTTPConnContextKey = &contextKey{key: "http-conn"}
|
||||
|
||||
type HTTPListener struct {
|
||||
BaseListener
|
||||
|
||||
Listener net.Listener
|
||||
|
||||
addr string
|
||||
isHTTP bool
|
||||
isHTTPS bool
|
||||
isHTTP3 bool
|
||||
httpServer *http.Server
|
||||
}
|
||||
|
||||
func (this *HTTPListener) Serve() error {
|
||||
this.addr = this.Group.Addr()
|
||||
this.isHTTP = this.Group.IsHTTP()
|
||||
this.isHTTPS = this.Group.IsHTTPS()
|
||||
|
||||
this.httpServer = &http.Server{
|
||||
Addr: this.addr,
|
||||
Handler: this,
|
||||
ReadHeaderTimeout: 3 * time.Second, // TODO 改成可以配置
|
||||
IdleTimeout: HTTPIdleTimeout, // TODO 改成可以配置
|
||||
ConnState: func(conn net.Conn, state http.ConnState) {
|
||||
switch state {
|
||||
case http.StateNew:
|
||||
atomic.AddInt64(&this.countActiveConnections, 1)
|
||||
case http.StateClosed:
|
||||
atomic.AddInt64(&this.countActiveConnections, -1)
|
||||
default:
|
||||
// do nothing
|
||||
}
|
||||
},
|
||||
ConnContext: func(ctx context.Context, conn net.Conn) context.Context {
|
||||
tlsConn, ok := conn.(*tls.Conn)
|
||||
|
||||
if ok {
|
||||
conn = NewClientTLSConn(tlsConn)
|
||||
}
|
||||
|
||||
return context.WithValue(ctx, HTTPConnContextKey, conn)
|
||||
},
|
||||
}
|
||||
|
||||
if !Tea.IsTesting() {
|
||||
this.httpServer.ErrorLog = httpErrorLogger
|
||||
}
|
||||
|
||||
this.httpServer.SetKeepAlivesEnabled(true)
|
||||
|
||||
// HTTP协议
|
||||
if this.isHTTP {
|
||||
err := this.httpServer.Serve(this.Listener)
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPS协议
|
||||
if this.isHTTPS {
|
||||
this.httpServer.TLSConfig = this.buildTLSConfig()
|
||||
|
||||
err := this.httpServer.ServeTLS(this.Listener, "", "")
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *HTTPListener) Close() error {
|
||||
if this.httpServer != nil {
|
||||
_ = this.httpServer.Close()
|
||||
}
|
||||
return this.Listener.Close()
|
||||
}
|
||||
|
||||
func (this *HTTPListener) Reload(group *serverconfigs.ServerAddressGroup) {
|
||||
this.Group = group
|
||||
|
||||
this.Reset()
|
||||
}
|
||||
|
||||
// ServeHTTPWithAddr 处理HTTP请求
|
||||
func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.Request) {
|
||||
this.ServeHTTPWithAddr(rawWriter, rawReq, this.addr)
|
||||
}
|
||||
|
||||
// ServeHTTPWithAddr 处理HTTP请求并指定服务地址
|
||||
func (this *HTTPListener) ServeHTTPWithAddr(rawWriter http.ResponseWriter, rawReq *http.Request, serverAddr string) {
|
||||
if len(rawReq.Host) > 253 {
|
||||
http.Error(rawWriter, "Host too long.", http.StatusBadRequest)
|
||||
time.Sleep(1 * time.Second) // make connection slow down
|
||||
return
|
||||
}
|
||||
|
||||
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
|
||||
if globalServerConfig != nil && !globalServerConfig.HTTPAll.SupportsLowVersionHTTP && (rawReq.ProtoMajor < 1 /** 0.x **/ || (rawReq.ProtoMajor == 1 && rawReq.ProtoMinor == 0 /** 1.0 **/)) {
|
||||
http.Error(rawWriter, rawReq.Proto+" request is not supported.", http.StatusHTTPVersionNotSupported)
|
||||
time.Sleep(1 * time.Second) // make connection slow down
|
||||
return
|
||||
}
|
||||
|
||||
// 不支持Connect
|
||||
if rawReq.Method == http.MethodConnect {
|
||||
http.Error(rawWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
time.Sleep(1 * time.Second) // make connection slow down
|
||||
return
|
||||
}
|
||||
|
||||
// 域名
|
||||
var reqHost = strings.ToLower(strings.TrimRight(rawReq.Host, "."))
|
||||
|
||||
// TLS域名
|
||||
if this.isIP(reqHost) {
|
||||
if rawReq.TLS != nil {
|
||||
var serverName = rawReq.TLS.ServerName
|
||||
if len(serverName) > 0 {
|
||||
// 端口
|
||||
var index = strings.LastIndex(reqHost, ":")
|
||||
if index >= 0 {
|
||||
reqHost = serverName + reqHost[index:]
|
||||
} else {
|
||||
reqHost = serverName
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 防止空Host
|
||||
if len(reqHost) == 0 {
|
||||
var ctx = rawReq.Context()
|
||||
if ctx != nil {
|
||||
addr := ctx.Value(http.LocalAddrContextKey)
|
||||
if addr != nil {
|
||||
reqHost = addr.(net.Addr).String()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
domain, _, err := net.SplitHostPort(reqHost)
|
||||
if err != nil {
|
||||
domain = reqHost
|
||||
}
|
||||
|
||||
server, serverName := this.findNamedServer(domain, false)
|
||||
if server == nil {
|
||||
server = this.emptyServer()
|
||||
} else if !server.CNameAsDomain && server.CNameDomain == domain {
|
||||
server = this.emptyServer()
|
||||
} else {
|
||||
serverName = domain
|
||||
}
|
||||
|
||||
// 绑定连接
|
||||
var clientConn ClientConnInterface
|
||||
if server != nil && server.Id > 0 {
|
||||
var requestConn = rawReq.Context().Value(HTTPConnContextKey)
|
||||
if requestConn != nil {
|
||||
var ok bool
|
||||
clientConn, ok = requestConn.(ClientConnInterface)
|
||||
if ok {
|
||||
var goNext = clientConn.SetServerId(server.Id)
|
||||
if !goNext {
|
||||
return
|
||||
}
|
||||
clientConn.SetUserId(server.UserId)
|
||||
|
||||
var userPlanId int64
|
||||
if server.UserPlan != nil && server.UserPlan.Id > 0 {
|
||||
userPlanId = server.UserPlan.Id
|
||||
}
|
||||
clientConn.SetUserPlanId(userPlanId)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查用户
|
||||
if server != nil && server.UserId > 0 {
|
||||
if !SharedUserManager.CheckUserServersIsEnabled(server.UserId) {
|
||||
rawWriter.WriteHeader(http.StatusNotFound)
|
||||
_, _ = rawWriter.Write([]byte("The site owner is unavailable."))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 包装新请求对象
|
||||
var req = &HTTPRequest{
|
||||
RawReq: rawReq,
|
||||
RawWriter: rawWriter,
|
||||
ReqServer: server,
|
||||
ReqHost: reqHost,
|
||||
ServerName: serverName,
|
||||
ServerAddr: serverAddr,
|
||||
IsHTTP: this.isHTTP,
|
||||
IsHTTPS: this.isHTTPS,
|
||||
IsHTTP3: this.isHTTP3,
|
||||
|
||||
nodeConfig: sharedNodeConfig,
|
||||
}
|
||||
req.Do()
|
||||
|
||||
// fix hijacked connection state
|
||||
if req.isHijacked && clientConn != nil && this.httpServer.ConnState != nil {
|
||||
netConn, ok := clientConn.(net.Conn)
|
||||
if ok {
|
||||
this.httpServer.ConnState(netConn, http.StateClosed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查host是否为IP
|
||||
func (this *HTTPListener) isIP(host string) bool {
|
||||
// IPv6
|
||||
if strings.Contains(host, "[") {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, b := range host {
|
||||
if b >= 'a' && b <= 'z' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// 默认的访问日志
|
||||
func (this *HTTPListener) emptyServer() *serverconfigs.ServerConfig {
|
||||
var server = &serverconfigs.ServerConfig{
|
||||
Type: serverconfigs.ServerTypeHTTPProxy,
|
||||
}
|
||||
|
||||
// 检查是否开启访问日志
|
||||
if sharedNodeConfig != nil {
|
||||
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
|
||||
if globalServerConfig != nil && globalServerConfig.HTTPAccessLog.EnableServerNotFound {
|
||||
var accessLogRef = serverconfigs.NewHTTPAccessLogRef()
|
||||
accessLogRef.IsOn = true
|
||||
accessLogRef.Fields = append([]int{}, serverconfigs.HTTPAccessLogDefaultFieldsCodes...)
|
||||
server.Web = &serverconfigs.HTTPWebConfig{
|
||||
IsOn: true,
|
||||
AccessLogRef: accessLogRef,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO 需要对访问频率过多的IP进行惩罚
|
||||
|
||||
return server
|
||||
}
|
||||
21
EdgeNode/internal/nodes/listener_interface.go
Normal file
21
EdgeNode/internal/nodes/listener_interface.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package nodes
|
||||
|
||||
import "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
|
||||
// ListenerInterface 各协议监听器的接口
|
||||
type ListenerInterface interface {
|
||||
// Init 初始化
|
||||
Init()
|
||||
|
||||
// Serve 监听
|
||||
Serve() error
|
||||
|
||||
// Close 关闭
|
||||
Close() error
|
||||
|
||||
// Reload 重载配置
|
||||
Reload(serverGroup *serverconfigs.ServerAddressGroup)
|
||||
|
||||
// CountActiveConnections 获取当前活跃的连接数
|
||||
CountActiveConnections() int
|
||||
}
|
||||
369
EdgeNode/internal/nodes/listener_manager.go
Normal file
369
EdgeNode/internal/nodes/listener_manager.go
Normal file
@@ -0,0 +1,369 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var sharedListenerManager *ListenerManager
|
||||
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
return
|
||||
}
|
||||
|
||||
sharedListenerManager = NewListenerManager()
|
||||
}
|
||||
|
||||
// ListenerManager 端口监听管理器
|
||||
type ListenerManager struct {
|
||||
listenersMap map[string]*Listener // addr => *Listener
|
||||
http3Listener *HTTPListener
|
||||
|
||||
locker sync.Mutex
|
||||
lastConfig *nodeconfigs.NodeConfig
|
||||
|
||||
retryListenerMap map[string]*Listener // 需要重试的监听器 addr => Listener
|
||||
ticker *time.Ticker
|
||||
|
||||
firewalld *firewalls.Firewalld
|
||||
lastPortStrings string
|
||||
lastTCPPortRanges [][2]int
|
||||
lastUDPPortRanges [][2]int
|
||||
}
|
||||
|
||||
// NewListenerManager 获取新对象
|
||||
func NewListenerManager() *ListenerManager {
|
||||
var manager = &ListenerManager{
|
||||
listenersMap: map[string]*Listener{},
|
||||
retryListenerMap: map[string]*Listener{},
|
||||
ticker: time.NewTicker(1 * time.Minute),
|
||||
firewalld: firewalls.NewFirewalld(),
|
||||
}
|
||||
|
||||
// 提升测试效率
|
||||
if Tea.IsTesting() {
|
||||
manager.ticker = time.NewTicker(5 * time.Second)
|
||||
}
|
||||
|
||||
goman.New(func() {
|
||||
for range manager.ticker.C {
|
||||
manager.retryListeners()
|
||||
}
|
||||
})
|
||||
|
||||
return manager
|
||||
}
|
||||
|
||||
// Start 启动监听
|
||||
func (this *ListenerManager) Start(nodeConfig *nodeconfigs.NodeConfig) error {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
// 重置数据
|
||||
this.retryListenerMap = map[string]*Listener{}
|
||||
|
||||
// 检查是否有变化
|
||||
/**if this.lastConfig != nil && this.lastConfig.Version == node.Version {
|
||||
return nil
|
||||
}**/
|
||||
this.lastConfig = nodeConfig
|
||||
|
||||
// 所有的新地址
|
||||
var groupAddrs = []string{}
|
||||
var availableServerGroups = nodeConfig.AvailableGroups()
|
||||
if !nodeConfig.IsOn {
|
||||
availableServerGroups = []*serverconfigs.ServerAddressGroup{}
|
||||
}
|
||||
|
||||
if len(availableServerGroups) == 0 {
|
||||
remotelogs.Println("LISTENER_MANAGER", "no available servers to startup")
|
||||
}
|
||||
|
||||
for _, group := range availableServerGroups {
|
||||
var addr = group.FullAddr()
|
||||
groupAddrs = append(groupAddrs, addr)
|
||||
}
|
||||
|
||||
// 停掉老的
|
||||
for listenerKey, listener := range this.listenersMap {
|
||||
var addr = listener.FullAddr()
|
||||
if !lists.ContainsString(groupAddrs, addr) {
|
||||
remotelogs.Println("LISTENER_MANAGER", "close '"+addr+"'")
|
||||
_ = listener.Close()
|
||||
|
||||
delete(this.listenersMap, listenerKey)
|
||||
}
|
||||
}
|
||||
|
||||
// 启动新的或修改老的
|
||||
for _, group := range availableServerGroups {
|
||||
var addr = group.FullAddr()
|
||||
listener, ok := this.listenersMap[addr]
|
||||
if ok {
|
||||
// 不需要打印reload信息,防止日志数量过多
|
||||
listener.Reload(group)
|
||||
} else {
|
||||
remotelogs.Println("LISTENER_MANAGER", "listen '"+this.prettyAddress(addr)+"'")
|
||||
listener = NewListener()
|
||||
listener.Reload(group)
|
||||
err := listener.Listen()
|
||||
if err != nil {
|
||||
// 放入到重试队列中
|
||||
this.retryListenerMap[addr] = listener
|
||||
|
||||
var firstServer = group.FirstServer()
|
||||
if firstServer == nil {
|
||||
remotelogs.Error("LISTENER_MANAGER", err.Error())
|
||||
} else {
|
||||
// 当前占用的进程名
|
||||
if strings.Contains(err.Error(), "in use") {
|
||||
portIndex := strings.LastIndex(addr, ":")
|
||||
if portIndex > 0 {
|
||||
var port = addr[portIndex+1:]
|
||||
var processName = this.findProcessNameWithPort(group.IsUDP(), port)
|
||||
if len(processName) > 0 {
|
||||
err = fmt.Errorf("%w (the process using port: '%s')", err, processName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
remotelogs.ServerError(firstServer.Id, "LISTENER_MANAGER", "listen '"+addr+"' failed: "+err.Error(), nodeconfigs.NodeLogTypeListenAddressFailed, maps.Map{"address": addr})
|
||||
}
|
||||
|
||||
continue
|
||||
} else {
|
||||
// TODO 是否是从错误中恢复
|
||||
}
|
||||
this.listenersMap[addr] = listener
|
||||
}
|
||||
}
|
||||
|
||||
// 加入到firewalld
|
||||
go this.addToFirewalld(groupAddrs)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TotalActiveConnections 获取总的活跃连接数
|
||||
func (this *ListenerManager) TotalActiveConnections() int {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
var total = 0
|
||||
for _, listener := range this.listenersMap {
|
||||
total += listener.listener.CountActiveConnections()
|
||||
}
|
||||
|
||||
if this.http3Listener != nil {
|
||||
total += this.http3Listener.CountActiveConnections()
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
// 返回更加友好格式的地址
|
||||
func (this *ListenerManager) prettyAddress(addr string) string {
|
||||
u, err := url.Parse(addr)
|
||||
if err != nil {
|
||||
return addr
|
||||
}
|
||||
if regexp.MustCompile(`^:\d+$`).MatchString(u.Host) {
|
||||
u.Host = "*" + u.Host
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// 重试失败的Listener
|
||||
func (this *ListenerManager) retryListeners() {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
for addr, listener := range this.retryListenerMap {
|
||||
err := listener.Listen()
|
||||
if err == nil {
|
||||
delete(this.retryListenerMap, addr)
|
||||
this.listenersMap[addr] = listener
|
||||
remotelogs.ServerSuccess(listener.group.FirstServer().Id, "LISTENER_MANAGER", "retry to listen '"+addr+"' successfully", nodeconfigs.NodeLogTypeListenAddressFailed, maps.Map{"address": addr})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *ListenerManager) findProcessNameWithPort(isUdp bool, port string) string {
|
||||
if runtime.GOOS != "linux" {
|
||||
return ""
|
||||
}
|
||||
|
||||
path, err := executils.LookPath("ss")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var option = "t"
|
||||
if isUdp {
|
||||
option = "u"
|
||||
}
|
||||
|
||||
var cmd = executils.NewTimeoutCmd(10*time.Second, path, "-"+option+"lpn", "sport = :"+port)
|
||||
cmd.WithStdout()
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var matches = regexp.MustCompile(`(?U)\(\("(.+)",pid=\d+,fd=\d+\)\)`).FindStringSubmatch(cmd.Stdout())
|
||||
if len(matches) > 1 {
|
||||
return matches[1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (this *ListenerManager) addToFirewalld(groupAddrs []string) {
|
||||
if !sharedNodeConfig.AutoOpenPorts {
|
||||
return
|
||||
}
|
||||
|
||||
if this.firewalld == nil || !this.firewalld.IsReady() {
|
||||
return
|
||||
}
|
||||
|
||||
// HTTP/3相关端口
|
||||
var http3Ports = sharedNodeConfig.FindHTTP3Ports()
|
||||
if len(http3Ports) > 0 {
|
||||
for _, port := range http3Ports {
|
||||
var groupAddr = "udp://:" + types.String(port)
|
||||
if !lists.ContainsString(groupAddrs, groupAddr) {
|
||||
groupAddrs = append(groupAddrs, groupAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 组合端口号
|
||||
var portStrings = []string{}
|
||||
var udpPorts = []int{}
|
||||
var tcpPorts = []int{}
|
||||
for _, addr := range groupAddrs {
|
||||
var protocol = "tcp"
|
||||
if strings.HasPrefix(addr, "udp") {
|
||||
protocol = "udp"
|
||||
}
|
||||
|
||||
var lastIndex = strings.LastIndex(addr, ":")
|
||||
if lastIndex > 0 {
|
||||
var portString = addr[lastIndex+1:]
|
||||
portStrings = append(portStrings, portString+"/"+protocol)
|
||||
|
||||
switch protocol {
|
||||
case "tcp":
|
||||
tcpPorts = append(tcpPorts, types.Int(portString))
|
||||
case "udp":
|
||||
udpPorts = append(udpPorts, types.Int(portString))
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(portStrings) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否有变化
|
||||
sort.Strings(portStrings)
|
||||
var newPortStrings = strings.Join(portStrings, ",")
|
||||
if newPortStrings == this.lastPortStrings {
|
||||
return
|
||||
}
|
||||
this.locker.Lock()
|
||||
this.lastPortStrings = newPortStrings
|
||||
this.locker.Unlock()
|
||||
|
||||
remotelogs.Println("FIREWALLD", "opening ports automatically ...")
|
||||
defer func() {
|
||||
remotelogs.Println("FIREWALLD", "open ports successfully")
|
||||
}()
|
||||
|
||||
// 合并端口
|
||||
var tcpPortRanges = utils.MergePorts(tcpPorts)
|
||||
var udpPortRanges = utils.MergePorts(udpPorts)
|
||||
|
||||
defer func() {
|
||||
this.locker.Lock()
|
||||
this.lastTCPPortRanges = tcpPortRanges
|
||||
this.lastUDPPortRanges = udpPortRanges
|
||||
this.locker.Unlock()
|
||||
}()
|
||||
|
||||
// 删除老的不存在的端口
|
||||
var tcpPortRangesMap = map[string]bool{}
|
||||
var udpPortRangesMap = map[string]bool{}
|
||||
for _, portRange := range tcpPortRanges {
|
||||
tcpPortRangesMap[this.firewalld.PortRangeString(portRange, "tcp")] = true
|
||||
}
|
||||
for _, portRange := range udpPortRanges {
|
||||
udpPortRangesMap[this.firewalld.PortRangeString(portRange, "udp")] = true
|
||||
}
|
||||
|
||||
for _, portRange := range this.lastTCPPortRanges {
|
||||
var s = this.firewalld.PortRangeString(portRange, "tcp")
|
||||
_, ok := tcpPortRangesMap[s]
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
remotelogs.Println("FIREWALLD", "remove port '"+s+"'")
|
||||
_ = this.firewalld.RemovePortRangePermanently(portRange, "tcp")
|
||||
}
|
||||
for _, portRange := range this.lastUDPPortRanges {
|
||||
var s = this.firewalld.PortRangeString(portRange, "udp")
|
||||
_, ok := udpPortRangesMap[s]
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
remotelogs.Println("FIREWALLD", "remove port '"+s+"'")
|
||||
_ = this.firewalld.RemovePortRangePermanently(portRange, "udp")
|
||||
}
|
||||
|
||||
// 添加新的
|
||||
_ = this.firewalld.AllowPortRangesPermanently(tcpPortRanges, "tcp")
|
||||
_ = this.firewalld.AllowPortRangesPermanently(udpPortRanges, "udp")
|
||||
}
|
||||
|
||||
func (this *ListenerManager) reloadFirewalld() {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
var nodeConfig = sharedNodeConfig
|
||||
|
||||
// 所有的新地址
|
||||
var groupAddrs = []string{}
|
||||
var availableServerGroups = nodeConfig.AvailableGroups()
|
||||
if !nodeConfig.IsOn {
|
||||
availableServerGroups = []*serverconfigs.ServerAddressGroup{}
|
||||
}
|
||||
|
||||
if len(availableServerGroups) == 0 {
|
||||
remotelogs.Println("LISTENER_MANAGER", "no available servers to startup")
|
||||
}
|
||||
|
||||
for _, group := range availableServerGroups {
|
||||
var addr = group.FullAddr()
|
||||
groupAddrs = append(groupAddrs, addr)
|
||||
}
|
||||
|
||||
go this.addToFirewalld(groupAddrs)
|
||||
}
|
||||
84
EdgeNode/internal/nodes/listener_manager_test.go
Normal file
84
EdgeNode/internal/nodes/listener_manager_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestListenerManager_Listen(t *testing.T) {
|
||||
manager := NewListenerManager()
|
||||
err := manager.Start(&nodeconfigs.NodeConfig{
|
||||
Servers: []*serverconfigs.ServerConfig{
|
||||
{
|
||||
IsOn: true,
|
||||
HTTP: &serverconfigs.HTTPProtocolConfig{
|
||||
BaseProtocol: serverconfigs.BaseProtocol{
|
||||
IsOn: true,
|
||||
Listen: []*serverconfigs.NetworkAddressConfig{
|
||||
{
|
||||
Protocol: serverconfigs.ProtocolHTTP,
|
||||
PortRange: "1234",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
IsOn: true,
|
||||
HTTP: &serverconfigs.HTTPProtocolConfig{
|
||||
BaseProtocol: serverconfigs.BaseProtocol{
|
||||
IsOn: true,
|
||||
Listen: []*serverconfigs.NetworkAddressConfig{
|
||||
{
|
||||
Protocol: serverconfigs.ProtocolHTTP,
|
||||
PortRange: "1235",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.Start(&nodeconfigs.NodeConfig{
|
||||
Servers: []*serverconfigs.ServerConfig{
|
||||
{
|
||||
IsOn: true,
|
||||
HTTP: &serverconfigs.HTTPProtocolConfig{
|
||||
BaseProtocol: serverconfigs.BaseProtocol{
|
||||
IsOn: true,
|
||||
Listen: []*serverconfigs.NetworkAddressConfig{
|
||||
{
|
||||
Protocol: serverconfigs.ProtocolHTTP,
|
||||
PortRange: "1234",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
IsOn: true,
|
||||
HTTP: &serverconfigs.HTTPProtocolConfig{
|
||||
BaseProtocol: serverconfigs.BaseProtocol{
|
||||
IsOn: true,
|
||||
Listen: []*serverconfigs.NetworkAddressConfig{
|
||||
{
|
||||
Protocol: serverconfigs.ProtocolHTTP,
|
||||
PortRange: "1236",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Log("all ok")
|
||||
}
|
||||
318
EdgeNode/internal/nodes/listener_tcp.go
Normal file
318
EdgeNode/internal/nodes/listener_tcp.go
Normal file
@@ -0,0 +1,318 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/bytepool"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/pires/go-proxyproto"
|
||||
"net"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type TCPListener struct {
|
||||
BaseListener
|
||||
|
||||
Listener net.Listener
|
||||
|
||||
port int
|
||||
}
|
||||
|
||||
func (this *TCPListener) Serve() error {
|
||||
var listener = this.Listener
|
||||
if this.Group.IsTLS() {
|
||||
listener = tls.NewListener(listener, this.buildTLSConfig())
|
||||
}
|
||||
|
||||
// 获取分组端口
|
||||
var groupAddr = this.Group.Addr()
|
||||
var portIndex = strings.LastIndex(groupAddr, ":")
|
||||
if portIndex >= 0 {
|
||||
var port = groupAddr[portIndex+1:]
|
||||
this.port = types.Int(port)
|
||||
}
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
atomic.AddInt64(&this.countActiveConnections, 1)
|
||||
|
||||
go func(conn net.Conn) {
|
||||
var server = this.Group.FirstServer()
|
||||
if server == nil {
|
||||
return
|
||||
}
|
||||
err = this.handleConn(server, conn)
|
||||
if err != nil {
|
||||
remotelogs.ServerError(server.Id, "TCP_LISTENER", err.Error(), "", nil)
|
||||
}
|
||||
atomic.AddInt64(&this.countActiveConnections, -1)
|
||||
}(conn)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *TCPListener) Reload(group *serverconfigs.ServerAddressGroup) {
|
||||
this.Group = group
|
||||
this.Reset()
|
||||
}
|
||||
|
||||
func (this *TCPListener) handleConn(server *serverconfigs.ServerConfig, conn net.Conn) error {
|
||||
if server == nil {
|
||||
return errors.New("no server available")
|
||||
}
|
||||
if server.ReverseProxy == nil {
|
||||
return errors.New("no ReverseProxy configured for the server")
|
||||
}
|
||||
|
||||
// 绑定连接和服务
|
||||
clientConn, ok := conn.(ClientConnInterface)
|
||||
if ok {
|
||||
var goNext = clientConn.SetServerId(server.Id)
|
||||
if !goNext {
|
||||
return nil
|
||||
}
|
||||
clientConn.SetUserId(server.UserId)
|
||||
|
||||
var userPlanId int64
|
||||
if server.UserPlan != nil && server.UserPlan.Id > 0 {
|
||||
userPlanId = server.UserPlan.Id
|
||||
}
|
||||
clientConn.SetUserPlanId(userPlanId)
|
||||
} else {
|
||||
tlsConn, ok := conn.(*tls.Conn)
|
||||
if ok {
|
||||
var internalConn = tlsConn.NetConn()
|
||||
if internalConn != nil {
|
||||
clientConn, ok = internalConn.(ClientConnInterface)
|
||||
if ok {
|
||||
var goNext = clientConn.SetServerId(server.Id)
|
||||
if !goNext {
|
||||
return nil
|
||||
}
|
||||
clientConn.SetUserId(server.UserId)
|
||||
|
||||
var userPlanId int64
|
||||
if server.UserPlan != nil && server.UserPlan.Id > 0 {
|
||||
userPlanId = server.UserPlan.Id
|
||||
}
|
||||
clientConn.SetUserPlanId(userPlanId)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 是否已达到流量限制
|
||||
if this.reachedTrafficLimit() || (server.UserId > 0 && !SharedUserManager.CheckUserServersIsEnabled(server.UserId)) {
|
||||
// 关闭连接
|
||||
tcpConn, ok := conn.(LingerConn)
|
||||
if ok {
|
||||
_ = tcpConn.SetLinger(0)
|
||||
}
|
||||
_ = conn.Close()
|
||||
|
||||
// TODO 使用系统防火墙drop当前端口的数据包一段时间(1分钟)
|
||||
// 不能使用阻止IP的方法,因为边缘节点只上有可能还有别的代理服务
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 记录域名排行
|
||||
tlsConn, ok := conn.(*tls.Conn)
|
||||
var recordStat = false
|
||||
var serverName = ""
|
||||
if ok {
|
||||
serverName = tlsConn.ConnectionState().ServerName
|
||||
if len(serverName) > 0 {
|
||||
// 统计
|
||||
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, serverName, 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
recordStat = true
|
||||
}
|
||||
}
|
||||
|
||||
// 统计
|
||||
if !recordStat {
|
||||
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
}
|
||||
|
||||
// DAU统计
|
||||
clientIP, _, parseErr := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
if parseErr == nil {
|
||||
stats.SharedDAUManager.AddIP(server.Id, clientIP)
|
||||
}
|
||||
|
||||
originConn, err := this.connectOrigin(server.Id, serverName, server.ReverseProxy, conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
var closer = func() {
|
||||
_ = conn.Close()
|
||||
_ = originConn.Close()
|
||||
}
|
||||
|
||||
// PROXY Protocol
|
||||
if server.ReverseProxy != nil &&
|
||||
server.ReverseProxy.ProxyProtocol != nil &&
|
||||
server.ReverseProxy.ProxyProtocol.IsOn &&
|
||||
(server.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || server.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) {
|
||||
var remoteAddr = conn.RemoteAddr()
|
||||
var transportProtocol = proxyproto.TCPv4
|
||||
if strings.Contains(remoteAddr.String(), "[") {
|
||||
transportProtocol = proxyproto.TCPv6
|
||||
}
|
||||
var header = proxyproto.Header{
|
||||
Version: byte(server.ReverseProxy.ProxyProtocol.Version),
|
||||
Command: proxyproto.PROXY,
|
||||
TransportProtocol: transportProtocol,
|
||||
SourceAddr: remoteAddr,
|
||||
DestinationAddr: conn.LocalAddr(),
|
||||
}
|
||||
_, err = header.WriteTo(originConn)
|
||||
if err != nil {
|
||||
closer()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 从源站读取
|
||||
goman.New(func() {
|
||||
var originBuf = bytepool.Pool16k.Get()
|
||||
defer func() {
|
||||
bytepool.Pool16k.Put(originBuf)
|
||||
}()
|
||||
for {
|
||||
n, err := originConn.Read(originBuf.Bytes)
|
||||
if n > 0 {
|
||||
_, err = conn.Write(originBuf.Bytes[:n])
|
||||
if err != nil {
|
||||
closer()
|
||||
break
|
||||
}
|
||||
|
||||
// 记录流量
|
||||
if server != nil {
|
||||
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
closer()
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// 从客户端读取
|
||||
var clientBuf = bytepool.Pool16k.Get()
|
||||
defer func() {
|
||||
bytepool.Pool16k.Put(clientBuf)
|
||||
}()
|
||||
for {
|
||||
// 是否已达到流量限制
|
||||
if this.reachedTrafficLimit() {
|
||||
closer()
|
||||
return nil
|
||||
}
|
||||
|
||||
n, err := conn.Read(clientBuf.Bytes)
|
||||
if n > 0 {
|
||||
_, err = originConn.Write(clientBuf.Bytes[:n])
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭连接
|
||||
closer()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *TCPListener) Close() error {
|
||||
return this.Listener.Close()
|
||||
}
|
||||
|
||||
// 连接源站
|
||||
func (this *TCPListener) connectOrigin(serverId int64, requestHost string, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
|
||||
if reverseProxy == nil {
|
||||
return nil, errors.New("no reverse proxy config")
|
||||
}
|
||||
|
||||
var requestCall = shared.NewRequestCall()
|
||||
requestCall.Domain = requestHost
|
||||
|
||||
var retries = 3
|
||||
var addr string
|
||||
|
||||
var failedOriginIds []int64
|
||||
|
||||
for i := 0; i < retries; i++ {
|
||||
var origin *serverconfigs.OriginConfig
|
||||
if len(failedOriginIds) > 0 {
|
||||
origin = reverseProxy.AnyOrigin(requestCall, failedOriginIds)
|
||||
}
|
||||
if origin == nil {
|
||||
origin = reverseProxy.NextOrigin(requestCall)
|
||||
}
|
||||
if origin == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 回源主机名
|
||||
if len(origin.RequestHost) > 0 {
|
||||
requestHost = origin.RequestHost
|
||||
} else if len(reverseProxy.RequestHost) > 0 {
|
||||
requestHost = reverseProxy.RequestHost
|
||||
}
|
||||
|
||||
conn, addr, err = OriginConnect(origin, this.port, remoteAddr, requestHost)
|
||||
if err != nil {
|
||||
failedOriginIds = append(failedOriginIds, origin.Id)
|
||||
|
||||
remotelogs.ServerError(serverId, "TCP_LISTENER", "unable to connect origin server: "+addr+": "+err.Error(), "", nil)
|
||||
|
||||
SharedOriginStateManager.Fail(origin, requestHost, reverseProxy, func() {
|
||||
reverseProxy.ResetScheduling()
|
||||
})
|
||||
|
||||
continue
|
||||
} else {
|
||||
if !origin.IsOk {
|
||||
SharedOriginStateManager.Success(origin, func() {
|
||||
reverseProxy.ResetScheduling()
|
||||
})
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
err = errors.New("server '" + types.String(serverId) + "': no available origin server can be used")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否已经达到流量限制
|
||||
func (this *TCPListener) reachedTrafficLimit() bool {
|
||||
var server = this.Group.FirstServer()
|
||||
if server == nil {
|
||||
return true
|
||||
}
|
||||
return server.TrafficLimitStatus != nil && server.TrafficLimitStatus.IsValid()
|
||||
}
|
||||
18
EdgeNode/internal/nodes/listener_test.go
Normal file
18
EdgeNode/internal/nodes/listener_test.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestListener_Listen(t *testing.T) {
|
||||
listener := NewListener()
|
||||
|
||||
group := serverconfigs.NewServerAddressGroup("https://:1234")
|
||||
|
||||
listener.Reload(group)
|
||||
err := listener.Listen()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
465
EdgeNode/internal/nodes/listener_udp.go
Normal file
465
EdgeNode/internal/nodes/listener_udp.go
Normal file
@@ -0,0 +1,465 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/bytepool"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/pires/go-proxyproto"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
UDPConnLifeSeconds = 30
|
||||
)
|
||||
|
||||
type UDPPacketListener interface {
|
||||
ReadFrom(b []byte) (n int, cm any, src net.Addr, err error)
|
||||
WriteTo(b []byte, cm any, dst net.Addr) (n int, err error)
|
||||
LocalAddr() net.Addr
|
||||
}
|
||||
|
||||
type UDPIPv4Listener struct {
|
||||
rawListener *ipv4.PacketConn
|
||||
}
|
||||
|
||||
func NewUDPIPv4Listener(rawListener *ipv4.PacketConn) *UDPIPv4Listener {
|
||||
return &UDPIPv4Listener{rawListener: rawListener}
|
||||
}
|
||||
|
||||
func (this *UDPIPv4Listener) ReadFrom(b []byte) (n int, cm any, src net.Addr, err error) {
|
||||
return this.rawListener.ReadFrom(b)
|
||||
}
|
||||
|
||||
func (this *UDPIPv4Listener) WriteTo(b []byte, cm any, dst net.Addr) (n int, err error) {
|
||||
return this.rawListener.WriteTo(b, cm.(*ipv4.ControlMessage), dst)
|
||||
}
|
||||
|
||||
func (this *UDPIPv4Listener) LocalAddr() net.Addr {
|
||||
return this.rawListener.LocalAddr()
|
||||
}
|
||||
|
||||
type UDPIPv6Listener struct {
|
||||
rawListener *ipv6.PacketConn
|
||||
}
|
||||
|
||||
func NewUDPIPv6Listener(rawListener *ipv6.PacketConn) *UDPIPv6Listener {
|
||||
return &UDPIPv6Listener{rawListener: rawListener}
|
||||
}
|
||||
|
||||
func (this *UDPIPv6Listener) ReadFrom(b []byte) (n int, cm any, src net.Addr, err error) {
|
||||
return this.rawListener.ReadFrom(b)
|
||||
}
|
||||
|
||||
func (this *UDPIPv6Listener) WriteTo(b []byte, cm any, dst net.Addr) (n int, err error) {
|
||||
return this.rawListener.WriteTo(b, cm.(*ipv6.ControlMessage), dst)
|
||||
}
|
||||
|
||||
func (this *UDPIPv6Listener) LocalAddr() net.Addr {
|
||||
return this.rawListener.LocalAddr()
|
||||
}
|
||||
|
||||
type UDPListener struct {
|
||||
BaseListener
|
||||
|
||||
IPv4Listener *ipv4.PacketConn
|
||||
IPv6Listener *ipv6.PacketConn
|
||||
|
||||
connMap map[string]*UDPConn
|
||||
connLocker sync.Mutex
|
||||
connTicker *utils.Ticker
|
||||
|
||||
reverseProxy *serverconfigs.ReverseProxyConfig
|
||||
|
||||
port int
|
||||
|
||||
isClosed bool
|
||||
}
|
||||
|
||||
func (this *UDPListener) Serve() error {
|
||||
if this.Group == nil {
|
||||
return nil
|
||||
}
|
||||
var server = this.Group.FirstServer()
|
||||
if server == nil {
|
||||
return nil
|
||||
}
|
||||
var serverId = server.Id
|
||||
|
||||
var wg = &sync.WaitGroup{}
|
||||
wg.Add(2) // 2 = ipv4 + ipv6
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
if this.IPv4Listener != nil {
|
||||
err := this.IPv4Listener.SetControlMessage(ipv4.FlagDst, true)
|
||||
if err != nil {
|
||||
remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv4 listener: "+err.Error(), "", nil)
|
||||
return
|
||||
}
|
||||
|
||||
err = this.servePacketListener(NewUDPIPv4Listener(this.IPv4Listener))
|
||||
if err != nil {
|
||||
remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv4 listener: "+err.Error(), "", nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
if this.IPv6Listener != nil {
|
||||
err := this.IPv6Listener.SetControlMessage(ipv6.FlagDst, true)
|
||||
if err != nil {
|
||||
remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv6 listener: "+err.Error(), "", nil)
|
||||
return
|
||||
}
|
||||
|
||||
err = this.servePacketListener(NewUDPIPv6Listener(this.IPv6Listener))
|
||||
if err != nil {
|
||||
remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv6 listener: "+err.Error(), "", nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *UDPListener) servePacketListener(listener UDPPacketListener) error {
|
||||
// 获取分组端口
|
||||
var groupAddr = this.Group.Addr()
|
||||
var portIndex = strings.LastIndex(groupAddr, ":")
|
||||
if portIndex >= 0 {
|
||||
var port = groupAddr[portIndex+1:]
|
||||
this.port = types.Int(port)
|
||||
}
|
||||
|
||||
var firstServer = this.Group.FirstServer()
|
||||
if firstServer == nil {
|
||||
return errors.New("no server available")
|
||||
}
|
||||
this.reverseProxy = firstServer.ReverseProxy
|
||||
if this.reverseProxy == nil {
|
||||
return errors.New("no ReverseProxy configured for the server '" + firstServer.Name + "'")
|
||||
}
|
||||
|
||||
this.connMap = map[string]*UDPConn{}
|
||||
this.connTicker = utils.NewTicker(1 * time.Minute)
|
||||
goman.New(func() {
|
||||
for this.connTicker.Next() {
|
||||
this.gcConns()
|
||||
}
|
||||
})
|
||||
|
||||
var buffer = make([]byte, 4<<10)
|
||||
for {
|
||||
if this.isClosed {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if firstServer.UserId > 0 && !SharedUserManager.CheckUserServersIsEnabled(firstServer.UserId) {
|
||||
return nil
|
||||
}
|
||||
|
||||
n, cm, clientAddr, err := listener.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
if this.isClosed {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查IP名单
|
||||
clientIP, _, parseHostErr := net.SplitHostPort(clientAddr.String())
|
||||
if parseHostErr == nil {
|
||||
ok, _, expiresAt := iplibrary.AllowIP(clientIP, firstServer.Id)
|
||||
if !ok {
|
||||
firewalls.DropTemporaryTo(clientIP, expiresAt)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
this.connLocker.Lock()
|
||||
conn, ok := this.connMap[clientAddr.String()]
|
||||
this.connLocker.Unlock()
|
||||
if ok && !conn.IsOk() {
|
||||
_ = conn.Close()
|
||||
ok = false
|
||||
}
|
||||
if !ok {
|
||||
originConn, err := this.connectOrigin(firstServer.Id, this.reverseProxy, listener.LocalAddr(), clientAddr)
|
||||
if err != nil {
|
||||
remotelogs.Error("UDP_LISTENER", "unable to connect to origin server: "+err.Error())
|
||||
continue
|
||||
}
|
||||
if originConn == nil {
|
||||
remotelogs.Error("UDP_LISTENER", "unable to find a origin server")
|
||||
continue
|
||||
}
|
||||
conn = NewUDPConn(firstServer, clientAddr, listener, cm, originConn.(*net.UDPConn))
|
||||
this.connLocker.Lock()
|
||||
this.connMap[clientAddr.String()] = conn
|
||||
this.connLocker.Unlock()
|
||||
}
|
||||
_, _ = conn.Write(buffer[:n])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *UDPListener) Close() error {
|
||||
this.isClosed = true
|
||||
|
||||
if this.connTicker != nil {
|
||||
this.connTicker.Stop()
|
||||
}
|
||||
|
||||
// 关闭所有连接
|
||||
this.connLocker.Lock()
|
||||
for _, conn := range this.connMap {
|
||||
_ = conn.Close()
|
||||
}
|
||||
this.connLocker.Unlock()
|
||||
|
||||
var errorStrings = []string{}
|
||||
if this.IPv4Listener != nil {
|
||||
err := this.IPv4Listener.Close()
|
||||
if err != nil {
|
||||
errorStrings = append(errorStrings, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if this.IPv6Listener != nil {
|
||||
err := this.IPv6Listener.Close()
|
||||
if err != nil {
|
||||
errorStrings = append(errorStrings, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if len(errorStrings) > 0 {
|
||||
return errors.New(errorStrings[0])
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *UDPListener) Reload(group *serverconfigs.ServerAddressGroup) {
|
||||
this.Group = group
|
||||
this.Reset()
|
||||
|
||||
// 重置配置
|
||||
var firstServer = this.Group.FirstServer()
|
||||
if firstServer == nil {
|
||||
return
|
||||
}
|
||||
this.reverseProxy = firstServer.ReverseProxy
|
||||
}
|
||||
|
||||
func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, localAddr net.Addr, remoteAddr net.Addr) (conn net.Conn, err error) {
|
||||
if reverseProxy == nil {
|
||||
return nil, errors.New("no reverse proxy config")
|
||||
}
|
||||
|
||||
var retries = 3
|
||||
var addr string
|
||||
|
||||
var failedOriginIds []int64
|
||||
|
||||
for i := 0; i < retries; i++ {
|
||||
var origin *serverconfigs.OriginConfig
|
||||
if len(failedOriginIds) > 0 {
|
||||
origin = reverseProxy.AnyOrigin(nil, failedOriginIds)
|
||||
}
|
||||
if origin == nil {
|
||||
origin = reverseProxy.NextOrigin(nil)
|
||||
}
|
||||
if origin == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
conn, addr, err = OriginConnect(origin, this.port, remoteAddr.String(), "")
|
||||
if err != nil {
|
||||
failedOriginIds = append(failedOriginIds, origin.Id)
|
||||
|
||||
remotelogs.ServerError(serverId, "UDP_LISTENER", "unable to connect origin server: "+addr+": "+err.Error(), "", nil)
|
||||
|
||||
SharedOriginStateManager.Fail(origin, "", reverseProxy, func() {
|
||||
reverseProxy.ResetScheduling()
|
||||
})
|
||||
|
||||
continue
|
||||
} else {
|
||||
if !origin.IsOk {
|
||||
SharedOriginStateManager.Success(origin, func() {
|
||||
reverseProxy.ResetScheduling()
|
||||
})
|
||||
}
|
||||
|
||||
// PROXY Protocol
|
||||
if reverseProxy != nil &&
|
||||
reverseProxy.ProxyProtocol != nil &&
|
||||
reverseProxy.ProxyProtocol.IsOn &&
|
||||
(reverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || reverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) {
|
||||
var transportProtocol = proxyproto.UDPv4
|
||||
if strings.Contains(remoteAddr.String(), "[") {
|
||||
transportProtocol = proxyproto.UDPv6
|
||||
}
|
||||
var header = proxyproto.Header{
|
||||
Version: byte(reverseProxy.ProxyProtocol.Version),
|
||||
Command: proxyproto.PROXY,
|
||||
TransportProtocol: transportProtocol,
|
||||
SourceAddr: remoteAddr,
|
||||
DestinationAddr: localAddr,
|
||||
}
|
||||
_, err = header.WriteTo(conn)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
err = errors.New("server '" + types.String(serverId) + "': no available origin server can be used")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 回收连接
|
||||
func (this *UDPListener) gcConns() {
|
||||
this.connLocker.Lock()
|
||||
var closingConns = []*UDPConn{}
|
||||
for addr, conn := range this.connMap {
|
||||
if !conn.IsOk() {
|
||||
closingConns = append(closingConns, conn)
|
||||
delete(this.connMap, addr)
|
||||
}
|
||||
}
|
||||
this.connLocker.Unlock()
|
||||
|
||||
for _, conn := range closingConns {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// UDPConn 自定义的UDP连接管理
|
||||
type UDPConn struct {
|
||||
addr net.Addr
|
||||
proxyListener UDPPacketListener
|
||||
serverConn net.Conn
|
||||
activatedAt int64
|
||||
isOk bool
|
||||
isClosed bool
|
||||
}
|
||||
|
||||
func NewUDPConn(server *serverconfigs.ServerConfig, clientAddr net.Addr, proxyListener UDPPacketListener, cm any, serverConn *net.UDPConn) *UDPConn {
|
||||
var conn = &UDPConn{
|
||||
addr: clientAddr,
|
||||
proxyListener: proxyListener,
|
||||
serverConn: serverConn,
|
||||
activatedAt: time.Now().Unix(),
|
||||
isOk: true,
|
||||
}
|
||||
|
||||
// 统计
|
||||
if server != nil {
|
||||
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
|
||||
// DAU统计
|
||||
clientIP, _, parseErr := net.SplitHostPort(clientAddr.String())
|
||||
if parseErr == nil {
|
||||
stats.SharedDAUManager.AddIP(server.Id, clientIP)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理ControlMessage
|
||||
switch controlMessage := cm.(type) {
|
||||
case *ipv4.ControlMessage:
|
||||
controlMessage.Src = controlMessage.Dst
|
||||
case *ipv6.ControlMessage:
|
||||
controlMessage.Src = controlMessage.Dst
|
||||
}
|
||||
|
||||
goman.New(func() {
|
||||
var buf = bytepool.Pool4k.Get()
|
||||
defer func() {
|
||||
bytepool.Pool4k.Put(buf)
|
||||
}()
|
||||
|
||||
for {
|
||||
n, err := serverConn.Read(buf.Bytes)
|
||||
if n > 0 {
|
||||
conn.activatedAt = time.Now().Unix()
|
||||
|
||||
_, writingErr := proxyListener.WriteTo(buf.Bytes[:n], cm, clientAddr)
|
||||
if writingErr != nil {
|
||||
conn.isOk = false
|
||||
break
|
||||
}
|
||||
|
||||
// 记录流量和带宽
|
||||
if server != nil {
|
||||
// 流量
|
||||
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
|
||||
|
||||
// 带宽
|
||||
var userPlanId int64
|
||||
if server.UserPlan != nil && server.UserPlan.Id > 0 {
|
||||
userPlanId = server.UserPlan.Id
|
||||
}
|
||||
stats.SharedBandwidthStatManager.AddBandwidth(server.UserId, userPlanId, server.Id, int64(n), int64(n))
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
conn.isOk = false
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
return conn
|
||||
}
|
||||
|
||||
func (this *UDPConn) Write(b []byte) (n int, err error) {
|
||||
this.activatedAt = time.Now().Unix()
|
||||
n, err = this.serverConn.Write(b)
|
||||
if err != nil {
|
||||
this.isOk = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (this *UDPConn) Close() error {
|
||||
this.isOk = false
|
||||
if this.isClosed {
|
||||
return nil
|
||||
}
|
||||
this.isClosed = true
|
||||
return this.serverConn.Close()
|
||||
}
|
||||
|
||||
func (this *UDPConn) IsOk() bool {
|
||||
if !this.isOk {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix()-this.activatedAt < UDPConnLifeSeconds // 如果超过 N 秒没有活动我们认为是超时
|
||||
}
|
||||
1241
EdgeNode/internal/nodes/node.go
Normal file
1241
EdgeNode/internal/nodes/node.go
Normal file
File diff suppressed because it is too large
Load Diff
21
EdgeNode/internal/nodes/node_ext.go
Normal file
21
EdgeNode/internal/nodes/node_ext.go
Normal file
@@ -0,0 +1,21 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build !script
|
||||
// +build !script
|
||||
|
||||
package nodes
|
||||
|
||||
func (this *Node) reloadCommonScripts() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Node) reloadIPLibrary() {
|
||||
|
||||
}
|
||||
|
||||
func (this *Node) notifyPlusChange() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Node) execTOAChangedTask() error {
|
||||
return nil
|
||||
}
|
||||
142
EdgeNode/internal/nodes/node_ext_plus.go
Normal file
142
EdgeNode/internal/nodes/node_ext_plus.go
Normal file
@@ -0,0 +1,142 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build plus && script
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
iplib "github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
return
|
||||
}
|
||||
|
||||
events.On(events.EventLoaded, func() {
|
||||
var plusTicker = time.NewTicker(1 * time.Hour)
|
||||
if Tea.IsTesting() {
|
||||
plusTicker = time.NewTicker(1 * time.Minute)
|
||||
}
|
||||
goman.New(func() {
|
||||
for range plusTicker.C {
|
||||
_ = nodeInstance.notifyPlusChange()
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
var lastEdition string
|
||||
|
||||
func (this *Node) reloadCommonScripts() error {
|
||||
// 下载配置
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
configsResp, err := rpcClient.ScriptRPC.ComposeScriptConfigs(rpcClient.Context(), &pb.ComposeScriptConfigsRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(configsResp.ScriptConfigsJSON) == 0 {
|
||||
sharedNodeConfig.CommonScripts = nil
|
||||
} else {
|
||||
var configs = []*serverconfigs.CommonScript{}
|
||||
err = json.Unmarshal(configsResp.ScriptConfigsJSON, &configs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode script configs failed: %w", err)
|
||||
}
|
||||
sharedNodeConfig.CommonScripts = configs
|
||||
}
|
||||
|
||||
// 通知更新
|
||||
select {
|
||||
case commonScriptsChangesChan <- true:
|
||||
default:
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Node) reloadIPLibrary() {
|
||||
if sharedNodeConfig.Edition == lastEdition {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
var err error
|
||||
lastEdition = sharedNodeConfig.Edition
|
||||
if len(lastEdition) > 0 && (lists.ContainsString([]string{"pro", "ent", "max", "ultra"}, lastEdition)) {
|
||||
err = iplib.InitPlus()
|
||||
} else {
|
||||
err = iplib.InitDefault()
|
||||
}
|
||||
if err != nil {
|
||||
remotelogs.Error("IP_LIBRARY", "load ip library failed: "+err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (this *Node) notifyPlusChange() error {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := rpcClient.AuthorityKeyRPC.CheckAuthority(rpcClient.Context(), &pb.CheckAuthorityRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var isChanged = resp.Edition != sharedNodeConfig.Edition
|
||||
if resp.IsPlus {
|
||||
sharedNodeConfig.Edition = resp.Edition
|
||||
} else {
|
||||
sharedNodeConfig.Edition = ""
|
||||
}
|
||||
|
||||
if isChanged {
|
||||
this.reloadIPLibrary()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Node) execTOAChangedTask() error {
|
||||
if sharedTOAManager == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := rpcClient.NodeRPC.FindNodeTOAConfig(rpcClient.Context(), &pb.FindNodeTOAConfigRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resp.ToaJSON) == 0 {
|
||||
return sharedTOAManager.Apply(&nodeconfigs.TOAConfig{IsOn: false})
|
||||
}
|
||||
|
||||
var config = nodeconfigs.NewTOAConfig()
|
||||
err = json.Unmarshal(resp.ToaJSON, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return sharedTOAManager.Apply(config)
|
||||
}
|
||||
44
EdgeNode/internal/nodes/node_panic.go
Normal file
44
EdgeNode/internal/nodes/node_panic.go
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build !arm64
|
||||
// +build !arm64
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// 处理异常
|
||||
func (this *Node) handlePanic() {
|
||||
// 如果是在前台运行就直接返回
|
||||
backgroundEnv, _ := os.LookupEnv("EdgeBackground")
|
||||
if backgroundEnv != "on" {
|
||||
return
|
||||
}
|
||||
|
||||
var panicFile = Tea.Root + "/logs/panic.log"
|
||||
|
||||
// 分析panic
|
||||
data, err := os.ReadFile(panicFile)
|
||||
if err == nil {
|
||||
var index = bytes.Index(data, []byte("panic:"))
|
||||
if index >= 0 {
|
||||
remotelogs.Error("NODE", "系统错误,请上报给开发者: "+string(data[index:]))
|
||||
}
|
||||
}
|
||||
|
||||
fp, err := os.OpenFile(panicFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY|os.O_APPEND, 0777)
|
||||
if err != nil {
|
||||
logs.Println("NODE", "open 'panic.log' failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
err = syscall.Dup2(int(fp.Fd()), int(os.Stderr.Fd()))
|
||||
if err != nil {
|
||||
logs.Println("NODE", "write to 'panic.log' failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
9
EdgeNode/internal/nodes/node_panic_arm64.go
Normal file
9
EdgeNode/internal/nodes/node_panic_arm64.go
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build arm64
|
||||
|
||||
package nodes
|
||||
|
||||
// 处理异常
|
||||
func (this *Node) handlePanic() {
|
||||
|
||||
}
|
||||
379
EdgeNode/internal/nodes/node_status_executor.go
Normal file
379
EdgeNode/internal/nodes/node_status_executor.go
Normal file
@@ -0,0 +1,379 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/monitor"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
fsutils "github.com/TeaOSLab/EdgeNode/internal/utils/fs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/trackers"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/shirou/gopsutil/v3/cpu"
|
||||
"github.com/shirou/gopsutil/v3/disk"
|
||||
"github.com/shirou/gopsutil/v3/net"
|
||||
"math"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type NodeStatusExecutor struct {
|
||||
isFirstTime bool
|
||||
lastUpdatedTime time.Time
|
||||
|
||||
cpuLogicalCount int
|
||||
cpuPhysicalCount int
|
||||
|
||||
// 流量统计
|
||||
lastIOCounterStat net.IOCountersStat
|
||||
lastUDPInDatagrams int64
|
||||
lastUDPOutDatagrams int64
|
||||
|
||||
apiCallStat *rpc.CallStat
|
||||
|
||||
ticker *time.Ticker
|
||||
}
|
||||
|
||||
func NewNodeStatusExecutor() *NodeStatusExecutor {
|
||||
return &NodeStatusExecutor{
|
||||
ticker: time.NewTicker(30 * time.Second),
|
||||
apiCallStat: rpc.NewCallStat(10),
|
||||
|
||||
lastUDPInDatagrams: -1,
|
||||
lastUDPOutDatagrams: -1,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *NodeStatusExecutor) Listen() {
|
||||
this.isFirstTime = true
|
||||
this.lastUpdatedTime = time.Now()
|
||||
this.update()
|
||||
|
||||
events.OnKey(events.EventQuit, this, func() {
|
||||
remotelogs.Println("NODE_STATUS", "quit executor")
|
||||
this.ticker.Stop()
|
||||
})
|
||||
|
||||
for range this.ticker.C {
|
||||
this.isFirstTime = false
|
||||
this.update()
|
||||
}
|
||||
}
|
||||
|
||||
func (this *NodeStatusExecutor) update() {
|
||||
if sharedNodeConfig == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var tr = trackers.Begin("UPLOAD_NODE_STATUS")
|
||||
defer tr.End()
|
||||
|
||||
var status = &nodeconfigs.NodeStatus{}
|
||||
status.BuildVersion = teaconst.Version
|
||||
status.BuildVersionCode = utils.VersionToLong(teaconst.Version)
|
||||
status.OS = runtime.GOOS
|
||||
status.Arch = runtime.GOARCH
|
||||
status.ExePath, _ = os.Executable()
|
||||
status.ConfigVersion = sharedNodeConfig.Version
|
||||
status.IsActive = true
|
||||
status.ConnectionCount = sharedListenerManager.TotalActiveConnections()
|
||||
status.CacheTotalDiskSize = caches.SharedManager.TotalDiskSize()
|
||||
status.CacheTotalMemorySize = caches.SharedManager.TotalMemorySize()
|
||||
status.TrafficInBytes = teaconst.InTrafficBytes
|
||||
status.TrafficOutBytes = teaconst.OutTrafficBytes
|
||||
|
||||
apiSuccessPercent, apiAvgCostSeconds := this.apiCallStat.Sum()
|
||||
status.APISuccessPercent = apiSuccessPercent
|
||||
status.APIAvgCostSeconds = apiAvgCostSeconds
|
||||
|
||||
var localFirewall = firewalls.Firewall()
|
||||
if localFirewall != nil && !localFirewall.IsMock() {
|
||||
status.LocalFirewallName = localFirewall.Name()
|
||||
}
|
||||
|
||||
// 记录监控数据
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemConnections, maps.Map{
|
||||
"total": status.ConnectionCount,
|
||||
})
|
||||
|
||||
hostname, _ := os.Hostname()
|
||||
status.Hostname = hostname
|
||||
|
||||
var cpuTR = tr.Begin("cpu")
|
||||
this.updateCPU(status)
|
||||
cpuTR.End()
|
||||
|
||||
var memTR = tr.Begin("memory")
|
||||
this.updateMem(status)
|
||||
memTR.End()
|
||||
|
||||
var loadTR = tr.Begin("load")
|
||||
this.updateLoad(status)
|
||||
loadTR.End()
|
||||
|
||||
var diskTR = tr.Begin("disk")
|
||||
this.updateDisk(status)
|
||||
diskTR.End()
|
||||
|
||||
var cacheSpaceTR = tr.Begin("cache space")
|
||||
this.updateCacheSpace(status)
|
||||
cacheSpaceTR.End()
|
||||
|
||||
this.updateAllTraffic(status)
|
||||
|
||||
// 修改更新时间
|
||||
this.lastUpdatedTime = time.Now()
|
||||
|
||||
status.UpdatedAt = time.Now().Unix()
|
||||
status.Timestamp = status.UpdatedAt
|
||||
|
||||
// 发送数据
|
||||
jsonData, err := json.Marshal(status)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE_STATUS", "serial NodeStatus fail: "+err.Error())
|
||||
return
|
||||
}
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE_STATUS", "failed to open rpc: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var before = time.Now()
|
||||
_, err = rpcClient.NodeRPC.UpdateNodeStatus(rpcClient.Context(), &pb.UpdateNodeStatusRequest{
|
||||
StatusJSON: jsonData,
|
||||
})
|
||||
var costSeconds = time.Since(before).Seconds()
|
||||
this.apiCallStat.Add(err == nil, costSeconds)
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Warn("NODE_STATUS", "rpc UpdateNodeStatus() failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("NODE_STATUS", "rpc UpdateNodeStatus() failed: "+err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 更新CPU
|
||||
func (this *NodeStatusExecutor) updateCPU(status *nodeconfigs.NodeStatus) {
|
||||
var duration = time.Duration(0)
|
||||
if this.isFirstTime {
|
||||
duration = 100 * time.Millisecond
|
||||
}
|
||||
percents, err := cpu.Percent(duration, false)
|
||||
if err != nil {
|
||||
status.Error = "cpu.Percent(): " + err.Error()
|
||||
return
|
||||
}
|
||||
if len(percents) == 0 {
|
||||
return
|
||||
}
|
||||
status.CPUUsage = percents[0] / 100
|
||||
|
||||
// 记录监控数据
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemCPU, maps.Map{
|
||||
"usage": status.CPUUsage,
|
||||
"cores": runtime.NumCPU(),
|
||||
})
|
||||
|
||||
if this.cpuLogicalCount == 0 && this.cpuPhysicalCount == 0 {
|
||||
status.CPULogicalCount, err = cpu.Counts(true)
|
||||
if err != nil {
|
||||
status.Error = "cpu.Counts(): " + err.Error()
|
||||
return
|
||||
}
|
||||
status.CPUPhysicalCount, err = cpu.Counts(false)
|
||||
if err != nil {
|
||||
status.Error = "cpu.Counts(): " + err.Error()
|
||||
return
|
||||
}
|
||||
this.cpuLogicalCount = status.CPULogicalCount
|
||||
this.cpuPhysicalCount = status.CPUPhysicalCount
|
||||
} else {
|
||||
status.CPULogicalCount = this.cpuLogicalCount
|
||||
status.CPUPhysicalCount = this.cpuPhysicalCount
|
||||
}
|
||||
}
|
||||
|
||||
// 更新硬盘
|
||||
func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
|
||||
status.DiskWritingSpeedMB = int(fsutils.DiskSpeedMB)
|
||||
|
||||
partitions, err := disk.Partitions(false)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE_STATUS", err.Error())
|
||||
return
|
||||
}
|
||||
lists.Sort(partitions, func(i int, j int) bool {
|
||||
p1 := partitions[i]
|
||||
p2 := partitions[j]
|
||||
return p1.Mountpoint > p2.Mountpoint
|
||||
})
|
||||
|
||||
// 当前TeaWeb所在的fs
|
||||
var rootFS = ""
|
||||
var rootTotal = uint64(0)
|
||||
var totalUsed = uint64(0)
|
||||
if lists.ContainsString([]string{"darwin", "linux", "freebsd"}, runtime.GOOS) {
|
||||
for _, p := range partitions {
|
||||
if p.Mountpoint == "/" {
|
||||
rootFS = p.Fstype
|
||||
usage, _ := disk.Usage(p.Mountpoint)
|
||||
if usage != nil {
|
||||
rootTotal = usage.Total
|
||||
totalUsed = usage.Used
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var total = rootTotal
|
||||
var maxUsage = float64(0)
|
||||
for _, partition := range partitions {
|
||||
if runtime.GOOS != "windows" && !strings.Contains(partition.Device, "/") && !strings.Contains(partition.Device, "\\") {
|
||||
continue
|
||||
}
|
||||
|
||||
// 跳过不同fs的
|
||||
if len(rootFS) > 0 && rootFS != partition.Fstype {
|
||||
continue
|
||||
}
|
||||
|
||||
usage, err := disk.Usage(partition.Mountpoint)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if partition.Mountpoint != "/" && (usage.Total != rootTotal || total == 0) {
|
||||
total += usage.Total
|
||||
totalUsed += usage.Used
|
||||
if usage.UsedPercent >= maxUsage {
|
||||
maxUsage = usage.UsedPercent
|
||||
status.DiskMaxUsagePartition = partition.Mountpoint
|
||||
}
|
||||
}
|
||||
}
|
||||
status.DiskTotal = total
|
||||
if total > 0 {
|
||||
status.DiskUsage = float64(totalUsed) / float64(total)
|
||||
}
|
||||
status.DiskMaxUsage = maxUsage / 100
|
||||
|
||||
// 记录监控数据
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemDisk, maps.Map{
|
||||
"total": status.DiskTotal,
|
||||
"usage": status.DiskUsage,
|
||||
"maxUsage": status.DiskMaxUsage,
|
||||
})
|
||||
}
|
||||
|
||||
// 缓存空间
|
||||
func (this *NodeStatusExecutor) updateCacheSpace(status *nodeconfigs.NodeStatus) {
|
||||
var result = []maps.Map{}
|
||||
var cachePaths = caches.SharedManager.FindAllCachePaths()
|
||||
for _, path := range cachePaths {
|
||||
stat, err := fsutils.StatDevice(path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
result = append(result, maps.Map{
|
||||
"path": path,
|
||||
"total": stat.TotalSize(),
|
||||
"avail": stat.FreeSize(),
|
||||
"used": stat.UsedSize(),
|
||||
})
|
||||
}
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemCacheDir, maps.Map{
|
||||
"dirs": result,
|
||||
})
|
||||
}
|
||||
|
||||
// 流量
|
||||
func (this *NodeStatusExecutor) updateAllTraffic(status *nodeconfigs.NodeStatus) {
|
||||
trafficCounters, err := net.IOCounters(true)
|
||||
if err != nil {
|
||||
remotelogs.Warn("NODE_STATUS_EXECUTOR", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var allCounter = net.IOCountersStat{}
|
||||
for _, counter := range trafficCounters {
|
||||
// 跳过lo
|
||||
if counter.Name == "lo" {
|
||||
continue
|
||||
}
|
||||
allCounter.BytesRecv += counter.BytesRecv
|
||||
allCounter.BytesSent += counter.BytesSent
|
||||
}
|
||||
if allCounter.BytesSent == 0 && allCounter.BytesRecv == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if this.lastIOCounterStat.BytesSent > 0 {
|
||||
// 记录监控数据
|
||||
if allCounter.BytesSent >= this.lastIOCounterStat.BytesSent && allCounter.BytesRecv >= this.lastIOCounterStat.BytesRecv {
|
||||
var costSeconds = int(math.Ceil(time.Since(this.lastUpdatedTime).Seconds()))
|
||||
if costSeconds > 0 {
|
||||
var bytesSent = allCounter.BytesSent - this.lastIOCounterStat.BytesSent
|
||||
var bytesRecv = allCounter.BytesRecv - this.lastIOCounterStat.BytesRecv
|
||||
|
||||
// UDP
|
||||
var udpInDatagrams int64 = 0
|
||||
var udpOutDatagrams int64 = 0
|
||||
protoStats, protoErr := net.ProtoCounters([]string{"udp"})
|
||||
if protoErr == nil {
|
||||
for _, protoStat := range protoStats {
|
||||
if protoStat.Protocol == "udp" {
|
||||
udpInDatagrams = protoStat.Stats["InDatagrams"]
|
||||
udpOutDatagrams = protoStat.Stats["OutDatagrams"]
|
||||
if udpInDatagrams < 0 {
|
||||
udpInDatagrams = 0
|
||||
}
|
||||
if udpOutDatagrams < 0 {
|
||||
udpOutDatagrams = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var avgUDPInDatagrams int64 = 0
|
||||
var avgUDPOutDatagrams int64 = 0
|
||||
if this.lastUDPInDatagrams >= 0 && this.lastUDPOutDatagrams >= 0 {
|
||||
avgUDPInDatagrams = (udpInDatagrams - this.lastUDPInDatagrams) / int64(costSeconds)
|
||||
avgUDPOutDatagrams = (udpOutDatagrams - this.lastUDPOutDatagrams) / int64(costSeconds)
|
||||
if avgUDPInDatagrams < 0 {
|
||||
avgUDPInDatagrams = 0
|
||||
}
|
||||
if avgUDPOutDatagrams < 0 {
|
||||
avgUDPOutDatagrams = 0
|
||||
}
|
||||
}
|
||||
|
||||
this.lastUDPInDatagrams = udpInDatagrams
|
||||
this.lastUDPOutDatagrams = udpOutDatagrams
|
||||
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemAllTraffic, maps.Map{
|
||||
"inBytes": bytesRecv,
|
||||
"outBytes": bytesSent,
|
||||
"avgInBytes": bytesRecv / uint64(costSeconds),
|
||||
"avgOutBytes": bytesSent / uint64(costSeconds),
|
||||
|
||||
"avgUDPInDatagrams": avgUDPInDatagrams,
|
||||
"avgUDPOutDatagrams": avgUDPOutDatagrams,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
this.lastIOCounterStat = allCounter
|
||||
}
|
||||
27
EdgeNode/internal/nodes/node_status_executor_test.go
Normal file
27
EdgeNode/internal/nodes/node_status_executor_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/shirou/gopsutil/v3/cpu"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNodeStatusExecutor_CPU(t *testing.T) {
|
||||
countLogicCPU, err := cpu.Counts(true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("logic count:", countLogicCPU)
|
||||
|
||||
countPhysicalCPU, err := cpu.Counts(false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("physical count:", countPhysicalCPU)
|
||||
|
||||
percents, err := cpu.Percent(100*time.Millisecond, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(percents)
|
||||
}
|
||||
71
EdgeNode/internal/nodes/node_status_executor_unix.go
Normal file
71
EdgeNode/internal/nodes/node_status_executor_unix.go
Normal file
@@ -0,0 +1,71 @@
|
||||
//go:build !windows
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/monitor"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/shirou/gopsutil/v3/load"
|
||||
"github.com/shirou/gopsutil/v3/mem"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
)
|
||||
|
||||
// 更新内存
|
||||
func (this *NodeStatusExecutor) updateMem(status *nodeconfigs.NodeStatus) {
|
||||
stat, err := mem.VirtualMemory()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 重新计算内存
|
||||
if stat.Total > 0 {
|
||||
stat.Used = stat.Total - stat.Free - stat.Buffers - stat.Cached
|
||||
status.MemoryUsage = float64(stat.Used) / float64(stat.Total)
|
||||
}
|
||||
|
||||
status.MemoryTotal = stat.Total
|
||||
|
||||
// 记录监控数据
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemMemory, maps.Map{
|
||||
"usage": status.MemoryUsage,
|
||||
"total": status.MemoryTotal,
|
||||
"used": stat.Used,
|
||||
})
|
||||
|
||||
// 内存严重不足时自动释放内存
|
||||
if stat.Total > 0 {
|
||||
var minFreeMemory = stat.Total / 8
|
||||
if minFreeMemory > 1<<30 {
|
||||
minFreeMemory = 1 << 30
|
||||
}
|
||||
if stat.Available > 0 && stat.Available < minFreeMemory {
|
||||
runtime.GC()
|
||||
debug.FreeOSMemory()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 更新负载
|
||||
func (this *NodeStatusExecutor) updateLoad(status *nodeconfigs.NodeStatus) {
|
||||
stat, err := load.Avg()
|
||||
if err != nil {
|
||||
status.Error = err.Error()
|
||||
return
|
||||
}
|
||||
if stat == nil {
|
||||
status.Error = "load is nil"
|
||||
return
|
||||
}
|
||||
status.Load1m = stat.Load1
|
||||
status.Load5m = stat.Load5
|
||||
status.Load15m = stat.Load15
|
||||
|
||||
// 记录监控数据
|
||||
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemLoad, maps.Map{
|
||||
"load1m": status.Load1m,
|
||||
"load5m": status.Load5m,
|
||||
"load15m": status.Load15m,
|
||||
})
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user