This commit is contained in:
unknown
2026-02-04 20:27:13 +08:00
commit 3b042d1dad
9410 changed files with 1488147 additions and 0 deletions

View File

@@ -0,0 +1,529 @@
package nodes
import (
"context"
"encoding/json"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/messageconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/configs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/errors"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/maps"
"net/url"
"regexp"
"runtime"
"strconv"
"time"
)
type APIStream struct {
stream pb.NodeService_NodeStreamClient
isQuiting bool
cancelFunc context.CancelFunc
}
func NewAPIStream() *APIStream {
return &APIStream{}
}
func (this *APIStream) Start() {
events.OnKey(events.EventQuit, this, func() {
this.isQuiting = true
if this.cancelFunc != nil {
this.cancelFunc()
}
})
for {
if this.isQuiting {
return
}
err := this.loop()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("API_STREAM", err.Error())
} else {
remotelogs.Warn("API_STREAM", err.Error())
}
time.Sleep(10 * time.Second)
continue
}
time.Sleep(1 * time.Second)
}
}
func (this *APIStream) loop() error {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return errors.Wrap(err)
}
ctx, cancelFunc := context.WithCancel(rpcClient.Context())
this.cancelFunc = cancelFunc
defer func() {
cancelFunc()
}()
nodeStream, err := rpcClient.NodeRPC.NodeStream(ctx)
if err != nil {
if this.isQuiting {
return nil
}
return err
}
this.stream = nodeStream
for {
if this.isQuiting {
remotelogs.Println("API_STREAM", "quit")
break
}
message, streamErr := nodeStream.Recv()
if streamErr != nil {
if this.isQuiting {
remotelogs.Println("API_STREAM", "quit")
return nil
}
return streamErr
}
// 处理消息
switch message.Code {
case messageconfigs.MessageCodeConnectedAPINode: // 连接API节点成功
err = this.handleConnectedAPINode(message)
case messageconfigs.MessageCodeWriteCache: // 写入缓存
err = this.handleWriteCache(message)
case messageconfigs.MessageCodeReadCache: // 读取缓存
err = this.handleReadCache(message)
case messageconfigs.MessageCodeStatCache: // 统计缓存
err = this.handleStatCache(message)
case messageconfigs.MessageCodeCleanCache: // 清理缓存
err = this.handleCleanCache(message)
case messageconfigs.MessageCodeNewNodeTask: // 有新的任务
err = this.handleNewNodeTask(message)
case messageconfigs.MessageCodeCheckSystemdService: // 检查Systemd服务
err = this.handleCheckSystemdService(message)
case messageconfigs.MessageCodeCheckLocalFirewall: // 检查本地防火墙
err = this.handleCheckLocalFirewall(message)
case messageconfigs.MessageCodeChangeAPINode: // 修改API节点地址
err = this.handleChangeAPINode(message)
default:
err = this.handleUnknownMessage(message)
}
if err != nil {
remotelogs.Error("API_STREAM", "handle message failed: "+err.Error())
}
}
return nil
}
// 连接API节点成功
func (this *APIStream) handleConnectedAPINode(message *pb.NodeStreamMessage) error {
// 更改连接的APINode信息
if len(message.DataJSON) == 0 {
return nil
}
msg := &messageconfigs.ConnectedAPINodeMessage{}
err := json.Unmarshal(message.DataJSON, msg)
if err != nil {
return errors.Wrap(err)
}
_, err = rpc.SharedRPC()
if err != nil {
return errors.Wrap(err)
}
remotelogs.Println("API_STREAM", "connected to api node '"+strconv.FormatInt(msg.APINodeId, 10)+"'")
// 重新读取配置
if nodeConfigUpdatedAt == 0 {
select {
case nodeConfigChangedNotify <- true:
default:
}
}
return nil
}
// 写入缓存
func (this *APIStream) handleWriteCache(message *pb.NodeStreamMessage) error {
msg := &messageconfigs.WriteCacheMessage{}
err := json.Unmarshal(message.DataJSON, msg)
if err != nil {
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
return err
}
storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON)
if err != nil {
return err
}
if shouldStop {
defer func() {
storage.Stop()
}()
}
expiredAt := time.Now().Unix() + msg.LifeSeconds
writer, err := storage.OpenWriter(msg.Key, expiredAt, 200, -1, int64(len(msg.Value)), -1, false)
if err != nil {
this.replyFail(message.RequestId, "prepare writing failed: "+err.Error())
return err
}
// 写入一个空的Header
_, err = writer.WriteHeader([]byte(":"))
if err != nil {
this.replyFail(message.RequestId, "write failed: "+err.Error())
return err
}
// 写入数据
_, err = writer.Write(msg.Value)
if err != nil {
this.replyFail(message.RequestId, "write failed: "+err.Error())
return err
}
err = writer.Close()
if err != nil {
this.replyFail(message.RequestId, "write failed: "+err.Error())
return err
}
storage.AddToList(&caches.Item{
Type: writer.ItemType(),
Key: msg.Key,
ExpiresAt: expiredAt,
HeaderSize: writer.HeaderSize(),
BodySize: writer.BodySize(),
})
this.replyOk(message.RequestId, "write ok")
return nil
}
// 读取缓存
func (this *APIStream) handleReadCache(message *pb.NodeStreamMessage) error {
msg := &messageconfigs.ReadCacheMessage{}
err := json.Unmarshal(message.DataJSON, msg)
if err != nil {
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
return err
}
storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON)
if err != nil {
return err
}
if shouldStop {
defer func() {
storage.Stop()
}()
}
reader, err := storage.OpenReader(msg.Key, false, false)
if err != nil {
if err == caches.ErrNotFound {
this.replyFail(message.RequestId, "key not found")
return nil
}
this.replyFail(message.RequestId, "read key failed: "+err.Error())
return nil
}
defer func() {
_ = reader.Close()
}()
this.replyOk(message.RequestId, "value "+strconv.FormatInt(reader.BodySize(), 10)+" bytes")
return nil
}
// 统计缓存
func (this *APIStream) handleStatCache(message *pb.NodeStreamMessage) error {
msg := &messageconfigs.ReadCacheMessage{}
err := json.Unmarshal(message.DataJSON, msg)
if err != nil {
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
return err
}
storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON)
if err != nil {
return err
}
if shouldStop {
defer func() {
storage.Stop()
}()
}
stat, err := storage.Stat()
if err != nil {
this.replyFail(message.RequestId, "stat failed: "+err.Error())
return err
}
sizeFormat := ""
if stat.Size < (1 << 10) {
sizeFormat = strconv.FormatInt(stat.Size, 10) + " Bytes"
} else if stat.Size < (1 << 20) {
sizeFormat = fmt.Sprintf("%.2f KiB", float64(stat.Size)/(1<<10))
} else if stat.Size < (1 << 30) {
sizeFormat = fmt.Sprintf("%.2f MiB", float64(stat.Size)/(1<<20))
} else {
sizeFormat = fmt.Sprintf("%.2f GiB", float64(stat.Size)/(1<<30))
}
this.replyOk(message.RequestId, "size:"+sizeFormat+", count:"+strconv.Itoa(stat.Count))
return nil
}
// 清理缓存
func (this *APIStream) handleCleanCache(message *pb.NodeStreamMessage) error {
msg := &messageconfigs.ReadCacheMessage{}
err := json.Unmarshal(message.DataJSON, msg)
if err != nil {
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
return err
}
storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON)
if err != nil {
return err
}
if shouldStop {
defer func() {
storage.Stop()
}()
}
err = storage.CleanAll()
if err != nil {
this.replyFail(message.RequestId, "clean cache failed: "+err.Error())
return err
}
this.replyOk(message.RequestId, "ok")
return nil
}
// 处理配置变化
func (this *APIStream) handleNewNodeTask(message *pb.NodeStreamMessage) error {
select {
case nodeTaskNotify <- true:
default:
}
this.replyOk(message.RequestId, "ok")
return nil
}
// 检查Systemd服务
func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage) error {
systemctl, err := executils.LookPath("systemctl")
if err != nil {
this.replyFail(message.RequestId, "'systemctl' not found")
return nil
}
if len(systemctl) == 0 {
this.replyFail(message.RequestId, "'systemctl' not found")
return nil
}
var shortName = teaconst.SystemdServiceName
var cmd = executils.NewTimeoutCmd(10*time.Second, systemctl, "is-enabled", shortName)
cmd.WithStdout()
err = cmd.Run()
if err != nil {
this.replyFail(message.RequestId, "'systemctl' command error: "+err.Error())
return nil
}
if cmd.Stdout() == "enabled" {
this.replyOk(message.RequestId, "ok")
} else {
this.replyFail(message.RequestId, "not installed")
}
return nil
}
// 检查本地防火墙
func (this *APIStream) handleCheckLocalFirewall(message *pb.NodeStreamMessage) error {
var dataMessage = &messageconfigs.CheckLocalFirewallMessage{}
err := json.Unmarshal(message.DataJSON, dataMessage)
if err != nil {
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
return nil
}
// nft
if dataMessage.Name == "nftables" {
if runtime.GOOS != "linux" {
this.replyFail(message.RequestId, "not Linux system")
return nil
}
nft, err := executils.LookPath("nft")
if err != nil {
this.replyFail(message.RequestId, "'nft' not found: "+err.Error())
return nil
}
var cmd = executils.NewTimeoutCmd(10*time.Second, nft, "--version")
cmd.WithStdout()
err = cmd.Run()
if err != nil {
this.replyFail(message.RequestId, "get version failed: "+err.Error())
return nil
}
var outputString = cmd.Stdout()
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
if len(versionMatches) <= 1 {
this.replyFail(message.RequestId, "can not get nft version")
return nil
}
var version = versionMatches[1]
var result = maps.Map{
"version": version,
}
var protectionConfig = sharedNodeConfig.DDoSProtection
err = firewalls.SharedDDoSProtectionManager.Apply(protectionConfig)
if err != nil {
this.replyFail(message.RequestId, dataMessage.Name+" was installed, but apply DDoS protection config failed: "+err.Error())
} else {
this.replyOk(message.RequestId, string(result.AsJSON()))
}
} else {
this.replyFail(message.RequestId, "invalid firewall name '"+dataMessage.Name+"'")
}
return nil
}
// 修改API地址
func (this *APIStream) handleChangeAPINode(message *pb.NodeStreamMessage) error {
config, err := configs.LoadAPIConfig()
if err != nil {
this.replyFail(message.RequestId, "read config error: "+err.Error())
return nil
}
var messageData = &messageconfigs.ChangeAPINodeMessage{}
err = json.Unmarshal(message.DataJSON, messageData)
if err != nil {
this.replyFail(message.RequestId, "unmarshal message failed: "+err.Error())
return nil
}
_, err = url.Parse(messageData.Addr)
if err != nil {
this.replyFail(message.RequestId, "invalid new api node address: '"+messageData.Addr+"'")
return nil
}
config.RPCEndpoints = []string{messageData.Addr}
// 保存到文件
err = config.WriteFile(Tea.ConfigFile(configs.ConfigFileName))
if err != nil {
this.replyFail(message.RequestId, "save config file failed: "+err.Error())
return nil
}
this.replyOk(message.RequestId, "")
goman.New(func() {
// 延后生效防止变更前的API无法读取到状态
time.Sleep(1 * time.Second)
rpcClient, err := rpc.SharedRPC()
if err != nil {
remotelogs.Error("API_STREAM", "change rpc endpoint to '"+
messageData.Addr+"' failed: "+err.Error())
return
}
rpcClient.Close()
err = rpcClient.UpdateConfig(config)
if err != nil {
remotelogs.Error("API_STREAM", "change rpc endpoint to '"+
messageData.Addr+"' failed: "+err.Error())
return
}
remotelogs.Println("API_STREAM", "change rpc endpoint to '"+
messageData.Addr+"' successfully")
})
return nil
}
// 处理未知消息
func (this *APIStream) handleUnknownMessage(message *pb.NodeStreamMessage) error {
this.replyFail(message.RequestId, "unknown message code '"+message.Code+"'")
return nil
}
// 回复失败
func (this *APIStream) replyFail(requestId int64, message string) {
_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: false, Message: message})
}
// 回复成功
func (this *APIStream) replyOk(requestId int64, message string) {
_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message})
}
// 回复成功并包含数据
func (this *APIStream) replyOkData(requestId int64, message string, dataJSON []byte) {
_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message, DataJSON: dataJSON})
}
// 获取缓存存取对象
func (this *APIStream) cacheStorage(message *pb.NodeStreamMessage, cachePolicyJSON []byte) (storage caches.StorageInterface, shouldStop bool, err error) {
cachePolicy := &serverconfigs.HTTPCachePolicy{}
err = json.Unmarshal(cachePolicyJSON, cachePolicy)
if err != nil {
this.replyFail(message.RequestId, "decode cache policy config failed: "+err.Error())
return nil, false, err
}
storage = caches.SharedManager.FindStorageWithPolicy(cachePolicy.Id)
if storage == nil {
storage = caches.SharedManager.NewStorageWithPolicy(cachePolicy)
if storage == nil {
this.replyFail(message.RequestId, "invalid storage type '"+cachePolicy.Type+"'")
return nil, false, errors.New("invalid storage type '" + cachePolicy.Type + "'")
}
err = storage.Init()
if err != nil {
this.replyFail(message.RequestId, "storage init failed: "+err.Error())
return nil, false, err
}
shouldStop = true
}
return
}

View File

@@ -0,0 +1,15 @@
package nodes
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"testing"
)
func TestAPIStream_Start(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
apiStream := NewAPIStream()
apiStream.Start()
}

View File

@@ -0,0 +1,326 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/conns"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/stats"
"github.com/TeaOSLab/EdgeNode/internal/utils"
connutils "github.com/TeaOSLab/EdgeNode/internal/utils/conns"
"github.com/TeaOSLab/EdgeNode/internal/utils/counters"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/types"
"net"
"os"
"strings"
"sync/atomic"
"time"
)
// ClientConn 客户端连接
type ClientConn struct {
BaseClientConn
createdAt int64
isTLS bool
isHTTP bool
hasRead bool
isLO bool // 是否为环路
isNoStat bool // 是否不统计带宽
isInAllowList bool
hasResetSYNFlood bool
lastReadAt int64
lastWriteAt int64
lastErr error
readDeadlineTime int64
isShortReading bool // reading header or tls handshake
isDebugging bool
autoReadTimeout bool
autoWriteTimeout bool
}
func NewClientConn(rawConn net.Conn, isHTTP bool, isTLS bool, isInAllowList bool) net.Conn {
// 是否为环路
var remoteAddr = rawConn.RemoteAddr().String()
var conn = &ClientConn{
BaseClientConn: BaseClientConn{rawConn: rawConn},
isTLS: isTLS,
isHTTP: isHTTP,
isLO: strings.HasPrefix(remoteAddr, "127.0.0.1:") || strings.HasPrefix(remoteAddr, "[::1]:"),
isNoStat: connutils.IsNoStatConn(remoteAddr),
isInAllowList: isInAllowList,
createdAt: fasttime.Now().Unix(),
}
if existsLnNodeIP(conn.RawIP()) {
conn.SetIsPersistent(true)
}
// 超时等设置
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
if globalServerConfig != nil {
var performanceConfig = globalServerConfig.Performance
conn.isDebugging = performanceConfig.Debug
conn.autoReadTimeout = performanceConfig.AutoReadTimeout
conn.autoWriteTimeout = performanceConfig.AutoWriteTimeout
}
if isHTTP {
// TODO 可以在配置中设置此值
_ = conn.SetLinger(nodeconfigs.DefaultTCPLinger)
}
// 加入到Map
conns.SharedMap.Add(conn)
return conn
}
func (this *ClientConn) Read(b []byte) (n int, err error) {
if this.isDebugging {
this.lastReadAt = fasttime.Now().Unix()
defer func() {
if err != nil {
this.lastErr = fmt.Errorf("read error: %w", err)
} else {
this.lastErr = nil
}
}()
}
// 环路直接读取
if this.isLO {
n, err = this.rawConn.Read(b)
if n > 0 {
atomic.AddUint64(&teaconst.InTrafficBytes, uint64(n))
}
return
}
// 设置读超时时间
if this.isHTTP && !this.isPersistent && !this.isShortReading && this.autoReadTimeout {
this.setHTTPReadTimeout()
}
// 开始读取
n, err = this.rawConn.Read(b)
if n > 0 {
atomic.AddUint64(&teaconst.InTrafficBytes, uint64(n))
this.hasRead = true
}
// 检测是否为超时错误
var isTimeout = err != nil && os.IsTimeout(err)
var isHandshakeError = isTimeout && !this.hasRead
if err != nil {
_ = this.SetLinger(nodeconfigs.DefaultTCPLinger)
}
// 忽略白名单和局域网
if !this.isPersistent && this.isHTTP && !this.isInAllowList && !utils.IsLocalIP(this.RawIP()) {
// SYN Flood检测
if this.serverId == 0 || !this.hasResetSYNFlood {
var synFloodConfig = sharedNodeConfig.SYNFloodConfig()
if synFloodConfig != nil && synFloodConfig.IsOn {
if isHandshakeError {
this.increaseSYNFlood(synFloodConfig)
} else if err == nil && !this.hasResetSYNFlood {
this.hasResetSYNFlood = true
this.resetSYNFlood()
}
}
}
}
return
}
func (this *ClientConn) Write(b []byte) (n int, err error) {
if len(b) == 0 {
return 0, nil
}
if this.isDebugging {
this.lastWriteAt = fasttime.Now().Unix()
defer func() {
if err != nil {
this.lastErr = fmt.Errorf("write error: %w", err)
} else {
this.lastErr = nil
}
}()
}
// 设置写超时时间
if !this.isPersistent && this.autoWriteTimeout {
var timeoutSeconds = len(b) / 1024
if timeoutSeconds < 3 {
timeoutSeconds = 3
}
_ = this.rawConn.SetWriteDeadline(time.Now().Add(time.Duration(timeoutSeconds) * time.Second)) // TODO 时间可以设置
}
// 延长读超时时间
if this.isHTTP && !this.isPersistent && this.autoReadTimeout {
this.setHTTPReadTimeout()
}
// 开始写入
var before = time.Now()
n, err = this.rawConn.Write(b)
if n > 0 {
atomic.AddInt64(&this.totalSentBytes, int64(n))
// 统计当前服务带宽
if this.serverId > 0 {
// TODO 需要加入在serverId绑定之前的带宽
if !this.isNoStat || Tea.IsTesting() { // 环路不统计带宽,避免缓存预热等行为产生带宽
atomic.AddUint64(&teaconst.OutTrafficBytes, uint64(n))
var cost = time.Since(before).Seconds()
if cost > 1 {
stats.SharedBandwidthStatManager.AddBandwidth(this.userId, this.userPlanId, this.serverId, int64(float64(n)/cost), int64(n))
} else {
stats.SharedBandwidthStatManager.AddBandwidth(this.userId, this.userPlanId, this.serverId, int64(n), int64(n))
}
}
}
}
// 如果是写入超时,则立即关闭连接
if err != nil && os.IsTimeout(err) {
// TODO 考虑对多次慢连接的IP做出惩罚
conn, ok := this.rawConn.(LingerConn)
if ok {
_ = conn.SetLinger(0)
}
_ = this.Close()
}
return
}
func (this *ClientConn) Close() error {
this.isClosed = true
err := this.rawConn.Close()
// 单个服务并发数限制
// 不能加条件限制,因为服务配置随时有变化
sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String())
// 从conn map中移除
conns.SharedMap.Remove(this)
return err
}
func (this *ClientConn) LocalAddr() net.Addr {
return this.rawConn.LocalAddr()
}
func (this *ClientConn) RemoteAddr() net.Addr {
return this.rawConn.RemoteAddr()
}
func (this *ClientConn) SetDeadline(t time.Time) error {
return this.rawConn.SetDeadline(t)
}
func (this *ClientConn) SetReadDeadline(t time.Time) error {
// 如果开启了HTTP自动读超时选项则自动控制超时时间
if this.isHTTP && !this.isPersistent && this.autoReadTimeout {
this.isShortReading = false
var unixTime = t.Unix()
if unixTime < 10 {
return nil
}
if unixTime == this.readDeadlineTime {
return nil
}
this.readDeadlineTime = unixTime
var seconds = -time.Since(t)
if seconds <= 0 || seconds > HTTPIdleTimeout {
return nil
}
if seconds < HTTPIdleTimeout-1*time.Second {
this.isShortReading = true
}
}
return this.rawConn.SetReadDeadline(t)
}
func (this *ClientConn) SetWriteDeadline(t time.Time) error {
return this.rawConn.SetWriteDeadline(t)
}
func (this *ClientConn) CreatedAt() int64 {
return this.createdAt
}
func (this *ClientConn) LastReadAt() int64 {
return this.lastReadAt
}
func (this *ClientConn) LastWriteAt() int64 {
return this.lastWriteAt
}
func (this *ClientConn) LastErr() error {
return this.lastErr
}
func (this *ClientConn) resetSYNFlood() {
counters.SharedCounter.ResetKey("SYN_FLOOD:" + this.RawIP())
}
func (this *ClientConn) increaseSYNFlood(synFloodConfig *firewallconfigs.SYNFloodConfig) {
var ip = this.RawIP()
if len(ip) > 0 && !iplibrary.IsInWhiteList(ip) && (!synFloodConfig.IgnoreLocal || !utils.IsLocalIP(ip)) {
var result = counters.SharedCounter.IncreaseKey("SYN_FLOOD:"+ip, 60)
var minAttempts = synFloodConfig.MinAttempts
if minAttempts < 5 {
minAttempts = 5
}
if !this.isTLS {
// 非TLS设置为两倍防止误封
minAttempts = 2 * minAttempts
}
if result >= types.Uint32(minAttempts) {
var timeout = synFloodConfig.TimeoutSeconds
if timeout <= 0 {
timeout = 600
}
// 关闭当前连接
_ = this.SetLinger(0)
_ = this.Close()
waf.SharedIPBlackList.RecordIP(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip, fasttime.Now().Unix()+int64(timeout), 0, true, 0, 0, "疑似SYN Flood攻击当前1分钟"+types.String(result)+"次空连接")
}
}
}
// 设置读超时时间
func (this *ClientConn) setHTTPReadTimeout() {
_ = this.SetReadDeadline(time.Now().Add(HTTPIdleTimeout))
}

View File

@@ -0,0 +1,193 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"crypto/tls"
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"net"
"sync/atomic"
"time"
)
type BaseClientConn struct {
rawConn net.Conn
isBound bool
userId int64
userPlanId int64
serverId int64
remoteAddr string
hasLimit bool
isPersistent bool // 是否为持久化连接
fingerprint []byte
isClosed bool
rawIP string
totalSentBytes int64
}
func (this *BaseClientConn) IsClosed() bool {
return this.isClosed
}
// IsBound 是否已绑定服务
func (this *BaseClientConn) IsBound() bool {
return this.isBound
}
// Bind 绑定服务
func (this *BaseClientConn) Bind(serverId int64, remoteAddr string, maxConnsPerServer int, maxConnsPerIP int) bool {
if this.isBound {
return true
}
this.isBound = true
this.serverId = serverId
this.remoteAddr = remoteAddr
this.hasLimit = true
// 检查是否可以连接
return sharedClientConnLimiter.Add(this.rawConn.RemoteAddr().String(), serverId, remoteAddr, maxConnsPerServer, maxConnsPerIP)
}
// SetServerId 设置服务ID
func (this *BaseClientConn) SetServerId(serverId int64) (goNext bool) {
goNext = true
// 检查服务相关IP黑名单
var rawIP = this.RawIP()
if serverId > 0 && len(rawIP) > 0 {
// 是否在白名单中
ok, _, expiresAt := iplibrary.AllowIP(rawIP, serverId)
if !ok {
_ = this.rawConn.Close()
firewalls.DropTemporaryTo(rawIP, expiresAt)
return false
}
}
this.serverId = serverId
// 设置包装前连接
switch conn := this.rawConn.(type) {
case *tls.Conn:
nativeConn, ok := conn.NetConn().(ClientConnInterface)
if ok {
nativeConn.SetServerId(serverId)
}
case *ClientConn:
conn.SetServerId(serverId)
}
return true
}
// ServerId 读取当前连接绑定的服务ID
func (this *BaseClientConn) ServerId() int64 {
return this.serverId
}
// SetUserId 设置所属服务的用户ID
func (this *BaseClientConn) SetUserId(userId int64) {
this.userId = userId
// 设置包装前连接
switch conn := this.rawConn.(type) {
case *tls.Conn:
nativeConn, ok := conn.NetConn().(ClientConnInterface)
if ok {
nativeConn.SetUserId(userId)
}
case *ClientConn:
conn.SetUserId(userId)
}
}
func (this *BaseClientConn) SetUserPlanId(userPlanId int64) {
this.userPlanId = userPlanId
// 设置包装前连接
switch conn := this.rawConn.(type) {
case *tls.Conn:
nativeConn, ok := conn.NetConn().(ClientConnInterface)
if ok {
nativeConn.SetUserPlanId(userPlanId)
}
case *ClientConn:
conn.SetUserPlanId(userPlanId)
}
}
// UserId 获取当前连接所属服务的用户ID
func (this *BaseClientConn) UserId() int64 {
return this.userId
}
// UserPlanId 用户套餐ID
func (this *BaseClientConn) UserPlanId() int64 {
return this.userPlanId
}
// RawIP 原本IP
func (this *BaseClientConn) RawIP() string {
if len(this.rawIP) > 0 {
return this.rawIP
}
ip, _, _ := net.SplitHostPort(this.rawConn.RemoteAddr().String())
this.rawIP = ip
return ip
}
// TCPConn 转换为TCPConn
func (this *BaseClientConn) TCPConn() (tcpConn *net.TCPConn, ok bool) {
// 设置包装前连接
switch conn := this.rawConn.(type) {
case *tls.Conn:
var internalConn = conn.NetConn()
clientConn, isClientConn := internalConn.(*ClientConn)
if isClientConn {
return clientConn.TCPConn()
}
tcpConn, ok = internalConn.(*net.TCPConn)
default:
tcpConn, ok = this.rawConn.(*net.TCPConn)
}
return
}
// SetLinger 设置Linger
func (this *BaseClientConn) SetLinger(seconds int) error {
tcpConn, ok := this.TCPConn()
if ok {
return tcpConn.SetLinger(seconds)
}
return nil
}
func (this *BaseClientConn) SetIsPersistent(isPersistent bool) {
this.isPersistent = isPersistent
_ = this.rawConn.SetDeadline(time.Time{})
}
// SetFingerprint 设置指纹信息
func (this *BaseClientConn) SetFingerprint(fingerprint []byte) {
this.fingerprint = fingerprint
}
// Fingerprint 读取指纹信息
func (this *BaseClientConn) Fingerprint() []byte {
return this.fingerprint
}
// LastRequestBytes 读取上一次请求发送的字节数
func (this *BaseClientConn) LastRequestBytes() int64 {
var result = atomic.LoadInt64(&this.totalSentBytes)
atomic.StoreInt64(&this.totalSentBytes, 0)
return result
}

View File

@@ -0,0 +1,41 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
type ClientConnInterface interface {
// IsClosed 是否已关闭
IsClosed() bool
// IsBound 是否已绑定服务
IsBound() bool
// Bind 绑定服务
Bind(serverId int64, remoteAddr string, maxConnsPerServer int, maxConnsPerIP int) bool
// ServerId 获取服务ID
ServerId() int64
// SetServerId 设置服务ID
SetServerId(serverId int64) (goNext bool)
// SetUserId 设置所属网站的用户ID
SetUserId(userId int64)
// SetUserPlanId 设置
SetUserPlanId(userPlanId int64)
// UserId 获取当前连接所属服务的用户ID
UserId() int64
// SetIsPersistent 设置是否为持久化
SetIsPersistent(isPersistent bool)
// SetFingerprint 设置指纹信息
SetFingerprint(fingerprint []byte)
// Fingerprint 读取指纹信息
Fingerprint() []byte
// LastRequestBytes 读取上一次请求发送的字节数
LastRequestBytes() int64
}

View File

@@ -0,0 +1,130 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/zero"
"sync"
)
var sharedClientConnLimiter = NewClientConnLimiter()
// ClientConnRemoteAddr 客户端地址定义
type ClientConnRemoteAddr struct {
remoteAddr string
serverId int64
}
// ClientConnLimiter 客户端连接数限制
type ClientConnLimiter struct {
remoteAddrMap map[string]*ClientConnRemoteAddr // raw remote addr => remoteAddr
ipConns map[string]map[string]zero.Zero // remoteAddr => { raw remote addr => Zero }
serverConns map[int64]map[string]zero.Zero // serverId => { remoteAddr => Zero }
locker sync.Mutex
}
func NewClientConnLimiter() *ClientConnLimiter {
return &ClientConnLimiter{
remoteAddrMap: map[string]*ClientConnRemoteAddr{},
ipConns: map[string]map[string]zero.Zero{},
serverConns: map[int64]map[string]zero.Zero{},
}
}
// Add 添加新连接
// 返回值为true的时候表示允许添加否则表示不允许添加
func (this *ClientConnLimiter) Add(rawRemoteAddr string, serverId int64, remoteAddr string, maxConnsPerServer int, maxConnsPerIP int) bool {
if (maxConnsPerServer <= 0 && maxConnsPerIP <= 0) || len(remoteAddr) == 0 || serverId <= 0 {
return true
}
this.locker.Lock()
defer this.locker.Unlock()
// 检查服务连接数
var serverMap = this.serverConns[serverId]
if maxConnsPerServer > 0 {
if serverMap == nil {
serverMap = map[string]zero.Zero{}
this.serverConns[serverId] = serverMap
}
if maxConnsPerServer <= len(serverMap) {
return false
}
}
// 检查IP连接数
var ipMap = this.ipConns[remoteAddr]
if maxConnsPerIP > 0 {
if ipMap == nil {
ipMap = map[string]zero.Zero{}
this.ipConns[remoteAddr] = ipMap
}
if maxConnsPerIP > 0 && maxConnsPerIP <= len(ipMap) {
return false
}
}
this.remoteAddrMap[rawRemoteAddr] = &ClientConnRemoteAddr{
remoteAddr: remoteAddr,
serverId: serverId,
}
if maxConnsPerServer > 0 {
serverMap[rawRemoteAddr] = zero.New()
}
if maxConnsPerIP > 0 {
ipMap[rawRemoteAddr] = zero.New()
}
return true
}
// Remove 删除连接
func (this *ClientConnLimiter) Remove(rawRemoteAddr string) {
this.locker.Lock()
defer this.locker.Unlock()
addr, ok := this.remoteAddrMap[rawRemoteAddr]
if !ok {
return
}
delete(this.remoteAddrMap, rawRemoteAddr)
delete(this.ipConns[addr.remoteAddr], rawRemoteAddr)
delete(this.serverConns[addr.serverId], rawRemoteAddr)
if len(this.ipConns[addr.remoteAddr]) == 0 {
delete(this.ipConns, addr.remoteAddr)
}
if len(this.serverConns[addr.serverId]) == 0 {
delete(this.serverConns, addr.serverId)
}
}
// Conns 获取连接信息
// 用于调试
func (this *ClientConnLimiter) Conns() (ipConns map[string][]string, serverConns map[int64][]string) {
this.locker.Lock()
defer this.locker.Unlock()
ipConns = map[string][]string{} // ip => [addr1, addr2, ...]
serverConns = map[int64][]string{} // serverId => [addr1, addr2, ...]
for ip, m := range this.ipConns {
for addr := range m {
ipConns[ip] = append(ipConns[ip], addr)
}
}
for serverId, m := range this.serverConns {
for addr := range m {
serverConns[serverId] = append(serverConns[serverId], addr)
}
}
return
}

View File

@@ -0,0 +1,38 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/iwind/TeaGo/logs"
"testing"
)
func TestClientConnLimiter_Add(t *testing.T) {
var limiter = NewClientConnLimiter()
{
b := limiter.Add("127.0.0.1:1234", 1, "192.168.1.100", 10, 5)
t.Log(b)
}
{
b := limiter.Add("127.0.0.1:1235", 1, "192.168.1.100", 10, 5)
t.Log(b)
}
{
b := limiter.Add("127.0.0.1:1236", 1, "192.168.1.100", 10, 5)
t.Log(b)
}
{
b := limiter.Add("127.0.0.1:1237", 1, "192.168.1.101", 10, 5)
t.Log(b)
}
{
b := limiter.Add("127.0.0.1:1238", 1, "192.168.1.100", 5, 5)
t.Log(b)
}
limiter.Remove("127.0.0.1:1238")
limiter.Remove("127.0.0.1:1239")
limiter.Remove("127.0.0.1:1237")
logs.PrintAsJSON(limiter.remoteAddrMap, t)
logs.PrintAsJSON(limiter.ipConns, t)
logs.PrintAsJSON(limiter.serverConns, t)
}

View File

@@ -0,0 +1,45 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/monitor"
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
"github.com/iwind/TeaGo/maps"
"sync/atomic"
"time"
)
// 发送监控流量
func init() {
if !teaconst.IsMain {
return
}
events.On(events.EventStart, func() {
var ticker = time.NewTicker(1 * time.Minute)
goman.New(func() {
for range ticker.C {
// 加入到数据队列中
var inBytes = atomic.LoadUint64(&teaconst.InTrafficBytes)
atomic.StoreUint64(&teaconst.InTrafficBytes, 0) // 重置数据
if inBytes > 0 {
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemTrafficIn, maps.Map{
"total": inBytes,
})
}
var outBytes = atomic.LoadUint64(&teaconst.OutTrafficBytes)
atomic.StoreUint64(&teaconst.OutTrafficBytes, 0) // 重置数据
if outBytes > 0 {
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemTrafficOut, maps.Map{
"total": outBytes,
})
}
}
})
})
}

View File

@@ -0,0 +1,20 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"net"
)
// 判断客户端连接是否已关闭
func isClientConnClosed(conn net.Conn) bool {
if conn == nil {
return true
}
clientConn, ok := conn.(ClientConnInterface)
if ok {
return clientConn.IsClosed()
}
return true
}

View File

@@ -0,0 +1,81 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"net"
)
// ClientListener 客户端网络监听
type ClientListener struct {
rawListener net.Listener
isHTTP bool
isTLS bool
}
func NewClientListener(listener net.Listener, isHTTP bool) *ClientListener {
return &ClientListener{
rawListener: listener,
isHTTP: isHTTP,
}
}
func (this *ClientListener) SetIsTLS(isTLS bool) {
this.isTLS = isTLS
}
func (this *ClientListener) IsTLS() bool {
return this.isTLS
}
func (this *ClientListener) Accept() (net.Conn, error) {
conn, err := this.rawListener.Accept()
if err != nil {
return nil, err
}
// 是否在WAF名单中
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
var isInAllowList = false
if err == nil {
canGoNext, inAllowList, expiresAt := iplibrary.AllowIP(ip, 0)
isInAllowList = inAllowList
if !canGoNext {
firewalls.DropTemporaryTo(ip, expiresAt)
} else {
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
var ok bool
expiresAt, ok = waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
if ok {
canGoNext = false
firewalls.DropTemporaryTo(ip, expiresAt)
}
}
}
if !canGoNext {
tcpConn, ok := conn.(*net.TCPConn)
if ok {
_ = tcpConn.SetLinger(0)
}
_ = conn.Close()
return this.Accept()
}
}
return NewClientConn(conn, this.isHTTP, this.isTLS, isInAllowList), nil
}
func (this *ClientListener) Close() error {
return this.rawListener.Close()
}
func (this *ClientListener) Addr() net.Addr {
return this.rawListener.Addr()
}

View File

@@ -0,0 +1,99 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"crypto/tls"
"net"
"time"
)
// ClientTLSConn TLS连接封装
type ClientTLSConn struct {
BaseClientConn
}
func NewClientTLSConn(conn *tls.Conn) net.Conn {
return &ClientTLSConn{BaseClientConn{rawConn: conn}}
}
func (this *ClientTLSConn) Read(b []byte) (n int, err error) {
n, err = this.rawConn.Read(b)
return
}
func (this *ClientTLSConn) Write(b []byte) (n int, err error) {
n, err = this.rawConn.Write(b)
return
}
func (this *ClientTLSConn) Close() error {
this.isClosed = true
// 单个服务并发数限制
sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String())
return this.rawConn.Close()
}
func (this *ClientTLSConn) LocalAddr() net.Addr {
return this.rawConn.LocalAddr()
}
func (this *ClientTLSConn) RemoteAddr() net.Addr {
return this.rawConn.RemoteAddr()
}
func (this *ClientTLSConn) SetDeadline(t time.Time) error {
return this.rawConn.SetDeadline(t)
}
func (this *ClientTLSConn) SetReadDeadline(t time.Time) error {
return this.rawConn.SetReadDeadline(t)
}
func (this *ClientTLSConn) SetWriteDeadline(t time.Time) error {
return this.rawConn.SetWriteDeadline(t)
}
func (this *ClientTLSConn) SetIsPersistent(isPersistent bool) {
tlsConn, ok := this.rawConn.(*tls.Conn)
if ok {
var rawConn = tlsConn.NetConn()
if rawConn != nil {
clientConn, ok := rawConn.(*ClientConn)
if ok {
clientConn.SetIsPersistent(isPersistent)
}
}
}
}
func (this *ClientTLSConn) Fingerprint() []byte {
tlsConn, ok := this.rawConn.(*tls.Conn)
if ok {
var rawConn = tlsConn.NetConn()
if rawConn != nil {
clientConn, ok := rawConn.(*ClientConn)
if ok {
return clientConn.fingerprint
}
}
}
return nil
}
// LastRequestBytes 读取上一次请求发送的字节数
func (this *ClientTLSConn) LastRequestBytes() int64 {
tlsConn, ok := this.rawConn.(*tls.Conn)
if ok {
var rawConn = tlsConn.NetConn()
if rawConn != nil {
clientConn, ok := rawConn.(*ClientConn)
if ok {
return clientConn.LastRequestBytes()
}
}
}
return 0
}

View File

@@ -0,0 +1,7 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
type LingerConn interface {
SetLinger(sec int) error
}

View File

@@ -0,0 +1,37 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package nodes
import (
"github.com/TeaOSLab/EdgeNode/internal/stats"
"sync/atomic"
)
type HTTP3ConnNotifier struct {
conn *HTTP3Conn
}
func NewHTTP3ConnNotifier(conn *HTTP3Conn) *HTTP3ConnNotifier {
return &HTTP3ConnNotifier{
conn: conn,
}
}
func (this *HTTP3ConnNotifier) ReadBytes(bytes int) {
}
func (this *HTTP3ConnNotifier) WriteBytes(bytes int) {
if this.conn != nil && this.conn.ServerId() > 0 {
stats.SharedBandwidthStatManager.AddBandwidth(this.conn.UserId(), this.conn.UserPlanId(), this.conn.ServerId(), int64(bytes), int64(bytes))
if this.conn.BaseClientConn != nil {
atomic.AddInt64(&this.conn.BaseClientConn.totalSentBytes, int64(bytes))
}
}
}
func (this *HTTP3ConnNotifier) Close() {
this.conn.NotifyClose()
}

View File

@@ -0,0 +1,34 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package nodes
import (
"github.com/TeaOSLab/EdgeNode/internal/conns"
"github.com/TeaOSLab/EdgeNode/internal/http3"
)
type HTTP3Conn struct {
*BaseClientConn
http3.Conn
}
func NewHTTP3Conn(http3Conn http3.Conn) *HTTP3Conn {
// 添加新的
var conn = &HTTP3Conn{
BaseClientConn: &BaseClientConn{rawConn: http3Conn},
Conn: http3Conn,
}
http3Conn.SetParentConn(conn)
http3Conn.SetNotifier(NewHTTP3ConnNotifier(conn))
// 添加到统计Map
conns.SharedMap.Add(conn)
return conn
}
func (this *HTTP3Conn) NotifyClose() {
conns.SharedMap.Remove(this)
}

View File

@@ -0,0 +1,89 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package nodes
import (
"context"
"crypto/tls"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
"github.com/TeaOSLab/EdgeNode/internal/http3"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/quic-go/quic-go"
"net"
)
type HTTP3Listener struct {
rawListener http3.Listener
}
func ListenHTTP3(addr string, tlsConfig *tls.Config) (http3.Listener, error) {
rawListener, err := http3.Listen(addr, tlsConfig)
if err != nil {
return nil, err
}
var listener = &HTTP3Listener{rawListener: rawListener}
listener.init()
return listener, nil
}
func (this *HTTP3Listener) init() {
events.OnKey(events.EventQuit, fmt.Sprintf("http_listener_%p", this), func() {
_ = this.Close()
})
events.OnKey(events.EventTerminated, fmt.Sprintf("http_listener_%p", this), func() {
_ = this.Close()
})
}
func (this *HTTP3Listener) Addr() net.Addr {
return this.rawListener.Addr()
}
func (this *HTTP3Listener) Accept(ctx context.Context, ctxFunc http3.ContextFunc) (quic.EarlyConnection, error) {
conn, err := this.rawListener.Accept(ctx, ctxFunc)
if err != nil {
return conn, err
}
// 是否在WAF名单中
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
var isInAllowList = false
if err == nil {
canGoNext, inAllowList, expiresAt := iplibrary.AllowIP(ip, 0)
isInAllowList = inAllowList
if !canGoNext {
firewalls.DropTemporaryTo(ip, expiresAt)
} else {
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
var ok bool
expiresAt, ok = waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
if ok {
canGoNext = false
firewalls.DropTemporaryTo(ip, expiresAt)
}
}
}
if !canGoNext {
_ = conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "")
return this.Accept(ctx, ctxFunc)
}
}
// TODO 将isInAllowList传递到HTTP3Conn
_ = isInAllowList
return NewHTTP3Conn(conn.(*http3.BasicConn)), nil
}
func (this *HTTP3Listener) Close() error {
events.Remove(fmt.Sprintf("http_listener_%p", this))
return this.rawListener.Close()
}

View File

@@ -0,0 +1,231 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package nodes
import (
"context"
"crypto/tls"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/http3"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/types"
"net"
"net/http"
"regexp"
"sync"
"sync/atomic"
)
var sharedHTTP3Manager = NewHTTP3Manager()
func init() {
if !teaconst.IsMain {
return
}
var listener = &HTTPListener{
isHTTPS: true,
isHTTP3: true,
}
events.On(events.EventLoaded, func() {
sharedListenerManager.http3Listener = listener // 注册到ListenerManager以便统计用
})
var eventLocker = sync.Mutex{}
events.OnEvents([]events.Event{events.EventReload, events.EventReloadSomeServers}, func() {
go func() {
eventLocker.Lock()
defer eventLocker.Unlock()
if sharedNodeConfig == nil {
return
}
_ = sharedHTTP3Manager.Update(sharedNodeConfig.HTTP3Policies)
sharedHTTP3Manager.UpdateHTTPListener(listener)
listener.Reload(sharedNodeConfig.HTTP3Group())
}()
})
}
// HTTP3Manager HTTP3管理器
type HTTP3Manager struct {
locker sync.RWMutex
hasHTTP3 bool
policies map[int64]*nodeconfigs.HTTP3Policy // clusterId => *HTTP3Policy
serverMap map[int]*http3.Server // port => *Server
mobileUserAgentReg *regexp.Regexp
httpListener *HTTPListener
tlsConfig *tls.Config
}
func NewHTTP3Manager() *HTTP3Manager {
return &HTTP3Manager{
policies: map[int64]*nodeconfigs.HTTP3Policy{},
serverMap: map[int]*http3.Server{},
mobileUserAgentReg: regexp.MustCompile(`(?i)(iPhone|Android)`),
}
}
// Update 更新配置
// m: clusterId => *HTTP3Policy
func (this *HTTP3Manager) Update(m map[int64]*nodeconfigs.HTTP3Policy) error {
this.locker.Lock()
defer this.locker.Unlock()
// 启动新的
var newPolicyMap = map[int64]*nodeconfigs.HTTP3Policy{} // clusterId => *HTTP3Policy
var newPorts = map[int]bool{} // port => bool
for clusterId, policy := range m {
if policy.IsOn && policy.Port > 0 {
this.policies[clusterId] = policy
newPolicyMap[clusterId] = policy
var port = policy.Port
newPorts[port] = true
_, existPort := this.serverMap[port]
if !existPort {
server, err := this.createServer(port)
if err != nil {
remotelogs.Error("HTTP3_MANAGER", "start port '"+types.String(port)+"' failed: "+err.Error())
continue
}
this.serverMap[port] = server
remotelogs.Debug("HTTP3_MANAGER", "start port '"+types.String(port)+"'")
}
}
}
this.policies = newPolicyMap
// 关闭老的
for port, server := range this.serverMap {
if !newPorts[port] {
_ = server.Close()
delete(this.serverMap, port)
remotelogs.Debug("HTTP3_MANAGER", "close port '"+types.String(port)+"'")
}
}
this.hasHTTP3 = len(this.serverMap) > 0
return nil
}
// UpdateHTTPListener 更新Listener
// 这里的Listener只是为了方便复用HTTPListener的相关方法
func (this *HTTP3Manager) UpdateHTTPListener(listener *HTTPListener) {
this.locker.Lock()
this.httpListener = listener
if listener != nil {
this.tlsConfig = listener.buildTLSConfig()
}
this.locker.Unlock()
}
// ProcessHTTP3Headers 处理HTTP3相关Headers
func (this *HTTP3Manager) ProcessHTTP3Headers(userAgent string, headers http.Header, clusterId int64) {
// 这里不要加锁,以便于提升性能
if !this.hasHTTP3 {
return
}
this.locker.RLock()
defer this.locker.RUnlock()
// 再次准确检查
if !this.hasHTTP3 {
return
}
policy, ok := this.policies[clusterId]
if !ok {
return
}
if policy.IsOn && policy.Port > 0 && (policy.SupportMobileBrowsers || !this.mobileUserAgentReg.MatchString(userAgent)) {
// TODO 版本好和有效期可以在策略里设置
headers.Set("Alt-Svc", `h3=":`+types.String(policy.Port)+`"; ma=2592000,h3-29=":`+types.String(policy.Port)+`"; ma=2592000`)
}
}
// 创建server
func (this *HTTP3Manager) createServer(port int) (*http3.Server, error) {
var addr = ":" + types.String(port)
listener, err := ListenHTTP3(addr, &tls.Config{
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
this.locker.RLock()
var tlsConfig = this.tlsConfig
this.locker.RUnlock()
if tlsConfig != nil && tlsConfig.GetConfigForClient != nil {
return tlsConfig.GetConfigForClient(info)
}
return nil, errors.New("http3: no tls config")
},
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
this.locker.RLock()
var tlsConfig = this.tlsConfig
this.locker.RUnlock()
if tlsConfig != nil && tlsConfig.GetCertificate != nil {
return tlsConfig.GetCertificate(clientInfo)
}
return nil, errors.New("http3: no tls config")
},
})
if err != nil {
return nil, err
}
var server = &http3.Server{
Addr: ":" + types.String(port),
Handler: http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {
if this.httpListener != nil {
var servePortString = "443"
if len(req.Host) > 0 {
_, hostPortString, hostErr := net.SplitHostPort(req.Host)
if hostErr == nil && len(hostPortString) > 0 {
servePortString = hostPortString
}
}
this.httpListener.ServeHTTPWithAddr(writer, req, ":"+servePortString)
}
}),
ConnState: func(conn net.Conn, state http.ConnState) {
if this.httpListener == nil {
return
}
switch state {
case http.StateNew:
atomic.AddInt64(&this.httpListener.countActiveConnections, 1)
case http.StateClosed:
atomic.AddInt64(&this.httpListener.countActiveConnections, -1)
default:
// do nothing
}
},
ConnContext: func(ctx context.Context, conn net.Conn) context.Context {
return context.WithValue(ctx, HTTPConnContextKey, conn)
},
}
go func() {
err = server.Serve(listener)
if err != nil {
remotelogs.Error("HTTP3_MANAGER", "serve '"+addr+"' failed: "+err.Error())
this.locker.Lock()
delete(this.serverMap, port)
this.locker.Unlock()
}
}()
return server, nil
}

View File

@@ -0,0 +1,194 @@
package nodes
import (
"bytes"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
memutils "github.com/TeaOSLab/EdgeNode/internal/utils/mem"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"strings"
"time"
"unicode/utf8"
)
var sharedHTTPAccessLogQueue = NewHTTPAccessLogQueue()
// HTTPAccessLogQueue HTTP访问日志队列
type HTTPAccessLogQueue struct {
queue chan *pb.HTTPAccessLog
rpcClient *rpc.RPCClient
}
// NewHTTPAccessLogQueue 获取新对象
func NewHTTPAccessLogQueue() *HTTPAccessLogQueue {
// 队列中最大的值,超出此数量的访问日志会被丢弃
var maxSize = 2_000 * (1 + memutils.SystemMemoryGB()/2)
if maxSize > 20_000 {
maxSize = 20_000
}
var queue = &HTTPAccessLogQueue{
queue: make(chan *pb.HTTPAccessLog, maxSize),
}
goman.New(func() {
queue.Start()
})
return queue
}
// Start 开始处理访问日志
func (this *HTTPAccessLogQueue) Start() {
ticker := time.NewTicker(1 * time.Second)
for range ticker.C {
err := this.loop()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("ACCESS_LOG_QUEUE", err.Error())
} else {
remotelogs.Error("ACCESS_LOG_QUEUE", err.Error())
}
}
}
}
// Push 加入新访问日志
func (this *HTTPAccessLogQueue) Push(accessLog *pb.HTTPAccessLog) {
select {
case this.queue <- accessLog:
default:
}
}
// 上传访问日志
func (this *HTTPAccessLogQueue) loop() error {
const maxLen = 2000
var accessLogs = make([]*pb.HTTPAccessLog, 0, maxLen)
var count = 0
Loop:
for {
select {
case accessLog := <-this.queue:
accessLogs = append(accessLogs, accessLog)
count++
// 每次只提交 N 条访问日志,防止网络拥堵
if count >= maxLen {
break Loop
}
default:
break Loop
}
}
if len(accessLogs) == 0 {
return nil
}
// 发送到本地
if sharedHTTPAccessLogViewer.HasConns() {
for _, accessLog := range accessLogs {
sharedHTTPAccessLogViewer.Send(accessLog)
}
}
// 发送到API
if this.rpcClient == nil {
client, err := rpc.SharedRPC()
if err != nil {
return err
}
this.rpcClient = client
}
_, err := this.rpcClient.HTTPAccessLogRPC.CreateHTTPAccessLogs(this.rpcClient.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: accessLogs})
if err != nil {
// 是否包含了invalid UTF-8
if strings.Contains(err.Error(), "string field contains invalid UTF-8") {
for _, accessLog := range accessLogs {
this.ToValidUTF8(accessLog)
}
// 重新提交
_, err = this.rpcClient.HTTPAccessLogRPC.CreateHTTPAccessLogs(this.rpcClient.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: accessLogs})
return err
}
// 是否请求内容过大
statusCode, ok := status.FromError(err)
if ok && statusCode.Code() == codes.ResourceExhausted {
// 去除Body
for _, accessLog := range accessLogs {
accessLog.RequestBody = nil
}
// 重新提交
_, err = this.rpcClient.HTTPAccessLogRPC.CreateHTTPAccessLogs(this.rpcClient.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: accessLogs})
return err
}
return err
}
return nil
}
// ToValidUTF8 处理访问日志中的非UTF-8字节
func (this *HTTPAccessLogQueue) ToValidUTF8(accessLog *pb.HTTPAccessLog) {
accessLog.RemoteAddr = utils.ToValidUTF8string(accessLog.RemoteAddr)
accessLog.RemoteUser = utils.ToValidUTF8string(accessLog.RemoteUser)
accessLog.RequestURI = utils.ToValidUTF8string(accessLog.RequestURI)
accessLog.RequestPath = utils.ToValidUTF8string(accessLog.RequestPath)
accessLog.RequestFilename = utils.ToValidUTF8string(accessLog.RequestFilename)
accessLog.RequestBody = bytes.ToValidUTF8(accessLog.RequestBody, []byte{})
accessLog.Host = utils.ToValidUTF8string(accessLog.Host)
accessLog.Hostname = utils.ToValidUTF8string(accessLog.Hostname)
for k, v := range accessLog.SentHeader {
if !utf8.ValidString(k) {
delete(accessLog.SentHeader, k)
continue
}
for index, s := range v.Values {
v.Values[index] = utils.ToValidUTF8string(s)
}
}
accessLog.Referer = utils.ToValidUTF8string(accessLog.Referer)
accessLog.UserAgent = utils.ToValidUTF8string(accessLog.UserAgent)
accessLog.Request = utils.ToValidUTF8string(accessLog.Request)
accessLog.ContentType = utils.ToValidUTF8string(accessLog.ContentType)
for k, c := range accessLog.Cookie {
if !utf8.ValidString(k) {
delete(accessLog.Cookie, k)
continue
}
accessLog.Cookie[k] = utils.ToValidUTF8string(c)
}
accessLog.Args = utils.ToValidUTF8string(accessLog.Args)
accessLog.QueryString = utils.ToValidUTF8string(accessLog.QueryString)
for k, v := range accessLog.Header {
if !utf8.ValidString(k) {
delete(accessLog.Header, k)
continue
}
for index, s := range v.Values {
v.Values[index] = utils.ToValidUTF8string(s)
}
}
for k, v := range accessLog.Errors {
accessLog.Errors[k] = utils.ToValidUTF8string(v)
}
}

View File

@@ -0,0 +1,232 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes_test
import (
"bytes"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/nodes"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
_ "github.com/iwind/TeaGo/bootstrap"
"google.golang.org/grpc/status"
"reflect"
"runtime"
"runtime/debug"
"strconv"
"strings"
"testing"
"time"
"unicode/utf8"
)
func TestHTTPAccessLogQueue_Push(t *testing.T) {
// 发送到API
client, err := rpc.SharedRPC()
if err != nil {
t.Fatal(err)
}
var requestId = 1_000_000
var utf8Bytes = []byte{}
for i := 0; i < 254; i++ {
utf8Bytes = append(utf8Bytes, uint8(i))
}
//bytes = []byte("真不错")
var accessLog = &pb.HTTPAccessLog{
ServerId: 23,
RequestId: strconv.FormatInt(time.Now().Unix(), 10) + strconv.Itoa(requestId) + strconv.FormatInt(1, 10),
NodeId: 48,
Host: "www.hello.com",
RequestURI: string(utf8Bytes),
RequestPath: string(utf8Bytes),
Timestamp: time.Now().Unix(),
Cookie: map[string]string{"test": string(utf8Bytes)},
Header: map[string]*pb.Strings{
"test": {Values: []string{string(utf8Bytes)}},
},
}
new(nodes.HTTPAccessLogQueue).ToValidUTF8(accessLog)
// logs.PrintAsJSON(accessLog)
//t.Log(strings.ToValidUTF8(string(utf8Bytes), ""))
_, err = client.HTTPAccessLogRPC.CreateHTTPAccessLogs(client.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: []*pb.HTTPAccessLog{
accessLog,
}})
if err != nil {
// 这里只是为了重现错误
t.Logf("%#v, %s", err, err.Error())
statusErr, ok := status.FromError(err)
if ok {
t.Logf("%#v", statusErr)
}
return
}
t.Log("ok")
}
func TestHTTPAccessLogQueue_Push2(t *testing.T) {
var utf8Bytes = []byte{}
for i := 0; i < 254; i++ {
utf8Bytes = append(utf8Bytes, uint8(i))
}
var accessLog = &pb.HTTPAccessLog{
ServerId: 23,
RequestId: strconv.FormatInt(time.Now().Unix(), 10) + strconv.Itoa(1) + strconv.FormatInt(1, 10),
NodeId: 48,
Host: "www.hello.com",
RequestURI: string(utf8Bytes),
RequestPath: string(utf8Bytes),
Timestamp: time.Now().Unix(),
}
var v = reflect.Indirect(reflect.ValueOf(accessLog))
var countFields = v.NumField()
for i := 0; i < countFields; i++ {
var field = v.Field(i)
if field.Kind() == reflect.String {
field.SetString(strings.ToValidUTF8(field.String(), ""))
}
}
client, err := rpc.SharedRPC()
if err != nil {
t.Fatal(err)
}
_, err = client.HTTPAccessLogRPC.CreateHTTPAccessLogs(client.Context(), &pb.CreateHTTPAccessLogsRequest{HttpAccessLogs: []*pb.HTTPAccessLog{
accessLog,
}})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestHTTPAccessLogQueue_Memory(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
testutils.StartMemoryStats(t)
debug.SetGCPercent(10)
var accessLogs = []*pb.HTTPAccessLog{}
for i := 0; i < 20_000; i++ {
accessLogs = append(accessLogs, &pb.HTTPAccessLog{
RequestPath: "https://goedge.cn/hello/world",
})
}
runtime.GC()
_ = accessLogs
// will not release automatically
func() {
var accessLogs1 = []*pb.HTTPAccessLog{}
for i := 0; i < 2_000_000; i++ {
accessLogs1 = append(accessLogs1, &pb.HTTPAccessLog{
RequestPath: "https://goedge.cn/hello/world",
})
}
_ = accessLogs1
}()
time.Sleep(5 * time.Second)
}
func TestUTF8_IsValid(t *testing.T) {
t.Log(utf8.ValidString("abc"))
var noneUTF8Bytes = []byte{}
for i := 0; i < 254; i++ {
noneUTF8Bytes = append(noneUTF8Bytes, uint8(i))
}
t.Log(utf8.ValidString(string(noneUTF8Bytes)))
}
func BenchmarkHTTPAccessLogQueue_ToValidUTF8(b *testing.B) {
runtime.GOMAXPROCS(1)
var utf8Bytes = []byte{}
for i := 0; i < 254; i++ {
utf8Bytes = append(utf8Bytes, uint8(i))
}
for i := 0; i < b.N; i++ {
_ = bytes.ToValidUTF8(utf8Bytes, nil)
}
}
func BenchmarkHTTPAccessLogQueue_ToValidUTF8String(b *testing.B) {
runtime.GOMAXPROCS(1)
var utf8Bytes = []byte{}
for i := 0; i < 254; i++ {
utf8Bytes = append(utf8Bytes, uint8(i))
}
var s = string(utf8Bytes)
for i := 0; i < b.N; i++ {
_ = strings.ToValidUTF8(s, "")
}
}
func BenchmarkAppendAccessLogs(b *testing.B) {
b.ReportAllocs()
var stat1 = &runtime.MemStats{}
runtime.ReadMemStats(stat1)
const count = 20000
var a = make([]*pb.HTTPAccessLog, 0, count)
for i := 0; i < b.N; i++ {
a = append(a, &pb.HTTPAccessLog{
RequestPath: "/hello/world",
Host: "example.com",
RequestBody: bytes.Repeat([]byte{'A'}, 1024),
})
if len(a) == count {
a = make([]*pb.HTTPAccessLog, 0, count)
}
}
_ = len(a)
var stat2 = &runtime.MemStats{}
runtime.ReadMemStats(stat2)
b.Log((stat2.TotalAlloc-stat1.TotalAlloc)>>20, "MB")
}
func BenchmarkAppendAccessLogs2(b *testing.B) {
b.ReportAllocs()
var stat1 = &runtime.MemStats{}
runtime.ReadMemStats(stat1)
const count = 20000
var a = []*pb.HTTPAccessLog{}
for i := 0; i < b.N; i++ {
a = append(a, &pb.HTTPAccessLog{
RequestPath: "/hello/world",
Host: "example.com",
RequestBody: bytes.Repeat([]byte{'A'}, 1024),
})
if len(a) == count {
a = []*pb.HTTPAccessLog{}
}
}
_ = len(a)
var stat2 = &runtime.MemStats{}
runtime.ReadMemStats(stat2)
b.Log((stat2.TotalAlloc-stat1.TotalAlloc)>>20, "MB")
}

View File

@@ -0,0 +1,117 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/types"
"net"
"os"
"strconv"
"sync"
"sync/atomic"
)
var sharedHTTPAccessLogViewer = NewHTTPAccessLogViewer()
// HTTPAccessLogViewer 本地访问日志浏览器
type HTTPAccessLogViewer struct {
sockFile string
listener net.Listener
connMap map[int64]net.Conn // connId => net.Conn
connId int64
locker sync.Mutex
}
// NewHTTPAccessLogViewer 获取新对象
func NewHTTPAccessLogViewer() *HTTPAccessLogViewer {
return &HTTPAccessLogViewer{
sockFile: os.TempDir() + "/" + teaconst.AccessLogSockName,
connMap: map[int64]net.Conn{},
}
}
// Start 启动
func (this *HTTPAccessLogViewer) Start() error {
this.locker.Lock()
defer this.locker.Unlock()
if this.listener == nil {
// remove if exists
_ = os.Remove(this.sockFile)
// start listening
listener, err := net.Listen("unix", this.sockFile)
if err != nil {
return err
}
this.listener = listener
go func() {
for {
conn, err := this.listener.Accept()
if err != nil {
remotelogs.Error("ACCESS_LOG", "start local reading failed: "+err.Error())
break
}
this.locker.Lock()
var connId = this.nextConnId()
this.connMap[connId] = conn
go func() {
this.startReading(conn, connId)
}()
this.locker.Unlock()
}
}()
}
return nil
}
// HasConns 检查是否有连接
func (this *HTTPAccessLogViewer) HasConns() bool {
this.locker.Lock()
defer this.locker.Unlock()
return len(this.connMap) > 0
}
// Send 发送日志
func (this *HTTPAccessLogViewer) Send(accessLog *pb.HTTPAccessLog) {
var conns = []net.Conn{}
this.locker.Lock()
for _, conn := range this.connMap {
conns = append(conns, conn)
}
this.locker.Unlock()
if len(conns) == 0 {
return
}
for _, conn := range conns {
// ignore error
_, _ = conn.Write([]byte(accessLog.RemoteAddr + " [" + accessLog.TimeLocal + "] \"" + accessLog.RequestMethod + " " + accessLog.Scheme + "://" + accessLog.Host + accessLog.RequestURI + " " + accessLog.Proto + "\" " + types.String(accessLog.Status) + " " + types.String(accessLog.BytesSent) + " " + strconv.Quote(accessLog.Referer) + " " + strconv.Quote(accessLog.UserAgent) + " - " + fmt.Sprintf("%.2fms", accessLog.RequestTime*1000) + "\n"))
}
}
func (this *HTTPAccessLogViewer) nextConnId() int64 {
return atomic.AddInt64(&this.connId, 1)
}
func (this *HTTPAccessLogViewer) startReading(conn net.Conn, connId int64) {
var buf = make([]byte, 1024)
for {
_, err := conn.Read(buf)
if err != nil {
this.locker.Lock()
delete(this.connMap, connId)
this.locker.Unlock()
break
}
}
}

View File

@@ -0,0 +1,327 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"context"
"crypto/tls"
"errors"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/compressions"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils/bytepool"
connutils "github.com/TeaOSLab/EdgeNode/internal/utils/conns"
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
"github.com/iwind/TeaGo/Tea"
"io"
"net"
"net/http"
"os"
"regexp"
"strings"
"sync"
"time"
)
func init() {
if !teaconst.IsMain {
return
}
events.On(events.EventStart, func() {
goman.New(func() {
SharedHTTPCacheTaskManager.Start()
})
})
}
var SharedHTTPCacheTaskManager = NewHTTPCacheTaskManager()
// HTTPCacheTaskManager 缓存任务管理
type HTTPCacheTaskManager struct {
ticker *time.Ticker
protocolReg *regexp.Regexp
timeoutClientMap map[time.Duration]*http.Client // timeout seconds=> *http.Client
locker sync.Mutex
taskQueue chan *pb.PurgeServerCacheRequest
}
func NewHTTPCacheTaskManager() *HTTPCacheTaskManager {
var duration = 30 * time.Second
if Tea.IsTesting() {
duration = 10 * time.Second
}
return &HTTPCacheTaskManager{
ticker: time.NewTicker(duration),
protocolReg: regexp.MustCompile(`^(?i)(http|https)://`),
taskQueue: make(chan *pb.PurgeServerCacheRequest, 1024),
timeoutClientMap: make(map[time.Duration]*http.Client),
}
}
func (this *HTTPCacheTaskManager) Start() {
// task queue
goman.New(func() {
rpcClient, _ := rpc.SharedRPC()
if rpcClient != nil {
for taskReq := range this.taskQueue {
_, err := rpcClient.ServerRPC.PurgeServerCache(rpcClient.Context(), taskReq)
if err != nil {
remotelogs.Error("HTTP_CACHE_TASK_MANAGER", "create purge task failed: "+err.Error())
}
}
}
})
// Loop
for range this.ticker.C {
err := this.Loop()
if err != nil {
remotelogs.Error("HTTP_CACHE_TASK_MANAGER", "execute task failed: "+err.Error())
}
}
}
func (this *HTTPCacheTaskManager) Loop() error {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
resp, err := rpcClient.HTTPCacheTaskKeyRPC.FindDoingHTTPCacheTaskKeys(rpcClient.Context(), &pb.FindDoingHTTPCacheTaskKeysRequest{})
if err != nil {
// 忽略连接错误
if rpc.IsConnError(err) {
return nil
}
return err
}
var keys = resp.HttpCacheTaskKeys
if len(keys) == 0 {
return nil
}
var pbResults = []*pb.UpdateHTTPCacheTaskKeysStatusRequest_KeyResult{}
var taskGroup = goman.NewTaskGroup()
for _, key := range keys {
var taskKey = key
taskGroup.Run(func() {
processErr := this.processKey(taskKey)
var pbResult = &pb.UpdateHTTPCacheTaskKeysStatusRequest_KeyResult{
Id: taskKey.Id,
NodeClusterId: taskKey.NodeClusterId,
Error: "",
}
if processErr != nil {
pbResult.Error = processErr.Error()
}
taskGroup.Lock()
pbResults = append(pbResults, pbResult)
taskGroup.Unlock()
})
}
taskGroup.Wait()
_, err = rpcClient.HTTPCacheTaskKeyRPC.UpdateHTTPCacheTaskKeysStatus(rpcClient.Context(), &pb.UpdateHTTPCacheTaskKeysStatusRequest{KeyResults: pbResults})
if err != nil {
return err
}
return nil
}
func (this *HTTPCacheTaskManager) PushTaskKeys(keys []string) {
select {
case this.taskQueue <- &pb.PurgeServerCacheRequest{
Keys: keys,
Prefixes: nil,
}:
default:
}
}
func (this *HTTPCacheTaskManager) processKey(key *pb.HTTPCacheTaskKey) error {
switch key.Type {
case "purge":
var storages = caches.SharedManager.FindAllStorages()
for _, storage := range storages {
switch key.KeyType {
case "key":
var cacheKeys = []string{key.Key}
if strings.HasPrefix(key.Key, "http://") {
cacheKeys = append(cacheKeys, strings.Replace(key.Key, "http://", "https://", 1))
} else if strings.HasPrefix(key.Key, "https://") {
cacheKeys = append(cacheKeys, strings.Replace(key.Key, "https://", "http://", 1))
}
// TODO 提升效率
for _, cacheKey := range cacheKeys {
var subKeys = []string{
cacheKey,
cacheKey + caches.SuffixMethod + "HEAD",
cacheKey + caches.SuffixWebP,
cacheKey + caches.SuffixPartial,
}
// TODO 根据实际缓存的内容进行组合
for _, encoding := range compressions.AllEncodings() {
subKeys = append(subKeys, cacheKey+caches.SuffixCompression+encoding)
subKeys = append(subKeys, cacheKey+caches.SuffixWebP+caches.SuffixCompression+encoding)
}
err := storage.Purge(subKeys, "file")
if err != nil {
return err
}
}
case "prefix":
var prefixes = []string{key.Key}
if strings.HasPrefix(key.Key, "http://") {
prefixes = append(prefixes, strings.Replace(key.Key, "http://", "https://", 1))
} else if strings.HasPrefix(key.Key, "https://") {
prefixes = append(prefixes, strings.Replace(key.Key, "https://", "http://", 1))
}
err := storage.Purge(prefixes, "dir")
if err != nil {
return err
}
}
}
case "fetch":
err := this.fetchKey(key)
if err != nil {
return err
}
default:
return errors.New("invalid operation type '" + key.Type + "'")
}
return nil
}
// TODO 增加失败重试
func (this *HTTPCacheTaskManager) fetchKey(key *pb.HTTPCacheTaskKey) error {
var fullKey = key.Key
if !this.protocolReg.MatchString(fullKey) {
fullKey = "https://" + fullKey
}
req, err := http.NewRequest(http.MethodGet, fullKey, nil)
if err != nil {
return fmt.Errorf("invalid url: '%s': %w", fullKey, err)
}
// TODO 可以在管理界面自定义Header
req.Header.Set("X-Edge-Cache-Action", "fetch")
req.Header.Set("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/85.0.4183.121 Safari/537.36") // TODO 可以定义
req.Header.Set("Accept-Encoding", "gzip, deflate, br")
resp, err := this.httpClient().Do(req)
if err != nil {
err = this.simplifyErr(err)
return fmt.Errorf("request failed: '%s': %w", fullKey, err)
}
defer func() {
_ = resp.Body.Close()
}()
// 处理502
if resp.StatusCode == http.StatusBadGateway {
return errors.New("read origin site timeout")
}
// 读取内容,以便于生成缓存
var buf = bytepool.Pool16k.Get()
_, err = io.CopyBuffer(io.Discard, resp.Body, buf.Bytes)
bytepool.Pool16k.Put(buf)
if err != nil {
if err != io.EOF {
err = this.simplifyErr(err)
return fmt.Errorf("request failed: '%s': %w", fullKey, err)
} else {
err = nil
}
}
return nil
}
func (this *HTTPCacheTaskManager) simplifyErr(err error) error {
if err == nil {
return nil
}
if os.IsTimeout(err) {
return errors.New("timeout to read origin site")
}
return err
}
func (this *HTTPCacheTaskManager) httpClient() *http.Client {
var timeout = serverconfigs.DefaultHTTPCachePolicyFetchTimeout
var nodeConfig = sharedNodeConfig // copy
if nodeConfig != nil {
var cachePolicies = nodeConfig.HTTPCachePolicies // copy
if len(cachePolicies) > 0 && cachePolicies[0].FetchTimeout != nil && cachePolicies[0].FetchTimeout.Count > 0 {
var fetchTimeout = cachePolicies[0].FetchTimeout.Duration()
if fetchTimeout > 0 {
timeout = fetchTimeout
}
}
}
this.locker.Lock()
defer this.locker.Unlock()
client, ok := this.timeoutClientMap[timeout]
if ok {
return client
}
client = &http.Client{
Timeout: timeout,
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
_, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
conn, err := net.Dial(network, "127.0.0.1:"+port)
if err != nil {
return nil, err
}
return connutils.NewNoStat(conn), nil
},
MaxIdleConns: 128,
MaxIdleConnsPerHost: 32,
MaxConnsPerHost: 32,
IdleConnTimeout: 2 * time.Minute,
ExpectContinueTimeout: 1 * time.Second,
TLSHandshakeTimeout: 0,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}
this.timeoutClientMap[timeout] = client
return client
}

View File

@@ -0,0 +1,25 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes_test
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/nodes"
"testing"
)
func TestHTTPCacheTaskManager_Loop(t *testing.T) {
// initialize cache policies
config, err := nodeconfigs.SharedNodeConfig()
if err != nil {
t.Fatal(err)
}
caches.SharedManager.UpdatePolicies(config.HTTPCachePolicies)
var manager = nodes.NewHTTPCacheTaskManager()
err = manager.Loop()
if err != nil {
t.Fatal(err)
}
}

View File

@@ -0,0 +1,47 @@
package nodes
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"net/http"
)
// HTTPClient HTTP客户端
type HTTPClient struct {
rawClient *http.Client
accessAt int64
isProxyProtocol bool
}
// NewHTTPClient 获取新客户端对象
func NewHTTPClient(rawClient *http.Client, isProxyProtocol bool) *HTTPClient {
return &HTTPClient{
rawClient: rawClient,
accessAt: fasttime.Now().Unix(),
isProxyProtocol: isProxyProtocol,
}
}
// RawClient 获取原始客户端对象
func (this *HTTPClient) RawClient() *http.Client {
return this.rawClient
}
// UpdateAccessTime 更新访问时间
func (this *HTTPClient) UpdateAccessTime() {
this.accessAt = fasttime.Now().Unix()
}
// AccessTime 获取访问时间
func (this *HTTPClient) AccessTime() int64 {
return this.accessAt
}
// IsProxyProtocol 判断是否为PROXY Protocol
func (this *HTTPClient) IsProxyProtocol() bool {
return this.isProxyProtocol
}
// Close 关闭
func (this *HTTPClient) Close() {
this.rawClient.CloseIdleConnections()
}

View File

@@ -0,0 +1,301 @@
package nodes
import (
"context"
"crypto/tls"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
"github.com/cespare/xxhash/v2"
"github.com/pires/go-proxyproto"
"golang.org/x/net/http2"
"net"
"net/http"
"runtime"
"strconv"
"strings"
"sync"
"time"
)
// SharedHTTPClientPool HTTP客户端池单例
var SharedHTTPClientPool = NewHTTPClientPool()
const httpClientProxyProtocolTag = "@ProxyProtocol@"
const maxHTTPRedirects = 8
// HTTPClientPool 客户端池
type HTTPClientPool struct {
clientsMap map[uint64]*HTTPClient // origin key => client
cleanTicker *time.Ticker
locker sync.RWMutex
}
// NewHTTPClientPool 获取新对象
func NewHTTPClientPool() *HTTPClientPool {
var pool = &HTTPClientPool{
cleanTicker: time.NewTicker(1 * time.Hour),
clientsMap: map[uint64]*HTTPClient{},
}
goman.New(func() {
pool.cleanClients()
})
return pool
}
// Client 根据地址获取客户端
func (this *HTTPClientPool) Client(req *HTTPRequest,
origin *serverconfigs.OriginConfig,
originAddr string,
proxyProtocol *serverconfigs.ProxyProtocolConfig,
followRedirects bool) (rawClient *http.Client, err error) {
if origin.Addr == nil {
return nil, errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")")
}
if req == nil || req.RawReq == nil || req.RawReq.URL == nil {
err = errors.New("invalid request url")
return
}
var originHost = req.RawReq.URL.Host
var urlPort = req.RawReq.URL.Port()
if len(urlPort) == 0 {
if req.RawReq.URL.Scheme == "http" {
urlPort = "80"
} else {
urlPort = "443"
}
originHost += ":" + urlPort
}
var rawKey = origin.UniqueKey() + "@" + originAddr + "@" + originHost
// if we are under available ProxyProtocol, we add client ip to key to make every client unique
var isProxyProtocol = false
if proxyProtocol != nil && proxyProtocol.IsOn {
rawKey += httpClientProxyProtocolTag + req.requestRemoteAddr(true)
isProxyProtocol = true
}
// follow redirects
if followRedirects {
rawKey += "@follow"
}
var key = xxhash.Sum64String(rawKey)
var isLnRequest = origin.Id == 0
this.locker.RLock()
client, found := this.clientsMap[key]
this.locker.RUnlock()
if found {
client.UpdateAccessTime()
return client.RawClient(), nil
}
// 这里不能使用RLock避免因为并发生成多个同样的client实例
this.locker.Lock()
defer this.locker.Unlock()
// 再次查找
client, found = this.clientsMap[key]
if found {
client.UpdateAccessTime()
return client.RawClient(), nil
}
var maxConnections = origin.MaxConns
var connectionTimeout = origin.ConnTimeoutDuration()
var readTimeout = origin.ReadTimeoutDuration()
var idleTimeout = origin.IdleTimeoutDuration()
var idleConns = origin.MaxIdleConns
// 超时时间
if connectionTimeout <= 0 {
connectionTimeout = 15 * time.Second
}
if idleTimeout <= 0 {
idleTimeout = 2 * time.Minute
}
var numberCPU = runtime.NumCPU()
if numberCPU < 8 {
numberCPU = 8
}
if maxConnections <= 0 {
maxConnections = numberCPU * 64
}
if idleConns <= 0 {
idleConns = numberCPU * 16
}
if isProxyProtocol { // ProxyProtocol无需保持太多空闲连接
idleConns = 3
} else if isLnRequest { // 可以判断为Ln节点请求
maxConnections *= 8
idleConns *= 8
idleTimeout *= 4
} else if sharedNodeConfig != nil && sharedNodeConfig.Level > 1 {
// Ln节点可以适当增加连接数
maxConnections *= 2
idleConns *= 2
}
// TLS通讯
var tlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
if origin.Cert != nil {
var obj = origin.Cert.CertObject()
if obj != nil {
tlsConfig.InsecureSkipVerify = false
tlsConfig.Certificates = []tls.Certificate{*obj}
if len(origin.Cert.ServerName) > 0 {
tlsConfig.ServerName = origin.Cert.ServerName
}
}
}
var transport = &HTTPClientTransport{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
var realAddr = originAddr
// for redirections
if followRedirects && originHost != addr {
realAddr = addr
}
// connect
conn, dialErr := (&net.Dialer{
Timeout: connectionTimeout,
KeepAlive: 1 * time.Minute,
}).DialContext(ctx, network, realAddr)
if dialErr != nil {
return nil, dialErr
}
// handle PROXY protocol
proxyErr := this.handlePROXYProtocol(conn, req, proxyProtocol)
if proxyErr != nil {
return nil, proxyErr
}
return NewOriginConn(conn), nil
},
MaxIdleConns: 0,
MaxIdleConnsPerHost: idleConns,
MaxConnsPerHost: maxConnections,
IdleConnTimeout: idleTimeout,
ExpectContinueTimeout: 1 * time.Second,
TLSHandshakeTimeout: 5 * time.Second,
TLSClientConfig: tlsConfig,
ReadBufferSize: 8 * 1024,
Proxy: nil,
},
}
// support http/2
if origin.HTTP2Enabled && origin.Addr != nil && origin.Addr.Protocol == serverconfigs.ProtocolHTTPS {
_ = http2.ConfigureTransport(transport.Transport)
}
rawClient = &http.Client{
Timeout: readTimeout,
Transport: transport,
CheckRedirect: func(targetReq *http.Request, via []*http.Request) error {
// follow redirects
if followRedirects && len(via) <= maxHTTPRedirects {
return nil
}
return http.ErrUseLastResponse
},
}
this.clientsMap[key] = NewHTTPClient(rawClient, isProxyProtocol)
return rawClient, nil
}
// 清理不使用的Client
func (this *HTTPClientPool) cleanClients() {
for range this.cleanTicker.C {
var nowTime = fasttime.Now().Unix()
var expiredKeys []uint64
var expiredClients = []*HTTPClient{}
// lookup expired clients
this.locker.RLock()
for k, client := range this.clientsMap {
if client.AccessTime() < nowTime-86400 ||
(client.IsProxyProtocol() && client.AccessTime() < nowTime-3600) { // 超过 N 秒没有调用就关闭
expiredKeys = append(expiredKeys, k)
expiredClients = append(expiredClients, client)
}
}
this.locker.RUnlock()
// remove expired keys
if len(expiredKeys) > 0 {
this.locker.Lock()
for _, k := range expiredKeys {
delete(this.clientsMap, k)
}
this.locker.Unlock()
}
// close expired clients
if len(expiredClients) > 0 {
for _, client := range expiredClients {
client.Close()
}
}
}
}
// 支持PROXY Protocol
func (this *HTTPClientPool) handlePROXYProtocol(conn net.Conn, req *HTTPRequest, proxyProtocol *serverconfigs.ProxyProtocolConfig) error {
if proxyProtocol != nil &&
proxyProtocol.IsOn &&
(proxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || proxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) {
var remoteAddr = req.requestRemoteAddr(true)
var transportProtocol = proxyproto.TCPv4
if strings.Contains(remoteAddr, ":") {
transportProtocol = proxyproto.TCPv6
}
var destAddr = conn.RemoteAddr()
var reqConn = req.RawReq.Context().Value(HTTPConnContextKey)
if reqConn != nil {
destAddr = reqConn.(net.Conn).LocalAddr()
}
var header = proxyproto.Header{
Version: byte(proxyProtocol.Version),
Command: proxyproto.PROXY,
TransportProtocol: transportProtocol,
SourceAddr: &net.TCPAddr{
IP: net.ParseIP(remoteAddr),
Port: req.requestRemotePort(),
},
DestinationAddr: destAddr,
}
_, err := header.WriteTo(conn)
if err != nil {
_ = conn.Close()
return err
}
return nil
}
return nil
}

View File

@@ -0,0 +1,92 @@
package nodes
import (
"context"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"net/http"
"runtime"
"testing"
"time"
)
func TestHTTPClientPool_Client(t *testing.T) {
var pool = NewHTTPClientPool()
rawReq, err := http.NewRequest(http.MethodGet, "https://example.com/", nil)
if err != nil {
t.Fatal(err)
}
var req = &HTTPRequest{RawReq: rawReq}
{
var origin = &serverconfigs.OriginConfig{
Id: 1,
Version: 2,
Addr: &serverconfigs.NetworkAddressConfig{Host: "127.0.0.1", PortRange: "1234"},
}
err := origin.Init(context.Background())
if err != nil {
t.Fatal(err)
}
{
client, err := pool.Client(req, origin, origin.Addr.PickAddress(), nil, false)
if err != nil {
t.Fatal(err)
}
t.Log("client:", client)
}
for i := 0; i < 10; i++ {
client, err := pool.Client(req, origin, origin.Addr.PickAddress(), nil, false)
if err != nil {
t.Fatal(err)
}
t.Log("client:", client)
}
}
}
func TestHTTPClientPool_cleanClients(t *testing.T) {
var origin = &serverconfigs.OriginConfig{
Id: 1,
Version: 2,
Addr: &serverconfigs.NetworkAddressConfig{Host: "127.0.0.1", PortRange: "1234"},
}
err := origin.Init(context.Background())
if err != nil {
t.Fatal(err)
}
var pool = NewHTTPClientPool()
for i := 0; i < 10; i++ {
t.Log("get", i)
_, _ = pool.Client(nil, origin, origin.Addr.PickAddress(), nil, false)
if testutils.IsSingleTesting() {
time.Sleep(1 * time.Second)
}
}
}
func BenchmarkHTTPClientPool_Client(b *testing.B) {
runtime.GOMAXPROCS(1)
var origin = &serverconfigs.OriginConfig{
Id: 1,
Version: 2,
Addr: &serverconfigs.NetworkAddressConfig{Host: "127.0.0.1", PortRange: "1234"},
}
err := origin.Init(context.Background())
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
var pool = NewHTTPClientPool()
for i := 0; i < b.N; i++ {
_, _ = pool.Client(nil, origin, origin.Addr.PickAddress(), nil, false)
}
}

View File

@@ -0,0 +1,26 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"net/http"
)
const emptyHTTPLocation = "/$EmptyHTTPLocation$"
type HTTPClientTransport struct {
*http.Transport
}
func (this *HTTPClientTransport) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := this.Transport.RoundTrip(req)
if err != nil {
return resp, err
}
// 检查在跳转相关状态中Location是否存在
if httpStatusIsRedirect(resp.StatusCode) && len(resp.Header.Get("Location")) == 0 {
resp.Header.Set("Location", emptyHTTPLocation)
}
return resp, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,39 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"path/filepath"
)
func (this *HTTPRequest) doACME() (shouldStop bool) {
// TODO 对请求进行校验,防止恶意攻击
var token = filepath.Base(this.RawReq.URL.Path)
if token == "acme-challenge" || len(token) <= 32 {
return false
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
remotelogs.Error("RPC", "[ACME]rpc failed: "+err.Error())
return false
}
keyResp, err := rpcClient.ACMEAuthenticationRPC.FindACMEAuthenticationKeyWithToken(rpcClient.Context(), &pb.FindACMEAuthenticationKeyWithTokenRequest{Token: token})
if err != nil {
remotelogs.Error("RPC", "[ACME]read key for token failed: "+err.Error())
return false
}
if len(keyResp.Key) == 0 {
return false
}
this.tags = append(this.tags, "ACME")
this.writer.Header().Set("Content-Type", "text/plain")
_, _ = this.writer.WriteString(keyResp.Key)
return true
}

View File

@@ -0,0 +1,72 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"bytes"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"io"
"net/http"
)
// 执行认证
func (this *HTTPRequest) doAuth() (shouldStop bool) {
if this.web.Auth == nil || !this.web.Auth.IsOn {
return
}
for _, ref := range this.web.Auth.PolicyRefs {
if !ref.IsOn || ref.AuthPolicy == nil || !ref.AuthPolicy.IsOn {
continue
}
if !ref.AuthPolicy.MatchRequest(this.RawReq) {
continue
}
ok, newURI, uriChanged, err := ref.AuthPolicy.Filter(this.RawReq, func(subReq *http.Request) (status int, err error) {
subReq.TLS = this.RawReq.TLS
subReq.RemoteAddr = this.RawReq.RemoteAddr
subReq.Host = this.RawReq.Host
subReq.Proto = this.RawReq.Proto
subReq.ProtoMinor = this.RawReq.ProtoMinor
subReq.ProtoMajor = this.RawReq.ProtoMajor
subReq.Body = io.NopCloser(bytes.NewReader([]byte{}))
subReq.Header.Set("Referer", this.URL())
var writer = NewEmptyResponseWriter(this.writer)
this.doSubRequest(writer, subReq)
return writer.StatusCode(), nil
}, this.Format)
if err != nil {
this.write50x(err, http.StatusInternalServerError, "Failed to execute the AuthPolicy", "认证策略执行失败", false)
return
}
if ok {
if uriChanged {
this.uri = newURI
}
this.tags = append(this.tags, "auth:"+ref.AuthPolicy.Type)
return
} else {
// Basic Auth比较特殊
if ref.AuthPolicy.Type == serverconfigs.HTTPAuthTypeBasicAuth {
method, ok := ref.AuthPolicy.Method().(*serverconfigs.HTTPAuthBasicMethod)
if ok {
var headerValue = "Basic realm=\""
if len(method.Realm) > 0 {
headerValue += method.Realm
} else {
headerValue += this.ReqHost
}
headerValue += "\""
if len(method.Charset) > 0 {
headerValue += ", charset=\"" + method.Charset + "\""
}
this.writer.Header()["WWW-Authenticate"] = []string{headerValue}
}
}
this.writer.WriteHeader(http.StatusUnauthorized)
this.tags = append(this.tags, "auth:"+ref.AuthPolicy.Type)
return true
}
}
return
}

View File

@@ -0,0 +1,760 @@
package nodes
import (
"bytes"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/compressions"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
rangeutils "github.com/TeaOSLab/EdgeNode/internal/utils/ranges"
"github.com/iwind/TeaGo/types"
"io"
"net/http"
"path/filepath"
"strconv"
"strings"
"time"
)
// 读取缓存
func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
// 需要动态Upgrade的不缓存
if len(this.RawReq.Header.Get("Upgrade")) > 0 {
return
}
this.cacheCanTryStale = false
var cachePolicy = this.ReqServer.HTTPCachePolicy
if cachePolicy == nil || !cachePolicy.IsOn {
return
}
if this.web.Cache == nil || !this.web.Cache.IsOn || (len(cachePolicy.CacheRefs) == 0 && len(this.web.Cache.CacheRefs) == 0) {
return
}
// 添加 X-Cache Header
var addStatusHeader = this.web.Cache.AddStatusHeader
var cacheBypassDescription = ""
if addStatusHeader {
defer func() {
if len(cacheBypassDescription) > 0 {
this.writer.Header().Set("X-Cache", cacheBypassDescription)
return
}
var cacheStatus = this.varMapping["cache.status"]
if cacheStatus != "HIT" {
this.writer.Header().Set("X-Cache", cacheStatus)
}
}()
}
// 检查服务独立的缓存条件
var refType = ""
for _, cacheRef := range this.web.Cache.CacheRefs {
if !cacheRef.IsOn {
continue
}
if (cacheRef.Conds != nil && cacheRef.Conds.HasRequestConds() && cacheRef.Conds.MatchRequest(this.Format)) ||
(cacheRef.SimpleCond != nil && cacheRef.SimpleCond.Match(this.Format)) {
if cacheRef.IsReverse {
return
}
this.cacheRef = cacheRef
refType = "server"
break
}
}
if this.cacheRef == nil && !this.web.Cache.DisablePolicyRefs {
// 检查策略默认的缓存条件
for _, cacheRef := range cachePolicy.CacheRefs {
if !cacheRef.IsOn {
continue
}
if (cacheRef.Conds != nil && cacheRef.Conds.HasRequestConds() && cacheRef.Conds.MatchRequest(this.Format)) ||
(cacheRef.SimpleCond != nil && cacheRef.SimpleCond.Match(this.Format)) {
if cacheRef.IsReverse {
return
}
this.cacheRef = cacheRef
refType = "policy"
break
}
}
}
if this.cacheRef == nil {
return
}
// 是否强制Range回源
if this.cacheRef.AlwaysForwardRangeRequest && len(this.RawReq.Header.Get("Range")) > 0 {
this.cacheRef = nil
cacheBypassDescription = "BYPASS, forward range"
return
}
// 是否正在Purge
var isPurging = this.web.Cache.PurgeIsOn && strings.ToUpper(this.RawReq.Method) == "PURGE" && this.RawReq.Header.Get("X-Edge-Purge-Key") == this.web.Cache.PurgeKey
if isPurging {
this.RawReq.Method = http.MethodGet
}
// 校验请求
if !this.cacheRef.MatchRequest(this.RawReq) {
this.cacheRef = nil
cacheBypassDescription = "BYPASS, not match"
return
}
// 相关变量
this.varMapping["cache.policy.name"] = cachePolicy.Name
this.varMapping["cache.policy.id"] = strconv.FormatInt(cachePolicy.Id, 10)
this.varMapping["cache.policy.type"] = cachePolicy.Type
// Cache-Pragma
if this.cacheRef.EnableRequestCachePragma {
if this.RawReq.Header.Get("Cache-Control") == "no-cache" || this.RawReq.Header.Get("Pragma") == "no-cache" {
this.cacheRef = nil
cacheBypassDescription = "BYPASS, Cache-Control or Pragma"
return
}
}
// TODO 支持Vary Header
// 缓存标签
var tags = []string{}
// 检查是否有缓存
var key string
if this.web.Cache.Key != nil && this.web.Cache.Key.IsOn && len(this.web.Cache.Key.Host) > 0 {
key = configutils.ParseVariables(this.cacheRef.Key, func(varName string) (value string) {
switch varName {
case "scheme":
return this.web.Cache.Key.Scheme
case "host":
return this.web.Cache.Key.Host
default:
return this.Format("${" + varName + "}")
}
})
} else {
key = this.Format(this.cacheRef.Key)
}
if len(key) == 0 {
this.cacheRef = nil
cacheBypassDescription = "BYPASS, empty key"
return
}
var method = this.Method()
if method != http.MethodGet {
key += caches.SuffixMethod + method
tags = append(tags, strings.ToLower(method))
}
this.cacheKey = key
this.varMapping["cache.key"] = key
// 读取缓存
var storage = caches.SharedManager.FindStorageWithPolicy(cachePolicy.Id)
if storage == nil {
this.cacheRef = nil
cacheBypassDescription = "BYPASS, no policy found"
return
}
this.writer.cacheStorage = storage
// 如果正在预热,则不读取缓存,等待下一个步骤重新生成
if (strings.HasPrefix(this.RawReq.RemoteAddr, "127.") || strings.HasPrefix(this.RawReq.RemoteAddr, "[::1]")) && this.RawReq.Header.Get("X-Edge-Cache-Action") == "fetch" {
return
}
// 判断是否在Purge
if isPurging {
this.varMapping["cache.status"] = "PURGE"
var subKeys = []string{
key,
key + caches.SuffixMethod + "HEAD",
key + caches.SuffixWebP,
key + caches.SuffixPartial,
}
// TODO 根据实际缓存的内容进行组合
for _, encoding := range compressions.AllEncodings() {
subKeys = append(subKeys, key+caches.SuffixCompression+encoding)
subKeys = append(subKeys, key+caches.SuffixWebP+caches.SuffixCompression+encoding)
}
for _, subKey := range subKeys {
err := storage.Delete(subKey)
if err != nil {
remotelogs.ErrorServer("HTTP_REQUEST_CACHE", "purge failed: "+err.Error())
}
}
// 通过API节点清除别节点上的的Key
SharedHTTPCacheTaskManager.PushTaskKeys([]string{key})
return true
}
// 调用回调
this.onRequest()
if this.writer.isFinished {
return
}
var reader caches.Reader
var err error
var rangeHeader = this.RawReq.Header.Get("Range")
var isPartialRequest = len(rangeHeader) > 0
// 检查是否支持WebP
var webPIsEnabled = false
var isHeadMethod = method == http.MethodHead
if !isPartialRequest &&
!isHeadMethod &&
this.web.WebP != nil &&
this.web.WebP.IsOn &&
this.web.WebP.MatchRequest(filepath.Ext(this.Path()), this.Format) &&
this.web.WebP.MatchAccept(this.RawReq.Header.Get("Accept")) {
webPIsEnabled = true
}
// 检查WebP压缩缓存
if webPIsEnabled && !isPartialRequest && !isHeadMethod && reader == nil {
if this.web.Compression != nil && this.web.Compression.IsOn {
_, encoding, ok := this.web.Compression.MatchAcceptEncoding(this.RawReq.Header.Get("Accept-Encoding"))
if ok {
reader, err = storage.OpenReader(key+caches.SuffixWebP+caches.SuffixCompression+encoding, useStale, false)
if err != nil && caches.IsBusyError(err) {
this.varMapping["cache.status"] = "BUSY"
this.cacheRef = nil
return
}
if reader != nil {
tags = append(tags, "webp", encoding)
}
}
}
}
// 检查WebP
if webPIsEnabled && !isPartialRequest &&
!isHeadMethod &&
reader == nil {
reader, err = storage.OpenReader(key+caches.SuffixWebP, useStale, false)
if err != nil && caches.IsBusyError(err) {
this.varMapping["cache.status"] = "BUSY"
this.cacheRef = nil
return
}
if reader != nil {
this.writer.cacheReaderSuffix = caches.SuffixWebP
tags = append(tags, "webp")
}
}
// 检查普通压缩缓存
if !isPartialRequest && !isHeadMethod && reader == nil {
if this.web.Compression != nil && this.web.Compression.IsOn {
_, encoding, ok := this.web.Compression.MatchAcceptEncoding(this.RawReq.Header.Get("Accept-Encoding"))
if ok {
reader, err = storage.OpenReader(key+caches.SuffixCompression+encoding, useStale, false)
if err != nil && caches.IsBusyError(err) {
this.varMapping["cache.status"] = "BUSY"
this.cacheRef = nil
return
}
if reader != nil {
tags = append(tags, encoding)
}
}
}
}
// 检查正常的文件
var isPartialCache = false
var partialRanges []rangeutils.Range
var firstRangeEnd int64
if reader == nil {
reader, err = storage.OpenReader(key, useStale, false)
if err != nil && caches.IsBusyError(err) {
this.varMapping["cache.status"] = "BUSY"
this.cacheRef = nil
return
}
if err != nil && this.cacheRef.AllowPartialContent {
// 尝试读取分片的缓存内容
if len(rangeHeader) == 0 && this.cacheRef.ForcePartialContent {
// 默认读取开头
rangeHeader = "bytes=0-"
}
if len(rangeHeader) > 0 {
pReader, ranges, rangeEnd, goNext := this.tryPartialReader(storage, key, useStale, rangeHeader, this.cacheRef.ForcePartialContent)
if !goNext {
this.cacheRef = nil
return
}
if pReader != nil {
isPartialCache = true
reader = pReader
partialRanges = ranges
firstRangeEnd = rangeEnd
err = nil
}
}
}
if err != nil {
if errors.Is(err, caches.ErrNotFound) {
// 移除请求中的 If-None-Match 和 If-Modified-Since防止源站返回304而无法缓存
if this.reverseProxy != nil {
this.RawReq.Header.Del("If-None-Match")
this.RawReq.Header.Del("If-Modified-Since")
}
// cache相关变量
this.varMapping["cache.status"] = "MISS"
if !useStale && this.web.Cache.Stale != nil && this.web.Cache.Stale.IsOn {
this.cacheCanTryStale = true
}
return
}
if !this.canIgnore(err) {
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: open cache failed: "+err.Error())
}
return
}
}
defer func() {
if !this.writer.DelayRead() {
_ = reader.Close()
}
}()
if useStale {
this.varMapping["cache.status"] = "STALE"
this.logAttrs["cache.status"] = "STALE"
} else {
this.varMapping["cache.status"] = "HIT"
this.logAttrs["cache.status"] = "HIT"
}
// 准备Buffer
var fileSize = reader.BodySize()
var totalSizeString = types.String(fileSize)
if isPartialCache {
fileSize = reader.(*caches.PartialFileReader).MaxLength()
if totalSizeString == "0" {
totalSizeString = "*"
}
}
// 读取Header
var headerData = []byte{}
this.writer.SetSentHeaderBytes(reader.HeaderSize())
var headerPool = this.bytePool(reader.HeaderSize())
var headerBuf = headerPool.Get()
err = reader.ReadHeader(headerBuf.Bytes, func(n int) (goNext bool, readErr error) {
headerData = append(headerData, headerBuf.Bytes[:n]...)
for {
var nIndex = bytes.Index(headerData, []byte{'\n'})
if nIndex >= 0 {
var row = headerData[:nIndex]
var spaceIndex = bytes.Index(row, []byte{':'})
if spaceIndex <= 0 {
return false, errors.New("invalid header '" + string(row) + "'")
}
this.writer.Header().Set(string(row[:spaceIndex]), string(row[spaceIndex+1:]))
headerData = headerData[nIndex+1:]
} else {
break
}
}
return true, nil
})
headerPool.Put(headerBuf)
if err != nil {
if !this.canIgnore(err) {
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: read header failed: "+err.Error())
}
return
}
// 设置cache.age变量
var age = strconv.FormatInt(fasttime.Now().Unix()-reader.LastModified(), 10)
this.varMapping["cache.age"] = age
if addStatusHeader {
if useStale {
this.writer.Header().Set("X-Cache", "STALE, "+refType+", "+reader.TypeName())
} else {
this.writer.Header().Set("X-Cache", "HIT, "+refType+", "+reader.TypeName())
}
} else {
this.writer.Header().Del("X-Cache")
}
if this.web.Cache.AddAgeHeader {
this.writer.Header().Set("Age", age)
}
// ETag
var respHeader = this.writer.Header()
var eTag = respHeader.Get("ETag")
var lastModifiedAt = reader.LastModified()
if len(eTag) == 0 {
if lastModifiedAt > 0 {
if len(tags) > 0 {
eTag = "\"" + strconv.FormatInt(lastModifiedAt, 10) + "_" + strings.Join(tags, "_") + "\""
} else {
eTag = "\"" + strconv.FormatInt(lastModifiedAt, 10) + "\""
}
respHeader.Del("Etag")
// 修复:移除 isPartialCache 限制,确保所有缓存类型都返回 ETag
respHeader["ETag"] = []string{eTag}
}
}
// 支持 Last-Modified
var modifiedTime = ""
if lastModifiedAt > 0 {
modifiedTime = time.Unix(utils.GMTUnixTime(lastModifiedAt), 0).Format("Mon, 02 Jan 2006 15:04:05") + " GMT"
// 修复:移除 isPartialCache 限制,确保所有缓存类型都返回 Last-Modified
respHeader.Set("Last-Modified", modifiedTime)
}
// 支持 If-None-Match
// 修复:移除 isPartialCache 限制,允许分片缓存也支持 304 响应
if !this.isLnRequest && len(eTag) > 0 && this.requestHeader("If-None-Match") == eTag {
// 自定义Header
this.ProcessResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.addExpiresHeader(reader.ExpiresAt())
this.writer.WriteHeader(http.StatusNotModified)
this.isCached = true
this.cacheRef = nil
this.writer.SetOk()
return true
}
// 支持 If-Modified-Since
// 修复:移除 isPartialCache 限制,允许分片缓存也支持 304 响应
if !this.isLnRequest && len(modifiedTime) > 0 && this.requestHeader("If-Modified-Since") == modifiedTime {
// 自定义Header
this.ProcessResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.addExpiresHeader(reader.ExpiresAt())
this.writer.WriteHeader(http.StatusNotModified)
this.isCached = true
this.cacheRef = nil
this.writer.SetOk()
return true
}
this.ProcessResponseHeaders(this.writer.Header(), reader.Status())
this.addExpiresHeader(reader.ExpiresAt())
// 返回上级节点过期时间
if this.isLnRequest {
respHeader.Set(LNExpiresHeader, types.String(reader.ExpiresAt()))
}
// 输出Body
if this.RawReq.Method == http.MethodHead {
this.writer.WriteHeader(reader.Status())
} else {
ifRangeHeaders, ok := this.RawReq.Header["If-Range"]
var supportRange = true
if ok {
supportRange = false
for _, v := range ifRangeHeaders {
if v == this.writer.Header().Get("ETag") || v == this.writer.Header().Get("Last-Modified") {
supportRange = true
break
}
}
}
// 支持Range
var ranges = partialRanges
if supportRange {
if len(rangeHeader) > 0 {
if fileSize == 0 {
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
if len(ranges) == 0 {
ranges, ok = httpRequestParseRangeHeader(rangeHeader)
if !ok {
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
}
if len(ranges) > 0 {
for k, r := range ranges {
r2, ok := r.Convert(fileSize)
if !ok {
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
ranges[k] = r2
}
}
}
}
if len(ranges) == 1 {
respHeader.Set("Content-Range", ranges[0].ComposeContentRangeHeader(totalSizeString))
respHeader.Set("Content-Length", strconv.FormatInt(ranges[0].Length(), 10))
this.writer.WriteHeader(http.StatusPartialContent)
var pool = this.bytePool(fileSize)
var bodyBuf = pool.Get()
var rangeEnd = ranges[0].End()
if firstRangeEnd > 0 {
rangeEnd = firstRangeEnd
}
err = reader.ReadBodyRange(bodyBuf.Bytes, ranges[0].Start(), rangeEnd, func(n int) (goNext bool, readErr error) {
_, readErr = this.writer.Write(bodyBuf.Bytes[:n])
if readErr != nil {
return false, errWritingToClient
}
return true, nil
})
pool.Put(bodyBuf)
if err != nil {
this.varMapping["cache.status"] = "MISS"
if errors.Is(err, caches.ErrInvalidRange) {
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
if !this.canIgnore(err) {
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: "+err.Error())
}
return true
}
} else if len(ranges) > 1 {
var boundary = httpRequestGenBoundary()
respHeader.Set("Content-Type", "multipart/byteranges; boundary="+boundary)
respHeader.Del("Content-Length")
var contentType = respHeader.Get("Content-Type")
this.writer.WriteHeader(http.StatusPartialContent)
for index, r := range ranges {
if index == 0 {
_, err = this.writer.WriteString("--" + boundary + "\r\n")
} else {
_, err = this.writer.WriteString("\r\n--" + boundary + "\r\n")
}
if err != nil {
// 不提示写入客户端错误
return true
}
_, err = this.writer.WriteString("Content-Range: " + r.ComposeContentRangeHeader(totalSizeString) + "\r\n")
if err != nil {
// 不提示写入客户端错误
return true
}
if len(contentType) > 0 {
_, err = this.writer.WriteString("Content-Type: " + contentType + "\r\n\r\n")
if err != nil {
// 不提示写入客户端错误
return true
}
}
var pool = this.bytePool(fileSize)
var bodyBuf = pool.Get()
err = reader.ReadBodyRange(bodyBuf.Bytes, r.Start(), r.End(), func(n int) (goNext bool, readErr error) {
_, readErr = this.writer.Write(bodyBuf.Bytes[:n])
if readErr != nil {
return false, errWritingToClient
}
return true, nil
})
pool.Put(bodyBuf)
if err != nil {
if !this.canIgnore(err) {
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: "+err.Error())
}
return true
}
}
_, err = this.writer.WriteString("\r\n--" + boundary + "--\r\n")
if err != nil {
this.varMapping["cache.status"] = "MISS"
// 不提示写入客户端错误
return true
}
} else { // 没有Range
var resp = &http.Response{
Body: reader,
ContentLength: reader.BodySize(),
}
this.writer.Prepare(resp, fileSize, reader.Status(), false)
this.writer.WriteHeader(reader.Status())
if storage.CanSendfile() {
var pool = this.bytePool(fileSize)
var bodyBuf = pool.Get()
if fp, canSendFile := this.writer.canSendfile(); canSendFile {
this.writer.sentBodyBytes, err = io.CopyBuffer(this.writer.rawWriter, fp, bodyBuf.Bytes)
} else {
_, err = io.CopyBuffer(this.writer, resp.Body, bodyBuf.Bytes)
}
pool.Put(bodyBuf)
} else {
mmapReader, isMMAPReader := reader.(*caches.MMAPFileReader)
if isMMAPReader {
_, err = mmapReader.CopyBodyTo(this.writer)
} else {
var pool = this.bytePool(fileSize)
var bodyBuf = pool.Get()
_, err = io.CopyBuffer(this.writer, resp.Body, bodyBuf.Bytes)
pool.Put(bodyBuf)
}
}
if err == io.EOF {
err = nil
}
if err != nil {
this.varMapping["cache.status"] = "MISS"
if !this.canIgnore(err) {
remotelogs.WarnServer("HTTP_REQUEST_CACHE", this.URL()+": read from cache failed: read body failed: "+err.Error())
}
return
}
}
}
this.isCached = true
this.cacheRef = nil
this.writer.SetOk()
return true
}
// 设置Expires Header
func (this *HTTPRequest) addExpiresHeader(expiresAt int64) {
if this.cacheRef.ExpiresTime != nil && this.cacheRef.ExpiresTime.IsPrior && this.cacheRef.ExpiresTime.IsOn {
if this.cacheRef.ExpiresTime.Overwrite || len(this.writer.Header().Get("Expires")) == 0 {
if this.cacheRef.ExpiresTime.AutoCalculate {
this.writer.Header().Set("Expires", time.Unix(utils.GMTUnixTime(expiresAt), 0).Format("Mon, 2 Jan 2006 15:04:05")+" GMT")
this.writer.Header().Del("Cache-Control")
} else if this.cacheRef.ExpiresTime.Duration != nil {
var duration = this.cacheRef.ExpiresTime.Duration.Duration()
if duration > 0 {
this.writer.Header().Set("Expires", utils.GMTTime(time.Now().Add(duration)).Format("Mon, 2 Jan 2006 15:04:05")+" GMT")
this.writer.Header().Del("Cache-Control")
}
}
}
}
}
// 尝试读取区间缓存
func (this *HTTPRequest) tryPartialReader(storage caches.StorageInterface, key string, useStale bool, rangeHeader string, forcePartialContent bool) (resultReader caches.Reader, ranges []rangeutils.Range, firstRangeEnd int64, goNext bool) {
goNext = true
// 尝试读取Partial cache
if len(rangeHeader) == 0 {
return
}
ranges, ok := httpRequestParseRangeHeader(rangeHeader)
if !ok {
return
}
pReader, pErr := storage.OpenReader(key+caches.SuffixPartial, useStale, true)
if pErr != nil {
if caches.IsBusyError(pErr) {
this.varMapping["cache.status"] = "BUSY"
goNext = false
return
}
return
}
partialReader, ok := pReader.(*caches.PartialFileReader)
if !ok {
_ = pReader.Close()
return
}
var isOk = false
defer func() {
if !isOk {
_ = pReader.Close()
}
}()
// 检查是否已下载完整
if !forcePartialContent &&
len(ranges) > 0 &&
ranges[0][1] < 0 &&
!partialReader.IsCompleted() {
if partialReader.BodySize() > 0 {
var options = this.ReqServer.HTTPCachePolicy.Options
if options != nil {
fileStorage, isFileStorage := storage.(*caches.FileStorage)
if isFileStorage && fileStorage.Options() != nil && fileStorage.Options().EnableIncompletePartialContent {
var r = ranges[0]
r2, findOk := partialReader.Ranges().FindRangeAtPosition(r.Start())
if findOk && r2.Length() >= (256<<10) /* worth reading */ {
isOk = true
ranges[0] = [2]int64{r.Start(), partialReader.BodySize() - 1} // Content-Range: bytes 0-[CONTENT_LENGTH - 1]/CONTENT_LENGTH
pReader.SetNextReader(NewHTTPRequestPartialReader(this, r2.End(), partialReader))
return pReader, ranges, r2.End() - 1 /* not include last byte */, true
}
}
}
}
return
}
// 检查范围
// 这里 **切记不要** 为末尾位置指定一个中间值,因为部分软件客户端不支持
for index, r := range ranges {
r1, ok := r.Convert(partialReader.MaxLength())
if !ok {
return
}
r2, ok := partialReader.ContainsRange(r1)
if !ok {
return
}
ranges[index] = r2
}
isOk = true
return pReader, ranges, -1, true
}

View File

@@ -0,0 +1,103 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/iwind/TeaGo/types"
"io"
"net/http"
)
// HTTPRequestPartialReader 分区文件读取器
type HTTPRequestPartialReader struct {
req *HTTPRequest
offset int64
resp *http.Response
cacheReader caches.Reader
cacheWriter caches.Writer
}
// NewHTTPRequestPartialReader 构建新的分区文件读取器
// req 当前请求
// offset 读取位置
// reader 当前缓存读取器
func NewHTTPRequestPartialReader(req *HTTPRequest, offset int64, reader caches.Reader) *HTTPRequestPartialReader {
return &HTTPRequestPartialReader{
req: req,
offset: offset,
cacheReader: reader,
}
}
// 读取内容
func (this *HTTPRequestPartialReader) Read(p []byte) (n int, err error) {
if this.resp == nil {
_ = this.cacheReader.Close()
this.req.RawReq.Header.Set("Range", "bytes="+types.String(this.offset)+"-")
var resp = this.req.doReverseProxy(false)
if resp == nil {
err = io.ErrUnexpectedEOF
return
}
this.resp = resp
// 对比Content-MD5
partialReader, ok := this.cacheReader.(*caches.PartialFileReader)
if ok {
if partialReader.Ranges().Version >= 2 && resp.Header.Get("Content-MD5") != partialReader.Ranges().ContentMD5 {
err = io.ErrUnexpectedEOF
var storage = this.req.writer.cacheStorage
if storage != nil {
_ = storage.Delete(this.req.cacheKey + caches.SuffixPartial)
}
return
}
}
// 准备写入
this.prepareCacheWriter()
}
n, err = this.resp.Body.Read(p)
// 写入到缓存
if n > 0 && this.cacheWriter != nil {
_ = this.cacheWriter.WriteAt(this.offset, p[:n])
this.offset += int64(n)
}
return
}
// Close 关闭读取器
func (this *HTTPRequestPartialReader) Close() error {
if this.cacheWriter != nil {
_ = this.cacheWriter.Close()
}
if this.resp != nil && this.resp.Body != nil {
return this.resp.Body.Close()
}
return nil
}
// 准备缓存写入器
func (this *HTTPRequestPartialReader) prepareCacheWriter() {
var storage = this.req.writer.cacheStorage
if storage == nil {
return
}
var cacheKey = this.req.cacheKey + caches.SuffixPartial
writer, err := storage.OpenWriter(cacheKey, this.cacheReader.ExpiresAt(), this.cacheReader.Status(), int(this.cacheReader.HeaderSize()), this.cacheReader.BodySize(), -1, true)
if err == nil {
this.cacheWriter = writer
}
}

View File

@@ -0,0 +1,8 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package nodes
func (this *HTTPRequest) doCC() (block bool) {
return
}

View File

@@ -0,0 +1,355 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package nodes
import (
"fmt"
iplib "github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/conns"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/agents"
"github.com/TeaOSLab/EdgeNode/internal/utils/counters"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
maputils "github.com/TeaOSLab/EdgeNode/internal/utils/maps"
"github.com/TeaOSLab/EdgeNode/internal/utils/ttlcache"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
"net/http"
"net/url"
"path/filepath"
"strings"
)
type ccBlockedCounter struct {
count int32
updatedAt int64
}
var ccBlockedMap = maputils.NewFixedMap[string, *ccBlockedCounter](65535)
func (this *HTTPRequest) doCC() (block bool) {
if this.nodeConfig == nil || this.ReqServer == nil {
return
}
var validatePath = "/GE/CC/VALIDATOR"
var maxConnections = 30
// 策略
var httpCCPolicy = this.nodeConfig.FindHTTPCCPolicyWithClusterId(this.ReqServer.ClusterId)
var scope = firewallconfigs.FirewallScopeGlobal
if httpCCPolicy != nil {
if !httpCCPolicy.IsOn {
return
}
if len(httpCCPolicy.RedirectsChecking.ValidatePath) > 0 {
validatePath = httpCCPolicy.RedirectsChecking.ValidatePath
}
if httpCCPolicy.MaxConnectionsPerIP > 0 {
maxConnections = httpCCPolicy.MaxConnectionsPerIP
}
scope = httpCCPolicy.FirewallScope()
}
var ccConfig = this.web.CC
if ccConfig == nil || !ccConfig.IsOn || (this.RawReq.URL.Path != validatePath && !ccConfig.MatchURL(this.requestScheme()+"://"+this.ReqHost+this.Path())) {
return
}
// 忽略常用文件
if ccConfig.IgnoreCommonFiles {
if len(this.RawReq.Referer()) > 0 {
var ext = filepath.Ext(this.RawReq.URL.Path)
if len(ext) > 0 && utils.IsCommonFileExtension(ext) {
return
}
}
}
// 检查白名单
var remoteAddr = this.requestRemoteAddr(true)
// 检查是否为白名单直连
if !Tea.IsTesting() && this.nodeConfig.IPIsAutoAllowed(remoteAddr) {
return
}
// 是否在全局名单中
canGoNext, isInAllowedList, _ := iplibrary.AllowIP(remoteAddr, this.ReqServer.Id)
if !canGoNext {
this.disableLog = true
this.Close()
return true
}
if isInAllowedList {
return false
}
// WAF黑名单
if waf.SharedIPBlackList.Contains(waf.IPTypeAll, scope, this.ReqServer.Id, remoteAddr) {
this.disableLog = true
this.Close()
return true
}
// 检查是否启用QPS
if ccConfig.MinQPSPerIP > 0 && this.RawReq.URL.Path != validatePath {
var avgQPS = counters.SharedCounter.IncreaseKey("QPS:"+remoteAddr, 60) / 60
if avgQPS <= 0 {
avgQPS = 1
}
if avgQPS < types.Uint32(ccConfig.MinQPSPerIP) {
return false
}
}
// 检查连接数
if conns.SharedMap.CountIPConns(remoteAddr) >= maxConnections {
this.ccForbid(5)
var forbiddenTimes = this.increaseCCCounter(remoteAddr)
waf.SharedIPBlackList.RecordIP(waf.IPTypeAll, scope, this.ReqServer.Id, remoteAddr, fasttime.Now().Unix()+int64(forbiddenTimes*1800), 0, scope == firewallconfigs.FirewallScopeGlobal, 0, 0, "CC防护拦截并发连接数超出"+types.String(maxConnections)+"个")
return true
}
// GET302验证
if ccConfig.EnableGET302 &&
this.RawReq.Method == http.MethodGet &&
!agents.SharedManager.ContainsIP(remoteAddr) /** 搜索引擎亲和性 **/ &&
!strings.HasPrefix(this.RawReq.URL.Path, "/baidu_verify_") /** 百度验证 **/ &&
!strings.HasPrefix(this.RawReq.URL.Path, "/google") /** Google验证 **/ {
// 忽略搜索引擎
var canSkip302 = false
var ipResult = iplib.LookupIP(remoteAddr)
if ipResult != nil && ipResult.IsOk() {
var providerName = ipResult.ProviderName()
canSkip302 = strings.Contains(providerName, "百度") || strings.Contains(providerName, "谷歌") || strings.Contains(providerName, "baidu") || strings.Contains(providerName, "google")
}
if !canSkip302 {
// 检查参数
var ccWhiteListKey = "HTTP-CC-GET302-" + remoteAddr
var currentTime = fasttime.Now().Unix()
if this.RawReq.URL.Path == validatePath {
this.DisableAccessLog()
this.DisableStat()
// TODO 根据浏览器信息决定是否校验referer
if !this.checkCCRedirects(httpCCPolicy, remoteAddr) {
return true
}
var key = this.RawReq.URL.Query().Get("key")
var pieces = strings.Split(key, ".") // key1.key2.timestamp
if len(pieces) != 3 {
this.ccForbid(1)
return true
}
var urlKey = pieces[0]
var timestampKey = pieces[1]
var timestamp = pieces[2]
var targetURL = this.RawReq.URL.Query().Get("url")
var realURLKey = stringutil.Md5(sharedNodeConfig.Secret + "@" + targetURL + "@" + remoteAddr)
if urlKey != realURLKey {
this.ccForbid(2)
return true
}
// 校验时间
if timestampKey != stringutil.Md5(sharedNodeConfig.Secret+"@"+timestamp) {
this.ccForbid(3)
return true
}
var elapsedSeconds = currentTime - types.Int64(timestamp)
if elapsedSeconds > 10 /** 10秒钟 **/ { // 如果校验时间过长,则可能阻止当前访问
if elapsedSeconds > 300 /** 5分钟 **/ { // 如果超时时间过长就跳回原URL
httpRedirect(this.writer, this.RawReq, targetURL, http.StatusFound)
} else {
this.ccForbid(4)
}
return true
}
// 加入到临时白名单
ttlcache.SharedInt64Cache.Write(ccWhiteListKey, 1, currentTime+600 /** 10分钟 **/)
// 跳转回原来URL
httpRedirect(this.writer, this.RawReq, targetURL, http.StatusFound)
return true
} else {
// 检查临时白名单
if ttlcache.SharedInt64Cache.Read(ccWhiteListKey) == nil {
if !this.checkCCRedirects(httpCCPolicy, remoteAddr) {
return true
}
var urlKey = stringutil.Md5(sharedNodeConfig.Secret + "@" + this.URL() + "@" + remoteAddr)
var timestampKey = stringutil.Md5(sharedNodeConfig.Secret + "@" + types.String(currentTime))
// 跳转到验证URL
this.DisableStat()
httpRedirect(this.writer, this.RawReq, validatePath+"?key="+urlKey+"."+timestampKey+"."+types.String(currentTime)+"&url="+url.QueryEscape(this.URL()), http.StatusFound)
return true
}
}
}
} else if this.RawReq.URL.Path == validatePath {
// 直接跳回
var targetURL = this.RawReq.URL.Query().Get("url")
httpRedirect(this.writer, this.RawReq, targetURL, http.StatusFound)
return true
}
// Key
var ccKeys = []string{}
if ccConfig.WithRequestPath {
ccKeys = append(ccKeys, "HTTP-CC-"+remoteAddr+"-"+this.Path()) // 这里可以忽略域名,因为一个正常用户同时访问多个域名的可能性不大
} else {
ccKeys = append(ccKeys, "HTTP-CC-"+remoteAddr)
}
// 指纹
if this.IsHTTPS && ccConfig.EnableFingerprint {
var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
if requestConn != nil {
clientConn, ok := requestConn.(ClientConnInterface)
if ok {
var fingerprint = clientConn.Fingerprint()
if len(fingerprint) > 0 {
var fingerprintString = fmt.Sprintf("%x", fingerprint)
if ccConfig.WithRequestPath {
ccKeys = append(ccKeys, "HTTP-CC-"+fingerprintString+"-"+this.Path()) // 这里可以忽略域名,因为一个正常用户同时访问多个域名的可能性不大
} else {
ccKeys = append(ccKeys, "HTTP-CC-"+fingerprintString)
}
}
}
}
}
// 检查阈值
var thresholds = ccConfig.Thresholds
if len(thresholds) == 0 || ccConfig.UseDefaultThresholds {
if httpCCPolicy != nil && len(httpCCPolicy.Thresholds) > 0 {
thresholds = httpCCPolicy.Thresholds
} else {
thresholds = serverconfigs.DefaultHTTPCCThresholds
}
}
if len(thresholds) == 0 {
return
}
var currentTime = fasttime.Now().Unix()
for _, threshold := range thresholds {
if threshold.PeriodSeconds <= 0 || threshold.MaxRequests <= 0 {
continue
}
for _, ccKey := range ccKeys {
var count = counters.SharedCounter.IncreaseKey(ccKey+"-T"+types.String(threshold.PeriodSeconds), int(threshold.PeriodSeconds))
if count >= types.Uint32(threshold.MaxRequests) {
this.writeCode(http.StatusTooManyRequests, "Too many requests, please wait for a few minutes.", "访问过于频繁,请稍等片刻后再访问。")
// 记录到黑名单
if threshold.BlockSeconds > 0 {
// 如果被重复拦截,则加大惩罚力度
var forbiddenTimes = this.increaseCCCounter(remoteAddr)
waf.SharedIPBlackList.RecordIP(waf.IPTypeAll, scope, this.ReqServer.Id, remoteAddr, currentTime+int64(threshold.BlockSeconds*forbiddenTimes), 0, scope == firewallconfigs.FirewallScopeGlobal, 0, 0, "CC防护拦截在"+types.String(threshold.PeriodSeconds)+"秒内请求超过"+types.String(threshold.MaxRequests)+"次")
}
this.tags = append(this.tags, "CCProtection")
this.isAttack = true
// 关闭连接
this.writer.Flush()
this.Close()
// 关闭同一个IP其他连接
conns.SharedMap.CloseIPConns(remoteAddr)
return true
}
}
}
return
}
// TODO 对forbidden比较多的IP进行惩罚
func (this *HTTPRequest) ccForbid(code int) {
this.writeCode(http.StatusForbidden, "The request has been blocked by cc policy.", "当前请求已被CC策略拦截。")
}
// 检查跳转次数
func (this *HTTPRequest) checkCCRedirects(httpCCPolicy *nodeconfigs.HTTPCCPolicy, remoteAddr string) bool {
// 如果无效跳转次数太多,则拦截
var ccRedirectKey = "HTTP-CC-GET302-" + remoteAddr + "-REDIRECTS"
var maxRedirectDurationSeconds = 120
var maxRedirects uint32 = 30
var blockSeconds int64 = 3600
if httpCCPolicy != nil && httpCCPolicy.IsOn {
if httpCCPolicy.RedirectsChecking.DurationSeconds > 0 {
maxRedirectDurationSeconds = httpCCPolicy.RedirectsChecking.DurationSeconds
}
if httpCCPolicy.RedirectsChecking.MaxRedirects > 0 {
maxRedirects = types.Uint32(httpCCPolicy.RedirectsChecking.MaxRedirects)
}
if httpCCPolicy.RedirectsChecking.BlockSeconds > 0 {
blockSeconds = int64(httpCCPolicy.RedirectsChecking.BlockSeconds)
}
}
var countRedirects = counters.SharedCounter.IncreaseKey(ccRedirectKey, maxRedirectDurationSeconds)
if countRedirects >= maxRedirects {
// 加入黑名单
var scope = firewallconfigs.FirewallScopeGlobal
if httpCCPolicy != nil {
scope = httpCCPolicy.FirewallScope()
}
waf.SharedIPBlackList.RecordIP(waf.IPTypeAll, scope, this.ReqServer.Id, remoteAddr, fasttime.Now().Unix()+blockSeconds, 0, scope == firewallconfigs.FirewallScopeGlobal, 0, 0, "CC防护拦截在"+types.String(maxRedirectDurationSeconds)+"秒内无效跳转"+types.String(maxRedirects)+"次")
this.Close()
return false
}
return true
}
// 对CC禁用次数进行计数
func (this *HTTPRequest) increaseCCCounter(remoteAddr string) int32 {
counter, ok := ccBlockedMap.Get(remoteAddr)
if !ok {
counter = &ccBlockedCounter{
count: 1,
updatedAt: fasttime.Now().Unix(),
}
ccBlockedMap.Put(remoteAddr, counter)
} else {
if counter.updatedAt < fasttime.Now().Unix()-86400 /** 有效期间不要超过1天防止无限期封禁 **/ {
counter.count = 0
}
counter.updatedAt = fasttime.Now().Unix()
if counter.count < 32 {
counter.count++
}
}
return counter.count
}

View File

@@ -0,0 +1,531 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"bytes"
"compress/gzip"
"encoding/base64"
"fmt"
"io"
"net/http"
"regexp"
"strings"
"time"
"github.com/TeaOSLab/EdgeNode/internal/encryption"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/andybalholm/brotli"
)
const encryptionCacheVersion = "xor-v1"
// processPageEncryption 处理页面加密
func (this *HTTPRequest) processPageEncryption(resp *http.Response) error {
// 首先检查是否是 /waf/loader.js如果是则直接跳过不应该被加密
// 这个检查必须在所有其他检查之前,确保 loader.js 永远不会被加密
if strings.Contains(this.URL(), "/waf/loader.js") {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "skipping /waf/loader.js, should not be encrypted, URL: "+this.URL())
return nil
}
if this.web.Encryption == nil {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "encryption config is nil for URL: "+this.URL())
return nil
}
if !this.web.Encryption.IsOn {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "encryption switch is off for URL: "+this.URL())
return nil
}
if !this.web.Encryption.IsEnabled() {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "encryption is not enabled for URL: "+this.URL())
return nil
}
// 检查 URL 白名单
if this.web.Encryption.MatchExcludeURL(this.URL()) {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "URL is in exclude list: "+this.URL())
return nil
}
// 检查 Content-Type 和 URL
contentType := resp.Header.Get("Content-Type")
contentTypeLower := strings.ToLower(contentType)
urlLower := strings.ToLower(this.URL())
var isHTML = strings.Contains(contentTypeLower, "text/html")
// 判断是否为 JavaScript 文件:通过 Content-Type 或 URL 后缀
var isJavaScript = strings.Contains(contentTypeLower, "text/javascript") ||
strings.Contains(contentTypeLower, "application/javascript") ||
strings.Contains(contentTypeLower, "application/x-javascript") ||
strings.Contains(contentTypeLower, "text/ecmascript") ||
strings.HasSuffix(urlLower, ".js") ||
strings.Contains(urlLower, ".js?") ||
strings.Contains(urlLower, ".js&")
if !isHTML && !isJavaScript {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "content type not match, URL: "+this.URL()+", Content-Type: "+contentType)
return nil
}
// 检查内容大小(仅处理小于 10MB 的内容)
if resp.ContentLength > 0 && resp.ContentLength > 10*1024*1024 {
return nil
}
// 读取响应体
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
_ = resp.Body.Close()
// 如果源站返回了压缩内容,先解压再处理
decodedBody, decoded, err := decodeResponseBody(bodyBytes, resp.Header.Get("Content-Encoding"))
if err == nil && decoded {
bodyBytes = decodedBody
// 已经解压,移除 Content-Encoding
resp.Header.Del("Content-Encoding")
}
// 检查实际大小
if len(bodyBytes) > 10*1024*1024 {
// 内容太大,恢复原始响应体
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
var encryptedBytes []byte
// 处理 JavaScript 文件
if isJavaScript {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "processing JavaScript file, URL: "+this.URL())
// 检查是否需要加密独立的 JavaScript 文件
if this.web.Encryption.Javascript == nil {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "Javascript config is nil for URL: "+this.URL())
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
if !this.web.Encryption.Javascript.IsOn {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "Javascript encryption is not enabled for URL: "+this.URL())
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
// 检查 URL 匹配
if !this.web.Encryption.Javascript.MatchURL(this.URL()) {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "URL does not match patterns for URL: "+this.URL())
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
// 跳过 Loader 文件(必须排除,否则 loader.js 会被错误加密)
if strings.Contains(this.URL(), "/waf/loader.js") ||
strings.Contains(this.URL(), "waf-loader.js") ||
strings.Contains(this.URL(), "__WAF_") {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "skipping loader file, URL: "+this.URL())
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
// 加密 JavaScript 文件
remotelogs.Println("HTTP_REQUEST_ENCRYPTION", "encrypting JavaScript file, URL: "+this.URL())
encryptedBytes, err = this.encryptJavaScriptFile(bodyBytes, resp)
if err != nil {
remotelogs.Warn("HTTP_REQUEST_ENCRYPTION", "encrypt JavaScript file failed: "+err.Error())
// 加密失败,恢复原始响应体
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
remotelogs.Println("HTTP_REQUEST_ENCRYPTION", "JavaScript file encrypted successfully, URL: "+this.URL())
} else if isHTML {
// 加密 HTML 内容
encryptedBytes, err = this.encryptHTMLScripts(bodyBytes, resp)
if err != nil {
remotelogs.Warn("HTTP_REQUEST_ENCRYPTION", "encrypt HTML failed: "+err.Error())
// 加密失败,恢复原始响应体
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
} else {
// 未知类型,恢复原始响应体
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
// 替换响应体
resp.Body = io.NopCloser(bytes.NewReader(encryptedBytes))
resp.ContentLength = int64(len(encryptedBytes))
resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(encryptedBytes)))
// 避免旧缓存导致解密算法不匹配
resp.Header.Set("Cache-Control", "no-store, no-cache, must-revalidate")
// 删除 Content-Encoding如果存在因为我们修改了内容
resp.Header.Del("Content-Encoding")
return nil
}
// encryptHTMLScripts 加密 HTML 中的脚本
func (this *HTTPRequest) encryptHTMLScripts(htmlBytes []byte, resp *http.Response) ([]byte, error) {
html := string(htmlBytes)
// 检查是否需要加密 HTML 脚本
if this.web.Encryption.HTML == nil || !this.web.Encryption.HTML.IsOn {
return htmlBytes, nil
}
// 检查 URL 匹配
if !this.web.Encryption.HTML.MatchURL(this.URL()) {
return htmlBytes, nil
}
// 生成密钥
remoteIP := this.requestRemoteAddr(true)
userAgent := this.RawReq.UserAgent()
keyID, actualKey := encryption.GenerateEncryptionKey(remoteIP, userAgent, this.web.Encryption.KeyPolicy)
// 检查缓存
var cacheKey string
if this.web.Encryption.Cache != nil && this.web.Encryption.Cache.IsOn {
// 生成缓存键keyID + URL + contentHash
contentHash := fmt.Sprintf("%x", bytesHash(htmlBytes))
cacheKey = fmt.Sprintf("encrypt_%s_%s_%s_%s", encryptionCacheVersion, keyID, this.URL(), contentHash)
cache := encryption.SharedEncryptionCache(
int(this.web.Encryption.Cache.MaxSize),
time.Duration(this.web.Encryption.Cache.TTL)*time.Second,
)
if cached, ok := cache.Get(cacheKey); ok {
return cached, nil
}
}
// 提取并加密内联脚本
if this.web.Encryption.HTML.EncryptInlineScripts {
html = this.encryptInlineScripts(html, actualKey, keyID)
}
// 提取并加密外部脚本(通过 src 属性)
if this.web.Encryption.HTML.EncryptExternalScripts {
html = this.encryptExternalScripts(html, actualKey, keyID)
}
// 注入 Loader
html = this.injectLoader(html)
result := []byte(html)
// 保存到缓存
if this.web.Encryption.Cache != nil && this.web.Encryption.Cache.IsOn && len(cacheKey) > 0 {
cache := encryption.SharedEncryptionCache(
int(this.web.Encryption.Cache.MaxSize),
time.Duration(this.web.Encryption.Cache.TTL)*time.Second,
)
cache.Set(cacheKey, result, this.web.Encryption.Cache.TTL)
}
return result, nil
}
// encryptInlineScripts 加密内联脚本
func (this *HTTPRequest) encryptInlineScripts(html string, key string, keyID string) string {
// 匹配 <script>...</script>(不包含 src 属性)
scriptRegex := regexp.MustCompile(`(?i)<script(?:\s+[^>]*)?>([\s\S]*?)</script>`)
return scriptRegex.ReplaceAllStringFunc(html, func(match string) string {
// 检查是否有 src 属性(外部脚本)
if strings.Contains(strings.ToLower(match), "src=") {
return match // 跳过外部脚本
}
// 提取脚本内容
contentMatch := regexp.MustCompile(`(?i)<script(?:\s+[^>]*)?>([\s\S]*?)</script>`)
matches := contentMatch.FindStringSubmatch(match)
if len(matches) < 2 {
return match
}
scriptContent := matches[1]
// 跳过空脚本或仅包含空白字符的脚本
if strings.TrimSpace(scriptContent) == "" {
return match
}
// 跳过已加密的脚本(包含 __WAF_P__
if strings.Contains(scriptContent, "__WAF_P__") {
return match
}
// 加密脚本内容
encrypted, err := this.encryptScript(scriptContent, key)
if err != nil {
return match // 加密失败,返回原始内容
}
// 生成元数据k 是 keyID用于缓存key 是实际密钥,用于解密)
meta := fmt.Sprintf(`{"k":"%s","key":"%s","t":%d,"alg":"xor"}`, keyID, key, time.Now().Unix())
// 替换为加密后的脚本(同步解密执行,保证脚本顺序)
return fmt.Sprintf(
`<script>(function(){
function xorDecodeToString(b64,key){
var bin=atob(b64);
var out=new Uint8Array(bin.length);
for(var i=0;i<bin.length;i++){out[i]=bin.charCodeAt(i)^key.charCodeAt(i%%key.length);}
if (typeof TextDecoder !== 'undefined') {
return new TextDecoder().decode(out);
}
var s='';for(var j=0;j<out.length;j++){s+=String.fromCharCode(out[j]);}
return s;
}
try{var meta=%s;var code=xorDecodeToString("%s",meta.key);window.eval(code);}catch(e){console.error('WAF inline decrypt/execute failed',e);}
})();</script>`,
meta,
encrypted,
)
})
}
// encryptExternalScripts 加密外部脚本(通过替换 src 为加密后的内容)
// 注意:这里我们实际上是将外部脚本的内容内联并加密
func (this *HTTPRequest) encryptExternalScripts(html string, key string, keyID string) string {
// 匹配 <script src="..."></script>
scriptRegex := regexp.MustCompile(`(?i)<script\s+([^>]*src\s*=\s*["']([^"']+)["'][^>]*)>\s*</script>`)
return scriptRegex.ReplaceAllStringFunc(html, func(match string) string {
// 提取 src URL
srcMatch := regexp.MustCompile(`(?i)src\s*=\s*["']([^"']+)["']`)
srcMatches := srcMatch.FindStringSubmatch(match)
if len(srcMatches) < 2 {
return match
}
srcURL := srcMatches[1]
// 跳过已加密的脚本或 Loader
if strings.Contains(srcURL, "waf-loader.js") || strings.Contains(srcURL, "__WAF_") {
return match
}
// 注意:实际实现中,我们需要获取外部脚本的内容
// 这里为了简化,我们只是标记需要加密,实际内容获取需要在响应处理时进行
// 当前实现:将外部脚本转换为内联加密脚本的占位符
// 实际生产环境需要1. 获取外部脚本内容 2. 加密 3. 替换
return match // 暂时返回原始内容,后续可以扩展
})
}
// encryptScript 加密脚本内容
func (this *HTTPRequest) encryptScript(scriptContent string, key string) (string, error) {
// 1. XOR 加密(不压缩,避免浏览器解压依赖导致顺序问题)
encrypted := xorEncrypt([]byte(scriptContent), []byte(key))
// 2. Base64 编码
encoded := base64.StdEncoding.EncodeToString(encrypted)
return encoded, nil
}
// injectLoader 注入 Loader 脚本
func (this *HTTPRequest) injectLoader(html string) string {
// 检查是否已经注入
if strings.Contains(html, "waf-loader.js") {
return html
}
// 在 </head> 之前注入,如果没有 </head>,则在 </body> 之前注入
// 不使用 async确保 loader 在解析阶段优先加载并执行
loaderScript := `<script src="/waf/loader.js"></script>`
if strings.Contains(html, "</head>") {
return strings.Replace(html, "</head>", loaderScript+"</head>", 1)
} else if strings.Contains(html, "</body>") {
return strings.Replace(html, "</body>", loaderScript+"</body>", 1)
} else {
// 如果都没有,在开头注入
return loaderScript + html
}
}
// compressGzip 使用 Gzip 压缩(浏览器原生支持)
func compressGzip(data []byte) ([]byte, error) {
var buf bytes.Buffer
writer := gzip.NewWriter(&buf)
_, err := writer.Write(data)
if err != nil {
writer.Close()
return nil, err
}
err = writer.Close()
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// decodeResponseBody 根据 Content-Encoding 解压响应体
func decodeResponseBody(body []byte, encoding string) ([]byte, bool, error) {
enc := strings.ToLower(strings.TrimSpace(encoding))
if enc == "" || enc == "identity" {
return body, false, nil
}
switch enc {
case "gzip":
reader, err := gzip.NewReader(bytes.NewReader(body))
if err != nil {
return body, false, err
}
defer reader.Close()
decoded, err := io.ReadAll(reader)
if err != nil {
return body, false, err
}
return decoded, true, nil
case "br":
reader := brotli.NewReader(bytes.NewReader(body))
decoded, err := io.ReadAll(reader)
if err != nil {
return body, false, err
}
return decoded, true, nil
default:
// 未知编码,保持原样
return body, false, nil
}
}
// compressBrotli 使用 Brotli 压缩(保留用于其他用途)
func compressBrotli(data []byte, level int) ([]byte, error) {
var buf bytes.Buffer
writer := brotli.NewWriterOptions(&buf, brotli.WriterOptions{
Quality: level,
LGWin: 14,
})
_, err := writer.Write(data)
if err != nil {
writer.Close()
return nil, err
}
err = writer.Close()
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// xorEncrypt XOR 加密
func xorEncrypt(data []byte, key []byte) []byte {
result := make([]byte, len(data))
keyLen := len(key)
if keyLen == 0 {
return data
}
for i := 0; i < len(data); i++ {
result[i] = data[i] ^ key[i%keyLen]
}
return result
}
// encryptJavaScriptFile 加密独立的 JavaScript 文件
func (this *HTTPRequest) encryptJavaScriptFile(jsBytes []byte, resp *http.Response) ([]byte, error) {
jsContent := string(jsBytes)
// 跳过空文件
if strings.TrimSpace(jsContent) == "" {
return jsBytes, nil
}
// 跳过已加密的脚本(包含 __WAF_P__
if strings.Contains(jsContent, "__WAF_P__") {
return jsBytes, nil
}
// 生成密钥
remoteIP := this.requestRemoteAddr(true)
userAgent := this.RawReq.UserAgent()
keyID, actualKey := encryption.GenerateEncryptionKey(remoteIP, userAgent, this.web.Encryption.KeyPolicy)
// 检查缓存
var cacheKey string
if this.web.Encryption.Cache != nil && this.web.Encryption.Cache.IsOn {
// 生成缓存键keyID + URL + contentHash
contentHash := fmt.Sprintf("%x", bytesHash(jsBytes))
cacheKey = fmt.Sprintf("encrypt_js_%s_%s_%s_%s", encryptionCacheVersion, keyID, this.URL(), contentHash)
cache := encryption.SharedEncryptionCache(
int(this.web.Encryption.Cache.MaxSize),
time.Duration(this.web.Encryption.Cache.TTL)*time.Second,
)
if cached, ok := cache.Get(cacheKey); ok {
return cached, nil
}
}
// 加密脚本内容
encrypted, err := this.encryptScript(jsContent, actualKey)
if err != nil {
return nil, err
}
// 生成元数据k 是 keyID用于缓存key 是实际密钥,用于解密)
meta := fmt.Sprintf(`{"k":"%s","key":"%s","t":%d,"alg":"xor"}`, keyID, actualKey, time.Now().Unix())
// 生成加密后的 JavaScript 代码(同步解密执行,保证脚本顺序)
encryptedJS := fmt.Sprintf(`(function() {
try {
function xorDecodeToString(b64, key) {
var bin = atob(b64);
var out = new Uint8Array(bin.length);
for (var i = 0; i < bin.length; i++) {
out[i] = bin.charCodeAt(i) ^ key.charCodeAt(i %% key.length);
}
if (typeof TextDecoder !== 'undefined') {
return new TextDecoder().decode(out);
}
var s = '';
for (var j = 0; j < out.length; j++) {
s += String.fromCharCode(out[j]);
}
return s;
}
var meta = %s;
var code = xorDecodeToString("%s", meta.key);
// 使用全局 eval尽量保持和 <script> 一致的作用域
window.eval(code);
} catch (e) {
console.error('WAF JS decrypt/execute failed', e);
}
})();`, meta, encrypted)
result := []byte(encryptedJS)
// 保存到缓存
if this.web.Encryption.Cache != nil && this.web.Encryption.Cache.IsOn && len(cacheKey) > 0 {
cache := encryption.SharedEncryptionCache(
int(this.web.Encryption.Cache.MaxSize),
time.Duration(this.web.Encryption.Cache.TTL)*time.Second,
)
cache.Set(cacheKey, result, this.web.Encryption.Cache.TTL)
}
return result, nil
}
// bytesHash 计算字节数组的简单哈希(用于缓存键)
func bytesHash(data []byte) uint64 {
var hash uint64 = 5381
for _, b := range data {
hash = ((hash << 5) + hash) + uint64(b)
}
return hash
}

View File

@@ -0,0 +1,119 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
"net/http"
"strings"
)
const httpStatusPageTemplate = `<!DOCTYPE html>
<html lang="en">
<head>
<title>${status} ${statusMessage}</title>
<meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
<style>
address { line-height: 1.8; }
</style>
</head>
<body>
<h1>${status} ${statusMessage}</h1>
<p>${message}</p>
<address>Connection: ${remoteAddr} (Client) -&gt; ${serverAddr} (Server)</address>
<address>Request ID: ${requestId}.</address>
</body>
</html>`
func (this *HTTPRequest) write404() {
this.writeCode(http.StatusNotFound, "", "")
}
func (this *HTTPRequest) writeCode(statusCode int, enMessage string, zhMessage string) {
if this.doPage(statusCode) {
return
}
var pageContent = configutils.ParseVariables(httpStatusPageTemplate, func(varName string) (value string) {
switch varName {
case "status":
return types.String(statusCode)
case "statusMessage":
return http.StatusText(statusCode)
case "message":
var acceptLanguages = this.RawReq.Header.Get("Accept-Language")
if len(acceptLanguages) > 0 {
var index = strings.Index(acceptLanguages, ",")
if index > 0 {
var firstLanguage = acceptLanguages[:index]
if firstLanguage == "zh-CN" {
return zhMessage
}
}
}
return enMessage
}
return this.Format("${" + varName + "}")
})
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
_, _ = this.writer.Write([]byte(pageContent))
}
func (this *HTTPRequest) write50x(err error, statusCode int, enMessage string, zhMessage string, canTryStale bool) {
if err != nil {
this.addError(err)
}
// 尝试从缓存中恢复
if canTryStale &&
this.cacheCanTryStale &&
this.web.Cache.Stale != nil &&
this.web.Cache.Stale.IsOn &&
(len(this.web.Cache.Stale.Status) == 0 || lists.ContainsInt(this.web.Cache.Stale.Status, statusCode)) {
var ok = this.doCacheRead(true)
if ok {
return
}
}
// 显示自定义页面
if this.doPage(statusCode) {
return
}
// 内置HTML模板
var pageContent = configutils.ParseVariables(httpStatusPageTemplate, func(varName string) (value string) {
switch varName {
case "status":
return types.String(statusCode)
case "statusMessage":
return http.StatusText(statusCode)
case "requestId":
return this.requestId
case "message":
var acceptLanguages = this.RawReq.Header.Get("Accept-Language")
if len(acceptLanguages) > 0 {
var index = strings.Index(acceptLanguages, ",")
if index > 0 {
var firstLanguage = acceptLanguages[:index]
if firstLanguage == "zh-CN" {
return "网站出了一点小问题,原因:" + zhMessage + "。"
}
}
}
return "The site is unavailable now, cause: " + enMessage + "."
}
return this.Format("${" + varName + "}")
})
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
_, _ = this.writer.Write([]byte(pageContent))
}

View File

@@ -0,0 +1,11 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !script
// +build !script
package nodes
func (this *HTTPRequest) onInit() {
}
func (this *HTTPRequest) onRequest() {
}

View File

@@ -0,0 +1,124 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build script
// +build script
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"rogchap.com/v8go"
"strings"
)
func (this *HTTPRequest) onInit() {
this.fireScriptEvent("init")
}
func (this *HTTPRequest) onRequest() {
if this.isDone || this.writer.isFinished {
return
}
this.fireScriptEvent("request")
}
// 触发事件
func (this *HTTPRequest) fireScriptEvent(eventType string) {
if SharedJSPool == nil {
return
}
if this.web.RequestScripts == nil {
return
}
var group *serverconfigs.ScriptGroupConfig
if eventType == "init" && this.web.RequestScripts.InitGroup != nil && this.web.RequestScripts.InitGroup.IsOn {
group = this.web.RequestScripts.InitGroup
} else if eventType == "request" && this.web.RequestScripts.RequestGroup != nil && this.web.RequestScripts.RequestGroup.IsOn {
group = this.web.RequestScripts.RequestGroup
}
if group == nil {
return
}
for _, script := range group.Scripts {
if this.isDone {
return
}
if !script.IsOn {
continue
}
if len(script.RealCode()) > 0 {
this.runScript(eventType, script.RealCode())
}
}
}
func (this *HTTPRequest) runScript(eventType string, script string) {
ctx, err := SharedJSPool.GetContext()
if err != nil {
remotelogs.Error("SCRIPT", "get context failed: "+err.Error())
return
}
ctx.SetServerId(this.ReqServer.Id)
defer SharedJSPool.PutContext(ctx)
var reqObjectId = ctx.AddGoRequestObject(this)
var respObjectId = ctx.AddGoResponseObject(this.writer)
script = `(function () {
let req = new gojs.net.http.Request()
req.setGoObject(` + types.String(reqObjectId) + `)
let resp = new gojs.net.http.Response()
resp.setGoObject(` + types.String(respObjectId) + `)
` + script + `
})()`
_, err = ctx.Run(script, "request."+eventType+".js")
if err != nil {
var errString = ""
jsErr, ok := err.(*v8go.JSError)
if ok {
var pieces = strings.Split(jsErr.Location, ":")
if len(pieces) < 3 {
errString = err.Error()
} else {
if strings.HasPrefix(pieces[0], "gojs") {
var line = types.Int(pieces[len(pieces)-2])
var scriptLine = ctx.ReadLineFromLibrary(pieces[0], line)
if len(scriptLine) > 0 {
if len(scriptLine) > 256 {
scriptLine = scriptLine[:256] + "..."
}
errString = jsErr.Error() + ", location: " + strings.Join(pieces[:len(pieces)-2], ":") + ":" + types.String(line) + ":" + pieces[len(pieces)-1] + ", code: " + scriptLine
}
}
if len(errString) == 0 {
var line = types.Int(pieces[len(pieces)-2])
var scriptLines = strings.Split(script, "\n")
var scriptLine = ""
if len(scriptLines) > line-1 {
scriptLine = scriptLines[line-1]
if len(scriptLine) > 256 {
scriptLine = scriptLine[:256] + "..."
}
}
line -= 6 /* 6是req和resp构造用的行数 */
errString = jsErr.Error() + ", location: " + strings.Join(pieces[:len(pieces)-2], ":") + ":" + types.String(line) + ":" + pieces[len(pieces)-1] + ", code: " + scriptLine
}
}
} else {
errString = err.Error()
}
remotelogs.ServerError(this.ReqServer.Id, "SCRIPT", "run "+eventType+" script failed: "+errString, nodeconfigs.NodeLogTypeRunScriptFailed, maps.Map{})
return
}
}

View File

@@ -0,0 +1,230 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"errors"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
"github.com/iwind/gofcgi/pkg/fcgi"
"io"
"net"
"net/http"
"net/url"
"path/filepath"
"strings"
)
func (this *HTTPRequest) doFastcgi() (shouldStop bool) {
fastcgiList := []*serverconfigs.HTTPFastcgiConfig{}
for _, fastcgi := range this.web.FastcgiList {
if !fastcgi.IsOn {
continue
}
fastcgiList = append(fastcgiList, fastcgi)
}
if len(fastcgiList) == 0 {
return false
}
shouldStop = true
fastcgi := fastcgiList[rands.Int(0, len(fastcgiList)-1)]
env := fastcgi.FilterParams()
if !env.Has("DOCUMENT_ROOT") {
env["DOCUMENT_ROOT"] = ""
}
if !env.Has("REMOTE_ADDR") {
env["REMOTE_ADDR"] = this.requestRemoteAddr(true)
}
if !env.Has("QUERY_STRING") {
u, err := url.ParseRequestURI(this.uri)
if err == nil {
env["QUERY_STRING"] = u.RawQuery
} else {
env["QUERY_STRING"] = this.RawReq.URL.RawQuery
}
}
if !env.Has("SERVER_NAME") {
env["SERVER_NAME"] = this.ReqHost
}
if !env.Has("REQUEST_URI") {
env["REQUEST_URI"] = this.uri
}
if !env.Has("HOST") {
env["HOST"] = this.ReqHost
}
if len(this.ServerAddr) > 0 {
if !env.Has("SERVER_ADDR") {
env["SERVER_ADDR"] = this.ServerAddr
}
if !env.Has("SERVER_PORT") {
_, port, err := net.SplitHostPort(this.ServerAddr)
if err == nil {
env["SERVER_PORT"] = port
}
}
}
// 设置为持久化连接
var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
if requestConn == nil {
return
}
requestClientConn, ok := requestConn.(ClientConnInterface)
if ok {
requestClientConn.SetIsPersistent(true)
}
// 连接池配置
poolSize := fastcgi.PoolSize
if poolSize <= 0 {
poolSize = 32
}
client, err := fcgi.SharedPool(fastcgi.Network(), fastcgi.RealAddress(), uint(poolSize)).Client()
if err != nil {
this.write50x(err, http.StatusInternalServerError, "Failed to create Fastcgi pool", "Fastcgi池生成失败", false)
return
}
// 请求相关
if !env.Has("REQUEST_METHOD") {
env["REQUEST_METHOD"] = this.RawReq.Method
}
if !env.Has("CONTENT_LENGTH") {
env["CONTENT_LENGTH"] = fmt.Sprintf("%d", this.RawReq.ContentLength)
}
if !env.Has("CONTENT_TYPE") {
env["CONTENT_TYPE"] = this.RawReq.Header.Get("Content-Type")
}
if !env.Has("SERVER_SOFTWARE") {
env["SERVER_SOFTWARE"] = teaconst.ProductName + "/v" + teaconst.Version
}
// 处理SCRIPT_FILENAME
scriptPath := env.GetString("SCRIPT_FILENAME")
if len(scriptPath) > 0 && !strings.Contains(scriptPath, "/") && !strings.Contains(scriptPath, "\\") {
env["SCRIPT_FILENAME"] = env.GetString("DOCUMENT_ROOT") + Tea.DS + scriptPath
}
scriptFilename := filepath.Base(this.RawReq.URL.Path)
// PATH_INFO
pathInfoReg := fastcgi.PathInfoRegexp()
pathInfo := ""
if pathInfoReg != nil {
matches := pathInfoReg.FindStringSubmatch(this.RawReq.URL.Path)
countMatches := len(matches)
if countMatches == 1 {
pathInfo = matches[0]
} else if countMatches == 2 {
pathInfo = matches[1]
} else if countMatches > 2 {
scriptFilename = matches[1]
pathInfo = matches[2]
}
if !env.Has("PATH_INFO") {
env["PATH_INFO"] = pathInfo
}
}
this.addVarMapping(map[string]string{
"fastcgi.documentRoot": env.GetString("DOCUMENT_ROOT"),
"fastcgi.filename": scriptFilename,
"fastcgi.pathInfo": pathInfo,
})
params := map[string]string{}
for key, value := range env {
params[key] = this.Format(types.String(value))
}
this.processRequestHeaders(this.RawReq.Header)
for k, v := range this.RawReq.Header {
if k == "Connection" {
continue
}
for _, subV := range v {
params["HTTP_"+strings.ToUpper(strings.Replace(k, "-", "_", -1))] = subV
}
}
host, found := params["HTTP_HOST"]
if !found || len(host) == 0 {
params["HTTP_HOST"] = this.ReqHost
}
fcgiReq := fcgi.NewRequest()
fcgiReq.SetTimeout(fastcgi.ReadTimeoutDuration())
fcgiReq.SetParams(params)
fcgiReq.SetBody(this.RawReq.Body, uint32(this.requestLength()))
resp, stderr, err := client.Call(fcgiReq)
if err != nil {
this.write50x(err, http.StatusInternalServerError, "Failed to read Fastcgi", "读取Fastcgi失败", false)
return
}
if len(stderr) > 0 {
err := errors.New("Fastcgi Error: " + strings.TrimSpace(string(stderr)) + " script: " + maps.NewMap(params).GetString("SCRIPT_FILENAME"))
this.write50x(err, http.StatusInternalServerError, "Failed to read Fastcgi", "读取Fastcgi失败", false)
return
}
defer func() {
_ = resp.Body.Close()
}()
// 设置Charset
// TODO 这里应该可以设置文本类型的列表,以及是否强制覆盖所有文本类型的字符集
if this.web.Charset != nil && this.web.Charset.IsOn && len(this.web.Charset.Charset) > 0 {
contentTypes, ok := resp.Header["Content-Type"]
if ok && len(contentTypes) > 0 {
contentType := contentTypes[0]
if _, found := textMimeMap[contentType]; found {
resp.Header["Content-Type"][0] = contentType + "; charset=" + this.web.Charset.Charset
}
}
}
// 响应Header
this.writer.AddHeaders(resp.Header)
this.ProcessResponseHeaders(this.writer.Header(), resp.StatusCode)
// 准备
this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true)
// 设置响应代码
this.writer.WriteHeader(resp.StatusCode)
// 输出到客户端
var pool = this.bytePool(resp.ContentLength)
var buf = pool.Get()
_, err = io.CopyBuffer(this.writer, resp.Body, buf.Bytes)
pool.Put(buf)
closeErr := resp.Body.Close()
if closeErr != nil {
remotelogs.Warn("HTTP_REQUEST_FASTCGI", closeErr.Error())
}
if err != nil && err != io.EOF {
remotelogs.Warn("HTTP_REQUEST_FASTCGI", err.Error())
this.addError(err)
}
// 是否成功结束
if err == nil && closeErr == nil {
this.writer.SetOk()
}
return
}

View File

@@ -0,0 +1,35 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeutils"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
)
// 健康检查
func (this *HTTPRequest) doHealthCheck(key string, isHealthCheck *bool) (stop bool) {
this.tags = append(this.tags, "healthCheck")
this.RawReq.Header.Del(serverconfigs.HealthCheckHeaderName)
data, err := nodeutils.Base64DecodeMap(key)
if err != nil {
remotelogs.Error("HTTP_REQUEST_HEALTH_CHECK", "decode key failed: "+err.Error())
return
}
*isHealthCheck = true
this.web.StatRef = nil
if !data.GetBool("accessLogIsOn") {
this.disableLog = true
}
if data.GetBool("onlyBasicRequest") {
return true
}
return
}

View File

@@ -0,0 +1,16 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package nodes
import "net/http"
func (this *HTTPRequest) processHLSBefore() (blocked bool) {
// stub
return false
}
func (this *HTTPRequest) processM3u8Response(resp *http.Response) error {
// stub
return nil
}

View File

@@ -0,0 +1,190 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package nodes
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"github.com/TeaOSLab/EdgeNode/internal/compressions"
"github.com/TeaOSLab/EdgeNode/internal/utils/readers"
"github.com/iwind/TeaGo/types"
"io"
"net/http"
"net/url"
"strings"
)
const (
defaultM3u8EncryptParam = "ge_m3u8_token"
m3u8EncryptExtension = ".ge-m3u8-key"
)
// process hls
func (this *HTTPRequest) processHLSBefore() (blocked bool) {
var requestPath = this.RawReq.URL.Path
// .ge-m3u8-key
if !strings.HasSuffix(requestPath, m3u8EncryptExtension) {
return
}
blocked = true
tokenBytes, base64Err := base64.StdEncoding.DecodeString(this.RawReq.URL.Query().Get(defaultM3u8EncryptParam))
if base64Err != nil {
this.writeCode(http.StatusBadRequest, "invalid m3u8 token: bad format", "invalid m3u8 token: bad format")
return
}
if len(tokenBytes) != 32 {
this.writeCode(http.StatusBadRequest, "invalid m3u8 token: bad length", "invalid m3u8 token: bad length")
return
}
_, _ = this.writer.Write(tokenBytes[:16])
return
}
// process m3u8 file
func (this *HTTPRequest) processM3u8Response(resp *http.Response) error {
var requestPath = this.RawReq.URL.Path
// .m3u8
if strings.HasSuffix(requestPath, ".m3u8") &&
this.web.HLS.Encrypting.MatchURL(this.URL()) {
return this.processM3u8File(resp)
}
// .ts
if strings.HasSuffix(requestPath, ".ts") {
var token = this.RawReq.URL.Query().Get(defaultM3u8EncryptParam)
if len(token) > 0 {
return this.processTSFile(resp, token)
}
}
return nil
}
func (this *HTTPRequest) processTSFile(resp *http.Response, token string) error {
rawToken, err := base64.StdEncoding.DecodeString(token)
if err != nil {
return err
}
if len(rawToken) != 32 {
return errors.New("invalid token length")
}
var key = rawToken[:16]
var iv = rawToken[16:]
block, err := aes.NewCipher(key)
if err != nil {
return fmt.Errorf("create cipher failed: %w", err)
}
var stream = cipher.NewCBCEncrypter(block, iv)
var reader = readers.NewFilterReaderCloser(resp.Body)
var blockSize = stream.BlockSize()
reader.Add(func(p []byte, readErr error) error {
var l = len(p)
if l == 0 {
return nil
}
var mod = l % blockSize
if mod != 0 {
p = append(p, bytes.Repeat([]byte{'0'}, blockSize-mod)...)
}
stream.CryptBlocks(p, p)
return readErr
})
resp.Body = reader
return nil
}
func (this *HTTPRequest) processM3u8File(resp *http.Response) error {
const maxSize = 1 << 20
// 检查内容长度
if resp.Body == nil || resp.ContentLength == 0 || resp.ContentLength > maxSize {
return nil
}
// 解压缩
compressions.WrapHTTPResponse(resp)
// 读取内容
data, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
if err != nil {
_ = resp.Body.Close()
return err
}
// 超出尺寸直接返回
if len(data) == maxSize {
resp.Body = io.NopCloser(io.MultiReader(bytes.NewBuffer(data), resp.Body))
return nil
}
var lines = bytes.Split(data, []byte{'\n'})
var addedKey = false
var ivBytes = make([]byte, 16)
var keyBytes = make([]byte, 16)
_, ivErr := rand.Read(ivBytes)
_, keyErr := rand.Read(keyBytes)
if ivErr != nil || keyErr != nil {
resp.Body = io.NopCloser(bytes.NewBuffer(data))
return nil
}
var ivString = fmt.Sprintf("%x", ivBytes)
var token = url.QueryEscape(base64.StdEncoding.EncodeToString(append(keyBytes, ivBytes...)))
for index, line := range lines {
if len(line) == 0 {
continue
}
if bytes.HasPrefix(line, []byte("#EXT-X-KEY")) {
goto returnData
}
if !addedKey && bytes.HasPrefix(line, []byte("#EXTINF")) {
this.URL()
var keyPath = strings.TrimSuffix(this.RawReq.URL.Path, ".m3u8") + m3u8EncryptExtension + "?" + defaultM3u8EncryptParam + "=" + token
lines[index] = append([]byte("#EXT-X-KEY:METHOD=AES-128,URI=\""+this.requestScheme()+"://"+this.ReqHost+keyPath+"\",IV=0x"+ivString+
"\n"), line...)
addedKey = true
continue
}
if line[0] != '#' && bytes.Contains(line, []byte(".ts")) {
if bytes.ContainsRune(line, '?') {
lines[index] = append(line, []byte("&"+defaultM3u8EncryptParam+"="+token)...)
} else {
lines[index] = append(line, []byte("?"+defaultM3u8EncryptParam+"="+token)...)
}
}
}
if addedKey {
this.tags = append(this.tags, "m3u8")
}
returnData:
data = bytes.Join(lines, []byte{'\n'})
resp.Body = io.NopCloser(bytes.NewBuffer(data))
resp.ContentLength = int64(len(data))
resp.Header.Set("Content-Length", types.String(resp.ContentLength))
return nil
}

View File

@@ -0,0 +1,228 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/types"
"net"
"net/http"
"strconv"
"strings"
)
// 主机地址快速跳转
func (this *HTTPRequest) doHostRedirect() (blocked bool) {
var urlPath = this.RawReq.URL.Path
if this.web.MergeSlashes {
urlPath = utils.CleanPath(urlPath)
}
for _, u := range this.web.HostRedirects {
if !u.IsOn {
continue
}
if !u.MatchRequest(this.Format) {
continue
}
if len(u.ExceptDomains) > 0 && configutils.MatchDomains(u.ExceptDomains, this.ReqHost) {
continue
}
if len(u.OnlyDomains) > 0 && !configutils.MatchDomains(u.OnlyDomains, this.ReqHost) {
continue
}
var status = u.Status
if status <= 0 {
if searchEngineRegex.MatchString(this.RawReq.UserAgent()) {
status = http.StatusMovedPermanently
} else {
status = http.StatusTemporaryRedirect
}
}
var fullURL string
if u.BeforeHasQuery() {
fullURL = this.URL()
} else {
fullURL = this.requestScheme() + "://" + this.ReqHost + urlPath
}
if len(u.Type) == 0 || u.Type == serverconfigs.HTTPHostRedirectTypeURL {
if u.MatchPrefix { // 匹配前缀
if strings.HasPrefix(fullURL, u.BeforeURL) {
var afterURL = u.AfterURL
if u.KeepRequestURI {
afterURL += this.RawReq.URL.RequestURI()
}
// 前后是否一致
if fullURL == afterURL {
return false
}
this.ProcessResponseHeaders(this.writer.Header(), status)
httpRedirect(this.writer, this.RawReq, afterURL, status)
return true
}
} else if u.MatchRegexp { // 正则匹配
var reg = u.BeforeURLRegexp()
if reg == nil {
continue
}
var matches = reg.FindStringSubmatch(fullURL)
if len(matches) == 0 {
continue
}
var afterURL = u.AfterURL
for i, match := range matches {
afterURL = strings.ReplaceAll(afterURL, "${"+strconv.Itoa(i)+"}", match)
}
var subNames = reg.SubexpNames()
if len(subNames) > 0 {
for _, subName := range subNames {
if len(subName) > 0 {
index := reg.SubexpIndex(subName)
if index > -1 {
afterURL = strings.ReplaceAll(afterURL, "${"+subName+"}", matches[index])
}
}
}
}
// 前后是否一致
if fullURL == afterURL {
return false
}
if u.KeepArgs {
var qIndex = strings.Index(this.uri, "?")
if qIndex >= 0 {
afterURL += this.uri[qIndex:]
}
}
this.ProcessResponseHeaders(this.writer.Header(), status)
httpRedirect(this.writer, this.RawReq, afterURL, status)
return true
} else { // 精准匹配
if fullURL == u.RealBeforeURL() {
// 前后是否一致
if fullURL == u.AfterURL {
return false
}
var afterURL = u.AfterURL
if u.KeepArgs {
var qIndex = strings.Index(this.uri, "?")
if qIndex >= 0 {
var afterQIndex = strings.Index(u.AfterURL, "?")
if afterQIndex >= 0 {
afterURL = u.AfterURL[:afterQIndex] + this.uri[qIndex:]
} else {
afterURL += this.uri[qIndex:]
}
}
}
this.ProcessResponseHeaders(this.writer.Header(), status)
httpRedirect(this.writer, this.RawReq, afterURL, status)
return true
}
}
} else if u.Type == serverconfigs.HTTPHostRedirectTypeDomain {
if len(u.DomainAfter) == 0 {
continue
}
var reqHost = this.ReqHost
// 忽略跳转前端口
if u.DomainBeforeIgnorePorts {
h, _, err := net.SplitHostPort(reqHost)
if err == nil && len(h) > 0 {
reqHost = h
}
}
var scheme = u.DomainAfterScheme
if len(scheme) == 0 {
scheme = this.requestScheme()
}
if u.DomainsAll || configutils.MatchDomains(u.DomainsBefore, reqHost) {
var afterURL = scheme + "://" + u.DomainAfter + urlPath
if fullURL == afterURL {
// 终止匹配
return false
}
// 如果跳转前后域名一致,则终止
if u.DomainAfter == reqHost {
return false
}
this.ProcessResponseHeaders(this.writer.Header(), status)
// 参数
var qIndex = strings.Index(this.uri, "?")
if qIndex >= 0 {
afterURL += this.uri[qIndex:]
}
httpRedirect(this.writer, this.RawReq, afterURL, status)
return true
}
} else if u.Type == serverconfigs.HTTPHostRedirectTypePort {
if u.PortAfter <= 0 {
continue
}
var scheme = u.PortAfterScheme
if len(scheme) == 0 {
scheme = this.requestScheme()
}
reqHost, reqPort, _ := net.SplitHostPort(this.ReqHost)
if len(reqHost) == 0 {
reqHost = this.ReqHost
}
if len(reqPort) == 0 {
switch this.requestScheme() {
case "http":
reqPort = "80"
case "https":
reqPort = "443"
}
}
// 如果跳转前后端口一致,则终止
if reqPort == types.String(u.PortAfter) {
return false
}
var containsPort bool
if u.PortsAll {
containsPort = true
} else {
containsPort = u.ContainsPort(types.Int(reqPort))
}
if containsPort {
var newReqHost = reqHost
if !((scheme == "http" && u.PortAfter == 80) || (scheme == "https" && u.PortAfter == 443)) {
newReqHost += ":" + types.String(u.PortAfter)
}
var afterURL = scheme + "://" + newReqHost + urlPath
if fullURL == afterURL {
// 终止匹配
return false
}
this.ProcessResponseHeaders(this.writer.Header(), status)
httpRedirect(this.writer, this.RawReq, afterURL, status)
return true
}
}
}
return
}

View File

@@ -0,0 +1,10 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package nodes
import "net/http"
func (this *HTTPRequest) processHTTP3Headers(respHeader http.Header) {
// stub
}

View File

@@ -0,0 +1,13 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package nodes
import "net/http"
func (this *HTTPRequest) processHTTP3Headers(respHeader http.Header) {
if this.ReqServer == nil || this.ReqServer.ClusterId <= 0 {
return
}
sharedHTTP3Manager.ProcessHTTP3Headers(this.RawReq.UserAgent(), respHeader, this.ReqServer.ClusterId)
}

View File

@@ -0,0 +1,41 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"net/http"
)
func (this *HTTPRequest) doRequestLimit() (shouldStop bool) {
// 是否在全局名单中
_, isInAllowedList, _ := iplibrary.AllowIP(this.RemoteAddr(), this.ReqServer.Id)
if isInAllowedList {
return false
}
// 检查请求Body尺寸
// TODO 处理分片提交的内容
if this.web.RequestLimit.MaxBodyBytes() > 0 &&
this.RawReq.ContentLength > this.web.RequestLimit.MaxBodyBytes() {
this.writeCode(http.StatusRequestEntityTooLarge, "", "")
return true
}
// 设置连接相关参数
if this.web.RequestLimit.MaxConns > 0 || this.web.RequestLimit.MaxConnsPerIP > 0 {
var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
if requestConn != nil {
clientConn, ok := requestConn.(ClientConnInterface)
if ok && !clientConn.IsBound() {
if !clientConn.Bind(this.ReqServer.Id, this.requestRemoteAddr(true), this.web.RequestLimit.MaxConns, this.web.RequestLimit.MaxConnsPerIP) {
this.writeCode(http.StatusTooManyRequests, "", "")
this.Close()
return true
}
}
}
}
return false
}

View File

@@ -0,0 +1,25 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !plus
// +build !plus
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
)
const (
LNExpiresHeader = "X-Edge-Ln-Expires"
)
func existsLnNodeIP(nodeIP string) bool {
return false
}
func (this *HTTPRequest) checkLnRequest() bool {
return false
}
func (this *HTTPRequest) getLnOrigin(excludingNodeIds []int64, urlHash uint64) (originConfig *serverconfigs.OriginConfig, lnNodeId int64, hasMultipleNodes bool) {
return nil, 0, false
}

View File

@@ -0,0 +1,199 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoimpl"
"reflect"
"sync"
)
func UnmarshalLnRequestKey(data []byte, key proto.Message) error {
return proto.Unmarshal(data, key)
}
// LnRequestKey
// 使用数字作为json标识意在减少编码后的字节数
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type LnRequestKey struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Timestamp int64 `protobuf:"varint,1,opt,name=Timestamp,proto3" json:"1,omitempty"`
NodeId int64 `protobuf:"varint,2,opt,name=NodeId,proto3" json:"2,omitempty"`
RequestId string `protobuf:"bytes,3,opt,name=RequestId,proto3" json:"3,omitempty"`
RemoteAddr string `protobuf:"bytes,4,opt,name=RemoteAddr,proto3" json:"4,omitempty"`
URLMd5 string `protobuf:"bytes,5,opt,name=URLMd5,proto3" json:"5,omitempty"`
Method string `protobuf:"bytes,6,opt,name=Method,proto3" json:"6,omitempty"`
}
func (x *LnRequestKey) Reset() {
*x = LnRequestKey{}
if protoimpl.UnsafeEnabled {
mi := &file_models_model_ln_request_key_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *LnRequestKey) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*LnRequestKey) ProtoMessage() {}
func (x *LnRequestKey) ProtoReflect() protoreflect.Message {
mi := &file_models_model_ln_request_key_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use LnRequestKey.ProtoReflect.Descriptor instead.
func (*LnRequestKey) Descriptor() ([]byte, []int) {
return file_models_model_ln_request_key_proto_rawDescGZIP(), []int{0}
}
func (x *LnRequestKey) GetTimestamp() int64 {
if x != nil {
return x.Timestamp
}
return 0
}
func (x *LnRequestKey) GetNodeId() int64 {
if x != nil {
return x.NodeId
}
return 0
}
func (x *LnRequestKey) GetRequestId() string {
if x != nil {
return x.RequestId
}
return ""
}
func (x *LnRequestKey) GetRemoteAddr() string {
if x != nil {
return x.RemoteAddr
}
return ""
}
func (x *LnRequestKey) GetURLMd5() string {
if x != nil {
return x.URLMd5
}
return ""
}
func (x *LnRequestKey) GetMethod() string {
if x != nil {
return x.Method
}
return ""
}
var File_models_model_ln_request_key_proto protoreflect.FileDescriptor
var file_models_model_ln_request_key_proto_rawDesc = []byte{
0x0a, 0x21, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x73, 0x2f, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x5f, 0x6c,
0x6e, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x6b, 0x65, 0x79, 0x2e, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x12, 0x02, 0x70, 0x62, 0x22, 0xb2, 0x01, 0x0a, 0x0c, 0x4c, 0x6e, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x4b, 0x65, 0x79, 0x12, 0x1c, 0x0a, 0x09, 0x54, 0x69, 0x6d, 0x65,
0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x54, 0x69, 0x6d,
0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x6f, 0x64, 0x65, 0x49, 0x64,
0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x6f, 0x64, 0x65, 0x49, 0x64, 0x12, 0x1c,
0x0a, 0x09, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x49, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28,
0x09, 0x52, 0x09, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x49, 0x64, 0x12, 0x1e, 0x0a, 0x0a,
0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09,
0x52, 0x0a, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x12, 0x16, 0x0a, 0x06,
0x55, 0x52, 0x4c, 0x4d, 0x64, 0x35, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x55, 0x52,
0x4c, 0x4d, 0x64, 0x35, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x06,
0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x42, 0x06, 0x5a, 0x04,
0x2e, 0x2f, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_models_model_ln_request_key_proto_rawDescOnce sync.Once
file_models_model_ln_request_key_proto_rawDescData = file_models_model_ln_request_key_proto_rawDesc
)
func file_models_model_ln_request_key_proto_rawDescGZIP() []byte {
file_models_model_ln_request_key_proto_rawDescOnce.Do(func() {
file_models_model_ln_request_key_proto_rawDescData = protoimpl.X.CompressGZIP(file_models_model_ln_request_key_proto_rawDescData)
})
return file_models_model_ln_request_key_proto_rawDescData
}
var file_models_model_ln_request_key_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_models_model_ln_request_key_proto_goTypes = []interface{}{
(*LnRequestKey)(nil), // 0: pb.LnRequestKey
}
var file_models_model_ln_request_key_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_models_model_ln_request_key_proto_init() }
func file_models_model_ln_request_key_proto_init() {
if File_models_model_ln_request_key_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_models_model_ln_request_key_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*LnRequestKey); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_models_model_ln_request_key_proto_rawDesc,
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_models_model_ln_request_key_proto_goTypes,
DependencyIndexes: file_models_model_ln_request_key_proto_depIdxs,
MessageInfos: file_models_model_ln_request_key_proto_msgTypes,
}.Build()
File_models_model_ln_request_key_proto = out.File
file_models_model_ln_request_key_proto_rawDesc = nil
file_models_model_ln_request_key_proto_goTypes = nil
file_models_model_ln_request_key_proto_depIdxs = nil
}
func (this *LnRequestKey) AsPB() ([]byte, error) {
return proto.Marshal(this)
}

View File

@@ -0,0 +1,72 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes_test
import (
"encoding/json"
"github.com/TeaOSLab/EdgeNode/internal/nodes"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
stringutil "github.com/iwind/TeaGo/utils/string"
"net/http"
"runtime"
"testing"
)
func TestMarshalLnRequestKey(t *testing.T) {
var key = &nodes.LnRequestKey{
Timestamp: fasttime.Now().Unix(),
NodeId: 1024,
RequestId: "abc",
RemoteAddr: "1.2.3.4",
URLMd5: stringutil.Md5("123456"),
Method: http.MethodPost,
}
pbData, err := key.AsPB()
if err != nil {
t.Fatal(err)
}
t.Log(len(pbData), "bytes")
var key2 = &nodes.LnRequestKey{}
err = nodes.UnmarshalLnRequestKey(pbData, key2)
if err != nil {
t.Fatal(err)
}
t.Log(key2)
}
func TestMarshalLnRequestKey_JSON(t *testing.T) {
var key = &nodes.LnRequestKey{
Timestamp: fasttime.Now().Unix(),
NodeId: 1024,
RequestId: "abc",
RemoteAddr: "1.2.3.4",
URLMd5: stringutil.Md5("123456"),
Method: http.MethodPost,
}
data, err := json.Marshal(key)
if err != nil {
t.Fatal(err)
}
t.Log(string(data))
}
func BenchmarkMarshalLnRequestKey(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
var key = &nodes.LnRequestKey{
Timestamp: fasttime.Now().Unix(),
NodeId: 1024,
RequestId: "abc",
RemoteAddr: "1.2.3.4",
URLMd5: stringutil.Md5("123456"),
Method: http.MethodPost,
}
_, _ = key.AsPB()
}
})
}

View File

@@ -0,0 +1,271 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package nodes
import (
"context"
"encoding/base64"
"github.com/TeaOSLab/EdgeCommon/pkg/iputils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils/encrypt"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
setutils "github.com/TeaOSLab/EdgeNode/internal/utils/sets"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
"net"
"sync"
"time"
)
const (
LNKeyHeader = "X-Edge-Ln-Key"
LNExpiresHeader = "X-Edge-Ln-Expires"
LnExpiresSeconds int64 = 30
)
var lnOriginConfigsMap = map[string]*serverconfigs.OriginConfig{}
var lnEncryptMethodMap = map[string]encrypt.MethodInterface{}
var lnRequestIdSet = setutils.NewFixedSet(32 * 1024)
var lnNodeIPMap = map[string]bool{} // node ip => bool
var lnNodeIPLocker = &sync.RWMutex{}
var lnLocker = &sync.RWMutex{}
func existsLnNodeIP(nodeIP string) bool {
lnNodeIPLocker.RLock()
defer lnNodeIPLocker.RUnlock()
return lnNodeIPMap[nodeIP]
}
func (this *HTTPRequest) checkLnRequest() bool {
var keyString = this.RawReq.Header.Get(LNKeyHeader)
if len(keyString) <= 20 {
return false
}
// 删除
this.RawReq.Header.Del(LNKeyHeader)
// 当前连接IP
connIP, _, _ := net.SplitHostPort(this.RawReq.RemoteAddr)
var realIP = this.RawReq.Header.Get("X-Real-Ip")
if len(realIP) > 0 && !iputils.IsValid(realIP) {
realIP = ""
}
// 如果是在已经识别的L[n-1]节点IP列表中无需再次验证
if existsLnNodeIP(connIP) && len(realIP) > 0 {
this.tags = append(this.tags, "L"+types.String(this.nodeConfig.Level))
this.lnRemoteAddr = realIP
return true
}
// 如果已经在允许的IP中则直接允许通过无需再次验证
// 这个需要放在 existsLnNodeIP() 检查之后
if this.nodeConfig != nil && this.nodeConfig.IPIsAutoAllowed(connIP) && len(realIP) > 0 {
this.tags = append(this.tags, "L"+types.String(this.nodeConfig.Level))
this.lnRemoteAddr = realIP
lnNodeIPLocker.Lock()
lnNodeIPMap[connIP] = true
lnNodeIPLocker.Unlock()
return true
}
// 检查Key
keyEncodedData, err := base64.StdEncoding.DecodeString(keyString)
if err != nil {
return false
}
var secret = this.nodeConfig.SecretHash()
lnLocker.Lock()
method, ok := lnEncryptMethodMap[secret]
if !ok {
method, err = encrypt.NewMethodInstance("aes-192-cfb", secret, secret)
if err != nil {
lnLocker.Unlock()
return false
}
}
lnLocker.Unlock()
keyData, err := method.Decrypt(keyEncodedData)
if err != nil {
return false
}
var key = &LnRequestKey{}
err = UnmarshalLnRequestKey(keyData, key)
if err != nil {
return false
}
// Method和URL需要一致
if key.URLMd5 != stringutil.Md5(this.URL()) || key.Method != this.Method() {
return false
}
// N秒钟过期这里要求两个节点时间相差不能超过此时间
var currentUnixTime = fasttime.Now().Unix()
if key.Timestamp < currentUnixTime-LnExpiresSeconds {
return false
}
// 检查请求ID唯一性
// TODO 因为FixedSet是有限的这里仍然无法避免重放攻击
// TODO 而RequestId是并发的并不能简单的对比大小
// TODO 所以为了绝对的安全需要使用HTTPS并检查子节点的IP
if lnRequestIdSet.Has(key.RequestId) {
return false
}
lnRequestIdSet.Push(key.RequestId)
this.lnRemoteAddr = key.RemoteAddr
this.tags = append(this.tags, "L"+types.String(this.nodeConfig.Level))
// 当前连接IP
if len(connIP) > 0 {
lnNodeIPLocker.Lock()
lnNodeIPMap[connIP] = true
lnNodeIPLocker.Unlock()
}
return true
}
func (this *HTTPRequest) getLnOrigin(excludingNodeIds []int64, urlHash uint64) (originConfig *serverconfigs.OriginConfig, lnNodeId int64, hasMultipleNodes bool) {
var parentNodesMap = this.nodeConfig.ParentNodes // 需要复制,防止运行过程中修改
if len(parentNodesMap) == 0 {
return nil, 0, false
}
parentNodes, ok := parentNodesMap[this.ReqServer.ClusterId]
var countParentNodes = len(parentNodes)
if ok && countParentNodes > 0 {
var parentNode *nodeconfigs.ParentNodeConfig
// 尝试顺序读取
if len(excludingNodeIds) > 0 {
for _, node := range parentNodes {
if !lists.ContainsInt64(excludingNodeIds, node.Id) {
parentNode = node
break
}
}
}
// 尝试随机读取
if parentNode == nil {
if countParentNodes == 1 {
parentNode = parentNodes[0]
} else {
var method = serverconfigs.LnRequestSchedulingMethodURLMapping
if this.nodeConfig != nil {
var globalServerConfig = this.nodeConfig.GlobalServerConfig // copy
if globalServerConfig != nil {
method = globalServerConfig.HTTPAll.LnRequestSchedulingMethod
}
}
switch method {
case serverconfigs.LnRequestSchedulingMethodRandom:
// 随机选取一个L2节点有利于流量均衡但同一份缓存可能会存在多个L2节点上占用更多的空间
parentNode = parentNodes[rands.Int(0, countParentNodes-1)]
default:
// 从固定的L2节点读取内容优点是能够提升缓存命中率缺点是可能会导致多个L2节点流量不均衡
parentNode = parentNodes[urlHash%uint64(countParentNodes)]
}
}
}
lnNodeId = parentNode.Id
var countAddrs = len(parentNode.Addrs)
if countAddrs == 0 {
return nil, 0, false
}
var addr = parentNode.Addrs[rands.Int(0, countAddrs-1)]
var protocol = this.requestScheme() // http|https TODO 需要可以设置是否强制HTTPS回二级节点
var portString = types.String(this.requestServerPort())
var originKey = protocol + "@" + addr + "@" + portString
lnLocker.RLock()
config, ok := lnOriginConfigsMap[originKey]
lnLocker.RUnlock()
if !ok {
config = &serverconfigs.OriginConfig{
Id: 0,
IsOn: true,
Addr: &serverconfigs.NetworkAddressConfig{
Protocol: serverconfigs.Protocol(protocol),
Host: addr,
PortRange: portString,
},
IsOk: true,
}
err := config.Init(context.Background())
if err != nil {
remotelogs.Error("HTTP_REQUEST", "create ln origin config failed: "+err.Error())
return nil, 0, false
}
lnLocker.Lock()
lnOriginConfigsMap[originKey] = config
lnLocker.Unlock()
}
// 添加Header
this.RawReq.Header.Set(LNKeyHeader, this.encodeLnKey(parentNode.SecretHash))
return config, lnNodeId, len(parentNodes) > 0
}
return nil, 0, false
}
func (this *HTTPRequest) encodeLnKey(secretHash string) string {
var key = &LnRequestKey{
NodeId: this.nodeConfig.Id,
RequestId: this.requestId,
Timestamp: time.Now().Unix(),
RemoteAddr: this.RemoteAddr(),
URLMd5: stringutil.Md5(this.URL()),
Method: this.Method(),
}
data, err := key.AsPB()
if err != nil {
remotelogs.Error("HTTP_REQUEST", "ln request: encode key failed: "+err.Error())
return ""
}
lnLocker.Lock()
method, ok := lnEncryptMethodMap[secretHash]
if !ok {
method, err = encrypt.NewMethodInstance("aes-192-cfb", secretHash, secretHash)
if err != nil {
remotelogs.Error("HTTP_REQUEST", "ln request: create encrypt method failed: "+err.Error())
lnLocker.Unlock()
return ""
}
lnEncryptMethodMap[secretHash] = method
}
lnLocker.Unlock()
dst, err := method.Encrypt(data)
if err != nil {
remotelogs.Error("HTTP_REQUEST", "ln request: encode key failed: "+err.Error())
return ""
}
return base64.StdEncoding.EncodeToString(dst)
}

View File

@@ -0,0 +1,211 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"net/http"
)
// serveWAFLoader 提供 WAF Loader JavaScript 文件
func (this *HTTPRequest) serveWAFLoader() {
loaderJS := `(function() {
'use strict';
// 全局队列与执行器
window.__WAF_Q__ = window.__WAF_Q__ || [];
if (!window.__WAF_LOADER__) {
window.__WAF_LOADER__ = {
executing: false,
execute: function() {
if (this.executing) {
return;
}
this.executing = true;
var self = this;
var queue = window.__WAF_Q__ || [];
var runNext = function() {
if (!queue.length) {
self.executing = false;
return;
}
var item = queue.shift();
executeDecryptedCode(item.p, item.m, runNext);
};
runNext();
}
};
}
// 1. XOR 解码为字符串(不压缩,避免顺序问题)
function xorDecodeToString(b64, key) {
try {
var bin = atob(b64);
var out = new Uint8Array(bin.length);
for (var i = 0; i < bin.length; i++) {
out[i] = bin.charCodeAt(i) ^ key.charCodeAt(i % key.length);
}
if (typeof TextDecoder !== 'undefined') {
return new TextDecoder().decode(out);
}
var s = '';
for (var j = 0; j < out.length; j++) {
s += String.fromCharCode(out[j]);
}
return s;
} catch (e) {
console.error('WAF Loader: xor decode failed', e);
return '';
}
}
// 2. XOR 解密
function decryptXOR(payload, key) {
try {
var binary = atob(payload);
var output = [];
var keyLen = key.length;
if (keyLen === 0) {
return '';
}
for (var i = 0; i < binary.length; i++) {
var charCode = binary.charCodeAt(i) ^ key.charCodeAt(i % keyLen);
output.push(String.fromCharCode(charCode));
}
return output.join('');
} catch (e) {
console.error('WAF Loader: Decrypt failed', e);
return '';
}
}
// 3. 执行解密后的代码
function executeDecryptedCode(cipher, meta, done) {
var finish = function() {
if (typeof done === 'function') {
done();
}
};
try {
if (!cipher || !meta || !meta.key) {
console.error('WAF Loader: Missing cipher or meta.key');
finish();
return;
}
if (meta.alg !== 'xor') {
console.error('WAF Loader: Unsupported alg', meta.alg);
finish();
return;
}
// 1. XOR 解码为字符串
var plainJS = xorDecodeToString(cipher, meta.key);
if (!plainJS) {
console.error('WAF Loader: XOR decode failed');
finish();
return;
}
// 2. 执行解密后的代码(同步)
try {
// 使用全局 eval尽量保持和 <script> 一致的作用域
window.eval(plainJS);
// 3. 计算 Token 并握手
calculateAndHandshake(plainJS);
} catch (e) {
console.error('WAF Loader: Execute failed', e);
}
finish();
} catch (e) {
console.error('WAF Loader: Security check failed', e);
finish();
}
}
// 7. Token 计算和握手
function base64EncodeUtf8(str) {
try {
return btoa(unescape(encodeURIComponent(str)));
} catch (e) {
console.error('WAF Loader: base64 encode failed', e);
return '';
}
}
function calculateAndHandshake(decryptedJS) {
try {
// 计算 Token简化示例
var tokenData = decryptedJS.substring(0, Math.min(100, decryptedJS.length)) + Date.now();
var token = base64EncodeUtf8(tokenData).substring(0, 64); // 限制长度
// 发送握手请求
fetch('/waf/handshake', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ token: token })
}).then(function(resp) {
if (resp.ok) {
// 握手成功,设置全局 Token
window.__WAF_TOKEN__ = token;
// 为后续请求设置 Header
if (window.XMLHttpRequest) {
var originalOpen = XMLHttpRequest.prototype.open;
XMLHttpRequest.prototype.open = function() {
originalOpen.apply(this, arguments);
if (window.__WAF_TOKEN__) {
this.setRequestHeader('X-WAF-TOKEN', window.__WAF_TOKEN__);
}
};
}
// 为 fetch 设置 Header
var originalFetch = window.fetch;
window.fetch = function() {
var args = Array.prototype.slice.call(arguments);
if (args.length > 1 && typeof args[1] === 'object') {
if (!args[1].headers) {
args[1].headers = {};
}
if (window.__WAF_TOKEN__) {
args[1].headers['X-WAF-TOKEN'] = window.__WAF_TOKEN__;
}
} else {
args[1] = {
headers: {
'X-WAF-TOKEN': window.__WAF_TOKEN__ || ''
}
};
}
return originalFetch.apply(this, args);
};
}
}).catch(function(err) {
console.error('WAF Loader: Handshake failed', err);
});
} catch (e) {
console.error('WAF Loader: Calculate token failed', e);
}
}
// 8. 主逻辑
if (window.__WAF_Q__ && window.__WAF_Q__.length) {
window.__WAF_LOADER__.execute();
} else {
// 如果没有加密内容,等待一下再检查(可能是异步加载)
setTimeout(function() {
if (window.__WAF_Q__ && window.__WAF_Q__.length) {
window.__WAF_LOADER__.execute();
}
}, 100);
}
})();`
this.writer.Header().Set("Content-Type", "application/javascript; charset=utf-8")
this.writer.Header().Set("Cache-Control", "public, max-age=3600")
this.writer.WriteHeader(http.StatusOK)
_, _ = this.writer.WriteString(loaderJS)
this.writer.SetOk()
}

View File

@@ -0,0 +1,188 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"strings"
"time"
)
const (
// AccessLogMaxRequestBodySize 访问日志存储的请求内容最大尺寸 TODO 此值应该可以在访问日志页设置
AccessLogMaxRequestBodySize = 2 << 20
)
// 日志
func (this *HTTPRequest) log() {
// 检查全局配置
if this.nodeConfig != nil && this.nodeConfig.GlobalServerConfig != nil && !this.nodeConfig.GlobalServerConfig.HTTPAccessLog.IsOn {
return
}
var ref *serverconfigs.HTTPAccessLogRef
if !this.forceLog {
if this.disableLog {
return
}
// 计算请求时间
this.requestCost = time.Since(this.requestFromTime).Seconds()
ref = this.web.AccessLogRef
if ref == nil {
ref = serverconfigs.DefaultHTTPAccessLogRef
}
if !ref.IsOn {
return
}
if !ref.Match(this.writer.StatusCode()) {
return
}
if ref.FirewallOnly && this.firewallPolicyId == 0 {
return
}
// 是否记录499
if !ref.EnableClientClosed && this.writer.StatusCode() == 499 {
return
}
}
var addr = this.RawReq.RemoteAddr
var index = strings.LastIndex(addr, ":")
if index > 0 {
addr = addr[:index]
}
var serverGlobalConfig = this.nodeConfig.GlobalServerConfig
// 请求Cookie
var cookies = map[string]string{}
var enableCookies = false
if serverGlobalConfig == nil || serverGlobalConfig.HTTPAccessLog.EnableCookies {
enableCookies = true
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldCookie) {
for _, cookie := range this.RawReq.Cookies() {
cookies[cookie.Name] = cookie.Value
}
}
}
// 请求Header
var pbReqHeader = map[string]*pb.Strings{}
if serverGlobalConfig == nil || serverGlobalConfig.HTTPAccessLog.EnableRequestHeaders {
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldHeader) {
// 是否只记录通用Header
var commonHeadersOnly = serverGlobalConfig != nil && serverGlobalConfig.HTTPAccessLog.CommonRequestHeadersOnly
for k, v := range this.RawReq.Header {
if commonHeadersOnly && !serverconfigs.IsCommonRequestHeader(k) {
continue
}
if !enableCookies && k == "Cookie" {
continue
}
pbReqHeader[k] = &pb.Strings{Values: v}
}
}
}
// 响应Header
var pbResHeader = map[string]*pb.Strings{}
if serverGlobalConfig == nil || serverGlobalConfig.HTTPAccessLog.EnableResponseHeaders {
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldSentHeader) {
for k, v := range this.writer.Header() {
pbResHeader[k] = &pb.Strings{Values: v}
}
}
}
// 参数列表
var queryString = ""
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldArg) {
queryString = this.requestQueryString()
}
// 浏览器
var userAgent = ""
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldUserAgent) || ref.ContainsField(serverconfigs.HTTPAccessLogFieldExtend) {
userAgent = this.RawReq.UserAgent()
}
// 请求来源
var referer = ""
if ref == nil || ref.ContainsField(serverconfigs.HTTPAccessLogFieldReferer) {
referer = this.RawReq.Referer()
}
var accessLog = &pb.HTTPAccessLog{
RequestId: this.requestId,
NodeId: this.nodeConfig.Id,
ServerId: this.ReqServer.Id,
RemoteAddr: this.requestRemoteAddr(true),
RawRemoteAddr: addr,
RemotePort: int32(this.requestRemotePort()),
RemoteUser: this.requestRemoteUser(),
RequestURI: this.rawURI,
RequestPath: this.Path(),
RequestLength: this.requestLength(),
RequestTime: this.requestCost,
RequestMethod: this.RawReq.Method,
RequestFilename: this.requestFilename(),
Scheme: this.requestScheme(),
Proto: this.RawReq.Proto,
BytesSent: this.writer.SentBodyBytes(), // TODO 加上Header Size
BodyBytesSent: this.writer.SentBodyBytes(),
Status: int32(this.writer.StatusCode()),
StatusMessage: "",
TimeISO8601: this.requestFromTime.Format("2006-01-02T15:04:05.000Z07:00"),
TimeLocal: this.requestFromTime.Format("2/Jan/2006:15:04:05 -0700"),
Msec: float64(this.requestFromTime.Unix()) + float64(this.requestFromTime.Nanosecond())/1000000000,
Timestamp: this.requestFromTime.Unix(),
Host: this.ReqHost,
Referer: referer,
UserAgent: userAgent,
Request: this.requestString(),
ContentType: this.writer.Header().Get("Content-Type"),
Cookie: cookies,
Args: queryString,
QueryString: queryString,
Header: pbReqHeader,
ServerName: this.ServerName,
ServerPort: int32(this.requestServerPort()),
ServerProtocol: this.RawReq.Proto,
SentHeader: pbResHeader,
Errors: this.errors,
Hostname: HOSTNAME,
FirewallPolicyId: this.firewallPolicyId,
FirewallRuleGroupId: this.firewallRuleGroupId,
FirewallRuleSetId: this.firewallRuleSetId,
FirewallRuleId: this.firewallRuleId,
FirewallActions: this.firewallActions,
Tags: this.tags,
Attrs: this.logAttrs,
}
if this.origin != nil {
accessLog.OriginId = this.origin.Id
accessLog.OriginAddress = this.originAddr
accessLog.OriginStatus = this.originStatus
}
// 请求Body
if (ref != nil && ref.ContainsField(serverconfigs.HTTPAccessLogFieldRequestBody)) || this.wafHasRequestBody {
accessLog.RequestBody = this.requestBodyData
if len(accessLog.RequestBody) > AccessLogMaxRequestBodySize {
accessLog.RequestBody = accessLog.RequestBody[:AccessLogMaxRequestBodySize]
}
}
// TODO 记录匹配的 locationId和rewriteId非必要需求
sharedHTTPAccessLogQueue.Push(accessLog)
}

View File

@@ -0,0 +1,48 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/metrics"
)
// 指标统计 - 响应
// 只需要在结束时调用指标进行统计
func (this *HTTPRequest) doMetricsResponse() {
metrics.SharedManager.Add(this)
}
func (this *HTTPRequest) MetricKey(key string) string {
return this.Format(key)
}
func (this *HTTPRequest) MetricValue(value string) (result int64, ok bool) {
// TODO 需要忽略健康检查的请求,但是同时也要防止攻击者模拟健康检查
switch value {
case "${countRequest}":
return 1, true
case "${countTrafficOut}":
// 这里不包括Header长度
return this.writer.SentBodyBytes(), true
case "${countTrafficIn}":
var hl int64 = 0 // header length
for k, values := range this.RawReq.Header {
for _, v := range values {
hl += int64(len(k) + len(v) + 2 /** k: v **/)
}
}
return this.RawReq.ContentLength + hl, true
case "${countConnection}":
return 1, true
}
return 0, false
}
func (this *HTTPRequest) MetricServerId() int64 {
return this.ReqServer.Id
}
func (this *HTTPRequest) MetricCategory() string {
return serverconfigs.MetricItemCategoryHTTP
}

View File

@@ -0,0 +1,116 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeutils"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/ttlcache"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/types"
"net/http"
"time"
)
// 域名无匹配情况处理
func (this *HTTPRequest) doMismatch() {
// 是否为健康检查
var healthCheckKey = this.RawReq.Header.Get(serverconfigs.HealthCheckHeaderName)
if len(healthCheckKey) > 0 {
_, err := nodeutils.Base64DecodeMap(healthCheckKey)
if err == nil {
this.writer.WriteHeader(http.StatusOK)
return
}
}
// 是否已经在黑名单
var remoteIP = this.RemoteAddr()
if waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteIP) {
this.Close()
return
}
// 根据配置进行相应的处理
var nodeConfig = sharedNodeConfig // copy
if nodeConfig != nil {
var globalServerConfig = nodeConfig.GlobalServerConfig
if globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly {
var statusCode = 404
var httpAllConfig = globalServerConfig.HTTPAll
var mismatchAction = httpAllConfig.DomainMismatchAction
if mismatchAction != nil && mismatchAction.Options != nil {
var mismatchStatusCode = mismatchAction.Options.GetInt("statusCode")
if mismatchStatusCode > 0 && mismatchStatusCode >= 100 && mismatchStatusCode < 1000 {
statusCode = mismatchStatusCode
}
}
// 是否正在访问IP
if globalServerConfig.HTTPAll.NodeIPShowPage && utils.IsWildIP(this.ReqHost) {
this.writer.statusCode = statusCode
var contentHTML = this.Format(globalServerConfig.HTTPAll.NodeIPPageHTML)
this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
this.writer.Header().Set("Content-Length", types.String(len(contentHTML)))
this.writer.WriteHeader(statusCode)
_, _ = this.writer.WriteString(contentHTML)
return
}
// 检查cc
// TODO 可以在管理端配置是否开启以及最多尝试次数
// 要考虑到服务在切换集群时,域名未生效状态时,用户访问的仍然是老集群中的节点,就会产生找不到域名的情况
if len(remoteIP) > 0 {
const maxAttempts = 100
if ttlcache.SharedInt64Cache.IncreaseInt64("MISMATCH_DOMAIN:"+remoteIP, int64(1), time.Now().Unix()+60, false) > maxAttempts {
// 在加入之前再次检查黑名单
if !waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteIP) {
waf.SharedIPBlackList.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteIP, time.Now().Unix()+3600)
}
}
}
// 处理当前连接
if mismatchAction != nil {
if mismatchAction.Code == serverconfigs.DomainMismatchActionPage {
if mismatchAction.Options != nil {
this.writer.statusCode = statusCode
var contentHTML = this.Format(mismatchAction.Options.GetString("contentHTML"))
this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
this.writer.Header().Set("Content-Length", types.String(len(contentHTML)))
this.writer.WriteHeader(statusCode)
_, _ = this.writer.Write([]byte(contentHTML))
} else {
http.Error(this.writer, "404 page not found: '"+this.URL()+"'", http.StatusNotFound)
}
return
}
if mismatchAction.Code == serverconfigs.DomainMismatchActionRedirect {
var url = this.Format(mismatchAction.Options.GetString("url"))
if len(url) > 0 {
httpRedirect(this.writer, this.RawReq, url, http.StatusTemporaryRedirect)
} else {
http.Error(this.writer, "404 page not found: '"+this.URL()+"'", http.StatusNotFound)
}
return
}
if mismatchAction.Code == serverconfigs.DomainMismatchActionClose {
http.Error(this.writer, "404 page not found: '"+this.URL()+"'", http.StatusNotFound)
this.Close()
return
}
}
http.Error(this.writer, "404 page not found: '"+this.URL()+"'", http.StatusNotFound)
this.Close()
return
}
}
http.Error(this.writer, "404 page not found: '"+this.URL()+"'", http.StatusNotFound)
}

View File

@@ -0,0 +1,15 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package nodes
import (
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"net/http"
)
func (this *HTTPRequest) doOSSOrigin(origin *serverconfigs.OriginConfig) (resp *http.Response, goNext bool, errorCode string, ossBucketName string, err error) {
// stub
return nil, false, "", "", errors.New("not implemented")
}

View File

@@ -0,0 +1,78 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package nodes
import (
"bytes"
"errors"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/oss"
"github.com/iwind/TeaGo/types"
"io"
"net/http"
)
// 请求OSS源站
func (this *HTTPRequest) doOSSOrigin(origin *serverconfigs.OriginConfig) (resp *http.Response, goNext bool, errorCode string, ossBucketName string, err error) {
if origin == nil || origin.OSS == nil {
err = errors.New("'origin' or 'origin.OSS' should not be nil")
return
}
// 只支持少数方法
var isHeadRequest = this.RawReq.Method != http.MethodGet &&
this.RawReq.Method != http.MethodPost &&
this.RawReq.Method != http.MethodPut
var rangeBytes = this.RawReq.Header.Get("Range")
if isHeadRequest && len(rangeBytes) == 0 {
resp, errorCode, ossBucketName, err = oss.SharedManager.Head(this.RawReq, this.ReqHost, origin.OSS)
} else {
if len(rangeBytes) > 0 {
resp, errorCode, ossBucketName, err = oss.SharedManager.GetRange(this.RawReq, this.ReqHost, rangeBytes, origin.OSS)
} else {
resp, errorCode, ossBucketName, err = oss.SharedManager.Get(this.RawReq, this.ReqHost, origin.OSS)
}
}
if len(ossBucketName) == 0 {
this.originAddr = origin.OSS.Type
} else {
this.originAddr = origin.OSS.Type + "/" + ossBucketName
}
this.tags = append(this.tags, "oss")
if err != nil {
if oss.IsNotFound(err) {
this.write404()
return nil, false, errorCode, ossBucketName, nil
}
if oss.IsTimeout(err) {
this.writeCode(http.StatusGatewayTimeout, "Read object timeout.", "读取对象超时。")
return nil, false, errorCode, ossBucketName, nil
}
return nil, false, errorCode, ossBucketName, fmt.Errorf("OSS: [%s]: %s: %w", origin.OSS.Type, errorCode, err)
}
if isHeadRequest {
_ = resp.Body.Close()
resp.Body = io.NopCloser(&bytes.Buffer{})
}
// fix Content-Length
if resp.Header == nil {
resp.Header = http.Header{}
}
if resp.ContentLength > 0 {
_, ok := resp.Header["Content-Length"]
if !ok {
resp.Header.Set("Content-Length", types.String(resp.ContentLength))
}
}
return resp, true, "", ossBucketName, nil
}

View File

@@ -0,0 +1,170 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/bytepool"
"github.com/iwind/TeaGo/Tea"
"net/http"
"os"
"path"
"strings"
)
const defaultPageContentType = "text/html; charset=utf-8"
// 请求特殊页面
func (this *HTTPRequest) doPage(status int) (shouldStop bool) {
if len(this.web.Pages) == 0 {
// 集群自定义页面
if this.nodeConfig != nil && this.ReqServer != nil && this.web.EnableGlobalPages {
var httpPagesPolicy = this.nodeConfig.FindHTTPPagesPolicyWithClusterId(this.ReqServer.ClusterId)
if httpPagesPolicy != nil && httpPagesPolicy.IsOn && len(httpPagesPolicy.Pages) > 0 {
return this.doPageLookup(httpPagesPolicy.Pages, status)
}
}
return false
}
// 查找当前网站自定义页面
shouldStop = this.doPageLookup(this.web.Pages, status)
if shouldStop {
return
}
// 集群自定义页面
if this.nodeConfig != nil && this.ReqServer != nil && this.web.EnableGlobalPages {
var httpPagesPolicy = this.nodeConfig.FindHTTPPagesPolicyWithClusterId(this.ReqServer.ClusterId)
if httpPagesPolicy != nil && httpPagesPolicy.IsOn && len(httpPagesPolicy.Pages) > 0 {
return this.doPageLookup(httpPagesPolicy.Pages, status)
}
}
return
}
func (this *HTTPRequest) doPageLookup(pages []*serverconfigs.HTTPPageConfig, status int) (shouldStop bool) {
var url = this.URL()
for _, page := range pages {
if !page.MatchURL(url) {
continue
}
if page.Match(status) {
if len(page.BodyType) == 0 || page.BodyType == serverconfigs.HTTPPageBodyTypeURL {
if urlSchemeRegexp.MatchString(page.URL) {
var newStatus = page.NewStatus
if newStatus <= 0 {
newStatus = status
}
this.doURL(http.MethodGet, page.URL, "", newStatus, true)
return true
} else {
var realpath = path.Clean(page.URL)
if !strings.HasPrefix(realpath, "/pages/") && !strings.HasPrefix(realpath, "pages/") { // only files under "/pages/" can be used
var msg = "404 page not found: '" + page.URL + "'"
this.writer.Header().Set("Content-Type", defaultPageContentType)
this.writer.WriteHeader(http.StatusNotFound)
_, _ = this.writer.Write([]byte(msg))
return true
}
var file = Tea.Root + Tea.DS + realpath
fp, err := os.Open(file)
if err != nil {
var msg = "404 page not found: '" + page.URL + "'"
this.writer.Header().Set("Content-Type", defaultPageContentType)
this.writer.WriteHeader(http.StatusNotFound)
_, _ = this.writer.Write([]byte(msg))
return true
}
defer func() {
_ = fp.Close()
}()
stat, err := fp.Stat()
if err != nil {
var msg = "404 could not read page content: '" + page.URL + "'"
this.writer.Header().Set("Content-Type", defaultPageContentType)
this.writer.WriteHeader(http.StatusNotFound)
_, _ = this.writer.Write([]byte(msg))
return true
}
// 修改状态码
if page.NewStatus > 0 {
// 自定义响应Headers
this.writer.Header().Set("Content-Type", defaultPageContentType)
this.ProcessResponseHeaders(this.writer.Header(), page.NewStatus)
this.writer.Prepare(nil, stat.Size(), page.NewStatus, true)
this.writer.WriteHeader(page.NewStatus)
} else {
this.writer.Header().Set("Content-Type", defaultPageContentType)
this.ProcessResponseHeaders(this.writer.Header(), status)
this.writer.Prepare(nil, stat.Size(), status, true)
this.writer.WriteHeader(status)
}
var buf = bytepool.Pool1k.Get()
_, err = utils.CopyWithFilter(this.writer, fp, buf.Bytes, func(p []byte) []byte {
return []byte(this.Format(string(p)))
})
bytepool.Pool1k.Put(buf)
if err != nil {
if !this.canIgnore(err) {
remotelogs.Warn("HTTP_REQUEST_PAGE", "write to client failed: "+err.Error())
}
} else {
this.writer.SetOk()
}
}
return true
} else if page.BodyType == serverconfigs.HTTPPageBodyTypeHTML {
// 这里需要实现设置Status因为在Format()中可以获取${status}等变量
if page.NewStatus > 0 {
this.writer.statusCode = page.NewStatus
} else {
this.writer.statusCode = status
}
var content = this.Format(page.Body)
// 修改状态码
if page.NewStatus > 0 {
// 自定义响应Headers
this.writer.Header().Set("Content-Type", defaultPageContentType)
this.ProcessResponseHeaders(this.writer.Header(), page.NewStatus)
this.writer.Prepare(nil, int64(len(content)), page.NewStatus, true)
this.writer.WriteHeader(page.NewStatus)
} else {
this.writer.Header().Set("Content-Type", defaultPageContentType)
this.ProcessResponseHeaders(this.writer.Header(), status)
this.writer.Prepare(nil, int64(len(content)), status, true)
this.writer.WriteHeader(status)
}
_, err := this.writer.WriteString(content)
if err != nil {
if !this.canIgnore(err) {
remotelogs.Warn("HTTP_REQUEST_PAGE", "write to client failed: "+err.Error())
}
} else {
this.writer.SetOk()
}
return true
} else if page.BodyType == serverconfigs.HTTPPageBodyTypeRedirectURL {
var newURL = this.Format(page.URL)
if len(newURL) == 0 {
newURL = "/"
}
if page.NewStatus > 0 && httpStatusIsRedirect(page.NewStatus) {
httpRedirect(this.writer, this.RawReq, newURL, page.NewStatus)
} else {
httpRedirect(this.writer, this.RawReq, newURL, http.StatusTemporaryRedirect)
}
this.writer.SetOk()
return true
}
}
}
return false
}

View File

@@ -0,0 +1,10 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !plus
package nodes
// 检查套餐
func (this *HTTPRequest) doPlanBefore() (blocked bool) {
// stub
return false
}

View File

@@ -0,0 +1,38 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"net/http"
)
// 检查套餐
func (this *HTTPRequest) doPlanBefore() (blocked bool) {
// check date
if !this.ReqServer.UserPlan.IsAvailable() {
this.tags = append(this.tags, "plan")
var statusCode = http.StatusNotFound
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
_, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultPlanExpireNoticePageBody))
return true
}
// check max upload size
if this.RawReq.ContentLength > 0 {
var plan = sharedNodeConfig.FindPlan(this.ReqServer.UserPlan.PlanId)
if plan != nil && plan.MaxUploadSize != nil && plan.MaxUploadSize.Count > 0 {
if this.RawReq.ContentLength > plan.MaxUploadSize.Bytes() {
this.writeCode(http.StatusRequestEntityTooLarge, "Reached max upload size limitation in your plan.", "触发套餐中最大文件上传尺寸限制。")
this.tags = append(this.tags, "plan")
return true
}
}
}
return false
}

View File

@@ -0,0 +1,49 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"net/http"
"strconv"
"strings"
)
func (this *HTTPRequest) doRedirectToHTTPS(redirectToHTTPSConfig *serverconfigs.HTTPRedirectToHTTPSConfig) (shouldBreak bool) {
var host = this.RawReq.Host
// 检查域名是否匹配
if !redirectToHTTPSConfig.MatchDomain(host) {
return false
}
if len(redirectToHTTPSConfig.Host) > 0 {
if redirectToHTTPSConfig.Port > 0 && redirectToHTTPSConfig.Port != 443 {
host = redirectToHTTPSConfig.Host + ":" + strconv.Itoa(redirectToHTTPSConfig.Port)
} else {
host = redirectToHTTPSConfig.Host
}
} else if redirectToHTTPSConfig.Port > 0 {
var lastIndex = strings.LastIndex(host, ":")
if lastIndex > 0 {
host = host[:lastIndex]
}
if redirectToHTTPSConfig.Port != 443 {
host = host + ":" + strconv.Itoa(redirectToHTTPSConfig.Port)
}
} else {
var lastIndex = strings.LastIndex(host, ":")
if lastIndex > 0 {
host = host[:lastIndex]
}
}
var statusCode = http.StatusMovedPermanently
if redirectToHTTPSConfig.Status > 0 {
statusCode = redirectToHTTPSConfig.Status
}
var newURL = "https://" + host + this.RawReq.RequestURI
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
httpRedirect(this.writer, this.RawReq, newURL, statusCode)
return true
}

View File

@@ -0,0 +1,78 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"net/http"
"net/url"
)
func (this *HTTPRequest) doCheckReferers() (shouldStop bool) {
if this.web.Referers == nil {
return
}
// 检查URL
if !this.web.Referers.MatchURL(this.URL()) {
return
}
var origin = this.RawReq.Header.Get("Origin")
const cacheSeconds = "3600" // 时间不能过长,防止修改设置后长期无法生效
// 处理用到Origin的特殊功能
if this.web.Referers.CheckOrigin && len(origin) > 0 {
// 处理Websocket
if this.web.Websocket != nil && this.web.Websocket.IsOn && this.RawReq.Header.Get("Upgrade") == "websocket" {
originHost, _ := httpParseHost(origin)
if len(originHost) > 0 && this.web.Websocket.MatchOrigin(originHost) {
return
}
}
}
var refererURL = this.RawReq.Header.Get("Referer")
if len(refererURL) == 0 && this.web.Referers.CheckOrigin {
if len(origin) > 0 && origin != "null" {
if urlSchemeRegexp.MatchString(origin) {
refererURL = origin
} else {
refererURL = "https://" + origin
}
}
}
if len(refererURL) == 0 {
if this.web.Referers.MatchDomain(this.ReqHost, "") {
return
}
this.tags = append(this.tags, "refererCheck")
this.writer.Header().Set("Cache-Control", "max-age="+cacheSeconds)
this.writeCode(http.StatusForbidden, "The referer has been blocked.", "当前访问已被防盗链系统拦截。")
return true
}
u, err := url.Parse(refererURL)
if err != nil {
if this.web.Referers.MatchDomain(this.ReqHost, "") {
return
}
this.tags = append(this.tags, "refererCheck")
this.writer.Header().Set("Cache-Control", "max-age="+cacheSeconds)
this.writeCode(http.StatusForbidden, "The referer has been blocked.", "当前访问已被防盗链系统拦截。")
return true
}
if !this.web.Referers.MatchDomain(this.ReqHost, u.Host) {
this.tags = append(this.tags, "refererCheck")
this.writer.Header().Set("Cache-Control", "max-age="+cacheSeconds)
this.writeCode(http.StatusForbidden, "The referer has been blocked.", "当前访问已被防盗链系统拦截。")
return true
}
return
}

View File

@@ -0,0 +1,684 @@
package nodes
import (
"context"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/TeaOSLab/EdgeNode/internal/compressions"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/bytepool"
"github.com/TeaOSLab/EdgeNode/internal/utils/fnv"
"github.com/TeaOSLab/EdgeNode/internal/utils/minifiers"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
"io"
"net/http"
"net/url"
"strconv"
"strings"
)
// 处理反向代理
// writeToClient 读取响应并发送到客户端
func (this *HTTPRequest) doReverseProxy(writeToClient bool) (resultResp *http.Response) {
if this.reverseProxy == nil {
return
}
var retries = 3
var failedOriginIds []int64
var failedLnNodeIds []int64
var failStatusCode int
for i := 0; i < retries; i++ {
originId, lnNodeId, shouldRetry, resp := this.doOriginRequest(failedOriginIds, failedLnNodeIds, i == 0, i == retries-1, &failStatusCode, writeToClient)
if !shouldRetry {
resultResp = resp
break
}
if originId > 0 {
failedOriginIds = append(failedOriginIds, originId)
}
if lnNodeId > 0 {
failedLnNodeIds = append(failedLnNodeIds, lnNodeId)
}
}
return
}
// 请求源站
func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeIds []int64, isFirstTry bool, isLastRetry bool, failStatusCode *int, writeToClient bool) (originId int64, lnNodeId int64, shouldRetry bool, resultResp *http.Response) {
// 对URL的处理
var stripPrefix = this.reverseProxy.StripPrefix
var requestURI = this.reverseProxy.RequestURI
var requestURIHasVariables = this.reverseProxy.RequestURIHasVariables()
var oldURI = this.uri
var requestHost = ""
if this.reverseProxy.RequestHostType == serverconfigs.RequestHostTypeCustomized {
requestHost = this.reverseProxy.RequestHost
}
var requestHostHasVariables = this.reverseProxy.RequestHostHasVariables()
// 源站
var requestCall = shared.NewRequestCall()
requestCall.Request = this.RawReq
requestCall.Formatter = this.Format
requestCall.Domain = this.ReqHost
var origin *serverconfigs.OriginConfig
// 二级节点
var hasMultipleLnNodes = false
if this.cacheRef != nil || (this.nodeConfig != nil && this.nodeConfig.GlobalServerConfig != nil && this.nodeConfig.GlobalServerConfig.HTTPAll.ForceLnRequest) {
origin, lnNodeId, hasMultipleLnNodes = this.getLnOrigin(failedLnNodeIds, fnv.HashString(this.URL()))
if origin != nil {
// 强制变更原来访问的域名
requestHost = this.ReqHost
}
if this.cacheRef != nil {
// 回源Header中去除If-None-Match和If-Modified-Since
if !this.cacheRef.EnableIfNoneMatch {
this.DeleteHeader("If-None-Match")
}
if !this.cacheRef.EnableIfModifiedSince {
this.DeleteHeader("If-Modified-Since")
}
}
}
// 自定义源站
if origin == nil {
if !isFirstTry {
origin = this.reverseProxy.AnyOrigin(requestCall, failedOriginIds)
}
if origin == nil {
origin = this.reverseProxy.NextOrigin(requestCall)
if origin != nil && origin.Id > 0 && (*failStatusCode >= 403 && *failStatusCode <= 404) && lists.ContainsInt64(failedOriginIds, origin.Id) {
shouldRetry = false
isLastRetry = true
}
}
requestCall.CallResponseCallbacks(this.writer)
if origin == nil {
err := errors.New(this.URL() + ": no available origin sites for reverse proxy")
remotelogs.ServerError(this.ReqServer.Id, "HTTP_REQUEST_REVERSE_PROXY", err.Error(), "", nil)
this.write50x(err, http.StatusBadGateway, "No origin site yet", "尚未配置源站", true)
return
}
originId = origin.Id
if len(origin.StripPrefix) > 0 {
stripPrefix = origin.StripPrefix
}
if len(origin.RequestURI) > 0 {
requestURI = origin.RequestURI
requestURIHasVariables = origin.RequestURIHasVariables()
}
}
this.origin = origin // 设置全局变量是为了日志等处理
if len(origin.RequestHost) > 0 {
requestHost = origin.RequestHost
requestHostHasVariables = origin.RequestHostHasVariables()
}
// 处理OSS
var isHTTPOrigin = origin.OSS == nil
// 处理Scheme
if isHTTPOrigin && origin.Addr == nil {
err := errors.New(this.URL() + ": Origin '" + strconv.FormatInt(origin.Id, 10) + "' does not has a address")
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", err.Error())
this.write50x(err, http.StatusBadGateway, "Origin site did not has a valid address", "源站尚未配置地址", true)
return
}
if isHTTPOrigin {
this.RawReq.URL.Scheme = origin.Addr.Protocol.Primary().Scheme()
}
// StripPrefix
if len(stripPrefix) > 0 {
if stripPrefix[0] != '/' {
stripPrefix = "/" + stripPrefix
}
this.uri = strings.TrimPrefix(this.uri, stripPrefix)
if len(this.uri) == 0 || this.uri[0] != '/' {
this.uri = "/" + this.uri
}
}
// RequestURI
if len(requestURI) > 0 {
if requestURIHasVariables {
this.uri = this.Format(requestURI)
} else {
this.uri = requestURI
}
if len(this.uri) == 0 || this.uri[0] != '/' {
this.uri = "/" + this.uri
}
// 处理RequestURI中的问号
var questionMark = strings.LastIndex(this.uri, "?")
if questionMark > 0 {
var path = this.uri[:questionMark]
if strings.Contains(path, "?") {
this.uri = path + "&" + this.uri[questionMark+1:]
}
}
// 去除多个/
this.uri = utils.CleanPath(this.uri)
}
var originAddr = ""
if isHTTPOrigin {
// 获取源站地址
originAddr = origin.Addr.PickAddress()
if origin.Addr.HostHasVariables() {
originAddr = this.Format(originAddr)
}
// 端口跟随
if origin.FollowPort {
var originHostIndex = strings.Index(originAddr, ":")
if originHostIndex < 0 {
var originErr = errors.New(this.URL() + ": Invalid origin address '" + originAddr + "', lacking port")
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", originErr.Error())
this.write50x(originErr, http.StatusBadGateway, "No port in origin site address", "源站地址中没有配置端口", true)
return
}
originAddr = originAddr[:originHostIndex+1] + types.String(this.requestServerPort())
}
this.originAddr = originAddr
// RequestHost
if len(requestHost) > 0 {
if requestHostHasVariables {
this.RawReq.Host = this.Format(requestHost)
} else {
this.RawReq.Host = requestHost
}
// 是否移除端口
if this.reverseProxy.RequestHostExcludingPort {
this.RawReq.Host = utils.ParseAddrHost(this.RawReq.Host)
}
this.RawReq.URL.Host = this.RawReq.Host
} else if this.reverseProxy.RequestHostType == serverconfigs.RequestHostTypeOrigin {
// 源站主机名
var hostname = originAddr
if origin.Addr.Protocol.IsHTTPFamily() {
hostname = strings.TrimSuffix(hostname, ":80")
} else if origin.Addr.Protocol.IsHTTPSFamily() {
hostname = strings.TrimSuffix(hostname, ":443")
}
this.RawReq.Host = hostname
// 是否移除端口
if this.reverseProxy.RequestHostExcludingPort {
this.RawReq.Host = utils.ParseAddrHost(this.RawReq.Host)
}
this.RawReq.URL.Host = this.RawReq.Host
} else {
this.RawReq.URL.Host = this.ReqHost
// 是否移除端口
if this.reverseProxy.RequestHostExcludingPort {
this.RawReq.Host = utils.ParseAddrHost(this.RawReq.Host)
this.RawReq.URL.Host = utils.ParseAddrHost(this.RawReq.URL.Host)
}
}
}
// 重组请求URL
var questionMark = strings.Index(this.uri, "?")
if questionMark > -1 {
this.RawReq.URL.Path = this.uri[:questionMark]
this.RawReq.URL.RawQuery = this.uri[questionMark+1:]
} else {
this.RawReq.URL.Path = this.uri
this.RawReq.URL.RawQuery = ""
}
this.RawReq.RequestURI = ""
// 处理Header
this.setForwardHeaders(this.RawReq.Header)
this.processRequestHeaders(this.RawReq.Header)
// 调用回调
this.onRequest()
if this.writer.isFinished {
return
}
// 判断是否为Websocket请求
if isHTTPOrigin && this.RawReq.Header.Get("Upgrade") == "websocket" {
shouldRetry = this.doWebsocket(requestHost, isLastRetry)
return
}
var resp *http.Response
var respBodyIsClosed bool
var requestErr error
var requestErrCode string
if isHTTPOrigin { // 普通HTTP(S)源站
// 修复空User-Agent问题
_, existsUserAgent := this.RawReq.Header["User-Agent"]
if !existsUserAgent {
this.RawReq.Header["User-Agent"] = []string{""}
}
// 获取请求客户端
client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol, this.reverseProxy.FollowRedirects)
if err != nil {
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Create client failed: "+err.Error())
this.write50x(err, http.StatusBadGateway, "Failed to create origin site client", "构造源站客户端失败", true)
return
}
// 尝试自动纠正源站地址中的scheme
if this.RawReq.URL.Scheme == "http" && strings.HasSuffix(originAddr, ":443") {
this.RawReq.URL.Scheme = "https"
} else if this.RawReq.URL.Scheme == "https" && strings.HasSuffix(originAddr, ":80") {
this.RawReq.URL.Scheme = "http"
}
// request origin with Accept-Encoding: gzip, ...
var rawAcceptEncoding string
var acceptEncodingChanged bool
if this.nodeConfig != nil &&
this.nodeConfig.GlobalServerConfig != nil &&
this.nodeConfig.GlobalServerConfig.HTTPAll.RequestOriginsWithEncodings &&
this.RawReq.ProtoAtLeast(1, 1) &&
this.RawReq.Header != nil {
rawAcceptEncoding = this.RawReq.Header.Get("Accept-Encoding")
if len(rawAcceptEncoding) == 0 {
this.RawReq.Header.Set("Accept-Encoding", "gzip")
acceptEncodingChanged = true
} else if strings.Index(rawAcceptEncoding, "gzip") < 0 {
this.RawReq.Header.Set("Accept-Encoding", rawAcceptEncoding+", gzip")
acceptEncodingChanged = true
}
}
// 开始请求
resp, requestErr = client.Do(this.RawReq)
// recover Accept-Encoding
if acceptEncodingChanged {
if len(rawAcceptEncoding) > 0 {
this.RawReq.Header.Set("Accept-Encoding", rawAcceptEncoding)
} else {
this.RawReq.Header.Del("Accept-Encoding")
}
if resp != nil && resp.Header != nil && resp.Header.Get("Content-Encoding") == "gzip" {
bodyReader, gzipErr := compressions.NewGzipReader(resp.Body)
if gzipErr == nil {
resp.Body = bodyReader
}
resp.TransferEncoding = nil
resp.Header.Del("Content-Encoding")
}
}
} else if origin.OSS != nil { // OSS源站
var goNext bool
resp, goNext, requestErrCode, _, requestErr = this.doOSSOrigin(origin)
if requestErr == nil {
if resp == nil || !goNext {
return
}
}
} else {
this.writeCode(http.StatusBadGateway, "The type of origin site has not been supported", "设置的源站类型尚未支持")
return
}
if resp != nil && resp.Body != nil {
defer func() {
if !respBodyIsClosed {
if writeToClient {
_ = resp.Body.Close()
}
}
}()
}
if requestErr != nil {
// 客户端取消请求,则不提示
var httpErr *url.Error
var ok = errors.As(requestErr, &httpErr)
if !ok {
if isHTTPOrigin {
SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() {
this.reverseProxy.ResetScheduling()
})
}
if len(requestErrCode) > 0 {
this.write50x(requestErr, http.StatusBadGateway, "Failed to read origin site (error code: "+requestErrCode+")", "源站读取失败(错误代号:"+requestErrCode+"", true)
} else {
this.write50x(requestErr, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
}
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.RawReq.URL.String()+": Request origin server failed: "+requestErr.Error())
} else if !errors.Is(httpErr, context.Canceled) {
if isHTTPOrigin {
SharedOriginStateManager.Fail(origin, requestHost, this.reverseProxy, func() {
this.reverseProxy.ResetScheduling()
})
}
// 是否需要重试
if (originId > 0 || (lnNodeId > 0 && hasMultipleLnNodes)) && !isLastRetry {
shouldRetry = true
this.uri = oldURI // 恢复备份
if httpErr.Err != io.EOF && !errors.Is(httpErr.Err, http.ErrBodyReadAfterClose) {
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Request origin server failed: "+requestErr.Error())
}
return
}
if httpErr.Timeout() {
this.write50x(requestErr, http.StatusGatewayTimeout, "Read origin site timeout", "源站读取超时", true)
} else if httpErr.Temporary() {
this.write50x(requestErr, http.StatusServiceUnavailable, "Origin site unavailable now", "源站当前不可用", true)
} else {
this.write50x(requestErr, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
}
if httpErr.Err != io.EOF && !errors.Is(httpErr.Err, http.ErrBodyReadAfterClose) {
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Request origin server failed: "+requestErr.Error())
}
} else {
// 是否为客户端方面的错误
var isClientError = false
if errors.Is(httpErr, context.Canceled) {
// 如果是服务器端主动关闭,则无需提示
if this.isConnClosed() {
this.disableLog = true
return
}
isClientError = true
this.addError(errors.New(httpErr.Op + " " + httpErr.URL + ": client closed the connection"))
this.writer.WriteHeader(499) // 仿照nginx
}
if !isClientError {
this.write50x(requestErr, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
}
}
return
}
if resp == nil {
this.write50x(requestErr, http.StatusBadGateway, "Failed to read origin site", "源站读取失败", true)
return
}
if !writeToClient {
resultResp = resp
return
}
// fix Content-Type
if resp.Header["Content-Type"] == nil {
resp.Header["Content-Type"] = []string{}
}
// 40x && 50x
*failStatusCode = resp.StatusCode
if ((resp.StatusCode >= 500 && resp.StatusCode < 510 && this.reverseProxy.Retry50X) ||
(resp.StatusCode >= 403 && resp.StatusCode <= 404 && this.reverseProxy.Retry40X)) &&
(originId > 0 || (lnNodeId > 0 && hasMultipleLnNodes)) &&
!isLastRetry {
if resp.Body != nil {
_ = resp.Body.Close()
}
shouldRetry = true
return
}
// 尝试从缓存中恢复
if resp.StatusCode >= 500 && // support 50X only
resp.StatusCode < 510 &&
this.cacheCanTryStale &&
this.web.Cache.Stale != nil &&
this.web.Cache.Stale.IsOn &&
(len(this.web.Cache.Stale.Status) == 0 || lists.ContainsInt(this.web.Cache.Stale.Status, resp.StatusCode)) {
var ok = this.doCacheRead(true)
if ok {
return
}
}
// 记录相关数据
this.originStatus = int32(resp.StatusCode)
// 恢复源站状态
if !origin.IsOk && isHTTPOrigin {
SharedOriginStateManager.Success(origin, func() {
this.reverseProxy.ResetScheduling()
})
}
// WAF对出站进行检查
if this.web.FirewallRef != nil && this.web.FirewallRef.IsOn {
if this.doWAFResponse(resp) {
return
}
}
// 特殊页面
if this.doPage(resp.StatusCode) {
return
}
// Page encryption (必须在页面优化之前)
if this.web.Encryption != nil && resp.Body != nil {
err := this.processPageEncryption(resp)
if err != nil {
remotelogs.Warn("HTTP_REQUEST_ENCRYPTION", "encrypt page failed: "+err.Error())
// 加密失败不影响正常响应,继续处理
}
}
// Page optimization (如果已加密,跳过优化)
if this.web.Optimization != nil && resp.Body != nil && this.cacheRef != nil /** must under cache **/ {
// 如果已加密,跳过优化
if this.web.Encryption == nil || !this.web.Encryption.IsOn || !this.web.Encryption.IsEnabled() {
err := minifiers.MinifyResponse(this.web.Optimization, this.URL(), resp)
if err != nil {
this.write50x(err, http.StatusBadGateway, "Page Optimization: fail to read content from origin", "内容优化:从源站读取内容失败", false)
return
}
}
}
// HLS
if this.web.HLS != nil &&
this.web.HLS.Encrypting != nil &&
this.web.HLS.Encrypting.IsOn &&
resp.StatusCode == http.StatusOK {
m3u8Err := this.processM3u8Response(resp)
if m3u8Err != nil {
this.write50x(m3u8Err, http.StatusBadGateway, "m3u8 encrypt: fail to read content from origin", "m3u8加密从源站读取内容失败", false)
return
}
}
// 设置Charset
// TODO 这里应该可以设置文本类型的列表
if this.web.Charset != nil && this.web.Charset.IsOn && len(this.web.Charset.Charset) > 0 {
contentTypes, ok := resp.Header["Content-Type"]
if ok && len(contentTypes) > 0 {
var contentType = contentTypes[0]
if this.web.Charset.Force {
var semiIndex = strings.Index(contentType, ";")
if semiIndex > 0 {
contentType = contentType[:semiIndex]
}
}
if _, found := textMimeMap[contentType]; found {
var newCharset = this.web.Charset.Charset
if this.web.Charset.IsUpper {
newCharset = strings.ToUpper(newCharset)
}
resp.Header["Content-Type"][0] = contentType + "; charset=" + newCharset
}
}
}
// 替换Location中的源站地址
var locationHeader = resp.Header.Get("Location")
if len(locationHeader) > 0 {
// 空Location处理
if locationHeader == emptyHTTPLocation {
resp.Header.Del("Location")
} else {
// 自动修正Location中的源站地址
locationURL, err := url.Parse(locationHeader)
if err == nil && locationURL.Host != this.ReqHost && (locationURL.Host == originAddr || strings.HasPrefix(originAddr, locationURL.Host+":")) {
locationURL.Host = this.ReqHost
var oldScheme = locationURL.Scheme
// 尝试和当前Scheme一致
if this.IsHTTP {
locationURL.Scheme = "http"
} else if this.IsHTTPS {
locationURL.Scheme = "https"
}
// 如果和当前URL一样则可能是http -> https防止无限循环
if locationURL.String() == this.URL() {
locationURL.Scheme = oldScheme
resp.Header.Set("Location", locationURL.String())
} else {
resp.Header.Set("Location", locationURL.String())
}
}
}
}
// 响应Header
this.writer.AddHeaders(resp.Header)
this.ProcessResponseHeaders(this.writer.Header(), resp.StatusCode)
// 是否需要刷新
var shouldAutoFlush = this.reverseProxy.AutoFlush || (resp.Header != nil && strings.Contains(resp.Header.Get("Content-Type"), "stream"))
// 设置当前连接为Persistence
if shouldAutoFlush && this.nodeConfig != nil && this.nodeConfig.HasConnTimeoutSettings() {
var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
if requestConn == nil {
return
}
requestClientConn, ok := requestConn.(ClientConnInterface)
if ok {
requestClientConn.SetIsPersistent(true)
}
}
// 准备
var delayHeaders = this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true)
// 设置响应代码
if !delayHeaders {
this.writer.WriteHeader(resp.StatusCode)
}
// 是否有内容
if resp.ContentLength == 0 && len(resp.TransferEncoding) == 0 {
// 即使内容为0也需要读取一次以便于触发相关事件
var buf = bytepool.Pool4k.Get()
_, _ = io.CopyBuffer(this.writer, resp.Body, buf.Bytes)
bytepool.Pool4k.Put(buf)
_ = resp.Body.Close()
respBodyIsClosed = true
this.writer.SetOk()
return
}
// 输出到客户端
var pool = this.bytePool(resp.ContentLength)
var buf = pool.Get()
var err error
if shouldAutoFlush {
for {
n, readErr := resp.Body.Read(buf.Bytes)
if n > 0 {
_, err = this.writer.Write(buf.Bytes[:n])
this.writer.Flush()
if err != nil {
break
}
}
if readErr != nil {
err = readErr
break
}
}
} else {
if this.cacheRef != nil &&
this.cacheRef.EnableReadingOriginAsync &&
resp.ContentLength > 0 &&
resp.ContentLength < (128<<20) { // TODO configure max content-length in cache policy OR CacheRef
var requestIsCanceled = false
for {
n, readErr := resp.Body.Read(buf.Bytes)
if n > 0 && !requestIsCanceled {
_, err = this.writer.Write(buf.Bytes[:n])
if err != nil {
requestIsCanceled = true
}
}
if readErr != nil {
err = readErr
break
}
}
} else {
_, err = io.CopyBuffer(this.writer, resp.Body, buf.Bytes)
}
}
pool.Put(buf)
var closeErr = resp.Body.Close()
respBodyIsClosed = true
if closeErr != nil {
if !this.canIgnore(closeErr) {
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Closing error: "+closeErr.Error())
}
}
if err != nil && err != io.EOF {
if !this.canIgnore(err) {
remotelogs.WarnServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Writing error: "+err.Error())
this.addError(err)
}
}
// 是否成功结束
if (err == nil || err == io.EOF) && (closeErr == nil || closeErr == io.EOF) {
this.writer.SetOk()
}
return
}

View File

@@ -0,0 +1,43 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"net/http"
)
// 调用Rewrite
func (this *HTTPRequest) doRewrite() (shouldShop bool) {
if this.rewriteRule == nil {
return
}
// 代理
if this.rewriteRule.Mode == serverconfigs.HTTPRewriteModeProxy {
// 外部URL
if this.rewriteIsExternalURL {
host := this.ReqHost
if len(this.rewriteRule.ProxyHost) > 0 {
host = this.rewriteRule.ProxyHost
}
this.doURL(this.RawReq.Method, this.rewriteReplace, host, 0, false)
return true
}
// 内部URL继续
return false
}
// 跳转
if this.rewriteRule.Mode == serverconfigs.HTTPRewriteModeRedirect {
if this.rewriteRule.RedirectStatus > 0 {
this.ProcessResponseHeaders(this.writer.Header(), this.rewriteRule.RedirectStatus)
httpRedirect(this.writer, this.RawReq, this.rewriteReplace, this.rewriteRule.RedirectStatus)
} else {
this.ProcessResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
httpRedirect(this.writer, this.RawReq, this.rewriteReplace, http.StatusTemporaryRedirect)
}
return true
}
return true
}

View File

@@ -0,0 +1,460 @@
package nodes
import (
"fmt"
rangeutils "github.com/TeaOSLab/EdgeNode/internal/utils/ranges"
"github.com/TeaOSLab/EdgeNode/internal/utils/zero"
"github.com/cespare/xxhash/v2"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/types"
"io"
"io/fs"
"mime"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
)
// 文本mime-type列表
var textMimeMap = map[string]zero.Zero{
"application/atom+xml": {},
"application/javascript": {},
"application/x-javascript": {},
"application/json": {},
"application/rss+xml": {},
"application/x-web-app-manifest+json": {},
"application/xhtml+xml": {},
"application/xml": {},
"image/svg+xml": {},
"text/css": {},
"text/plain": {},
"text/javascript": {},
"text/xml": {},
"text/html": {},
"text/xhtml": {},
"text/sgml": {},
}
// 调用本地静态资源
// 如果返回true则终止请求
func (this *HTTPRequest) doRoot() (isBreak bool) {
if this.web.Root == nil || !this.web.Root.IsOn {
return
}
if len(this.uri) == 0 {
this.write404()
return true
}
var rootDir = this.web.Root.Dir
if this.web.Root.HasVariables() {
rootDir = this.Format(rootDir)
}
if !filepath.IsAbs(rootDir) {
rootDir = Tea.Root + Tea.DS + rootDir
}
var requestPath = this.uri
var questionMarkIndex = strings.Index(this.uri, "?")
if questionMarkIndex > -1 {
requestPath = this.uri[:questionMarkIndex]
}
// except hidden files
if this.web.Root.ExceptHiddenFiles &&
(strings.Contains(requestPath, "/.") || strings.Contains(requestPath, "\\.")) {
this.write404()
return true
}
// except and only files
if !this.web.Root.MatchURL(this.URL()) {
this.write404()
return true
}
// 去掉其中的奇怪的路径
requestPath = strings.Replace(requestPath, "..\\", "", -1)
// 进行URL Decode
if this.web.Root.DecodePath {
p, err := url.QueryUnescape(requestPath)
if err == nil {
requestPath = p
} else {
if !this.canIgnore(err) {
logs.Error(err)
}
}
}
// 去掉前缀
stripPrefix := this.web.Root.StripPrefix
if len(stripPrefix) > 0 {
if stripPrefix[0] != '/' {
stripPrefix = "/" + stripPrefix
}
requestPath = strings.TrimPrefix(requestPath, stripPrefix)
if len(requestPath) == 0 || requestPath[0] != '/' {
requestPath = "/" + requestPath
}
}
var filename = strings.Replace(requestPath, "/", Tea.DS, -1)
var filePath string
if len(filename) > 0 && filename[0:1] == Tea.DS {
filePath = rootDir + filename
} else {
filePath = rootDir + Tea.DS + filename
}
this.filePath = filePath // 用来记录日志
stat, err := os.Stat(filePath)
if err != nil {
_, isPathError := err.(*fs.PathError)
if os.IsNotExist(err) || isPathError {
if this.web.Root.IsBreak {
this.write404()
return true
}
return
} else {
this.write50x(err, http.StatusInternalServerError, "Failed to stat the file", "查看文件统计信息失败", true)
if !this.canIgnore(err) {
logs.Error(err)
}
return true
}
}
if stat.IsDir() {
indexFile, indexStat := this.findIndexFile(filePath)
if len(indexFile) > 0 {
filePath += Tea.DS + indexFile
} else {
if this.web.Root.IsBreak {
this.write404()
return true
}
return
}
this.filePath = filePath
// stat again
if indexStat == nil {
stat, err = os.Stat(filePath)
if err != nil {
if os.IsNotExist(err) {
if this.web.Root.IsBreak {
this.write404()
return true
}
return
} else {
this.write50x(err, http.StatusInternalServerError, "Failed to stat the file", "查看文件统计信息失败", true)
if !this.canIgnore(err) {
logs.Error(err)
}
return true
}
}
} else {
stat = indexStat
}
}
// 响应header
var respHeader = this.writer.Header()
// mime type
var contentType = ""
if this.web.ResponseHeaderPolicy == nil || !this.web.ResponseHeaderPolicy.IsOn || !this.web.ResponseHeaderPolicy.ContainsHeader("CONTENT-TYPE") {
var ext = filepath.Ext(filePath)
if len(ext) > 0 {
mimeType := mime.TypeByExtension(ext)
if len(mimeType) > 0 {
var semicolonIndex = strings.Index(mimeType, ";")
var mimeTypeKey = mimeType
if semicolonIndex > 0 {
mimeTypeKey = mimeType[:semicolonIndex]
}
if _, found := textMimeMap[mimeTypeKey]; found {
if this.web.Charset != nil && this.web.Charset.IsOn && len(this.web.Charset.Charset) > 0 {
var charset = this.web.Charset.Charset
if this.web.Charset.IsUpper {
charset = strings.ToUpper(charset)
}
contentType = mimeTypeKey + "; charset=" + charset
respHeader.Set("Content-Type", mimeTypeKey+"; charset="+charset)
} else {
contentType = mimeType
respHeader.Set("Content-Type", mimeType)
}
} else {
contentType = mimeType
respHeader.Set("Content-Type", mimeType)
}
}
}
}
// length
var fileSize = stat.Size()
// 支持 Last-Modified
modifiedTime := stat.ModTime().Format("Mon, 02 Jan 2006 15:04:05 GMT")
if len(respHeader.Get("Last-Modified")) == 0 {
respHeader.Set("Last-Modified", modifiedTime)
}
// 支持 ETag
var eTag = "\"e" + fmt.Sprintf("%0x", xxhash.Sum64String(filename+strconv.FormatInt(stat.ModTime().UnixNano(), 10)+strconv.FormatInt(stat.Size(), 10))) + "\""
if len(respHeader.Get("ETag")) == 0 {
respHeader.Set("ETag", eTag)
}
// 调用回调
this.onRequest()
if this.writer.isFinished {
return
}
// 支持 If-None-Match
if this.requestHeader("If-None-Match") == eTag {
// 自定义Header
this.ProcessResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.writer.WriteHeader(http.StatusNotModified)
return true
}
// 支持 If-Modified-Since
if this.requestHeader("If-Modified-Since") == modifiedTime {
// 自定义Header
this.ProcessResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.writer.WriteHeader(http.StatusNotModified)
return true
}
// 支持Range
respHeader.Set("Accept-Ranges", "bytes")
ifRangeHeaders, ok := this.RawReq.Header["If-Range"]
var supportRange = true
if ok {
supportRange = false
for _, v := range ifRangeHeaders {
if v == eTag || v == modifiedTime {
supportRange = true
break
}
}
if !supportRange {
respHeader.Del("Accept-Ranges")
}
}
// 支持Range
var ranges = []rangeutils.Range{}
if supportRange {
var contentRange = this.RawReq.Header.Get("Range")
if len(contentRange) > 0 {
if fileSize == 0 {
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
set, ok := httpRequestParseRangeHeader(contentRange)
if !ok {
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
if len(set) > 0 {
ranges = set
for k, r := range ranges {
r2, ok := r.Convert(fileSize)
if !ok {
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
ranges[k] = r2
}
}
} else {
respHeader.Set("Content-Length", strconv.FormatInt(fileSize, 10))
}
} else {
respHeader.Set("Content-Length", strconv.FormatInt(fileSize, 10))
}
fileReader, err := os.OpenFile(filePath, os.O_RDONLY, 0444)
if err != nil {
this.write50x(err, http.StatusInternalServerError, "Failed to open the file", "试图打开文件失败", true)
return true
}
// 自定义Header
this.ProcessResponseHeaders(this.writer.Header(), http.StatusOK)
// 在Range请求中不能缓存
if len(ranges) > 0 {
this.cacheRef = nil // 不支持缓存
}
var resp = &http.Response{
ContentLength: fileSize,
Body: fileReader,
StatusCode: http.StatusOK,
}
this.writer.Prepare(resp, fileSize, http.StatusOK, true)
var pool = this.bytePool(fileSize)
var buf = pool.Get()
defer func() {
pool.Put(buf)
}()
if len(ranges) == 1 {
respHeader.Set("Content-Range", ranges[0].ComposeContentRangeHeader(types.String(fileSize)))
this.writer.WriteHeader(http.StatusPartialContent)
ok, err := httpRequestReadRange(resp.Body, buf.Bytes, ranges[0].Start(), ranges[0].End(), func(buf []byte, n int) error {
_, err := this.writer.Write(buf[:n])
return err
})
if err != nil {
if !this.canIgnore(err) {
logs.Error(err)
}
return true
}
if !ok {
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
} else if len(ranges) > 1 {
var boundary = httpRequestGenBoundary()
respHeader.Set("Content-Type", "multipart/byteranges; boundary="+boundary)
this.writer.WriteHeader(http.StatusPartialContent)
for index, r := range ranges {
if index == 0 {
_, err = this.writer.WriteString("--" + boundary + "\r\n")
} else {
_, err = this.writer.WriteString("\r\n--" + boundary + "\r\n")
}
if err != nil {
if !this.canIgnore(err) {
logs.Error(err)
}
return true
}
_, err = this.writer.WriteString("Content-Range: " + r.ComposeContentRangeHeader(types.String(fileSize)) + "\r\n")
if err != nil {
if !this.canIgnore(err) {
logs.Error(err)
}
return true
}
if len(contentType) > 0 {
_, err = this.writer.WriteString("Content-Type: " + contentType + "\r\n\r\n")
if err != nil {
if !this.canIgnore(err) {
logs.Error(err)
}
return true
}
}
ok, err := httpRequestReadRange(resp.Body, buf.Bytes, r.Start(), r.End(), func(buf []byte, n int) error {
_, err := this.writer.Write(buf[:n])
return err
})
if err != nil {
if !this.canIgnore(err) {
logs.Error(err)
}
return true
}
if !ok {
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
}
_, err = this.writer.WriteString("\r\n--" + boundary + "--\r\n")
if err != nil {
if !this.canIgnore(err) {
logs.Error(err)
}
return true
}
} else {
_, err = io.CopyBuffer(this.writer, resp.Body, buf.Bytes)
if err != nil {
if !this.canIgnore(err) {
logs.Error(err)
}
return true
}
}
// 设置成功
this.writer.SetOk()
return true
}
// 查找首页文件
func (this *HTTPRequest) findIndexFile(dir string) (filename string, stat os.FileInfo) {
if this.web.Root == nil || !this.web.Root.IsOn {
return "", nil
}
if len(this.web.Root.Indexes) == 0 {
return "", nil
}
for _, index := range this.web.Root.Indexes {
if len(index) == 0 {
continue
}
// 模糊查找
if strings.Contains(index, "*") {
indexFiles, err := filepath.Glob(dir + Tea.DS + index)
if err != nil {
if !this.canIgnore(err) {
logs.Error(err)
}
this.addError(err)
continue
}
if len(indexFiles) > 0 {
return filepath.Base(indexFiles[0]), nil
}
continue
}
// 精确查找
filePath := dir + Tea.DS + index
stat, err := os.Stat(filePath)
if err != nil || !stat.Mode().IsRegular() {
continue
}
return index, stat
}
return "", nil
}

View File

@@ -0,0 +1,115 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/bytepool"
"github.com/iwind/TeaGo/Tea"
"net/http"
"os"
"path"
"strings"
)
// 调用临时关闭页面
func (this *HTTPRequest) doShutdown() {
var shutdown = this.web.Shutdown
if shutdown == nil {
return
}
if len(shutdown.BodyType) == 0 || shutdown.BodyType == serverconfigs.HTTPPageBodyTypeURL {
// URL
if urlSchemeRegexp.MatchString(shutdown.URL) {
this.doURL(http.MethodGet, shutdown.URL, "", shutdown.Status, true)
return
}
// URL为空则显示文本
if len(shutdown.URL) == 0 {
// 自定义响应Headers
if shutdown.Status > 0 {
this.ProcessResponseHeaders(this.writer.Header(), shutdown.Status)
this.writer.WriteHeader(shutdown.Status)
} else {
this.ProcessResponseHeaders(this.writer.Header(), http.StatusOK)
this.writer.WriteHeader(http.StatusOK)
}
_, _ = this.writer.WriteString("The site have been shutdown.")
return
}
// 从本地文件中读取
var realpath = path.Clean(shutdown.URL)
if !strings.HasPrefix(realpath, "/pages/") && !strings.HasPrefix(realpath, "pages/") { // only files under "/pages/" can be used
var msg = "404 page not found: '" + shutdown.URL + "'"
this.writer.WriteHeader(http.StatusNotFound)
_, _ = this.writer.Write([]byte(msg))
return
}
var file = Tea.Root + Tea.DS + shutdown.URL
fp, err := os.Open(file)
if err != nil {
var msg = "404 page not found: '" + shutdown.URL + "'"
this.writer.WriteHeader(http.StatusNotFound)
_, _ = this.writer.Write([]byte(msg))
return
}
defer func() {
_ = fp.Close()
}()
// 自定义响应Headers
if shutdown.Status > 0 {
this.ProcessResponseHeaders(this.writer.Header(), shutdown.Status)
this.writer.WriteHeader(shutdown.Status)
} else {
this.ProcessResponseHeaders(this.writer.Header(), http.StatusOK)
this.writer.WriteHeader(http.StatusOK)
}
var buf = bytepool.Pool1k.Get()
_, err = utils.CopyWithFilter(this.writer, fp, buf.Bytes, func(p []byte) []byte {
return []byte(this.Format(string(p)))
})
bytepool.Pool1k.Put(buf)
if err != nil {
if !this.canIgnore(err) {
remotelogs.Warn("HTTP_REQUEST_SHUTDOWN", "write to client failed: "+err.Error())
}
} else {
this.writer.SetOk()
}
} else if shutdown.BodyType == serverconfigs.HTTPPageBodyTypeHTML {
// 自定义响应Headers
if shutdown.Status > 0 {
this.ProcessResponseHeaders(this.writer.Header(), shutdown.Status)
this.writer.WriteHeader(shutdown.Status)
} else {
this.ProcessResponseHeaders(this.writer.Header(), http.StatusOK)
this.writer.WriteHeader(http.StatusOK)
}
_, err := this.writer.WriteString(this.Format(shutdown.Body))
if err != nil {
if !this.canIgnore(err) {
remotelogs.Warn("HTTP_REQUEST_SHUTDOWN", "write to client failed: "+err.Error())
}
} else {
this.writer.SetOk()
}
} else if shutdown.BodyType == serverconfigs.HTTPPageBodyTypeRedirectURL {
var newURL = shutdown.URL
if len(newURL) == 0 {
newURL = "/"
}
if shutdown.Status > 0 && httpStatusIsRedirect(shutdown.Status) {
httpRedirect(this.writer, this.RawReq, newURL, shutdown.Status)
} else {
httpRedirect(this.writer, this.RawReq, newURL, http.StatusTemporaryRedirect)
}
this.writer.SetOk()
}
}

View File

@@ -0,0 +1,16 @@
package nodes
import (
"github.com/TeaOSLab/EdgeNode/internal/stats"
)
// 统计
func (this *HTTPRequest) doStat() {
if this.ReqServer == nil || this.web == nil || this.web.StatRef == nil {
return
}
// 内置的统计
stats.SharedHTTPRequestStatManager.AddRemoteAddr(this.ReqServer.Id, this.requestRemoteAddr(true), this.writer.SentBodyBytes(), this.isAttack)
stats.SharedHTTPRequestStatManager.AddUserAgent(this.ReqServer.Id, this.requestHeader("User-Agent"), this.remoteAddr)
}

View File

@@ -0,0 +1,22 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import "net/http"
// 执行子请求
func (this *HTTPRequest) doSubRequest(writer http.ResponseWriter, rawReq *http.Request) {
// 包装新请求对象
req := &HTTPRequest{
RawReq: rawReq,
RawWriter: writer,
ReqServer: this.ReqServer,
ReqHost: this.ReqHost,
ServerName: this.ServerName,
ServerAddr: this.ServerAddr,
IsHTTP: this.IsHTTP,
IsHTTPS: this.IsHTTPS,
}
req.isSubRequest = true
req.Do()
}

View File

@@ -0,0 +1,71 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/assert"
"net/http"
"runtime"
"testing"
)
func TestHTTPRequest_RedirectToHTTPS(t *testing.T) {
var a = assert.NewAssertion(t)
{
rawReq, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatal(err)
}
var req = &HTTPRequest{
RawReq: rawReq,
RawWriter: NewEmptyResponseWriter(nil),
ReqServer: &serverconfigs.ServerConfig{
IsOn: true,
Web: &serverconfigs.HTTPWebConfig{
IsOn: true,
RedirectToHttps: &serverconfigs.HTTPRedirectToHTTPSConfig{},
},
},
}
req.init()
req.Do()
a.IsBool(req.web.RedirectToHttps.IsOn == false)
}
{
rawReq, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatal(err)
}
var req = &HTTPRequest{
RawReq: rawReq,
RawWriter: NewEmptyResponseWriter(nil),
ReqServer: &serverconfigs.ServerConfig{
IsOn: true,
Web: &serverconfigs.HTTPWebConfig{
IsOn: true,
RedirectToHttps: &serverconfigs.HTTPRedirectToHTTPSConfig{
IsOn: true,
},
},
},
}
req.init()
req.Do()
a.IsBool(req.web.RedirectToHttps.IsOn == true)
}
}
func TestHTTPRequest_Memory(t *testing.T) {
var stat1 = &runtime.MemStats{}
runtime.ReadMemStats(stat1)
var requests = []*HTTPRequest{}
for i := 0; i < 1_000_000; i++ {
requests = append(requests, &HTTPRequest{})
}
var stat2 = &runtime.MemStats{}
runtime.ReadMemStats(stat2)
t.Log((stat2.HeapInuse-stat1.HeapInuse)/1024/1024, "MB,")
t.Log(len(requests))
}

View File

@@ -0,0 +1,120 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"encoding/json"
"net/http"
"sync"
"time"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/utils/ttlcache"
)
// TokenCache Token 缓存
type TokenCache struct {
cache *ttlcache.Cache[string] // IP => Token
}
var sharedTokenCache *TokenCache
var sharedTokenCacheOnce sync.Once
// SharedTokenCache 获取共享 Token 缓存实例
func SharedTokenCache() *TokenCache {
sharedTokenCacheOnce.Do(func() {
sharedTokenCache = NewTokenCache()
})
return sharedTokenCache
}
// NewTokenCache 创建新 Token 缓存
func NewTokenCache() *TokenCache {
cache := ttlcache.NewCache[string](
ttlcache.NewMaxItemsOption(10000),
ttlcache.NewGCConfigOption().
WithBaseInterval(5*time.Minute).
WithMinInterval(2*time.Minute).
WithMaxInterval(10*time.Minute).
WithAdaptive(true),
)
return &TokenCache{
cache: cache,
}
}
// Set 设置 Token
func (this *TokenCache) Set(ip string, token string) {
expiresAt := fasttime.Now().Unix() + 300 // 5 分钟过期
this.cache.Write(ip, token, expiresAt)
}
// Get 获取 Token
func (this *TokenCache) Get(ip string) (string, bool) {
item := this.cache.Read(ip)
if item == nil {
return "", false
}
if item.ExpiresAt() < fasttime.Now().Unix() {
return "", false
}
return item.Value, true
}
// ValidateToken 验证 Token
func (this *HTTPRequest) validateEncryptionToken(token string) bool {
if this.web.Encryption == nil || !this.web.Encryption.IsOn || !this.web.Encryption.IsEnabled() {
return true // 未启用加密,直接通过
}
if len(token) == 0 {
return false
}
remoteIP := this.requestRemoteAddr(true)
cache := SharedTokenCache()
storedToken, ok := cache.Get(remoteIP)
if !ok {
return false
}
return storedToken == token
}
// handleTokenHandshake 处理 Token 握手请求
func (this *HTTPRequest) handleTokenHandshake() {
if this.RawReq.Method != http.MethodPost {
this.writeCode(http.StatusMethodNotAllowed, "Method Not Allowed", "方法不允许")
return
}
// 解析请求体
var requestBody struct {
Token string `json:"token"`
}
err := json.NewDecoder(this.RawReq.Body).Decode(&requestBody)
if err != nil {
this.writeCode(http.StatusBadRequest, "Invalid Request", "请求格式错误")
return
}
if len(requestBody.Token) == 0 {
this.writeCode(http.StatusBadRequest, "Token Required", "Token 不能为空")
return
}
// 保存 Token
remoteIP := this.requestRemoteAddr(true)
cache := SharedTokenCache()
cache.Set(remoteIP, requestBody.Token)
// 返回成功响应
this.writer.Header().Set("Content-Type", "application/json")
this.writer.WriteHeader(http.StatusOK)
_, _ = this.writer.WriteString(`{"success":true}`)
this.writer.SetOk()
}

View File

@@ -0,0 +1,47 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
)
// 流量限制
func (this *HTTPRequest) doTrafficLimit(status *serverconfigs.TrafficLimitStatus) (blocked bool) {
if status == nil {
return false
}
// 如果是网站单独设置的流量限制,则检查是否已关闭
var config = this.ReqServer.TrafficLimit
if (config == nil || !config.IsOn) && status.PlanId == 0 {
return false
}
// 如果是套餐设置的流量限制,即使套餐变更了(变更套餐或者变更套餐的限制),仍然会提示流量超限
this.tags = append(this.tags, "trafficLimit")
var statusCode = 509
this.writer.statusCode = statusCode
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
this.writer.WriteHeader(statusCode)
// check plan traffic limit
if (config == nil || !config.IsOn) && this.ReqServer.PlanId() > 0 && this.nodeConfig != nil {
var planConfig = this.nodeConfig.FindPlan(this.ReqServer.PlanId())
if planConfig != nil && planConfig.TrafficLimit != nil && planConfig.TrafficLimit.IsOn {
config = planConfig.TrafficLimit
}
}
if config != nil && len(config.NoticePageBody) != 0 {
_, _ = this.writer.WriteString(this.Format(config.NoticePageBody))
} else {
_, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultTrafficLimitNoticePageBody))
}
return true
}

View File

@@ -0,0 +1,16 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !plus
// +build !plus
package nodes
func (this *HTTPRequest) isUAMRequest() bool {
// stub
return false
}
// UAM
func (this *HTTPRequest) doUAM() (block bool) {
// stub
return false
}

View File

@@ -0,0 +1,210 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/uam"
"github.com/TeaOSLab/EdgeNode/internal/utils/agents"
"github.com/TeaOSLab/EdgeNode/internal/utils/counters"
"github.com/TeaOSLab/EdgeNode/internal/utils/ttlcache"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/types"
"net/http"
"strings"
)
var sharedUAMManager *uam.Manager
func init() {
if !teaconst.IsMain {
return
}
events.On(events.EventLoaded, func() {
if sharedUAMManager != nil {
return
}
manager, _ := uam.NewManager(sharedNodeConfig.NodeId, sharedNodeConfig.Secret)
if manager != nil {
sharedUAMManager = manager
}
})
events.On(events.EventReload, func() {
if sharedUAMManager != nil {
return
}
manager, _ := uam.NewManager(sharedNodeConfig.NodeId, sharedNodeConfig.Secret)
if manager != nil {
sharedUAMManager = manager
}
})
}
func (this *HTTPRequest) isUAMRequest() bool {
if this.web.UAM == nil || !this.web.UAM.IsOn || this.RawReq.Method != http.MethodPost {
return false
}
var cookiesString = this.RawReq.Header.Get("Cookie")
if len(cookiesString) == 0 {
return false
}
return strings.HasPrefix(cookiesString, uam.CookiePrevKey+"=") ||
strings.Contains(cookiesString, " "+uam.CookiePrevKey+"=")
}
// UAM
// TODO 需要检查是否为plus
func (this *HTTPRequest) doUAM() (block bool) {
var serverId int64
if this.ReqServer != nil {
serverId = this.ReqServer.Id
}
var uamConfig = this.web.UAM
if uamConfig == nil ||
!uamConfig.IsOn ||
!uamConfig.MatchURL(this.requestScheme()+"://"+this.ReqHost+this.Path()) ||
!uamConfig.MatchRequest(this.Format) {
return
}
var policy = this.nodeConfig.FindUAMPolicyWithClusterId(this.ReqServer.ClusterId)
if policy == nil {
policy = nodeconfigs.NewUAMPolicy()
}
if policy == nil || !policy.IsOn {
return
}
// 获取UAM管理器
var manager = sharedUAMManager
if manager == nil {
return false
}
// 忽略URL白名单
if this.RawReq.URL.Path == "/favicon.ico" || this.RawReq.URL.Path == "/favicon.png" {
return false
}
// 检查白名单
var remoteAddr = this.requestRemoteAddr(true)
// 检查UAM白名单
if uamConfig.AddToWhiteList && ttlcache.SharedInt64Cache.Read("UAM:WHITE:"+remoteAddr) != nil {
return false
}
// 检查是否为白名单直连
if !Tea.IsTesting() && this.nodeConfig.IPIsAutoAllowed(remoteAddr) {
return
}
// 是否在全局名单中
canGoNext, isInAllowedList, _ := iplibrary.AllowIP(remoteAddr, serverId)
if !canGoNext {
this.disableLog = true
this.Close()
return true
}
if isInAllowedList {
return false
}
// 如果是搜索引擎直接通过
var userAgent = this.RawReq.UserAgent()
if len(userAgent) == 0 {
// 如果User-Agent为空则直接阻止
this.writer.WriteHeader(http.StatusForbidden)
// 增加失败次数
if manager.IncreaseFails(policy, remoteAddr, serverId) {
this.isAttack = true
}
return true
}
// 不管是否开启允许搜索引擎,这里都放行,避免收录了拦截的代码
if agents.SharedManager.ContainsIP(remoteAddr) {
return false
}
if policy.AllowSearchEngines {
if searchEngineRegex.MatchString(userAgent) {
return false
}
}
// 如果是python之类的直接拦截
if policy.DenySpiders && spiderRegexp.MatchString(userAgent) {
this.writer.WriteHeader(http.StatusForbidden)
// 增加失败次数
if manager.IncreaseFails(policy, remoteAddr, serverId) {
this.isAttack = true
}
return true
}
// 检查预生成Key
var step = this.Header().Get("X-GE-UA-Step")
if step == uam.StepPrev {
if this.Method() != http.MethodPost {
this.writer.WriteHeader(http.StatusForbidden)
return true
}
if manager.CheckPrevKey(policy, uamConfig, this.RawReq, remoteAddr, this.writer) {
_, _ = this.writer.Write([]byte(`{"ok": true}`))
} else {
_, _ = this.writer.Write([]byte(`{"ok": false}`))
}
// 增加失败次数
if manager.IncreaseFails(policy, remoteAddr, serverId) {
this.isAttack = true
}
return true
}
// 检查是否启用QPS
if uamConfig.MinQPSPerIP > 0 && len(step) == 0 {
var avgQPS = counters.SharedCounter.IncreaseKey("UAM:"+remoteAddr, 60) / 60
if avgQPS <= 0 {
avgQPS = 1
}
if avgQPS < types.Uint32(uamConfig.MinQPSPerIP) {
return false
}
}
// 检查Cookie中的Key
isVerified, isAttack, _ := manager.CheckKey(policy, this.RawReq, this.writer, remoteAddr, serverId, uamConfig.KeyLife)
if isVerified {
return false
}
this.isAttack = isAttack
// 检查是否已生成有效的prev key如果已经生成则表示当前请求是附带请求比如favicon.ico不再重新生成新的
// TODO 考虑这里的必要性
if manager.ExistsActivePreKey(this.RawReq) {
return true
}
// 显示加载页面
err := manager.LoadPage(policy, this.RawReq, this.Format, remoteAddr, this.writer)
if err != nil {
return false
}
return true
}

View File

@@ -0,0 +1,88 @@
package nodes
import (
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/logs"
"io"
"net/http"
"time"
)
// 请求某个URL
func (this *HTTPRequest) doURL(method string, url string, host string, statusCode int, supportVariables bool) {
req, err := http.NewRequest(method, url, this.RawReq.Body)
if err != nil {
logs.Error(err)
return
}
// 修改Host
if len(host) > 0 {
req.Host = this.Format(host)
}
// 添加当前Header
req.Header = this.RawReq.Header
// 代理头部
this.setForwardHeaders(req.Header)
// 自定义请求Header
this.processRequestHeaders(req.Header)
var client = utils.SharedHttpClient(60 * time.Second)
resp, err := client.Do(req)
if err != nil {
remotelogs.Error("HTTP_REQUEST_URL", req.URL.String()+": "+err.Error())
this.write50x(err, http.StatusInternalServerError, "Failed to read url", "读取URL失败", false)
return
}
defer func() {
_ = resp.Body.Close()
}()
// Header
if statusCode <= 0 {
this.ProcessResponseHeaders(this.writer.Header(), resp.StatusCode)
} else {
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
}
if supportVariables {
resp.Header.Del("Content-Length")
}
this.writer.AddHeaders(resp.Header)
if statusCode <= 0 {
this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true)
} else {
this.writer.Prepare(resp, resp.ContentLength, statusCode, true)
}
// 设置响应代码
if statusCode <= 0 {
this.writer.WriteHeader(resp.StatusCode)
} else {
this.writer.WriteHeader(statusCode)
}
// 输出内容
var pool = this.bytePool(resp.ContentLength)
var buf = pool.Get()
if supportVariables {
_, err = utils.CopyWithFilter(this.writer, resp.Body, buf.Bytes, func(p []byte) []byte {
return []byte(this.Format(string(p)))
})
} else {
_, err = io.CopyBuffer(this.writer, resp.Body, buf.Bytes)
}
pool.Put(buf)
if err != nil {
if !this.canIgnore(err) {
remotelogs.Warn("HTTP_REQUEST_URL", "write to client failed: "+err.Error())
}
} else {
this.writer.SetOk()
}
}

View File

@@ -0,0 +1,24 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"net/http"
)
func (this *HTTPRequest) doCheckUserAgent() (shouldStop bool) {
if this.web.UserAgent == nil || !this.web.UserAgent.IsOn {
return
}
const cacheSeconds = "3600" // 时间不能过长,防止修改设置后长期无法生效
if this.web.UserAgent.MatchURL(this.URL()) && !this.web.UserAgent.AllowRequest(this.RawReq) {
this.tags = append(this.tags, "userAgentCheck")
this.writer.Header().Set("Cache-Control", "max-age="+cacheSeconds)
this.writeCode(http.StatusForbidden, "The User-Agent has been blocked.", "当前访问已被UA名单拦截。")
return true
}
return
}

View File

@@ -0,0 +1,241 @@
package nodes
import (
"crypto/rand"
"fmt"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/utils/ranges"
"github.com/iwind/TeaGo/types"
"io"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"sync/atomic"
)
// 搜索引擎和爬虫正则
var searchEngineRegex = regexp.MustCompile(`(?i)(60spider|adldxbot|adsbot-google|applebot|admantx|alexa|baidu|bingbot|bingpreview|facebookexternalhit|googlebot|proximic|slurp|sogou|twitterbot|yandex)`)
var spiderRegexp = regexp.MustCompile(`(?i)(python|pycurl|http-client|httpclient|apachebench|nethttp|http_request|java|perl|ruby|scrapy|php|rust)`)
// 内容范围正则,其中的每个括号里的内容都在被引用,不能轻易修改
var contentRangeRegexp = regexp.MustCompile(`^bytes (\d+)-(\d+)/(\d+|\*)`)
// URL协议前缀
var urlSchemeRegexp = regexp.MustCompile("^(?i)(http|https|ftp)://")
// 分解Range
func httpRequestParseRangeHeader(rangeValue string) (result []rangeutils.Range, ok bool) {
// 参考RFChttps://tools.ietf.org/html/rfc7233
index := strings.Index(rangeValue, "=")
if index == -1 {
return
}
unit := rangeValue[:index]
if unit != "bytes" {
return
}
var rangeSetString = rangeValue[index+1:]
if len(rangeSetString) == 0 {
ok = true
return
}
var pieces = strings.Split(rangeSetString, ", ")
for _, piece := range pieces {
index = strings.Index(piece, "-")
if index == -1 {
return
}
first := piece[:index]
firstInt := int64(-1)
var err error
last := piece[index+1:]
var lastInt = int64(-1)
if len(first) > 0 {
firstInt, err = strconv.ParseInt(first, 10, 64)
if err != nil {
return
}
if len(last) > 0 {
lastInt, err = strconv.ParseInt(last, 10, 64)
if err != nil {
return
}
if lastInt < firstInt {
return
}
}
} else {
if len(last) == 0 {
return
}
lastInt, err = strconv.ParseInt(last, 10, 64)
if err != nil {
return
}
lastInt = -lastInt
}
result = append(result, [2]int64{firstInt, lastInt})
}
ok = true
return
}
// 读取内容Range
func httpRequestReadRange(reader io.Reader, buf []byte, start int64, end int64, callback func(buf []byte, n int) error) (ok bool, err error) {
if start < 0 || end < 0 {
return
}
seeker, ok := reader.(io.Seeker)
if !ok {
return
}
_, err = seeker.Seek(start, io.SeekStart)
if err != nil {
return false, nil
}
offset := start
for {
n, err := reader.Read(buf)
if n > 0 {
offset += int64(n)
if end < offset {
err = callback(buf, n-int(offset-end-1))
if err != nil {
return false, err
}
return true, nil
} else {
err = callback(buf, n)
if err != nil {
return false, err
}
}
}
if err != nil {
if err == io.EOF {
return true, nil
}
return false, err
}
}
}
// 分解Content-Range
func httpRequestParseContentRangeHeader(contentRange string) (start int64, total int64) {
var matches = contentRangeRegexp.FindStringSubmatch(contentRange)
if len(matches) < 4 {
return -1, -1
}
start = types.Int64(matches[1])
var sizeString = matches[3]
if sizeString != "*" {
total = types.Int64(sizeString)
}
return
}
// 生成boundary
// 仿照Golang自带的函数multipart包
func httpRequestGenBoundary() string {
var buf [8]byte
_, err := io.ReadFull(rand.Reader, buf[:])
if err != nil {
panic(err)
}
return fmt.Sprintf("%x", buf[:])
}
// 从content-type中读取boundary
func httpRequestParseBoundary(contentType string) string {
var delim = "boundary="
var boundaryIndex = strings.Index(contentType, delim)
if boundaryIndex < 0 {
return ""
}
var boundary = contentType[boundaryIndex+len(delim):]
semicolonIndex := strings.Index(boundary, ";")
if semicolonIndex >= 0 {
return boundary[:semicolonIndex]
}
return boundary
}
// 判断状态是否为跳转
func httpStatusIsRedirect(statusCode int) bool {
return statusCode == http.StatusPermanentRedirect ||
statusCode == http.StatusTemporaryRedirect ||
statusCode == http.StatusMovedPermanently ||
statusCode == http.StatusSeeOther ||
statusCode == http.StatusFound
}
// 生成请求ID
var httpRequestTimestamp int64
var httpRequestId int32 = 1_000_000
func httpRequestNextId() string {
unixTime, unixTimeString := fasttime.Now().UnixMilliString()
if unixTime > httpRequestTimestamp {
atomic.StoreInt32(&httpRequestId, 1_000_000)
httpRequestTimestamp = unixTime
}
// timestamp + nodeId + requestId
return unixTimeString + teaconst.NodeIdString + strconv.Itoa(int(atomic.AddInt32(&httpRequestId, 1)))
}
// 检查是否可以接受某个编码
func httpAcceptEncoding(acceptEncodings string, encoding string) bool {
if len(acceptEncodings) == 0 {
return false
}
var pieces = strings.Split(acceptEncodings, ",")
for _, piece := range pieces {
var qualityIndex = strings.Index(piece, ";")
if qualityIndex >= 0 {
piece = piece[:qualityIndex]
}
if strings.TrimSpace(piece) == encoding {
return true
}
}
return false
}
// 跳转到某个URL
func httpRedirect(writer http.ResponseWriter, req *http.Request, url string, code int) {
if len(writer.Header().Get("Content-Type")) == 0 {
// 设置Content-Type是为了让页面不输出链接
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
}
http.Redirect(writer, req, url, code)
}
// 分析URL中的Host部分
func httpParseHost(urlString string) (host string, err error) {
if !urlSchemeRegexp.MatchString(urlString) {
urlString = "https://" + urlString
}
u, err := url.Parse(urlString)
if err != nil && u != nil {
return "", err
}
return u.Host, nil
}

View File

@@ -0,0 +1,173 @@
package nodes
import (
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/utils/zero"
"github.com/iwind/TeaGo/assert"
"runtime"
"sync"
"testing"
"time"
)
func TestHTTPRequest_httpRequestGenBoundary(t *testing.T) {
for i := 0; i < 10; i++ {
var boundary = httpRequestGenBoundary()
t.Log(boundary, "[", len(boundary), "bytes", "]")
}
}
func TestHTTPRequest_httpRequestParseBoundary(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(httpRequestParseBoundary("multipart/byteranges") == "")
a.IsTrue(httpRequestParseBoundary("multipart/byteranges; boundary=123") == "123")
a.IsTrue(httpRequestParseBoundary("multipart/byteranges; boundary=123; 456") == "123")
}
func TestHTTPRequest_httpRequestParseRangeHeader(t *testing.T) {
var a = assert.NewAssertion(t)
{
_, ok := httpRequestParseRangeHeader("")
a.IsFalse(ok)
}
{
_, ok := httpRequestParseRangeHeader("byte=")
a.IsFalse(ok)
}
{
_, ok := httpRequestParseRangeHeader("byte=")
a.IsFalse(ok)
}
{
set, ok := httpRequestParseRangeHeader("bytes=")
a.IsTrue(ok)
a.IsTrue(len(set) == 0)
}
{
_, ok := httpRequestParseRangeHeader("bytes=60-50")
a.IsFalse(ok)
}
{
set, ok := httpRequestParseRangeHeader("bytes=0-50")
a.IsTrue(ok)
a.IsTrue(len(set) > 0)
t.Log(set)
}
{
set, ok := httpRequestParseRangeHeader("bytes=0-")
a.IsTrue(ok)
a.IsTrue(len(set) > 0)
if len(set) > 0 {
a.IsTrue(set[0][0] == 0)
}
t.Log(set)
}
{
set, ok := httpRequestParseRangeHeader("bytes=-50")
a.IsTrue(ok)
a.IsTrue(len(set) > 0)
t.Log(set)
}
{
set, ok := httpRequestParseRangeHeader("bytes=0-50, 60-100")
a.IsTrue(ok)
a.IsTrue(len(set) > 0)
t.Log(set)
}
}
func TestHTTPRequest_httpRequestParseContentRangeHeader(t *testing.T) {
{
var c1 = "bytes 0-100/*"
t.Log(httpRequestParseContentRangeHeader(c1))
}
{
var c1 = "bytes 30-100/*"
t.Log(httpRequestParseContentRangeHeader(c1))
}
{
var c1 = "bytes1 0-100/*"
t.Log(httpRequestParseContentRangeHeader(c1))
}
}
func BenchmarkHTTPRequest_httpRequestParseContentRangeHeader(b *testing.B) {
for i := 0; i < b.N; i++ {
var c1 = "bytes 0-100/*"
httpRequestParseContentRangeHeader(c1)
}
}
func TestHTTPRequest_httpRequestNextId(t *testing.T) {
teaconst.NodeId = 123
teaconst.NodeIdString = "123"
t.Log(httpRequestNextId())
t.Log(httpRequestNextId())
t.Log(httpRequestNextId())
time.Sleep(1 * time.Second)
t.Log(httpRequestNextId())
t.Log(httpRequestNextId())
time.Sleep(1 * time.Second)
t.Log(httpRequestNextId())
}
func TestHTTPRequest_httpRequestNextId_Concurrent(t *testing.T) {
var m = map[string]zero.Zero{}
var locker = sync.Mutex{}
var count = 4000
var wg = &sync.WaitGroup{}
wg.Add(count)
var countDuplicated = 0
for i := 0; i < count; i++ {
go func() {
defer wg.Done()
var requestId = httpRequestNextId()
locker.Lock()
_, ok := m[requestId]
if ok {
t.Log("duplicated:", requestId)
countDuplicated++
}
m[requestId] = zero.New()
locker.Unlock()
}()
}
wg.Wait()
t.Log("ok", countDuplicated, "duplicated")
var a = assert.NewAssertion(t)
a.IsTrue(countDuplicated == 0)
}
func TestHTTPParseURL(t *testing.T) {
for _, s := range []string{
"",
"null",
"example.com",
"https://example.com",
"https://example.com/hello",
} {
host, err := httpParseHost(s)
if err == nil {
t.Log(s, "=>", host)
} else {
t.Log(s, "=>")
}
}
}
func BenchmarkHTTPRequest_httpRequestNextId(b *testing.B) {
runtime.GOMAXPROCS(1)
teaconst.NodeIdString = "123"
for i := 0; i < b.N; i++ {
_ = httpRequestNextId()
}
}

View File

@@ -0,0 +1,553 @@
package nodes
import (
"bytes"
"io"
"net/http"
"time"
iplib "github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/stats"
"github.com/TeaOSLab/EdgeNode/internal/waf"
wafutils "github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/types"
)
// 调用WAF
func (this *HTTPRequest) doWAFRequest() (blocked bool) {
if this.web.FirewallRef == nil || !this.web.FirewallRef.IsOn {
return
}
var remoteAddr = this.requestRemoteAddr(true)
// 检查是否为白名单直连
if !Tea.IsTesting() && this.nodeConfig.IPIsAutoAllowed(remoteAddr) {
return
}
// 当前连接是否已关闭
if this.isConnClosed() {
this.disableLog = true
return true
}
// 是否在全局名单中
canGoNext, isInAllowedList, _ := iplibrary.AllowIP(remoteAddr, this.ReqServer.Id)
if !canGoNext {
this.disableLog = true
this.Close()
return true
}
if isInAllowedList {
return false
}
// 检查是否在临时黑名单中
if waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeServer, this.ReqServer.Id, remoteAddr) || waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteAddr) {
this.disableLog = true
this.Close()
return true
}
var forceLog = false
var forceLogRequestBody = false
var forceLogRegionDenying = false
if this.ReqServer.HTTPFirewallPolicy != nil &&
this.ReqServer.HTTPFirewallPolicy.IsOn &&
this.ReqServer.HTTPFirewallPolicy.Log != nil &&
this.ReqServer.HTTPFirewallPolicy.Log.IsOn {
forceLog = true
forceLogRequestBody = this.ReqServer.HTTPFirewallPolicy.Log.RequestBody
forceLogRegionDenying = this.ReqServer.HTTPFirewallPolicy.Log.RegionDenying
}
// 检查IP名单
{
// 当前服务的独立设置
if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
blockedRequest, breakChecking := this.checkWAFRemoteAddr(this.web.FirewallPolicy)
if blockedRequest {
return true
}
if breakChecking {
return false
}
}
// 公用的防火墙设置
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn {
blockedRequest, breakChecking := this.checkWAFRemoteAddr(this.ReqServer.HTTPFirewallPolicy)
if blockedRequest {
return true
}
if breakChecking {
return false
}
}
}
// 检查WAF规则
{
// 当前服务的独立设置
if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
blockedRequest, breakChecking := this.checkWAFRequest(this.web.FirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, false)
if blockedRequest {
return true
}
if breakChecking {
return false
}
}
// 公用的防火墙设置
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn {
blockedRequest, breakChecking := this.checkWAFRequest(this.ReqServer.HTTPFirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, this.web.FirewallRef.IgnoreGlobalRules)
if blockedRequest {
return true
}
if breakChecking {
return false
}
}
}
return
}
// check client remote address
func (this *HTTPRequest) checkWAFRemoteAddr(firewallPolicy *firewallconfigs.HTTPFirewallPolicy) (blocked bool, breakChecking bool) {
if firewallPolicy == nil {
return
}
var isDefendMode = firewallPolicy.Mode == firewallconfigs.FirewallModeDefend
// 检查IP白名单
var remoteAddrs []string
if len(this.remoteAddr) > 0 {
remoteAddrs = []string{this.remoteAddr}
} else {
remoteAddrs = this.requestRemoteAddrs()
}
var inbound = firewallPolicy.Inbound
if inbound == nil {
return
}
for _, ref := range inbound.AllAllowListRefs() {
if ref.IsOn && ref.ListId > 0 {
var list = iplibrary.SharedIPListManager.FindList(ref.ListId)
if list != nil {
_, found := list.ContainsIPStrings(remoteAddrs)
if found {
breakChecking = true
return
}
}
}
}
// 检查IP黑名单
if isDefendMode {
for _, ref := range inbound.AllDenyListRefs() {
if ref.IsOn && ref.ListId > 0 {
var list = iplibrary.SharedIPListManager.FindList(ref.ListId)
if list != nil {
item, found := list.ContainsIPStrings(remoteAddrs)
if found {
// 触发事件
if item != nil && len(item.EventLevel) > 0 {
actions := iplibrary.SharedActionManager.FindEventActions(item.EventLevel)
for _, action := range actions {
goNext, err := action.DoHTTP(this.RawReq, this.RawWriter)
if err != nil {
remotelogs.Error("HTTP_REQUEST_WAF", "do action '"+err.Error()+"' failed: "+err.Error())
return true, false
}
if !goNext {
this.disableLog = true
return true, false
}
}
}
// TODO 考虑是否需要记录日志信息吗,可能数据量非常庞大,所以暂时不记录
this.writer.WriteHeader(http.StatusForbidden)
this.writer.Close()
// 停止日志
this.disableLog = true
return true, false
}
}
}
}
}
return
}
// check waf inbound rules
func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFirewallPolicy, forceLog bool, logRequestBody bool, logDenying bool, ignoreRules bool) (blocked bool, breakChecking bool) {
// 检查配置是否为空
if firewallPolicy == nil || !firewallPolicy.IsOn || firewallPolicy.Inbound == nil || !firewallPolicy.Inbound.IsOn || firewallPolicy.Mode == firewallconfigs.FirewallModeBypass {
return
}
var isDefendMode = firewallPolicy.Mode == firewallconfigs.FirewallModeDefend
// 检查IP白名单
var remoteAddrs []string
if len(this.remoteAddr) > 0 {
remoteAddrs = []string{this.remoteAddr}
} else {
remoteAddrs = this.requestRemoteAddrs()
}
var inbound = firewallPolicy.Inbound
if inbound == nil {
return
}
// 检查地区封禁
if firewallPolicy.Inbound.Region != nil && firewallPolicy.Inbound.Region.IsOn {
var regionConfig = firewallPolicy.Inbound.Region
if regionConfig.IsNotEmpty() {
for _, remoteAddr := range remoteAddrs {
var result = iplib.LookupIP(remoteAddr)
if result != nil && result.IsOk() {
var currentURL = this.URL()
if regionConfig.MatchCountryURL(currentURL) {
// 检查国家/地区级别封禁
if !regionConfig.IsAllowedCountry(result.CountryId(), result.ProvinceId()) && (!regionConfig.AllowSearchEngine || wafutils.CheckSearchEngine(remoteAddr)) {
this.firewallPolicyId = firewallPolicy.Id
if isDefendMode {
var promptHTML string
if len(regionConfig.CountryHTML) > 0 {
promptHTML = regionConfig.CountryHTML
} else if this.ReqServer != nil && this.ReqServer.HTTPFirewallPolicy != nil && len(this.ReqServer.HTTPFirewallPolicy.DenyCountryHTML) > 0 {
promptHTML = this.ReqServer.HTTPFirewallPolicy.DenyCountryHTML
}
if len(promptHTML) > 0 {
var formattedHTML = this.Format(promptHTML)
this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
this.writer.Header().Set("Content-Length", types.String(len(formattedHTML)))
this.writer.WriteHeader(http.StatusForbidden)
_, _ = this.writer.Write([]byte(formattedHTML))
} else {
this.writeCode(http.StatusForbidden, "The region has been denied.", "当前区域禁止访问")
}
// 延时返回,避免攻击
time.Sleep(1 * time.Second)
}
// 停止日志
if !logDenying {
this.disableLog = true
} else {
this.tags = append(this.tags, "denyCountry")
}
if isDefendMode {
return true, false
}
}
}
if regionConfig.MatchProvinceURL(currentURL) {
// 检查省份封禁
if !regionConfig.IsAllowedProvince(result.CountryId(), result.ProvinceId()) {
this.firewallPolicyId = firewallPolicy.Id
if isDefendMode {
var promptHTML string
if len(regionConfig.ProvinceHTML) > 0 {
promptHTML = regionConfig.ProvinceHTML
} else if this.ReqServer != nil && this.ReqServer.HTTPFirewallPolicy != nil && len(this.ReqServer.HTTPFirewallPolicy.DenyProvinceHTML) > 0 {
promptHTML = this.ReqServer.HTTPFirewallPolicy.DenyProvinceHTML
}
if len(promptHTML) > 0 {
var formattedHTML = this.Format(promptHTML)
this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
this.writer.Header().Set("Content-Length", types.String(len(formattedHTML)))
this.writer.WriteHeader(http.StatusForbidden)
_, _ = this.writer.Write([]byte(formattedHTML))
} else {
this.writeCode(http.StatusForbidden, "The region has been denied.", "当前区域禁止访问")
}
// 延时返回,避免攻击
time.Sleep(1 * time.Second)
}
// 停止日志
if !logDenying {
this.disableLog = true
} else {
this.tags = append(this.tags, "denyProvince")
}
if isDefendMode {
return true, false
}
}
}
}
}
}
}
// 是否执行规则
if ignoreRules {
return
}
// 规则测试
var w = waf.SharedWAFManager.FindWAF(firewallPolicy.Id)
if w == nil {
return
}
result, err := w.MatchRequest(this, this.writer, this.web.FirewallRef.DefaultCaptchaType)
if err != nil {
if !this.canIgnore(err) {
remotelogs.Warn("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error())
}
return
}
if result.IsAllowed && (len(result.AllowScope) == 0 || result.AllowScope == waf.AllowScopeGlobal) {
breakChecking = true
}
if forceLog && logRequestBody && result.HasRequestBody && result.Set != nil && result.Set.HasAttackActions() {
this.wafHasRequestBody = true
}
if result.Set != nil {
if forceLog {
this.forceLog = true
}
if result.Set.HasSpecialActions() {
this.firewallPolicyId = firewallPolicy.Id
this.firewallRuleGroupId = types.Int64(result.Group.Id)
this.firewallRuleSetId = types.Int64(result.Set.Id)
if result.Set.HasAttackActions() {
this.isAttack = true
}
// 添加统计
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, result.Set.Actions)
}
this.firewallActions = append(result.Set.ActionCodes(), firewallPolicy.Mode)
}
return !result.GoNext, breakChecking
}
// call response waf
func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) {
if this.web.FirewallRef == nil || !this.web.FirewallRef.IsOn {
return
}
// 当前服务的独立设置
var forceLog = false
var forceLogRequestBody = false
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn && this.ReqServer.HTTPFirewallPolicy.Log != nil && this.ReqServer.HTTPFirewallPolicy.Log.IsOn {
forceLog = true
forceLogRequestBody = this.ReqServer.HTTPFirewallPolicy.Log.RequestBody
}
if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
blockedRequest, breakChecking := this.checkWAFResponse(this.web.FirewallPolicy, resp, forceLog, forceLogRequestBody, false)
if blockedRequest {
return true
}
if breakChecking {
return
}
}
// 公用的防火墙设置
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn {
blockedRequest, _ := this.checkWAFResponse(this.ReqServer.HTTPFirewallPolicy, resp, forceLog, forceLogRequestBody, this.web.FirewallRef.IgnoreGlobalRules)
if blockedRequest {
return true
}
}
return
}
// check waf outbound rules
func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFirewallPolicy, resp *http.Response, forceLog bool, logRequestBody bool, ignoreRules bool) (blocked bool, breakChecking bool) {
if firewallPolicy == nil || !firewallPolicy.IsOn || !firewallPolicy.Outbound.IsOn || firewallPolicy.Mode == firewallconfigs.FirewallModeBypass {
return
}
// 是否执行规则
if ignoreRules {
return
}
var w = waf.SharedWAFManager.FindWAF(firewallPolicy.Id)
if w == nil {
return
}
result, err := w.MatchResponse(this, resp, this.writer)
if err != nil {
if !this.canIgnore(err) {
remotelogs.Warn("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error())
}
return
}
if result.IsAllowed && (len(result.AllowScope) == 0 || result.AllowScope == waf.AllowScopeGlobal) {
breakChecking = true
}
if forceLog && logRequestBody && result.HasRequestBody && result.Set != nil && result.Set.HasAttackActions() {
this.wafHasRequestBody = true
}
if result.Set != nil {
if forceLog {
this.forceLog = true
}
if result.Set.HasSpecialActions() {
this.firewallPolicyId = firewallPolicy.Id
this.firewallRuleGroupId = types.Int64(result.Group.Id)
this.firewallRuleSetId = types.Int64(result.Set.Id)
if result.Set.HasAttackActions() {
this.isAttack = true
}
// 添加统计
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, result.Set.Actions)
}
this.firewallActions = append(result.Set.ActionCodes(), firewallPolicy.Mode)
}
return !result.GoNext, breakChecking
}
// WAFRaw 原始请求
func (this *HTTPRequest) WAFRaw() *http.Request {
return this.RawReq
}
// WAFRemoteIP 客户端IP
func (this *HTTPRequest) WAFRemoteIP() string {
return this.requestRemoteAddr(true)
}
// WAFGetCacheBody 获取缓存中的Body
func (this *HTTPRequest) WAFGetCacheBody() []byte {
return this.requestBodyData
}
// WAFSetCacheBody 设置Body
func (this *HTTPRequest) WAFSetCacheBody(body []byte) {
this.requestBodyData = body
}
// WAFReadBody 读取Body
func (this *HTTPRequest) WAFReadBody(max int64) (data []byte, err error) {
if this.RawReq.ContentLength > 0 {
data, err = io.ReadAll(io.LimitReader(this.RawReq.Body, max))
}
return
}
// WAFRestoreBody 恢复Body
func (this *HTTPRequest) WAFRestoreBody(data []byte) {
if len(data) > 0 {
this.RawReq.Body = io.NopCloser(io.MultiReader(bytes.NewBuffer(data), this.RawReq.Body))
}
}
// WAFServerId 服务ID
func (this *HTTPRequest) WAFServerId() int64 {
return this.ReqServer.Id
}
// WAFClose 关闭连接
func (this *HTTPRequest) WAFClose() {
this.Close()
// 这里不要强关IP所有连接避免因为单个服务而影响所有
}
func (this *HTTPRequest) WAFOnAction(action interface{}) (goNext bool) {
if action == nil {
return true
}
instance, ok := action.(waf.ActionInterface)
if !ok {
return true
}
switch instance.Code() {
case waf.ActionTag:
this.tags = append(this.tags, action.(*waf.TagAction).Tags...)
}
return true
}
func (this *HTTPRequest) WAFFingerprint() []byte {
// 目前只有HTTPS请求才有指纹
if !this.IsHTTPS {
return nil
}
var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
if requestConn == nil {
return nil
}
clientConn, ok := requestConn.(ClientConnInterface)
if ok {
return clientConn.Fingerprint()
}
return nil
}
func (this *HTTPRequest) WAFMaxRequestSize() int64 {
var maxRequestSize = firewallconfigs.DefaultMaxRequestBodySize
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.MaxRequestBodySize > 0 {
maxRequestSize = this.ReqServer.HTTPFirewallPolicy.MaxRequestBodySize
}
return maxRequestSize
}
// DisableAccessLog 在当前请求中不使用访问日志
func (this *HTTPRequest) DisableAccessLog() {
this.disableLog = true
}
// DisableStat 停用统计
func (this *HTTPRequest) DisableStat() {
if this.web != nil {
this.web.StatRef = nil
}
this.disableMetrics = true
}

View File

@@ -0,0 +1,205 @@
package nodes
import (
"bufio"
"bytes"
"errors"
"github.com/TeaOSLab/EdgeNode/internal/utils/bytepool"
"io"
"net/http"
"net/url"
)
// WebsocketResponseReader Websocket响应Reader
type WebsocketResponseReader struct {
rawReader io.Reader
buf []byte
}
func NewWebsocketResponseReader(rawReader io.Reader) *WebsocketResponseReader {
return &WebsocketResponseReader{
rawReader: rawReader,
}
}
func (this *WebsocketResponseReader) Read(p []byte) (n int, err error) {
n, err = this.rawReader.Read(p)
if n > 0 {
if len(this.buf) == 0 {
this.buf = make([]byte, n)
copy(this.buf, p[:n])
} else {
this.buf = append(this.buf, p[:n]...)
}
}
return
}
// 处理Websocket请求
func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shouldRetry bool) {
// 设置不缓存
this.web.Cache = nil
if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn {
this.writer.WriteHeader(http.StatusForbidden)
this.addError(errors.New("websocket have not been enabled yet"))
return
}
// TODO 实现handshakeTimeout
// 校验来源
var requestOrigin = this.RawReq.Header.Get("Origin")
if len(requestOrigin) > 0 {
u, err := url.Parse(requestOrigin)
if err == nil {
if !this.web.Websocket.MatchOrigin(u.Host) {
this.writer.WriteHeader(http.StatusForbidden)
this.addError(errors.New("websocket origin '" + requestOrigin + "' not been allowed"))
return
}
}
}
// 标记
this.isWebsocketResponse = true
// 设置指定的来源域
if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 {
var newRequestOrigin = this.web.Websocket.RequestOrigin
if this.web.Websocket.RequestOriginHasVariables() {
newRequestOrigin = this.Format(newRequestOrigin)
}
this.RawReq.Header.Set("Origin", newRequestOrigin)
}
// 获取当前连接
var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
if requestConn == nil {
return
}
// 连接源站
originConn, _, err := OriginConnect(this.origin, this.requestServerPort(), this.RawReq.RemoteAddr, requestHost)
if err != nil {
if isLastRetry {
this.write50x(err, http.StatusBadGateway, "Failed to connect origin site", "源站连接失败", false)
}
// 增加失败次数
SharedOriginStateManager.Fail(this.origin, requestHost, this.reverseProxy, func() {
this.reverseProxy.ResetScheduling()
})
shouldRetry = true
return
}
if !this.origin.IsOk {
SharedOriginStateManager.Success(this.origin, func() {
this.reverseProxy.ResetScheduling()
})
}
defer func() {
_ = originConn.Close()
}()
err = this.RawReq.Write(originConn)
if err != nil {
this.write50x(err, http.StatusBadGateway, "Failed to write request to origin site", "源站请求初始化失败", false)
return
}
requestClientConn, ok := requestConn.(ClientConnInterface)
if ok {
requestClientConn.SetIsPersistent(true)
}
clientConn, _, err := this.writer.Hijack()
if err != nil || clientConn == nil {
this.write50x(err, http.StatusInternalServerError, "Failed to get origin site connection", "获取源站连接失败", false)
return
}
defer func() {
_ = clientConn.Close()
}()
go func() {
// 读取第一个响应
var respReader = NewWebsocketResponseReader(originConn)
resp, respErr := http.ReadResponse(bufio.NewReader(respReader), this.RawReq)
if respErr != nil || resp == nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
_ = clientConn.Close()
_ = originConn.Close()
return
}
this.ProcessResponseHeaders(resp.Header, resp.StatusCode)
this.writer.statusCode = resp.StatusCode
// 将响应写回客户端
err = resp.Write(clientConn)
if err != nil {
if resp.Body != nil {
_ = resp.Body.Close()
}
_ = clientConn.Close()
_ = originConn.Close()
return
}
// 剩余已经从源站读取的内容
var headerBytes = respReader.buf
var headerIndex = bytes.Index(headerBytes, []byte{'\r', '\n', '\r', '\n'}) // CRLF
if headerIndex > 0 {
var leftBytes = headerBytes[headerIndex+4:]
if len(leftBytes) > 0 {
_, writeErr := clientConn.Write(leftBytes)
if writeErr != nil {
if resp.Body != nil {
_ = resp.Body.Close()
}
_ = clientConn.Close()
_ = originConn.Close()
return
}
}
}
if resp.Body != nil {
_ = resp.Body.Close()
}
// 复制剩余的数据
var buf = bytepool.Pool4k.Get()
defer bytepool.Pool4k.Put(buf)
for {
n, readErr := originConn.Read(buf.Bytes)
if n > 0 {
this.writer.sentBodyBytes += int64(n)
_, writeErr := clientConn.Write(buf.Bytes[:n])
if writeErr != nil {
break
}
}
if readErr != nil {
break
}
}
_ = clientConn.Close()
_ = originConn.Close()
}()
var buf = bytepool.Pool4k.Get()
_, _ = io.CopyBuffer(originConn, clientConn, buf.Bytes)
bytepool.Pool4k.Put(buf)
return
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,63 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"bufio"
"net"
"net/http"
)
// EmptyResponseWriter 空的响应Writer
type EmptyResponseWriter struct {
header http.Header
parentWriter http.ResponseWriter
statusCode int
}
func NewEmptyResponseWriter(parentWriter http.ResponseWriter) *EmptyResponseWriter {
return &EmptyResponseWriter{
header: http.Header{},
parentWriter: parentWriter,
}
}
func (this *EmptyResponseWriter) Header() http.Header {
return this.header
}
func (this *EmptyResponseWriter) Write(data []byte) (int, error) {
if this.statusCode > 300 && this.parentWriter != nil {
return this.parentWriter.Write(data)
}
return 0, nil
}
func (this *EmptyResponseWriter) WriteHeader(statusCode int) {
this.statusCode = statusCode
if this.statusCode > 300 && this.parentWriter != nil {
var parentHeader = this.parentWriter.Header()
for k, v := range this.header {
parentHeader[k] = v
}
this.parentWriter.WriteHeader(this.statusCode)
}
}
func (this *EmptyResponseWriter) StatusCode() int {
return this.statusCode
}
// Hijack Hijack
func (this *EmptyResponseWriter) Hijack() (conn net.Conn, buf *bufio.ReadWriter, err error) {
if this.parentWriter == nil {
return
}
hijack, ok := this.parentWriter.(http.Hijacker)
if ok {
return hijack.Hijack()
}
return
}

View File

@@ -0,0 +1,17 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !plus
// +build !plus
package nodes
import (
"os"
)
func (this *HTTPWriter) canSendfile() (*os.File, bool) {
return nil, false
}
func (this *HTTPWriter) checkPlanBandwidth(n int) {
// stub
}

View File

@@ -0,0 +1,50 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
// +build plus
package nodes
import (
"github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/utils/writers"
"io"
"os"
)
func (this *HTTPWriter) canSendfile() (*os.File, bool) {
if this.cacheWriter != nil || this.webpIsEncoding || this.isPartial || this.delayRead || this.compressionCacheWriter != nil {
return nil, false
}
if this.rawReader == nil {
return nil, false
}
fileReader, ok := this.rawReader.(*caches.FileReader)
if !ok {
return nil, false
}
_, ok = this.rawWriter.(io.ReaderFrom)
if !ok {
return nil, false
}
counterWriter, ok := this.writer.(*writers.BytesCounterWriter)
if !ok {
return nil, false
}
if counterWriter.RawWriter() != this.rawWriter {
return nil, false
}
return fileReader.FP(), true
}
// 检查套餐带宽限速
func (this *HTTPWriter) checkPlanBandwidth(n int) {
if this.req.ReqServer != nil && this.req.ReqServer.PlanId() > 0 {
sharedPlanBandwidthLimiter.Ack(this.req.RawReq.Context(), this.req.ReqServer.PlanId(), n)
}
}

View File

@@ -0,0 +1,230 @@
package nodes
import (
"context"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"net"
"strings"
"sync"
)
type Listener struct {
group *serverconfigs.ServerAddressGroup
listener ListenerInterface // 监听器
locker sync.RWMutex
}
func NewListener() *Listener {
return &Listener{}
}
func (this *Listener) Reload(group *serverconfigs.ServerAddressGroup) {
this.locker.Lock()
this.group = group
if this.listener != nil {
this.listener.Reload(group)
}
this.locker.Unlock()
}
func (this *Listener) FullAddr() string {
if this.group != nil {
return this.group.FullAddr()
}
return ""
}
func (this *Listener) Listen() error {
if this.group == nil {
return nil
}
var protocol = this.group.Protocol()
if protocol.IsUDPFamily() {
return this.listenUDP()
}
return this.listenTCP()
}
func (this *Listener) listenTCP() error {
if this.group == nil {
return nil
}
var protocol = this.group.Protocol()
tcpListener, err := this.createTCPListener()
if err != nil {
return err
}
var netListener = NewClientListener(tcpListener, protocol.IsHTTPFamily() || protocol.IsHTTPSFamily())
events.OnKey(events.EventQuit, this, func() {
remotelogs.Println("LISTENER", "quit "+this.group.FullAddr())
_ = netListener.Close()
})
switch protocol {
case serverconfigs.ProtocolHTTP, serverconfigs.ProtocolHTTP4, serverconfigs.ProtocolHTTP6:
this.listener = &HTTPListener{
BaseListener: BaseListener{Group: this.group},
Listener: netListener,
}
case serverconfigs.ProtocolHTTPS, serverconfigs.ProtocolHTTPS4, serverconfigs.ProtocolHTTPS6:
netListener.SetIsTLS(true)
this.listener = &HTTPListener{
BaseListener: BaseListener{Group: this.group},
Listener: netListener,
}
case serverconfigs.ProtocolTCP, serverconfigs.ProtocolTCP4, serverconfigs.ProtocolTCP6:
this.listener = &TCPListener{
BaseListener: BaseListener{Group: this.group},
Listener: netListener,
}
case serverconfigs.ProtocolTLS, serverconfigs.ProtocolTLS4, serverconfigs.ProtocolTLS6:
netListener.SetIsTLS(true)
this.listener = &TCPListener{
BaseListener: BaseListener{Group: this.group},
Listener: netListener,
}
default:
return errors.New("unknown protocol '" + protocol.String() + "'")
}
this.listener.Init()
goman.New(func() {
err := this.listener.Serve()
if err != nil {
// 在这里屏蔽accept错误防止在优雅关闭的时候有多余的提示
opErr, ok := err.(*net.OpError)
if ok && opErr.Op == "accept" {
return
}
// 打印其他错误
remotelogs.Error("LISTENER", err.Error())
}
})
return nil
}
func (this *Listener) listenUDP() error {
var addr = this.group.Addr()
var ipv4PacketListener *ipv4.PacketConn
var ipv6PacketListener *ipv6.PacketConn
host, _, err := net.SplitHostPort(addr)
if err != nil {
return err
}
if len(host) == 0 {
// ipv4
ipv4Listener, err := this.createUDPIPv4Listener()
if err == nil {
ipv4PacketListener = ipv4.NewPacketConn(ipv4Listener)
} else {
remotelogs.Error("LISTENER", "create udp ipv4 listener '"+addr+"': "+err.Error())
}
// ipv6
ipv6Listener, err := this.createUDPIPv6Listener()
if err == nil {
ipv6PacketListener = ipv6.NewPacketConn(ipv6Listener)
} else {
remotelogs.Error("LISTENER", "create udp ipv6 listener '"+addr+"': "+err.Error())
}
} else if strings.Contains(host, ":") { // ipv6
ipv6Listener, err := this.createUDPIPv6Listener()
if err == nil {
ipv6PacketListener = ipv6.NewPacketConn(ipv6Listener)
} else {
remotelogs.Error("LISTENER", "create udp ipv6 listener '"+addr+"': "+err.Error())
}
} else { // ipv4
ipv4Listener, err := this.createUDPIPv4Listener()
if err == nil {
ipv4PacketListener = ipv4.NewPacketConn(ipv4Listener)
} else {
remotelogs.Error("LISTENER", "create udp ipv4 listener '"+addr+"': "+err.Error())
}
}
events.OnKey(events.EventQuit, this, func() {
remotelogs.Println("LISTENER", "quit "+this.group.FullAddr())
if ipv4PacketListener != nil {
_ = ipv4PacketListener.Close()
}
if ipv6PacketListener != nil {
_ = ipv6PacketListener.Close()
}
})
this.listener = &UDPListener{
BaseListener: BaseListener{Group: this.group},
IPv4Listener: ipv4PacketListener,
IPv6Listener: ipv6PacketListener,
}
goman.New(func() {
err := this.listener.Serve()
if err != nil {
remotelogs.Error("LISTENER", err.Error())
}
})
return nil
}
func (this *Listener) Close() error {
events.Remove(this)
if this.listener == nil {
return nil
}
return this.listener.Close()
}
// 创建TCP监听器
func (this *Listener) createTCPListener() (net.Listener, error) {
var listenConfig = net.ListenConfig{
Control: nil,
KeepAlive: 0,
}
switch this.group.Protocol() {
case serverconfigs.ProtocolHTTP4, serverconfigs.ProtocolHTTPS4, serverconfigs.ProtocolTLS4:
return listenConfig.Listen(context.Background(), "tcp4", this.group.Addr())
case serverconfigs.ProtocolHTTP6, serverconfigs.ProtocolHTTPS6, serverconfigs.ProtocolTLS6:
return listenConfig.Listen(context.Background(), "tcp6", this.group.Addr())
}
return listenConfig.Listen(context.Background(), "tcp", this.group.Addr())
}
// 创建UDP IPv4监听器
func (this *Listener) createUDPIPv4Listener() (*net.UDPConn, error) {
addr, err := net.ResolveUDPAddr("udp", this.group.Addr())
if err != nil {
return nil, err
}
return net.ListenUDP("udp4", addr)
}
// 创建UDP监听器
func (this *Listener) createUDPIPv6Listener() (*net.UDPConn, error) {
addr, err := net.ResolveUDPAddr("udp", this.group.Addr())
if err != nil {
return nil, err
}
return net.ListenUDP("udp6", addr)
}

View File

@@ -0,0 +1,276 @@
package nodes
import (
"crypto/tls"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/types"
"net"
)
type BaseListener struct {
Group *serverconfigs.ServerAddressGroup
countActiveConnections int64 // 当前活跃的连接数
}
// Init 初始化
func (this *BaseListener) Init() {
}
// Reset 清除既有配置
func (this *BaseListener) Reset() {
}
// CountActiveConnections 获取当前活跃连接数
func (this *BaseListener) CountActiveConnections() int {
return types.Int(this.countActiveConnections)
}
// 构造TLS配置
func (this *BaseListener) buildTLSConfig() *tls.Config {
return &tls.Config{
Certificates: nil,
GetConfigForClient: func(clientInfo *tls.ClientHelloInfo) (config *tls.Config, e error) {
// 指纹信息
var fingerprint = this.calculateFingerprint(clientInfo)
if len(fingerprint) > 0 && clientInfo.Conn != nil {
clientConn, ok := clientInfo.Conn.(ClientConnInterface)
if ok {
clientConn.SetFingerprint(fingerprint)
}
}
tlsPolicy, _, err := this.matchSSL(this.helloServerNames(clientInfo))
if err != nil {
return nil, err
}
if tlsPolicy == nil {
return nil, nil
}
tlsPolicy.CheckOCSP()
return tlsPolicy.TLSConfig(), nil
},
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
// 指纹信息
var fingerprint = this.calculateFingerprint(clientInfo)
if len(fingerprint) > 0 && clientInfo.Conn != nil {
clientConn, ok := clientInfo.Conn.(ClientConnInterface)
if ok {
clientConn.SetFingerprint(fingerprint)
}
}
tlsPolicy, cert, err := this.matchSSL(this.helloServerNames(clientInfo))
if err != nil {
return nil, err
}
if cert == nil {
return nil, errors.New("no ssl certs found for '" + clientInfo.ServerName + "'")
}
tlsPolicy.CheckOCSP()
return cert, nil
},
}
}
// 根据域名匹配证书
func (this *BaseListener) matchSSL(domains []string) (*sslconfigs.SSLPolicy, *tls.Certificate, error) {
var group = this.Group
if group == nil {
return nil, nil, errors.New("no configure found")
}
var globalServerConfig *serverconfigs.GlobalServerConfig
if sharedNodeConfig != nil {
globalServerConfig = sharedNodeConfig.GlobalServerConfig
}
// 如果域名为空,则取第一个
// 通常域名为空是因为是直接通过IP访问的
if len(domains) == 0 {
if group.IsHTTPS() && globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly {
return nil, nil, errors.New("no tls server name matched")
}
firstServer := group.FirstTLSServer()
if firstServer == nil {
return nil, nil, errors.New("no tls server available")
}
sslConfig := firstServer.SSLPolicy()
if sslConfig != nil {
return sslConfig, sslConfig.FirstCert(), nil
}
return nil, nil, errors.New("no tls server name found")
}
var firstDomain = domains[0]
// 通过网站域名配置匹配
var server *serverconfigs.ServerConfig
var matchedDomain string
for _, domain := range domains {
server, _ = this.findNamedServer(domain, true)
if server != nil {
matchedDomain = domain
break
}
}
if server == nil {
server, _ = this.findNamedServer(firstDomain, false)
if server != nil {
matchedDomain = firstDomain
}
}
if server == nil {
// 找不到或者此时的服务没有配置证书需要搜索所有的Server通过SSL证书内容中的DNSName匹配
// 此功能仅为了兼容以往版本v1.0.4),不应该作为常态启用
if globalServerConfig != nil && globalServerConfig.HTTPAll.MatchCertFromAllServers {
for _, searchingServer := range group.Servers() {
if searchingServer.SSLPolicy() == nil || !searchingServer.SSLPolicy().IsOn {
continue
}
cert, ok := searchingServer.SSLPolicy().MatchDomain(firstDomain)
if ok {
return searchingServer.SSLPolicy(), cert, nil
}
}
}
return nil, nil, errors.New("no server found for '" + firstDomain + "'")
}
if server.SSLPolicy() == nil || !server.SSLPolicy().IsOn {
// 找不到或者此时的服务没有配置证书需要搜索所有的Server通过SSL证书内容中的DNSName匹配
// 此功能仅为了兼容以往版本v1.0.4),不应该作为常态启用
if globalServerConfig != nil && globalServerConfig.HTTPAll.MatchCertFromAllServers {
for _, searchingServer := range group.Servers() {
if searchingServer.SSLPolicy() == nil || !searchingServer.SSLPolicy().IsOn {
continue
}
cert, ok := searchingServer.SSLPolicy().MatchDomain(matchedDomain)
if ok {
return searchingServer.SSLPolicy(), cert, nil
}
}
}
return nil, nil, errors.New("no cert found for '" + matchedDomain + "'")
}
// 证书是否匹配
var sslConfig = server.SSLPolicy()
cert, ok := sslConfig.MatchDomain(matchedDomain)
if ok {
return sslConfig, cert, nil
}
if len(sslConfig.Certs) == 0 {
remotelogs.ServerError(server.Id, "BASE_LISTENER", "no ssl certs found for '"+matchedDomain+"', server id: "+types.String(server.Id), "", nil)
}
return sslConfig, sslConfig.FirstCert(), nil
}
// 根据域名来查找匹配的域名
func (this *BaseListener) findNamedServer(name string, exactly bool) (serverConfig *serverconfigs.ServerConfig, serverName string) {
serverConfig, serverName = this.findNamedServerMatched(name)
if serverConfig != nil {
return
}
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
var matchDomainStrictly = globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly
if globalServerConfig != nil &&
len(globalServerConfig.HTTPAll.DefaultDomain) > 0 &&
(!matchDomainStrictly || configutils.MatchDomains(globalServerConfig.HTTPAll.AllowMismatchDomains, name) || (globalServerConfig.HTTPAll.AllowNodeIP && utils.IsWildIP(name))) {
if globalServerConfig.HTTPAll.AllowNodeIP &&
globalServerConfig.HTTPAll.NodeIPShowPage &&
utils.IsWildIP(name) {
return
} else {
var defaultDomain = globalServerConfig.HTTPAll.DefaultDomain
serverConfig, serverName = this.findNamedServerMatched(defaultDomain)
if serverConfig != nil {
return
}
}
}
if matchDomainStrictly && !configutils.MatchDomains(globalServerConfig.HTTPAll.AllowMismatchDomains, name) && (!globalServerConfig.HTTPAll.AllowNodeIP || (!utils.IsWildIP(name) || globalServerConfig.HTTPAll.NodeIPShowPage)) {
return
}
if !exactly {
// 如果没有找到,则匹配到第一个
var group = this.Group
var currentServers = group.Servers()
var countServers = len(currentServers)
if countServers == 0 {
return nil, ""
}
return currentServers[0], name
}
return
}
// 严格查找域名
func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *serverconfigs.ServerConfig, serverName string) {
var group = this.Group
if group == nil {
return nil, ""
}
server := group.MatchServerName(name)
if server != nil {
return server, name
}
// 是否严格匹配域名
var matchDomainStrictly = sharedNodeConfig.GlobalServerConfig != nil && sharedNodeConfig.GlobalServerConfig.HTTPAll.MatchDomainStrictly
// 如果只有一个server则默认为这个
var currentServers = group.Servers()
var countServers = len(currentServers)
if countServers == 1 && !matchDomainStrictly {
return currentServers[0], name
}
return nil, name
}
// 从Hello信息中获取服务名称
func (this *BaseListener) helloServerNames(clientInfo *tls.ClientHelloInfo) (serverNames []string) {
if len(clientInfo.ServerName) != 0 {
serverNames = append(serverNames, clientInfo.ServerName)
return
}
if clientInfo.Conn != nil {
var localAddr = clientInfo.Conn.LocalAddr()
if localAddr != nil {
tcpAddr, ok := localAddr.(*net.TCPAddr)
if ok {
serverNames = append(serverNames, tcpAddr.IP.String())
}
}
}
serverNames = append(serverNames, sharedNodeConfig.IPAddresses...)
return
}

View File

@@ -0,0 +1,10 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package nodes
import "crypto/tls"
func (this *BaseListener) calculateFingerprint(clientInfo *tls.ClientHelloInfo) []byte {
return nil
}

View File

@@ -0,0 +1,50 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package nodes
import (
"crypto/md5"
"crypto/tls"
"encoding/binary"
"hash"
"sync"
)
var md5Pool = &sync.Pool{
New: func() any {
return md5.New()
},
}
func md5Bytes(b []byte) []byte {
var h = md5Pool.Get().(hash.Hash)
h.Write(b)
var sum = h.Sum(nil)
h.Reset()
md5Pool.Put(h)
return sum
}
func (this *BaseListener) calculateFingerprint(clientInfo *tls.ClientHelloInfo) []byte {
var b = []byte{}
for _, c := range clientInfo.CipherSuites {
b = binary.BigEndian.AppendUint16(b, c)
}
b = append(b, 0)
for _, c := range clientInfo.SupportedCurves {
b = binary.BigEndian.AppendUint16(b, uint16(c))
}
b = append(b, 0)
b = append(b, clientInfo.SupportedPoints...)
b = append(b, 0)
for _, s := range clientInfo.SignatureSchemes {
b = binary.BigEndian.AppendUint16(b, uint16(s))
}
b = append(b, 0)
return md5Bytes(b)
}

View File

@@ -0,0 +1,37 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"context"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/types"
"testing"
"time"
)
func TestBaseListener_FindServer(t *testing.T) {
sharedNodeConfig = &nodeconfigs.NodeConfig{}
var listener = &BaseListener{}
listener.Group = serverconfigs.NewServerAddressGroup("https://*:443")
for i := 0; i < 1_000_000; i++ {
var server = &serverconfigs.ServerConfig{
IsOn: true,
Name: types.String(i) + ".hello.com",
ServerNames: []*serverconfigs.ServerNameConfig{
{Name: types.String(i) + ".hello.com"},
},
}
_ = server.Init(context.Background())
listener.Group.Add(server)
}
var before = time.Now()
defer func() {
t.Log(time.Since(before).Seconds()*1000, "ms")
}()
t.Log(listener.findNamedServerMatched("855555.hello.com"))
}

View File

@@ -0,0 +1,278 @@
package nodes
import (
"context"
"crypto/tls"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/Tea"
"io"
"log"
"net"
"net/http"
"strings"
"sync/atomic"
"time"
)
var httpErrorLogger = log.New(io.Discard, "", 0)
const HTTPIdleTimeout = 75 * time.Second
type contextKey struct {
key string
}
var HTTPConnContextKey = &contextKey{key: "http-conn"}
type HTTPListener struct {
BaseListener
Listener net.Listener
addr string
isHTTP bool
isHTTPS bool
isHTTP3 bool
httpServer *http.Server
}
func (this *HTTPListener) Serve() error {
this.addr = this.Group.Addr()
this.isHTTP = this.Group.IsHTTP()
this.isHTTPS = this.Group.IsHTTPS()
this.httpServer = &http.Server{
Addr: this.addr,
Handler: this,
ReadHeaderTimeout: 3 * time.Second, // TODO 改成可以配置
IdleTimeout: HTTPIdleTimeout, // TODO 改成可以配置
ConnState: func(conn net.Conn, state http.ConnState) {
switch state {
case http.StateNew:
atomic.AddInt64(&this.countActiveConnections, 1)
case http.StateClosed:
atomic.AddInt64(&this.countActiveConnections, -1)
default:
// do nothing
}
},
ConnContext: func(ctx context.Context, conn net.Conn) context.Context {
tlsConn, ok := conn.(*tls.Conn)
if ok {
conn = NewClientTLSConn(tlsConn)
}
return context.WithValue(ctx, HTTPConnContextKey, conn)
},
}
if !Tea.IsTesting() {
this.httpServer.ErrorLog = httpErrorLogger
}
this.httpServer.SetKeepAlivesEnabled(true)
// HTTP协议
if this.isHTTP {
err := this.httpServer.Serve(this.Listener)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
}
// HTTPS协议
if this.isHTTPS {
this.httpServer.TLSConfig = this.buildTLSConfig()
err := this.httpServer.ServeTLS(this.Listener, "", "")
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
}
return nil
}
func (this *HTTPListener) Close() error {
if this.httpServer != nil {
_ = this.httpServer.Close()
}
return this.Listener.Close()
}
func (this *HTTPListener) Reload(group *serverconfigs.ServerAddressGroup) {
this.Group = group
this.Reset()
}
// ServeHTTPWithAddr 处理HTTP请求
func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.Request) {
this.ServeHTTPWithAddr(rawWriter, rawReq, this.addr)
}
// ServeHTTPWithAddr 处理HTTP请求并指定服务地址
func (this *HTTPListener) ServeHTTPWithAddr(rawWriter http.ResponseWriter, rawReq *http.Request, serverAddr string) {
if len(rawReq.Host) > 253 {
http.Error(rawWriter, "Host too long.", http.StatusBadRequest)
time.Sleep(1 * time.Second) // make connection slow down
return
}
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
if globalServerConfig != nil && !globalServerConfig.HTTPAll.SupportsLowVersionHTTP && (rawReq.ProtoMajor < 1 /** 0.x **/ || (rawReq.ProtoMajor == 1 && rawReq.ProtoMinor == 0 /** 1.0 **/)) {
http.Error(rawWriter, rawReq.Proto+" request is not supported.", http.StatusHTTPVersionNotSupported)
time.Sleep(1 * time.Second) // make connection slow down
return
}
// 不支持Connect
if rawReq.Method == http.MethodConnect {
http.Error(rawWriter, "Method Not Allowed", http.StatusMethodNotAllowed)
time.Sleep(1 * time.Second) // make connection slow down
return
}
// 域名
var reqHost = strings.ToLower(strings.TrimRight(rawReq.Host, "."))
// TLS域名
if this.isIP(reqHost) {
if rawReq.TLS != nil {
var serverName = rawReq.TLS.ServerName
if len(serverName) > 0 {
// 端口
var index = strings.LastIndex(reqHost, ":")
if index >= 0 {
reqHost = serverName + reqHost[index:]
} else {
reqHost = serverName
}
}
}
}
// 防止空Host
if len(reqHost) == 0 {
var ctx = rawReq.Context()
if ctx != nil {
addr := ctx.Value(http.LocalAddrContextKey)
if addr != nil {
reqHost = addr.(net.Addr).String()
}
}
}
domain, _, err := net.SplitHostPort(reqHost)
if err != nil {
domain = reqHost
}
server, serverName := this.findNamedServer(domain, false)
if server == nil {
server = this.emptyServer()
} else if !server.CNameAsDomain && server.CNameDomain == domain {
server = this.emptyServer()
} else {
serverName = domain
}
// 绑定连接
var clientConn ClientConnInterface
if server != nil && server.Id > 0 {
var requestConn = rawReq.Context().Value(HTTPConnContextKey)
if requestConn != nil {
var ok bool
clientConn, ok = requestConn.(ClientConnInterface)
if ok {
var goNext = clientConn.SetServerId(server.Id)
if !goNext {
return
}
clientConn.SetUserId(server.UserId)
var userPlanId int64
if server.UserPlan != nil && server.UserPlan.Id > 0 {
userPlanId = server.UserPlan.Id
}
clientConn.SetUserPlanId(userPlanId)
}
}
}
// 检查用户
if server != nil && server.UserId > 0 {
if !SharedUserManager.CheckUserServersIsEnabled(server.UserId) {
rawWriter.WriteHeader(http.StatusNotFound)
_, _ = rawWriter.Write([]byte("The site owner is unavailable."))
return
}
}
// 包装新请求对象
var req = &HTTPRequest{
RawReq: rawReq,
RawWriter: rawWriter,
ReqServer: server,
ReqHost: reqHost,
ServerName: serverName,
ServerAddr: serverAddr,
IsHTTP: this.isHTTP,
IsHTTPS: this.isHTTPS,
IsHTTP3: this.isHTTP3,
nodeConfig: sharedNodeConfig,
}
req.Do()
// fix hijacked connection state
if req.isHijacked && clientConn != nil && this.httpServer.ConnState != nil {
netConn, ok := clientConn.(net.Conn)
if ok {
this.httpServer.ConnState(netConn, http.StateClosed)
}
}
}
// 检查host是否为IP
func (this *HTTPListener) isIP(host string) bool {
// IPv6
if strings.Contains(host, "[") {
return true
}
for _, b := range host {
if b >= 'a' && b <= 'z' {
return false
}
}
return true
}
// 默认的访问日志
func (this *HTTPListener) emptyServer() *serverconfigs.ServerConfig {
var server = &serverconfigs.ServerConfig{
Type: serverconfigs.ServerTypeHTTPProxy,
}
// 检查是否开启访问日志
if sharedNodeConfig != nil {
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
if globalServerConfig != nil && globalServerConfig.HTTPAccessLog.EnableServerNotFound {
var accessLogRef = serverconfigs.NewHTTPAccessLogRef()
accessLogRef.IsOn = true
accessLogRef.Fields = append([]int{}, serverconfigs.HTTPAccessLogDefaultFieldsCodes...)
server.Web = &serverconfigs.HTTPWebConfig{
IsOn: true,
AccessLogRef: accessLogRef,
}
}
}
// TODO 需要对访问频率过多的IP进行惩罚
return server
}

View File

@@ -0,0 +1,21 @@
package nodes
import "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
// ListenerInterface 各协议监听器的接口
type ListenerInterface interface {
// Init 初始化
Init()
// Serve 监听
Serve() error
// Close 关闭
Close() error
// Reload 重载配置
Reload(serverGroup *serverconfigs.ServerAddressGroup)
// CountActiveConnections 获取当前活跃的连接数
CountActiveConnections() int
}

View File

@@ -0,0 +1,369 @@
package nodes
import (
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"net/url"
"regexp"
"runtime"
"sort"
"strings"
"sync"
"time"
)
var sharedListenerManager *ListenerManager
func init() {
if !teaconst.IsMain {
return
}
sharedListenerManager = NewListenerManager()
}
// ListenerManager 端口监听管理器
type ListenerManager struct {
listenersMap map[string]*Listener // addr => *Listener
http3Listener *HTTPListener
locker sync.Mutex
lastConfig *nodeconfigs.NodeConfig
retryListenerMap map[string]*Listener // 需要重试的监听器 addr => Listener
ticker *time.Ticker
firewalld *firewalls.Firewalld
lastPortStrings string
lastTCPPortRanges [][2]int
lastUDPPortRanges [][2]int
}
// NewListenerManager 获取新对象
func NewListenerManager() *ListenerManager {
var manager = &ListenerManager{
listenersMap: map[string]*Listener{},
retryListenerMap: map[string]*Listener{},
ticker: time.NewTicker(1 * time.Minute),
firewalld: firewalls.NewFirewalld(),
}
// 提升测试效率
if Tea.IsTesting() {
manager.ticker = time.NewTicker(5 * time.Second)
}
goman.New(func() {
for range manager.ticker.C {
manager.retryListeners()
}
})
return manager
}
// Start 启动监听
func (this *ListenerManager) Start(nodeConfig *nodeconfigs.NodeConfig) error {
this.locker.Lock()
defer this.locker.Unlock()
// 重置数据
this.retryListenerMap = map[string]*Listener{}
// 检查是否有变化
/**if this.lastConfig != nil && this.lastConfig.Version == node.Version {
return nil
}**/
this.lastConfig = nodeConfig
// 所有的新地址
var groupAddrs = []string{}
var availableServerGroups = nodeConfig.AvailableGroups()
if !nodeConfig.IsOn {
availableServerGroups = []*serverconfigs.ServerAddressGroup{}
}
if len(availableServerGroups) == 0 {
remotelogs.Println("LISTENER_MANAGER", "no available servers to startup")
}
for _, group := range availableServerGroups {
var addr = group.FullAddr()
groupAddrs = append(groupAddrs, addr)
}
// 停掉老的
for listenerKey, listener := range this.listenersMap {
var addr = listener.FullAddr()
if !lists.ContainsString(groupAddrs, addr) {
remotelogs.Println("LISTENER_MANAGER", "close '"+addr+"'")
_ = listener.Close()
delete(this.listenersMap, listenerKey)
}
}
// 启动新的或修改老的
for _, group := range availableServerGroups {
var addr = group.FullAddr()
listener, ok := this.listenersMap[addr]
if ok {
// 不需要打印reload信息防止日志数量过多
listener.Reload(group)
} else {
remotelogs.Println("LISTENER_MANAGER", "listen '"+this.prettyAddress(addr)+"'")
listener = NewListener()
listener.Reload(group)
err := listener.Listen()
if err != nil {
// 放入到重试队列中
this.retryListenerMap[addr] = listener
var firstServer = group.FirstServer()
if firstServer == nil {
remotelogs.Error("LISTENER_MANAGER", err.Error())
} else {
// 当前占用的进程名
if strings.Contains(err.Error(), "in use") {
portIndex := strings.LastIndex(addr, ":")
if portIndex > 0 {
var port = addr[portIndex+1:]
var processName = this.findProcessNameWithPort(group.IsUDP(), port)
if len(processName) > 0 {
err = fmt.Errorf("%w (the process using port: '%s')", err, processName)
}
}
}
remotelogs.ServerError(firstServer.Id, "LISTENER_MANAGER", "listen '"+addr+"' failed: "+err.Error(), nodeconfigs.NodeLogTypeListenAddressFailed, maps.Map{"address": addr})
}
continue
} else {
// TODO 是否是从错误中恢复
}
this.listenersMap[addr] = listener
}
}
// 加入到firewalld
go this.addToFirewalld(groupAddrs)
return nil
}
// TotalActiveConnections 获取总的活跃连接数
func (this *ListenerManager) TotalActiveConnections() int {
this.locker.Lock()
defer this.locker.Unlock()
var total = 0
for _, listener := range this.listenersMap {
total += listener.listener.CountActiveConnections()
}
if this.http3Listener != nil {
total += this.http3Listener.CountActiveConnections()
}
return total
}
// 返回更加友好格式的地址
func (this *ListenerManager) prettyAddress(addr string) string {
u, err := url.Parse(addr)
if err != nil {
return addr
}
if regexp.MustCompile(`^:\d+$`).MatchString(u.Host) {
u.Host = "*" + u.Host
}
return u.String()
}
// 重试失败的Listener
func (this *ListenerManager) retryListeners() {
this.locker.Lock()
defer this.locker.Unlock()
for addr, listener := range this.retryListenerMap {
err := listener.Listen()
if err == nil {
delete(this.retryListenerMap, addr)
this.listenersMap[addr] = listener
remotelogs.ServerSuccess(listener.group.FirstServer().Id, "LISTENER_MANAGER", "retry to listen '"+addr+"' successfully", nodeconfigs.NodeLogTypeListenAddressFailed, maps.Map{"address": addr})
}
}
}
func (this *ListenerManager) findProcessNameWithPort(isUdp bool, port string) string {
if runtime.GOOS != "linux" {
return ""
}
path, err := executils.LookPath("ss")
if err != nil {
return ""
}
var option = "t"
if isUdp {
option = "u"
}
var cmd = executils.NewTimeoutCmd(10*time.Second, path, "-"+option+"lpn", "sport = :"+port)
cmd.WithStdout()
err = cmd.Run()
if err != nil {
return ""
}
var matches = regexp.MustCompile(`(?U)\(\("(.+)",pid=\d+,fd=\d+\)\)`).FindStringSubmatch(cmd.Stdout())
if len(matches) > 1 {
return matches[1]
}
return ""
}
func (this *ListenerManager) addToFirewalld(groupAddrs []string) {
if !sharedNodeConfig.AutoOpenPorts {
return
}
if this.firewalld == nil || !this.firewalld.IsReady() {
return
}
// HTTP/3相关端口
var http3Ports = sharedNodeConfig.FindHTTP3Ports()
if len(http3Ports) > 0 {
for _, port := range http3Ports {
var groupAddr = "udp://:" + types.String(port)
if !lists.ContainsString(groupAddrs, groupAddr) {
groupAddrs = append(groupAddrs, groupAddr)
}
}
}
// 组合端口号
var portStrings = []string{}
var udpPorts = []int{}
var tcpPorts = []int{}
for _, addr := range groupAddrs {
var protocol = "tcp"
if strings.HasPrefix(addr, "udp") {
protocol = "udp"
}
var lastIndex = strings.LastIndex(addr, ":")
if lastIndex > 0 {
var portString = addr[lastIndex+1:]
portStrings = append(portStrings, portString+"/"+protocol)
switch protocol {
case "tcp":
tcpPorts = append(tcpPorts, types.Int(portString))
case "udp":
udpPorts = append(udpPorts, types.Int(portString))
}
}
}
if len(portStrings) == 0 {
return
}
// 检查是否有变化
sort.Strings(portStrings)
var newPortStrings = strings.Join(portStrings, ",")
if newPortStrings == this.lastPortStrings {
return
}
this.locker.Lock()
this.lastPortStrings = newPortStrings
this.locker.Unlock()
remotelogs.Println("FIREWALLD", "opening ports automatically ...")
defer func() {
remotelogs.Println("FIREWALLD", "open ports successfully")
}()
// 合并端口
var tcpPortRanges = utils.MergePorts(tcpPorts)
var udpPortRanges = utils.MergePorts(udpPorts)
defer func() {
this.locker.Lock()
this.lastTCPPortRanges = tcpPortRanges
this.lastUDPPortRanges = udpPortRanges
this.locker.Unlock()
}()
// 删除老的不存在的端口
var tcpPortRangesMap = map[string]bool{}
var udpPortRangesMap = map[string]bool{}
for _, portRange := range tcpPortRanges {
tcpPortRangesMap[this.firewalld.PortRangeString(portRange, "tcp")] = true
}
for _, portRange := range udpPortRanges {
udpPortRangesMap[this.firewalld.PortRangeString(portRange, "udp")] = true
}
for _, portRange := range this.lastTCPPortRanges {
var s = this.firewalld.PortRangeString(portRange, "tcp")
_, ok := tcpPortRangesMap[s]
if ok {
continue
}
remotelogs.Println("FIREWALLD", "remove port '"+s+"'")
_ = this.firewalld.RemovePortRangePermanently(portRange, "tcp")
}
for _, portRange := range this.lastUDPPortRanges {
var s = this.firewalld.PortRangeString(portRange, "udp")
_, ok := udpPortRangesMap[s]
if ok {
continue
}
remotelogs.Println("FIREWALLD", "remove port '"+s+"'")
_ = this.firewalld.RemovePortRangePermanently(portRange, "udp")
}
// 添加新的
_ = this.firewalld.AllowPortRangesPermanently(tcpPortRanges, "tcp")
_ = this.firewalld.AllowPortRangesPermanently(udpPortRanges, "udp")
}
func (this *ListenerManager) reloadFirewalld() {
this.locker.Lock()
defer this.locker.Unlock()
var nodeConfig = sharedNodeConfig
// 所有的新地址
var groupAddrs = []string{}
var availableServerGroups = nodeConfig.AvailableGroups()
if !nodeConfig.IsOn {
availableServerGroups = []*serverconfigs.ServerAddressGroup{}
}
if len(availableServerGroups) == 0 {
remotelogs.Println("LISTENER_MANAGER", "no available servers to startup")
}
for _, group := range availableServerGroups {
var addr = group.FullAddr()
groupAddrs = append(groupAddrs, addr)
}
go this.addToFirewalld(groupAddrs)
}

View File

@@ -0,0 +1,84 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"testing"
)
func TestListenerManager_Listen(t *testing.T) {
manager := NewListenerManager()
err := manager.Start(&nodeconfigs.NodeConfig{
Servers: []*serverconfigs.ServerConfig{
{
IsOn: true,
HTTP: &serverconfigs.HTTPProtocolConfig{
BaseProtocol: serverconfigs.BaseProtocol{
IsOn: true,
Listen: []*serverconfigs.NetworkAddressConfig{
{
Protocol: serverconfigs.ProtocolHTTP,
PortRange: "1234",
},
},
},
},
},
{
IsOn: true,
HTTP: &serverconfigs.HTTPProtocolConfig{
BaseProtocol: serverconfigs.BaseProtocol{
IsOn: true,
Listen: []*serverconfigs.NetworkAddressConfig{
{
Protocol: serverconfigs.ProtocolHTTP,
PortRange: "1235",
},
},
},
},
},
},
})
if err != nil {
t.Fatal(err)
}
err = manager.Start(&nodeconfigs.NodeConfig{
Servers: []*serverconfigs.ServerConfig{
{
IsOn: true,
HTTP: &serverconfigs.HTTPProtocolConfig{
BaseProtocol: serverconfigs.BaseProtocol{
IsOn: true,
Listen: []*serverconfigs.NetworkAddressConfig{
{
Protocol: serverconfigs.ProtocolHTTP,
PortRange: "1234",
},
},
},
},
},
{
IsOn: true,
HTTP: &serverconfigs.HTTPProtocolConfig{
BaseProtocol: serverconfigs.BaseProtocol{
IsOn: true,
Listen: []*serverconfigs.NetworkAddressConfig{
{
Protocol: serverconfigs.ProtocolHTTP,
PortRange: "1236",
},
},
},
},
},
},
})
if err != nil {
t.Fatal(err)
}
t.Log("all ok")
}

View File

@@ -0,0 +1,318 @@
package nodes
import (
"crypto/tls"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/stats"
"github.com/TeaOSLab/EdgeNode/internal/utils/bytepool"
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
"github.com/iwind/TeaGo/types"
"github.com/pires/go-proxyproto"
"net"
"strings"
"sync/atomic"
)
type TCPListener struct {
BaseListener
Listener net.Listener
port int
}
func (this *TCPListener) Serve() error {
var listener = this.Listener
if this.Group.IsTLS() {
listener = tls.NewListener(listener, this.buildTLSConfig())
}
// 获取分组端口
var groupAddr = this.Group.Addr()
var portIndex = strings.LastIndex(groupAddr, ":")
if portIndex >= 0 {
var port = groupAddr[portIndex+1:]
this.port = types.Int(port)
}
for {
conn, err := listener.Accept()
if err != nil {
break
}
atomic.AddInt64(&this.countActiveConnections, 1)
go func(conn net.Conn) {
var server = this.Group.FirstServer()
if server == nil {
return
}
err = this.handleConn(server, conn)
if err != nil {
remotelogs.ServerError(server.Id, "TCP_LISTENER", err.Error(), "", nil)
}
atomic.AddInt64(&this.countActiveConnections, -1)
}(conn)
}
return nil
}
func (this *TCPListener) Reload(group *serverconfigs.ServerAddressGroup) {
this.Group = group
this.Reset()
}
func (this *TCPListener) handleConn(server *serverconfigs.ServerConfig, conn net.Conn) error {
if server == nil {
return errors.New("no server available")
}
if server.ReverseProxy == nil {
return errors.New("no ReverseProxy configured for the server")
}
// 绑定连接和服务
clientConn, ok := conn.(ClientConnInterface)
if ok {
var goNext = clientConn.SetServerId(server.Id)
if !goNext {
return nil
}
clientConn.SetUserId(server.UserId)
var userPlanId int64
if server.UserPlan != nil && server.UserPlan.Id > 0 {
userPlanId = server.UserPlan.Id
}
clientConn.SetUserPlanId(userPlanId)
} else {
tlsConn, ok := conn.(*tls.Conn)
if ok {
var internalConn = tlsConn.NetConn()
if internalConn != nil {
clientConn, ok = internalConn.(ClientConnInterface)
if ok {
var goNext = clientConn.SetServerId(server.Id)
if !goNext {
return nil
}
clientConn.SetUserId(server.UserId)
var userPlanId int64
if server.UserPlan != nil && server.UserPlan.Id > 0 {
userPlanId = server.UserPlan.Id
}
clientConn.SetUserPlanId(userPlanId)
}
}
}
}
// 是否已达到流量限制
if this.reachedTrafficLimit() || (server.UserId > 0 && !SharedUserManager.CheckUserServersIsEnabled(server.UserId)) {
// 关闭连接
tcpConn, ok := conn.(LingerConn)
if ok {
_ = tcpConn.SetLinger(0)
}
_ = conn.Close()
// TODO 使用系统防火墙drop当前端口的数据包一段时间1分钟
// 不能使用阻止IP的方法因为边缘节点只上有可能还有别的代理服务
return nil
}
// 记录域名排行
tlsConn, ok := conn.(*tls.Conn)
var recordStat = false
var serverName = ""
if ok {
serverName = tlsConn.ConnectionState().ServerName
if len(serverName) > 0 {
// 统计
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, serverName, 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
recordStat = true
}
}
// 统计
if !recordStat {
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
}
// DAU统计
clientIP, _, parseErr := net.SplitHostPort(conn.RemoteAddr().String())
if parseErr == nil {
stats.SharedDAUManager.AddIP(server.Id, clientIP)
}
originConn, err := this.connectOrigin(server.Id, serverName, server.ReverseProxy, conn.RemoteAddr().String())
if err != nil {
_ = conn.Close()
return err
}
var closer = func() {
_ = conn.Close()
_ = originConn.Close()
}
// PROXY Protocol
if server.ReverseProxy != nil &&
server.ReverseProxy.ProxyProtocol != nil &&
server.ReverseProxy.ProxyProtocol.IsOn &&
(server.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || server.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) {
var remoteAddr = conn.RemoteAddr()
var transportProtocol = proxyproto.TCPv4
if strings.Contains(remoteAddr.String(), "[") {
transportProtocol = proxyproto.TCPv6
}
var header = proxyproto.Header{
Version: byte(server.ReverseProxy.ProxyProtocol.Version),
Command: proxyproto.PROXY,
TransportProtocol: transportProtocol,
SourceAddr: remoteAddr,
DestinationAddr: conn.LocalAddr(),
}
_, err = header.WriteTo(originConn)
if err != nil {
closer()
return err
}
}
// 从源站读取
goman.New(func() {
var originBuf = bytepool.Pool16k.Get()
defer func() {
bytepool.Pool16k.Put(originBuf)
}()
for {
n, err := originConn.Read(originBuf.Bytes)
if n > 0 {
_, err = conn.Write(originBuf.Bytes[:n])
if err != nil {
closer()
break
}
// 记录流量
if server != nil {
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
}
}
if err != nil {
closer()
break
}
}
})
// 从客户端读取
var clientBuf = bytepool.Pool16k.Get()
defer func() {
bytepool.Pool16k.Put(clientBuf)
}()
for {
// 是否已达到流量限制
if this.reachedTrafficLimit() {
closer()
return nil
}
n, err := conn.Read(clientBuf.Bytes)
if n > 0 {
_, err = originConn.Write(clientBuf.Bytes[:n])
if err != nil {
break
}
}
if err != nil {
break
}
}
// 关闭连接
closer()
return nil
}
func (this *TCPListener) Close() error {
return this.Listener.Close()
}
// 连接源站
func (this *TCPListener) connectOrigin(serverId int64, requestHost string, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
if reverseProxy == nil {
return nil, errors.New("no reverse proxy config")
}
var requestCall = shared.NewRequestCall()
requestCall.Domain = requestHost
var retries = 3
var addr string
var failedOriginIds []int64
for i := 0; i < retries; i++ {
var origin *serverconfigs.OriginConfig
if len(failedOriginIds) > 0 {
origin = reverseProxy.AnyOrigin(requestCall, failedOriginIds)
}
if origin == nil {
origin = reverseProxy.NextOrigin(requestCall)
}
if origin == nil {
continue
}
// 回源主机名
if len(origin.RequestHost) > 0 {
requestHost = origin.RequestHost
} else if len(reverseProxy.RequestHost) > 0 {
requestHost = reverseProxy.RequestHost
}
conn, addr, err = OriginConnect(origin, this.port, remoteAddr, requestHost)
if err != nil {
failedOriginIds = append(failedOriginIds, origin.Id)
remotelogs.ServerError(serverId, "TCP_LISTENER", "unable to connect origin server: "+addr+": "+err.Error(), "", nil)
SharedOriginStateManager.Fail(origin, requestHost, reverseProxy, func() {
reverseProxy.ResetScheduling()
})
continue
} else {
if !origin.IsOk {
SharedOriginStateManager.Success(origin, func() {
reverseProxy.ResetScheduling()
})
}
return
}
}
if err == nil {
err = errors.New("server '" + types.String(serverId) + "': no available origin server can be used")
}
return
}
// 检查是否已经达到流量限制
func (this *TCPListener) reachedTrafficLimit() bool {
var server = this.Group.FirstServer()
if server == nil {
return true
}
return server.TrafficLimitStatus != nil && server.TrafficLimitStatus.IsValid()
}

View File

@@ -0,0 +1,18 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"testing"
)
func TestListener_Listen(t *testing.T) {
listener := NewListener()
group := serverconfigs.NewServerAddressGroup("https://:1234")
listener.Reload(group)
err := listener.Listen()
if err != nil {
t.Fatal(err)
}
}

View File

@@ -0,0 +1,465 @@
package nodes
import (
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/stats"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/bytepool"
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
"github.com/iwind/TeaGo/types"
"github.com/pires/go-proxyproto"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"net"
"strings"
"sync"
"time"
)
const (
UDPConnLifeSeconds = 30
)
type UDPPacketListener interface {
ReadFrom(b []byte) (n int, cm any, src net.Addr, err error)
WriteTo(b []byte, cm any, dst net.Addr) (n int, err error)
LocalAddr() net.Addr
}
type UDPIPv4Listener struct {
rawListener *ipv4.PacketConn
}
func NewUDPIPv4Listener(rawListener *ipv4.PacketConn) *UDPIPv4Listener {
return &UDPIPv4Listener{rawListener: rawListener}
}
func (this *UDPIPv4Listener) ReadFrom(b []byte) (n int, cm any, src net.Addr, err error) {
return this.rawListener.ReadFrom(b)
}
func (this *UDPIPv4Listener) WriteTo(b []byte, cm any, dst net.Addr) (n int, err error) {
return this.rawListener.WriteTo(b, cm.(*ipv4.ControlMessage), dst)
}
func (this *UDPIPv4Listener) LocalAddr() net.Addr {
return this.rawListener.LocalAddr()
}
type UDPIPv6Listener struct {
rawListener *ipv6.PacketConn
}
func NewUDPIPv6Listener(rawListener *ipv6.PacketConn) *UDPIPv6Listener {
return &UDPIPv6Listener{rawListener: rawListener}
}
func (this *UDPIPv6Listener) ReadFrom(b []byte) (n int, cm any, src net.Addr, err error) {
return this.rawListener.ReadFrom(b)
}
func (this *UDPIPv6Listener) WriteTo(b []byte, cm any, dst net.Addr) (n int, err error) {
return this.rawListener.WriteTo(b, cm.(*ipv6.ControlMessage), dst)
}
func (this *UDPIPv6Listener) LocalAddr() net.Addr {
return this.rawListener.LocalAddr()
}
type UDPListener struct {
BaseListener
IPv4Listener *ipv4.PacketConn
IPv6Listener *ipv6.PacketConn
connMap map[string]*UDPConn
connLocker sync.Mutex
connTicker *utils.Ticker
reverseProxy *serverconfigs.ReverseProxyConfig
port int
isClosed bool
}
func (this *UDPListener) Serve() error {
if this.Group == nil {
return nil
}
var server = this.Group.FirstServer()
if server == nil {
return nil
}
var serverId = server.Id
var wg = &sync.WaitGroup{}
wg.Add(2) // 2 = ipv4 + ipv6
go func() {
defer wg.Done()
if this.IPv4Listener != nil {
err := this.IPv4Listener.SetControlMessage(ipv4.FlagDst, true)
if err != nil {
remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv4 listener: "+err.Error(), "", nil)
return
}
err = this.servePacketListener(NewUDPIPv4Listener(this.IPv4Listener))
if err != nil {
remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv4 listener: "+err.Error(), "", nil)
return
}
}
}()
go func() {
defer wg.Done()
if this.IPv6Listener != nil {
err := this.IPv6Listener.SetControlMessage(ipv6.FlagDst, true)
if err != nil {
remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv6 listener: "+err.Error(), "", nil)
return
}
err = this.servePacketListener(NewUDPIPv6Listener(this.IPv6Listener))
if err != nil {
remotelogs.ServerError(serverId, "UDP_LISTENER", "can not serve ipv6 listener: "+err.Error(), "", nil)
return
}
}
}()
wg.Wait()
return nil
}
func (this *UDPListener) servePacketListener(listener UDPPacketListener) error {
// 获取分组端口
var groupAddr = this.Group.Addr()
var portIndex = strings.LastIndex(groupAddr, ":")
if portIndex >= 0 {
var port = groupAddr[portIndex+1:]
this.port = types.Int(port)
}
var firstServer = this.Group.FirstServer()
if firstServer == nil {
return errors.New("no server available")
}
this.reverseProxy = firstServer.ReverseProxy
if this.reverseProxy == nil {
return errors.New("no ReverseProxy configured for the server '" + firstServer.Name + "'")
}
this.connMap = map[string]*UDPConn{}
this.connTicker = utils.NewTicker(1 * time.Minute)
goman.New(func() {
for this.connTicker.Next() {
this.gcConns()
}
})
var buffer = make([]byte, 4<<10)
for {
if this.isClosed {
return nil
}
// 检查用户状态
if firstServer.UserId > 0 && !SharedUserManager.CheckUserServersIsEnabled(firstServer.UserId) {
return nil
}
n, cm, clientAddr, err := listener.ReadFrom(buffer)
if err != nil {
if this.isClosed {
return nil
}
return err
}
// 检查IP名单
clientIP, _, parseHostErr := net.SplitHostPort(clientAddr.String())
if parseHostErr == nil {
ok, _, expiresAt := iplibrary.AllowIP(clientIP, firstServer.Id)
if !ok {
firewalls.DropTemporaryTo(clientIP, expiresAt)
continue
}
}
if n > 0 {
this.connLocker.Lock()
conn, ok := this.connMap[clientAddr.String()]
this.connLocker.Unlock()
if ok && !conn.IsOk() {
_ = conn.Close()
ok = false
}
if !ok {
originConn, err := this.connectOrigin(firstServer.Id, this.reverseProxy, listener.LocalAddr(), clientAddr)
if err != nil {
remotelogs.Error("UDP_LISTENER", "unable to connect to origin server: "+err.Error())
continue
}
if originConn == nil {
remotelogs.Error("UDP_LISTENER", "unable to find a origin server")
continue
}
conn = NewUDPConn(firstServer, clientAddr, listener, cm, originConn.(*net.UDPConn))
this.connLocker.Lock()
this.connMap[clientAddr.String()] = conn
this.connLocker.Unlock()
}
_, _ = conn.Write(buffer[:n])
}
}
}
func (this *UDPListener) Close() error {
this.isClosed = true
if this.connTicker != nil {
this.connTicker.Stop()
}
// 关闭所有连接
this.connLocker.Lock()
for _, conn := range this.connMap {
_ = conn.Close()
}
this.connLocker.Unlock()
var errorStrings = []string{}
if this.IPv4Listener != nil {
err := this.IPv4Listener.Close()
if err != nil {
errorStrings = append(errorStrings, err.Error())
}
}
if this.IPv6Listener != nil {
err := this.IPv6Listener.Close()
if err != nil {
errorStrings = append(errorStrings, err.Error())
}
}
if len(errorStrings) > 0 {
return errors.New(errorStrings[0])
}
return nil
}
func (this *UDPListener) Reload(group *serverconfigs.ServerAddressGroup) {
this.Group = group
this.Reset()
// 重置配置
var firstServer = this.Group.FirstServer()
if firstServer == nil {
return
}
this.reverseProxy = firstServer.ReverseProxy
}
func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, localAddr net.Addr, remoteAddr net.Addr) (conn net.Conn, err error) {
if reverseProxy == nil {
return nil, errors.New("no reverse proxy config")
}
var retries = 3
var addr string
var failedOriginIds []int64
for i := 0; i < retries; i++ {
var origin *serverconfigs.OriginConfig
if len(failedOriginIds) > 0 {
origin = reverseProxy.AnyOrigin(nil, failedOriginIds)
}
if origin == nil {
origin = reverseProxy.NextOrigin(nil)
}
if origin == nil {
continue
}
conn, addr, err = OriginConnect(origin, this.port, remoteAddr.String(), "")
if err != nil {
failedOriginIds = append(failedOriginIds, origin.Id)
remotelogs.ServerError(serverId, "UDP_LISTENER", "unable to connect origin server: "+addr+": "+err.Error(), "", nil)
SharedOriginStateManager.Fail(origin, "", reverseProxy, func() {
reverseProxy.ResetScheduling()
})
continue
} else {
if !origin.IsOk {
SharedOriginStateManager.Success(origin, func() {
reverseProxy.ResetScheduling()
})
}
// PROXY Protocol
if reverseProxy != nil &&
reverseProxy.ProxyProtocol != nil &&
reverseProxy.ProxyProtocol.IsOn &&
(reverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || reverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) {
var transportProtocol = proxyproto.UDPv4
if strings.Contains(remoteAddr.String(), "[") {
transportProtocol = proxyproto.UDPv6
}
var header = proxyproto.Header{
Version: byte(reverseProxy.ProxyProtocol.Version),
Command: proxyproto.PROXY,
TransportProtocol: transportProtocol,
SourceAddr: remoteAddr,
DestinationAddr: localAddr,
}
_, err = header.WriteTo(conn)
if err != nil {
_ = conn.Close()
return nil, err
}
}
return
}
}
if err == nil {
err = errors.New("server '" + types.String(serverId) + "': no available origin server can be used")
}
return
}
// 回收连接
func (this *UDPListener) gcConns() {
this.connLocker.Lock()
var closingConns = []*UDPConn{}
for addr, conn := range this.connMap {
if !conn.IsOk() {
closingConns = append(closingConns, conn)
delete(this.connMap, addr)
}
}
this.connLocker.Unlock()
for _, conn := range closingConns {
_ = conn.Close()
}
}
// UDPConn 自定义的UDP连接管理
type UDPConn struct {
addr net.Addr
proxyListener UDPPacketListener
serverConn net.Conn
activatedAt int64
isOk bool
isClosed bool
}
func NewUDPConn(server *serverconfigs.ServerConfig, clientAddr net.Addr, proxyListener UDPPacketListener, cm any, serverConn *net.UDPConn) *UDPConn {
var conn = &UDPConn{
addr: clientAddr,
proxyListener: proxyListener,
serverConn: serverConn,
activatedAt: time.Now().Unix(),
isOk: true,
}
// 统计
if server != nil {
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
// DAU统计
clientIP, _, parseErr := net.SplitHostPort(clientAddr.String())
if parseErr == nil {
stats.SharedDAUManager.AddIP(server.Id, clientIP)
}
}
// 处理ControlMessage
switch controlMessage := cm.(type) {
case *ipv4.ControlMessage:
controlMessage.Src = controlMessage.Dst
case *ipv6.ControlMessage:
controlMessage.Src = controlMessage.Dst
}
goman.New(func() {
var buf = bytepool.Pool4k.Get()
defer func() {
bytepool.Pool4k.Put(buf)
}()
for {
n, err := serverConn.Read(buf.Bytes)
if n > 0 {
conn.activatedAt = time.Now().Unix()
_, writingErr := proxyListener.WriteTo(buf.Bytes[:n], cm, clientAddr)
if writingErr != nil {
conn.isOk = false
break
}
// 记录流量和带宽
if server != nil {
// 流量
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
// 带宽
var userPlanId int64
if server.UserPlan != nil && server.UserPlan.Id > 0 {
userPlanId = server.UserPlan.Id
}
stats.SharedBandwidthStatManager.AddBandwidth(server.UserId, userPlanId, server.Id, int64(n), int64(n))
}
}
if err != nil {
conn.isOk = false
break
}
}
})
return conn
}
func (this *UDPConn) Write(b []byte) (n int, err error) {
this.activatedAt = time.Now().Unix()
n, err = this.serverConn.Write(b)
if err != nil {
this.isOk = false
}
return
}
func (this *UDPConn) Close() error {
this.isOk = false
if this.isClosed {
return nil
}
this.isClosed = true
return this.serverConn.Close()
}
func (this *UDPConn) IsOk() bool {
if !this.isOk {
return false
}
return time.Now().Unix()-this.activatedAt < UDPConnLifeSeconds // 如果超过 N 秒没有活动我们认为是超时
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !script
// +build !script
package nodes
func (this *Node) reloadCommonScripts() error {
return nil
}
func (this *Node) reloadIPLibrary() {
}
func (this *Node) notifyPlusChange() error {
return nil
}
func (this *Node) execTOAChangedTask() error {
return nil
}

View File

@@ -0,0 +1,142 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus && script
package nodes
import (
"encoding/json"
"fmt"
iplib "github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/lists"
"time"
)
func init() {
if !teaconst.IsMain {
return
}
events.On(events.EventLoaded, func() {
var plusTicker = time.NewTicker(1 * time.Hour)
if Tea.IsTesting() {
plusTicker = time.NewTicker(1 * time.Minute)
}
goman.New(func() {
for range plusTicker.C {
_ = nodeInstance.notifyPlusChange()
}
})
})
}
var lastEdition string
func (this *Node) reloadCommonScripts() error {
// 下载配置
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
configsResp, err := rpcClient.ScriptRPC.ComposeScriptConfigs(rpcClient.Context(), &pb.ComposeScriptConfigsRequest{})
if err != nil {
return err
}
if len(configsResp.ScriptConfigsJSON) == 0 {
sharedNodeConfig.CommonScripts = nil
} else {
var configs = []*serverconfigs.CommonScript{}
err = json.Unmarshal(configsResp.ScriptConfigsJSON, &configs)
if err != nil {
return fmt.Errorf("decode script configs failed: %w", err)
}
sharedNodeConfig.CommonScripts = configs
}
// 通知更新
select {
case commonScriptsChangesChan <- true:
default:
}
return nil
}
func (this *Node) reloadIPLibrary() {
if sharedNodeConfig.Edition == lastEdition {
return
}
go func() {
var err error
lastEdition = sharedNodeConfig.Edition
if len(lastEdition) > 0 && (lists.ContainsString([]string{"pro", "ent", "max", "ultra"}, lastEdition)) {
err = iplib.InitPlus()
} else {
err = iplib.InitDefault()
}
if err != nil {
remotelogs.Error("IP_LIBRARY", "load ip library failed: "+err.Error())
}
}()
}
func (this *Node) notifyPlusChange() error {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
resp, err := rpcClient.AuthorityKeyRPC.CheckAuthority(rpcClient.Context(), &pb.CheckAuthorityRequest{})
if err != nil {
return err
}
var isChanged = resp.Edition != sharedNodeConfig.Edition
if resp.IsPlus {
sharedNodeConfig.Edition = resp.Edition
} else {
sharedNodeConfig.Edition = ""
}
if isChanged {
this.reloadIPLibrary()
}
return nil
}
func (this *Node) execTOAChangedTask() error {
if sharedTOAManager == nil {
return nil
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
resp, err := rpcClient.NodeRPC.FindNodeTOAConfig(rpcClient.Context(), &pb.FindNodeTOAConfigRequest{})
if err != nil {
return err
}
if len(resp.ToaJSON) == 0 {
return sharedTOAManager.Apply(&nodeconfigs.TOAConfig{IsOn: false})
}
var config = nodeconfigs.NewTOAConfig()
err = json.Unmarshal(resp.ToaJSON, config)
if err != nil {
return err
}
return sharedTOAManager.Apply(config)
}

View File

@@ -0,0 +1,44 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !arm64
// +build !arm64
package nodes
import (
"bytes"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/logs"
"os"
"syscall"
)
// 处理异常
func (this *Node) handlePanic() {
// 如果是在前台运行就直接返回
backgroundEnv, _ := os.LookupEnv("EdgeBackground")
if backgroundEnv != "on" {
return
}
var panicFile = Tea.Root + "/logs/panic.log"
// 分析panic
data, err := os.ReadFile(panicFile)
if err == nil {
var index = bytes.Index(data, []byte("panic:"))
if index >= 0 {
remotelogs.Error("NODE", "系统错误,请上报给开发者: "+string(data[index:]))
}
}
fp, err := os.OpenFile(panicFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY|os.O_APPEND, 0777)
if err != nil {
logs.Println("NODE", "open 'panic.log' failed: "+err.Error())
return
}
err = syscall.Dup2(int(fp.Fd()), int(os.Stderr.Fd()))
if err != nil {
logs.Println("NODE", "write to 'panic.log' failed: "+err.Error())
}
}

View File

@@ -0,0 +1,9 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build arm64
package nodes
// 处理异常
func (this *Node) handlePanic() {
}

View File

@@ -0,0 +1,379 @@
package nodes
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/caches"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
"github.com/TeaOSLab/EdgeNode/internal/monitor"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils"
fsutils "github.com/TeaOSLab/EdgeNode/internal/utils/fs"
"github.com/TeaOSLab/EdgeNode/internal/utils/trackers"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/shirou/gopsutil/v3/cpu"
"github.com/shirou/gopsutil/v3/disk"
"github.com/shirou/gopsutil/v3/net"
"math"
"os"
"runtime"
"strings"
"time"
)
type NodeStatusExecutor struct {
isFirstTime bool
lastUpdatedTime time.Time
cpuLogicalCount int
cpuPhysicalCount int
// 流量统计
lastIOCounterStat net.IOCountersStat
lastUDPInDatagrams int64
lastUDPOutDatagrams int64
apiCallStat *rpc.CallStat
ticker *time.Ticker
}
func NewNodeStatusExecutor() *NodeStatusExecutor {
return &NodeStatusExecutor{
ticker: time.NewTicker(30 * time.Second),
apiCallStat: rpc.NewCallStat(10),
lastUDPInDatagrams: -1,
lastUDPOutDatagrams: -1,
}
}
func (this *NodeStatusExecutor) Listen() {
this.isFirstTime = true
this.lastUpdatedTime = time.Now()
this.update()
events.OnKey(events.EventQuit, this, func() {
remotelogs.Println("NODE_STATUS", "quit executor")
this.ticker.Stop()
})
for range this.ticker.C {
this.isFirstTime = false
this.update()
}
}
func (this *NodeStatusExecutor) update() {
if sharedNodeConfig == nil {
return
}
var tr = trackers.Begin("UPLOAD_NODE_STATUS")
defer tr.End()
var status = &nodeconfigs.NodeStatus{}
status.BuildVersion = teaconst.Version
status.BuildVersionCode = utils.VersionToLong(teaconst.Version)
status.OS = runtime.GOOS
status.Arch = runtime.GOARCH
status.ExePath, _ = os.Executable()
status.ConfigVersion = sharedNodeConfig.Version
status.IsActive = true
status.ConnectionCount = sharedListenerManager.TotalActiveConnections()
status.CacheTotalDiskSize = caches.SharedManager.TotalDiskSize()
status.CacheTotalMemorySize = caches.SharedManager.TotalMemorySize()
status.TrafficInBytes = teaconst.InTrafficBytes
status.TrafficOutBytes = teaconst.OutTrafficBytes
apiSuccessPercent, apiAvgCostSeconds := this.apiCallStat.Sum()
status.APISuccessPercent = apiSuccessPercent
status.APIAvgCostSeconds = apiAvgCostSeconds
var localFirewall = firewalls.Firewall()
if localFirewall != nil && !localFirewall.IsMock() {
status.LocalFirewallName = localFirewall.Name()
}
// 记录监控数据
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemConnections, maps.Map{
"total": status.ConnectionCount,
})
hostname, _ := os.Hostname()
status.Hostname = hostname
var cpuTR = tr.Begin("cpu")
this.updateCPU(status)
cpuTR.End()
var memTR = tr.Begin("memory")
this.updateMem(status)
memTR.End()
var loadTR = tr.Begin("load")
this.updateLoad(status)
loadTR.End()
var diskTR = tr.Begin("disk")
this.updateDisk(status)
diskTR.End()
var cacheSpaceTR = tr.Begin("cache space")
this.updateCacheSpace(status)
cacheSpaceTR.End()
this.updateAllTraffic(status)
// 修改更新时间
this.lastUpdatedTime = time.Now()
status.UpdatedAt = time.Now().Unix()
status.Timestamp = status.UpdatedAt
// 发送数据
jsonData, err := json.Marshal(status)
if err != nil {
remotelogs.Error("NODE_STATUS", "serial NodeStatus fail: "+err.Error())
return
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
remotelogs.Error("NODE_STATUS", "failed to open rpc: "+err.Error())
return
}
var before = time.Now()
_, err = rpcClient.NodeRPC.UpdateNodeStatus(rpcClient.Context(), &pb.UpdateNodeStatusRequest{
StatusJSON: jsonData,
})
var costSeconds = time.Since(before).Seconds()
this.apiCallStat.Add(err == nil, costSeconds)
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Warn("NODE_STATUS", "rpc UpdateNodeStatus() failed: "+err.Error())
} else {
remotelogs.Error("NODE_STATUS", "rpc UpdateNodeStatus() failed: "+err.Error())
}
return
}
}
// 更新CPU
func (this *NodeStatusExecutor) updateCPU(status *nodeconfigs.NodeStatus) {
var duration = time.Duration(0)
if this.isFirstTime {
duration = 100 * time.Millisecond
}
percents, err := cpu.Percent(duration, false)
if err != nil {
status.Error = "cpu.Percent(): " + err.Error()
return
}
if len(percents) == 0 {
return
}
status.CPUUsage = percents[0] / 100
// 记录监控数据
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemCPU, maps.Map{
"usage": status.CPUUsage,
"cores": runtime.NumCPU(),
})
if this.cpuLogicalCount == 0 && this.cpuPhysicalCount == 0 {
status.CPULogicalCount, err = cpu.Counts(true)
if err != nil {
status.Error = "cpu.Counts(): " + err.Error()
return
}
status.CPUPhysicalCount, err = cpu.Counts(false)
if err != nil {
status.Error = "cpu.Counts(): " + err.Error()
return
}
this.cpuLogicalCount = status.CPULogicalCount
this.cpuPhysicalCount = status.CPUPhysicalCount
} else {
status.CPULogicalCount = this.cpuLogicalCount
status.CPUPhysicalCount = this.cpuPhysicalCount
}
}
// 更新硬盘
func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
status.DiskWritingSpeedMB = int(fsutils.DiskSpeedMB)
partitions, err := disk.Partitions(false)
if err != nil {
remotelogs.Error("NODE_STATUS", err.Error())
return
}
lists.Sort(partitions, func(i int, j int) bool {
p1 := partitions[i]
p2 := partitions[j]
return p1.Mountpoint > p2.Mountpoint
})
// 当前TeaWeb所在的fs
var rootFS = ""
var rootTotal = uint64(0)
var totalUsed = uint64(0)
if lists.ContainsString([]string{"darwin", "linux", "freebsd"}, runtime.GOOS) {
for _, p := range partitions {
if p.Mountpoint == "/" {
rootFS = p.Fstype
usage, _ := disk.Usage(p.Mountpoint)
if usage != nil {
rootTotal = usage.Total
totalUsed = usage.Used
}
break
}
}
}
var total = rootTotal
var maxUsage = float64(0)
for _, partition := range partitions {
if runtime.GOOS != "windows" && !strings.Contains(partition.Device, "/") && !strings.Contains(partition.Device, "\\") {
continue
}
// 跳过不同fs的
if len(rootFS) > 0 && rootFS != partition.Fstype {
continue
}
usage, err := disk.Usage(partition.Mountpoint)
if err != nil {
continue
}
if partition.Mountpoint != "/" && (usage.Total != rootTotal || total == 0) {
total += usage.Total
totalUsed += usage.Used
if usage.UsedPercent >= maxUsage {
maxUsage = usage.UsedPercent
status.DiskMaxUsagePartition = partition.Mountpoint
}
}
}
status.DiskTotal = total
if total > 0 {
status.DiskUsage = float64(totalUsed) / float64(total)
}
status.DiskMaxUsage = maxUsage / 100
// 记录监控数据
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemDisk, maps.Map{
"total": status.DiskTotal,
"usage": status.DiskUsage,
"maxUsage": status.DiskMaxUsage,
})
}
// 缓存空间
func (this *NodeStatusExecutor) updateCacheSpace(status *nodeconfigs.NodeStatus) {
var result = []maps.Map{}
var cachePaths = caches.SharedManager.FindAllCachePaths()
for _, path := range cachePaths {
stat, err := fsutils.StatDevice(path)
if err != nil {
return
}
result = append(result, maps.Map{
"path": path,
"total": stat.TotalSize(),
"avail": stat.FreeSize(),
"used": stat.UsedSize(),
})
}
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemCacheDir, maps.Map{
"dirs": result,
})
}
// 流量
func (this *NodeStatusExecutor) updateAllTraffic(status *nodeconfigs.NodeStatus) {
trafficCounters, err := net.IOCounters(true)
if err != nil {
remotelogs.Warn("NODE_STATUS_EXECUTOR", err.Error())
return
}
var allCounter = net.IOCountersStat{}
for _, counter := range trafficCounters {
// 跳过lo
if counter.Name == "lo" {
continue
}
allCounter.BytesRecv += counter.BytesRecv
allCounter.BytesSent += counter.BytesSent
}
if allCounter.BytesSent == 0 && allCounter.BytesRecv == 0 {
return
}
if this.lastIOCounterStat.BytesSent > 0 {
// 记录监控数据
if allCounter.BytesSent >= this.lastIOCounterStat.BytesSent && allCounter.BytesRecv >= this.lastIOCounterStat.BytesRecv {
var costSeconds = int(math.Ceil(time.Since(this.lastUpdatedTime).Seconds()))
if costSeconds > 0 {
var bytesSent = allCounter.BytesSent - this.lastIOCounterStat.BytesSent
var bytesRecv = allCounter.BytesRecv - this.lastIOCounterStat.BytesRecv
// UDP
var udpInDatagrams int64 = 0
var udpOutDatagrams int64 = 0
protoStats, protoErr := net.ProtoCounters([]string{"udp"})
if protoErr == nil {
for _, protoStat := range protoStats {
if protoStat.Protocol == "udp" {
udpInDatagrams = protoStat.Stats["InDatagrams"]
udpOutDatagrams = protoStat.Stats["OutDatagrams"]
if udpInDatagrams < 0 {
udpInDatagrams = 0
}
if udpOutDatagrams < 0 {
udpOutDatagrams = 0
}
}
}
}
var avgUDPInDatagrams int64 = 0
var avgUDPOutDatagrams int64 = 0
if this.lastUDPInDatagrams >= 0 && this.lastUDPOutDatagrams >= 0 {
avgUDPInDatagrams = (udpInDatagrams - this.lastUDPInDatagrams) / int64(costSeconds)
avgUDPOutDatagrams = (udpOutDatagrams - this.lastUDPOutDatagrams) / int64(costSeconds)
if avgUDPInDatagrams < 0 {
avgUDPInDatagrams = 0
}
if avgUDPOutDatagrams < 0 {
avgUDPOutDatagrams = 0
}
}
this.lastUDPInDatagrams = udpInDatagrams
this.lastUDPOutDatagrams = udpOutDatagrams
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemAllTraffic, maps.Map{
"inBytes": bytesRecv,
"outBytes": bytesSent,
"avgInBytes": bytesRecv / uint64(costSeconds),
"avgOutBytes": bytesSent / uint64(costSeconds),
"avgUDPInDatagrams": avgUDPInDatagrams,
"avgUDPOutDatagrams": avgUDPOutDatagrams,
})
}
}
}
this.lastIOCounterStat = allCounter
}

View File

@@ -0,0 +1,27 @@
package nodes
import (
"github.com/shirou/gopsutil/v3/cpu"
"testing"
"time"
)
func TestNodeStatusExecutor_CPU(t *testing.T) {
countLogicCPU, err := cpu.Counts(true)
if err != nil {
t.Fatal(err)
}
t.Log("logic count:", countLogicCPU)
countPhysicalCPU, err := cpu.Counts(false)
if err != nil {
t.Fatal(err)
}
t.Log("physical count:", countPhysicalCPU)
percents, err := cpu.Percent(100*time.Millisecond, false)
if err != nil {
t.Fatal(err)
}
t.Log(percents)
}

View File

@@ -0,0 +1,71 @@
//go:build !windows
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeNode/internal/monitor"
"github.com/iwind/TeaGo/maps"
"github.com/shirou/gopsutil/v3/load"
"github.com/shirou/gopsutil/v3/mem"
"runtime"
"runtime/debug"
)
// 更新内存
func (this *NodeStatusExecutor) updateMem(status *nodeconfigs.NodeStatus) {
stat, err := mem.VirtualMemory()
if err != nil {
return
}
// 重新计算内存
if stat.Total > 0 {
stat.Used = stat.Total - stat.Free - stat.Buffers - stat.Cached
status.MemoryUsage = float64(stat.Used) / float64(stat.Total)
}
status.MemoryTotal = stat.Total
// 记录监控数据
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemMemory, maps.Map{
"usage": status.MemoryUsage,
"total": status.MemoryTotal,
"used": stat.Used,
})
// 内存严重不足时自动释放内存
if stat.Total > 0 {
var minFreeMemory = stat.Total / 8
if minFreeMemory > 1<<30 {
minFreeMemory = 1 << 30
}
if stat.Available > 0 && stat.Available < minFreeMemory {
runtime.GC()
debug.FreeOSMemory()
}
}
}
// 更新负载
func (this *NodeStatusExecutor) updateLoad(status *nodeconfigs.NodeStatus) {
stat, err := load.Avg()
if err != nil {
status.Error = err.Error()
return
}
if stat == nil {
status.Error = "load is nil"
return
}
status.Load1m = stat.Load1
status.Load5m = stat.Load5
status.Load15m = stat.Load15
// 记录监控数据
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemLoad, maps.Map{
"load1m": status.Load1m,
"load5m": status.Load5m,
"load15m": status.Load15m,
})
}

Some files were not shown because too many files have changed in this diff Show More