v1.5.1 增强程序稳定性

This commit is contained in:
robin
2026-03-22 17:37:40 +08:00
parent afbaaa869c
commit 17e182b413
652 changed files with 22949 additions and 34397 deletions

View File

@@ -39,10 +39,11 @@ type IngestLog struct {
FirewallRuleGroupId int64 `json:"firewall_rule_group_id,omitempty"`
FirewallRuleSetId int64 `json:"firewall_rule_set_id,omitempty"`
FirewallRuleId int64 `json:"firewall_rule_id,omitempty"`
RequestHeaders string `json:"request_headers,omitempty"`
RequestBody string `json:"request_body,omitempty"`
ResponseHeaders string `json:"response_headers,omitempty"`
ResponseBody string `json:"response_body,omitempty"`
RequestHeaders string `json:"request_headers,omitempty"`
RequestBody string `json:"request_body,omitempty"`
ResponseHeaders string `json:"response_headers,omitempty"`
ResponseBody string `json:"response_body,omitempty"`
Attrs map[string]string `json:"attrs,omitempty"`
}
// stringsMapToJSON 将 map[string]*Strings 转为 JSON 字符串,便于落盘与 ClickHouse 存储
@@ -117,6 +118,7 @@ func FromHTTPAccessLog(l *pb.HTTPAccessLog, clusterId int64) (ingest IngestLog,
RequestHeaders: stringsMapToJSON(l.GetHeader()),
RequestBody: buildRequestBody(l),
ResponseHeaders: stringsMapToJSON(l.GetSentHeader()),
Attrs: l.GetAttrs(),
}
if ingest.IP == "" {
ingest.IP = l.GetRemoteAddr()

View File

@@ -3,6 +3,7 @@
package caches
import (
"database/sql"
"errors"
"fmt"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
@@ -141,7 +142,7 @@ func (this *SQLiteFileListDB) Init() error {
return err
}
this.selectByHashStmt, err = this.readDB.Prepare(`SELECT "key", "headerSize", "bodySize", "metaSize", "expiredAt" FROM "` + this.itemsTableName + `" WHERE "hash"=? LIMIT 1`)
this.selectByHashStmt, err = this.readDB.Prepare(`SELECT "key", "headerSize", "bodySize", "metaSize", "expiredAt", "staleAt", "host", "serverId", "createdAt" FROM "` + this.itemsTableName + `" WHERE "hash"=? LIMIT 1`)
if err != nil {
return err
}
@@ -302,6 +303,28 @@ func (this *SQLiteFileListDB) ListHashes(lastId int64) (hashList []string, maxId
return
}
func (this *SQLiteFileListDB) ReadItem(hash string) (*Item, error) {
if len(hash) == 0 {
return nil, nil
}
row := this.selectByHashStmt.QueryRow(hash)
if row == nil {
return nil, nil
}
var item = &Item{Type: ItemTypeFile}
err := row.Scan(&item.Key, &item.HeaderSize, &item.BodySize, &item.MetaSize, &item.ExpiresAt, &item.StaleAt, &item.Host, &item.ServerId, &item.CreatedAt)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, err
}
return item, nil
}
func (this *SQLiteFileListDB) IncreaseHitAsync(hash string) error {
// do nothing
return nil

View File

@@ -0,0 +1,96 @@
package caches
import (
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"os"
)
func MigrateSQLiteFileListDir(sqliteDir string, kvDir string) error {
if len(sqliteDir) == 0 || len(kvDir) == 0 {
return nil
}
_, err := os.Stat(sqliteDir)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
remotelogs.Println("CACHE", "migrating sqlite indexes from '"+sqliteDir+"' to '"+kvDir+"' ...")
src := NewSQLiteFileList(sqliteDir).(*SQLiteFileList)
err = src.Init()
if err != nil {
return err
}
defer func() {
_ = src.Close()
}()
dst := NewKVFileList(kvDir)
err = dst.Init()
if err != nil {
return err
}
defer func() {
_ = dst.Close()
}()
err = dst.CleanAll()
if err != nil {
return err
}
for _, db := range src.dbList {
if db == nil {
continue
}
var lastId int64
for {
hashes, maxId, listErr := db.ListHashes(lastId)
if listErr != nil {
return listErr
}
if len(hashes) == 0 {
break
}
for _, hash := range hashes {
item, readErr := db.ReadItem(hash)
if readErr != nil {
return readErr
}
if item == nil {
continue
}
addErr := dst.Add(hash, item)
if addErr != nil {
return addErr
}
}
lastId = maxId
}
}
for _, store := range dst.stores {
if store != nil && store.rawStore != nil {
err = store.rawStore.Flush()
if err != nil {
return err
}
}
}
err = os.RemoveAll(sqliteDir)
if err != nil {
return err
}
remotelogs.Println("CACHE", "migrated sqlite indexes to pebble")
return nil
}

View File

@@ -330,16 +330,25 @@ func (this *FileStorage) Init() error {
// read list
var list ListInterface
var sqliteIndexesDir = dir + "/p" + types.String(this.policy.Id) + "/.indexes"
var kvStoresDir = dir + "/p" + types.String(this.policy.Id) + "/.stores"
_, sqliteIndexesDirErr := os.Stat(sqliteIndexesDir)
if sqliteIndexesDirErr == nil || !teaconst.EnableKVCacheStore {
var useSQLite bool
if sqliteIndexesDirErr == nil {
err = MigrateSQLiteFileListDir(sqliteIndexesDir, kvStoresDir)
if err != nil {
remotelogs.Error("CACHE", "migrate sqlite indexes failed: "+err.Error())
useSQLite = true
}
}
if useSQLite {
list = NewSQLiteFileList(sqliteIndexesDir)
err = list.Init()
if err != nil {
return err
}
list.(*SQLiteFileList).SetOldDir(dir + "/p" + types.String(this.policy.Id))
} else {
list = NewKVFileList(dir + "/p" + types.String(this.policy.Id) + "/.stores")
list = NewKVFileList(kvStoresDir)
err = list.Init()
if err != nil {
return err

View File

@@ -1,7 +1,7 @@
package teaconst
const (
Version = "1.5.0" //1.3.8.2
Version = "1.5.1" //1.3.8.2
ProductName = "Edge Node"
ProcessName = "edge-node"

View File

@@ -220,6 +220,6 @@ func (this *Firewalld) pushCmd(cmd *executils.Cmd, denyIP string) {
select {
case this.cmdQueue <- &firewalldCmd{cmd: cmd, denyIP: denyIP}:
default:
// we discard the command
remotelogs.Warn("FIREWALL", "command queue full, discarding firewall command for IP: "+denyIP)
}
}

View File

@@ -5,7 +5,6 @@ package nftables
import (
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
@@ -28,15 +27,6 @@ func init() {
return
}
nodeConfig, err := nodeconfigs.SharedNodeConfig()
if err != nil {
return
}
if nodeConfig == nil || !nodeConfig.AutoInstallNftables {
return
}
if os.Getgid() == 0 { // root user only
if len(NftExePath()) > 0 {
return

View File

@@ -6,8 +6,10 @@ package http3
import (
"context"
"errors"
"fmt"
"github.com/quic-go/quic-go"
http3quic "github.com/quic-go/quic-go/http3"
"log"
"net"
"net/http"
)
@@ -45,6 +47,12 @@ func (this *Server) Serve(listener Listener) error {
continue
}
go func() {
defer func() {
if r := recover(); r != nil {
log.Println(fmt.Sprintf("[HTTP3]goroutine panic: %v", r))
}
}()
// 通知ConnState
if this.ConnState != nil {
netConn, isNetConn := conn.(net.Conn)

View File

@@ -25,8 +25,10 @@ func (this *IPItemEncoder[T]) EncodeField(value T, fieldName string) ([]byte, er
switch fieldName {
case "expiresAt":
var expiresAt = any(value).(*pb.IPItem).ExpiredAt
if expiresAt < 0 || expiresAt > int64(math.MaxUint32) {
if expiresAt < 0 {
expiresAt = 0
} else if expiresAt > int64(math.MaxUint32) {
expiresAt = int64(math.MaxUint32)
}
var b = make([]byte, 4)
binary.BigEndian.PutUint32(b, uint32(expiresAt))

View File

@@ -0,0 +1,94 @@
package iplibrary
import (
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"os"
)
func MigrateSQLiteIPListToKV(sqlitePath string) error {
_, err := os.Stat(sqlitePath)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
remotelogs.Println("IP_LIST_DB", "migrating sqlite data to kvstore ...")
src, err := NewSQLiteIPList()
if err != nil {
return err
}
defer func() {
_ = src.Close()
}()
dst, err := NewKVIPList()
if err != nil {
return err
}
defer func() {
_ = dst.Close()
}()
err = dst.ipTable.DB().Truncate()
if err != nil {
return err
}
var offset int64
const size int64 = 1000
for {
items, goNext, readErr := src.ReadItems(offset, size)
if readErr != nil {
return readErr
}
for _, item := range items {
addErr := dst.AddItem(item)
if addErr != nil {
return addErr
}
}
if !goNext {
break
}
offset += size
}
version, err := src.ReadMaxVersion()
if err != nil {
return err
}
if version > 0 {
err = dst.UpdateMaxVersion(version)
if err != nil {
return err
}
}
err = dst.Flush()
if err != nil {
return err
}
err = removeSQLiteFiles(sqlitePath)
if err != nil {
return err
}
remotelogs.Println("IP_LIST_DB", "migrated sqlite data to kvstore")
return nil
}
func removeSQLiteFiles(path string) error {
for _, filename := range []string{path, path + "-shm", path + "-wal"} {
err := os.Remove(filename)
if err != nil && !os.IsNotExist(err) {
return err
}
}
return nil
}

View File

@@ -117,11 +117,16 @@ func (this *IPListManager) Init() {
// 检查sqlite文件是否存在以便决定使用sqlite还是kv
var sqlitePath = Tea.Root + "/data/ip_list.db"
_, sqliteErr := os.Stat(sqlitePath)
var db IPListDB
var err error
if sqliteErr == nil || !teaconst.EnableKVCacheStore {
db, err = NewSQLiteIPList()
if sqliteErr == nil {
err = MigrateSQLiteIPListToKV(sqlitePath)
if err != nil {
remotelogs.Error("IP_LIST_MANAGER", "migrate sqlite data failed: "+err.Error())
db, err = NewSQLiteIPList()
} else {
db, err = NewKVIPList()
}
} else {
db, err = NewKVIPList()
}

View File

@@ -92,8 +92,14 @@ func (this *Manager) Update(items []*serverconfigs.MetricItemConfig) {
remotelogs.Println("METRIC_MANAGER", "start task '"+strconv.FormatInt(newItem.Id, 10)+"'")
var task Task
if CheckSQLiteDB(newItem.Id) || !teaconst.EnableKVCacheStore {
task = NewSQLiteTask(newItem)
if CheckSQLiteDB(newItem.Id) {
migrateErr := MigrateSQLiteTaskToKV(newItem)
if migrateErr != nil {
remotelogs.Error("METRIC_MANAGER", "migrate sqlite task failed: "+migrateErr.Error())
task = NewSQLiteTask(newItem)
} else {
task = NewKVTask(newItem)
}
} else {
task = NewKVTask(newItem)
}

View File

@@ -0,0 +1,151 @@
package metrics
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils/dbs"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/types"
"os"
)
func MigrateSQLiteTaskToKV(item *serverconfigs.MetricItemConfig) error {
var itemIdString = types.String(item.Id)
var sqlitePath = Tea.Root + "/data/metric." + itemIdString + ".db"
_, err := os.Stat(sqlitePath)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
dst := NewKVTask(item)
err = dst.Init()
if err != nil {
return err
}
err = dst.Truncate()
if err != nil {
return err
}
remotelogs.Println("METRIC", "migrating sqlite task '"+itemIdString+"' to kvstore ...")
src := NewSQLiteTask(item)
err = src.Init()
if err != nil {
return err
}
defer closeSQLiteTask(src)
var offset = 0
const size = 1000
for {
rows, queryErr := src.db.Query(`SELECT "hash", "keys", "value", "time", "serverId" FROM "`+src.statTableName+`" WHERE "version"=? ORDER BY "id" ASC LIMIT ?, ?`, item.Version, offset, size)
if queryErr != nil {
return queryErr
}
var countRows int
for rows.Next() {
countRows++
var hash string
var keysData []byte
var value float64
var timeString string
var serverId int64
scanErr := rows.Scan(&hash, &keysData, &value, &timeString, &serverId)
if scanErr != nil {
_ = rows.Close()
return scanErr
}
var keys []string
if len(keysData) > 0 {
unmarshalErr := json.Unmarshal(keysData, &keys)
if unmarshalErr != nil {
_ = rows.Close()
return unmarshalErr
}
}
insertErr := dst.InsertStat(&Stat{
ServerId: serverId,
Keys: keys,
Hash: hash,
Value: int64(value),
Time: timeString,
})
if insertErr != nil {
_ = rows.Close()
return insertErr
}
}
err = rows.Err()
if err != nil {
_ = rows.Close()
return err
}
closeErr := rows.Close()
if closeErr != nil {
return closeErr
}
if countRows < size {
break
}
offset += size
}
err = dst.Flush()
if err != nil {
return err
}
err = removeMetricSQLiteFiles(sqlitePath)
if err != nil {
return err
}
remotelogs.Println("METRIC", "migrated sqlite task '"+itemIdString+"' to kvstore")
return nil
}
func closeSQLiteTask(task *SQLiteTask) {
if task == nil {
return
}
for _, stmt := range []*dbs.Stmt{
task.insertStatStmt,
task.deleteByVersionStmt,
task.deleteByExpiresTimeStmt,
task.selectTopStmt,
task.sumStmt,
} {
if stmt != nil {
_ = stmt.Close()
}
}
if task.db != nil {
_ = task.db.Close()
}
}
func removeMetricSQLiteFiles(path string) error {
for _, filename := range []string{path, path + "-shm", path + "-wal"} {
err := os.Remove(filename)
if err != nil && !os.IsNotExist(err) {
return err
}
}
return nil
}

View File

@@ -407,7 +407,7 @@ func (this *APIStream) handleCheckLocalFirewall(message *pb.NodeStreamMessage) e
"version": version,
}
var protectionConfig = sharedNodeConfig.DDoSProtection
var protectionConfig = nodeConfig().DDoSProtection
err = firewalls.SharedDDoSProtectionManager.Apply(protectionConfig)
if err != nil {
this.replyFail(message.RequestId, dataMessage.Name+" was installed, but apply DDoS protection config failed: "+err.Error())

View File

@@ -71,7 +71,7 @@ func NewClientConn(rawConn net.Conn, isHTTP bool, isTLS bool, isInAllowList bool
}
// 超时等设置
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
var globalServerConfig = nodeConfig().GlobalServerConfig
if globalServerConfig != nil {
var performanceConfig = globalServerConfig.Performance
conn.isDebugging = performanceConfig.Debug
@@ -136,7 +136,7 @@ func (this *ClientConn) Read(b []byte) (n int, err error) {
if !this.isPersistent && this.isHTTP && !this.isInAllowList && !utils.IsLocalIP(this.RawIP()) {
// SYN Flood检测
if this.serverId == 0 || !this.hasResetSYNFlood {
var synFloodConfig = sharedNodeConfig.SYNFloodConfig()
var synFloodConfig = nodeConfig().SYNFloodConfig()
if synFloodConfig != nil && synFloodConfig.IsOn {
if isHandshakeError {
this.increaseSYNFlood(synFloodConfig)

View File

@@ -7,6 +7,7 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/events"
@@ -41,14 +42,14 @@ func init() {
eventLocker.Lock()
defer eventLocker.Unlock()
if sharedNodeConfig == nil {
if nodeConfig() == nil {
return
}
_ = sharedHTTP3Manager.Update(sharedNodeConfig.HTTP3Policies)
_ = sharedHTTP3Manager.Update(nodeConfig().HTTP3Policies)
sharedHTTP3Manager.UpdateHTTPListener(listener)
listener.Reload(sharedNodeConfig.HTTP3Group())
listener.Reload(nodeConfig().HTTP3Group())
}()
})
}
@@ -219,6 +220,12 @@ func (this *HTTP3Manager) createServer(port int) (*http3.Server, error) {
},
}
go func() {
defer func() {
if r := recover(); r != nil {
remotelogs.Error("HTTP3_MANAGER", fmt.Sprintf("goroutine panic: %v", r))
}
}()
err = server.Serve(listener)
if err != nil {
remotelogs.Error("HTTP3_MANAGER", "serve '"+addr+"' failed: "+err.Error())

View File

@@ -95,8 +95,8 @@ Loop:
}
var writeTargets *serverconfigs.AccessLogWriteTargets
if sharedNodeConfig != nil && sharedNodeConfig.GlobalServerConfig != nil {
writeTargets = sharedNodeConfig.GlobalServerConfig.HTTPAccessLog.WriteTargets
if nodeConfig() != nil && nodeConfig().GlobalServerConfig != nil {
writeTargets = nodeConfig().GlobalServerConfig.HTTPAccessLog.WriteTargets
}
needWriteFile := writeTargets == nil || writeTargets.NeedWriteFile()
needReportAPI := writeTargets == nil || writeTargets.NeedReportToAPI()
@@ -104,8 +104,8 @@ Loop:
// 落盘 JSON LinesFluent Bit 采集 → ClickHouse
if needWriteFile {
var clusterId int64
if sharedNodeConfig != nil {
clusterId = sharedNodeConfig.GroupId
if nodeConfig() != nil {
clusterId = nodeConfig().GroupId
}
accesslogs.SharedFileWriter().WriteBatch(accessLogs, clusterId)
}

View File

@@ -63,6 +63,11 @@ func (this *HTTPAccessLogViewer) Start() error {
var connId = this.nextConnId()
this.connMap[connId] = conn
go func() {
defer func() {
if r := recover(); r != nil {
remotelogs.Error("ACCESS_LOG", fmt.Sprintf("goroutine panic: %v", r))
}
}()
this.startReading(conn, connId)
}()
this.locker.Unlock()

View File

@@ -275,9 +275,9 @@ func (this *HTTPCacheTaskManager) simplifyErr(err error) error {
func (this *HTTPCacheTaskManager) httpClient() *http.Client {
var timeout = serverconfigs.DefaultHTTPCachePolicyFetchTimeout
var nodeConfig = sharedNodeConfig // copy
if nodeConfig != nil {
var cachePolicies = nodeConfig.HTTPCachePolicies // copy
var cfg = nodeConfig()
if cfg != nil {
var cachePolicies = cfg.HTTPCachePolicies // copy
if len(cachePolicies) > 0 && cachePolicies[0].FetchTimeout != nil && cachePolicies[0].FetchTimeout.Count > 0 {
var fetchTimeout = cachePolicies[0].FetchTimeout.Duration()
if fetchTimeout > 0 {

View File

@@ -144,7 +144,7 @@ func (this *HTTPClientPool) Client(req *HTTPRequest,
maxConnections *= 8
idleConns *= 8
idleTimeout *= 4
} else if sharedNodeConfig != nil && sharedNodeConfig.Level > 1 {
} else if nodeConfig() != nil && nodeConfig().Level > 1 {
// Ln节点可以适当增加连接数
maxConnections *= 2
idleConns *= 2

View File

@@ -21,7 +21,6 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/metrics"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/stats"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/bytepool"
@@ -168,12 +167,6 @@ func (this *HTTPRequest) Do() {
return
}
// 调试JS 请求是否命中加密配置
if strings.HasSuffix(strings.ToLower(this.Path()), ".js") {
encryptionOn := this.web.Encryption != nil && this.web.Encryption.IsOn && this.web.Encryption.IsEnabled()
remotelogs.Println("HTTP_REQUEST_ENCRYPTION", fmt.Sprintf("JS request matched - URL: %s, Host: %s, ServerID: %d, encryptionOn=%v", this.URL(), this.ReqHost, this.ReqServer.Id, encryptionOn))
}
// 是否为低级别节点
this.isLnRequest = this.checkLnRequest()

View File

@@ -160,14 +160,14 @@ func (this *HTTPRequest) doCC() (block bool) {
var targetURL = this.RawReq.URL.Query().Get("url")
var realURLKey = stringutil.Md5(sharedNodeConfig.Secret + "@" + targetURL + "@" + remoteAddr)
var realURLKey = stringutil.Md5(nodeConfig().Secret + "@" + targetURL + "@" + remoteAddr)
if urlKey != realURLKey {
this.ccForbid(2)
return true
}
// 校验时间
if timestampKey != stringutil.Md5(sharedNodeConfig.Secret+"@"+timestamp) {
if timestampKey != stringutil.Md5(nodeConfig().Secret+"@"+timestamp) {
this.ccForbid(3)
return true
}
@@ -196,8 +196,8 @@ func (this *HTTPRequest) doCC() (block bool) {
return true
}
var urlKey = stringutil.Md5(sharedNodeConfig.Secret + "@" + this.URL() + "@" + remoteAddr)
var timestampKey = stringutil.Md5(sharedNodeConfig.Secret + "@" + types.String(currentTime))
var urlKey = stringutil.Md5(nodeConfig().Secret + "@" + this.URL() + "@" + remoteAddr)
var timestampKey = stringutil.Md5(nodeConfig().Secret + "@" + types.String(currentTime))
// 跳转到验证URL
this.DisableStat()

View File

@@ -25,28 +25,23 @@ func (this *HTTPRequest) processPageEncryption(resp *http.Response) error {
// 首先检查是否是 /waf/loader.js如果是则直接跳过不应该被加密
// 这个检查必须在所有其他检查之前,确保 loader.js 永远不会被加密
if strings.Contains(this.URL(), "/waf/loader.js") {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "skipping /waf/loader.js, should not be encrypted, URL: "+this.URL())
return nil
}
if this.web.Encryption == nil {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "encryption config is nil for URL: "+this.URL())
return nil
}
if !this.web.Encryption.IsOn {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "encryption switch is off for URL: "+this.URL())
return nil
}
if !this.web.Encryption.IsEnabled() {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "encryption is not enabled for URL: "+this.URL())
return nil
}
// 检查 URL 白名单
if this.web.Encryption.MatchExcludeURL(this.URL()) {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "URL is in exclude list: "+this.URL())
return nil
}
@@ -66,7 +61,6 @@ func (this *HTTPRequest) processPageEncryption(resp *http.Response) error {
strings.Contains(urlLower, ".js&")
if !isHTML && !isJavaScript {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "content type not match, URL: "+this.URL()+", Content-Type: "+contentType)
return nil
}
@@ -101,47 +95,37 @@ func (this *HTTPRequest) processPageEncryption(resp *http.Response) error {
// 处理 JavaScript 文件
if isJavaScript {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "processing JavaScript file, URL: "+this.URL())
// 检查是否需要加密独立的 JavaScript 文件
if this.web.Encryption.Javascript == nil {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "Javascript config is nil for URL: "+this.URL())
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
if !this.web.Encryption.Javascript.IsOn {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "Javascript encryption is not enabled for URL: "+this.URL())
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
// 检查 URL 匹配
if !this.web.Encryption.Javascript.MatchURL(this.URL()) {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "URL does not match patterns for URL: "+this.URL())
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
// 跳过 Loader 文件(必须排除,否则 loader.js 会被错误加密)
// 跳过 Loader 文件
if strings.Contains(this.URL(), "/waf/loader.js") ||
strings.Contains(this.URL(), "waf-loader.js") ||
strings.Contains(this.URL(), "__WAF_") {
remotelogs.Debug("HTTP_REQUEST_ENCRYPTION", "skipping loader file, URL: "+this.URL())
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
// 加密 JavaScript 文件
remotelogs.Println("HTTP_REQUEST_ENCRYPTION", "encrypting JavaScript file, URL: "+this.URL())
encryptedBytes, err = this.encryptJavaScriptFile(bodyBytes, resp)
if err != nil {
remotelogs.Warn("HTTP_REQUEST_ENCRYPTION", "encrypt JavaScript file failed: "+err.Error())
// 加密失败,恢复原始响应体
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return nil
}
remotelogs.Println("HTTP_REQUEST_ENCRYPTION", "JavaScript file encrypted successfully, URL: "+this.URL())
} else if isHTML {
// 加密 HTML 内容
encryptedBytes, err = this.encryptHTMLScripts(bodyBytes, resp)

View File

@@ -34,9 +34,9 @@ func (this *HTTPRequest) doMismatch() {
}
// 根据配置进行相应的处理
var nodeConfig = sharedNodeConfig // copy
if nodeConfig != nil {
var globalServerConfig = nodeConfig.GlobalServerConfig
var cfg = nodeConfig()
if cfg != nil {
var globalServerConfig = cfg.GlobalServerConfig
if globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly {
var statusCode = 404
var httpAllConfig = globalServerConfig.HTTPAll

View File

@@ -24,7 +24,7 @@ func (this *HTTPRequest) doPlanBefore() (blocked bool) {
// check max upload size
if this.RawReq.ContentLength > 0 {
var plan = sharedNodeConfig.FindPlan(this.ReqServer.UserPlan.PlanId)
var plan = nodeConfig().FindPlan(this.ReqServer.UserPlan.PlanId)
if plan != nil && plan.MaxUploadSize != nil && plan.MaxUploadSize.Count > 0 {
if this.RawReq.ContentLength > plan.MaxUploadSize.Bytes() {
this.writeCode(http.StatusRequestEntityTooLarge, "Reached max upload size limitation in your plan.", "触发套餐中最大文件上传尺寸限制。")

View File

@@ -29,7 +29,7 @@ func init() {
if sharedUAMManager != nil {
return
}
manager, _ := uam.NewManager(sharedNodeConfig.NodeId, sharedNodeConfig.Secret)
manager, _ := uam.NewManager(nodeConfig().NodeId, nodeConfig().Secret)
if manager != nil {
sharedUAMManager = manager
}
@@ -39,7 +39,7 @@ func init() {
if sharedUAMManager != nil {
return
}
manager, _ := uam.NewManager(sharedNodeConfig.NodeId, sharedNodeConfig.Secret)
manager, _ := uam.NewManager(nodeConfig().NodeId, nodeConfig().Secret)
if manager != nil {
sharedUAMManager = manager
}

View File

@@ -4,8 +4,10 @@ import (
"bufio"
"bytes"
"errors"
"fmt"
"github.com/TeaOSLab/EdgeNode/internal/utils/bytepool"
"io"
"log"
"net/http"
"net/url"
)
@@ -126,6 +128,12 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
}()
go func() {
defer func() {
if r := recover(); r != nil {
log.Println(fmt.Sprintf("[WEBSOCKET]goroutine panic: %v", r))
}
}()
// 读取第一个响应
var respReader = NewWebsocketResponseReader(originConn)
resp, respErr := http.ReadResponse(bufio.NewReader(respReader), this.RawReq)

View File

@@ -93,8 +93,8 @@ func (this *BaseListener) matchSSL(domains []string) (*sslconfigs.SSLPolicy, *tl
}
var globalServerConfig *serverconfigs.GlobalServerConfig
if sharedNodeConfig != nil {
globalServerConfig = sharedNodeConfig.GlobalServerConfig
if nodeConfig() != nil {
globalServerConfig = nodeConfig().GlobalServerConfig
}
// 如果域名为空,则取第一个
@@ -191,7 +191,7 @@ func (this *BaseListener) findNamedServer(name string, exactly bool) (serverConf
return
}
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
var globalServerConfig = nodeConfig().GlobalServerConfig
var matchDomainStrictly = globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly
if globalServerConfig != nil &&
@@ -241,7 +241,7 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser
}
// 是否严格匹配域名
var matchDomainStrictly = sharedNodeConfig.GlobalServerConfig != nil && sharedNodeConfig.GlobalServerConfig.HTTPAll.MatchDomainStrictly
var matchDomainStrictly = nodeConfig().GlobalServerConfig != nil && nodeConfig().GlobalServerConfig.HTTPAll.MatchDomainStrictly
// 如果只有一个server则默认为这个
var currentServers = group.Servers()
@@ -270,7 +270,7 @@ func (this *BaseListener) helloServerNames(clientInfo *tls.ClientHelloInfo) (ser
}
}
serverNames = append(serverNames, sharedNodeConfig.IPAddresses...)
serverNames = append(serverNames, nodeConfig().IPAddresses...)
return
}

View File

@@ -12,7 +12,7 @@ import (
)
func TestBaseListener_FindServer(t *testing.T) {
sharedNodeConfig = &nodeconfigs.NodeConfig{}
sharedNodeConfig.Store(&nodeconfigs.NodeConfig{})
var listener = &BaseListener{}
listener.Group = serverconfigs.NewServerAddressGroup("https://*:443")

View File

@@ -121,7 +121,7 @@ func (this *HTTPListener) ServeHTTPWithAddr(rawWriter http.ResponseWriter, rawRe
return
}
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
var globalServerConfig = nodeConfig().GlobalServerConfig
if globalServerConfig != nil && !globalServerConfig.HTTPAll.SupportsLowVersionHTTP && (rawReq.ProtoMajor < 1 /** 0.x **/ || (rawReq.ProtoMajor == 1 && rawReq.ProtoMinor == 0 /** 1.0 **/)) {
http.Error(rawWriter, rawReq.Proto+" request is not supported.", http.StatusHTTPVersionNotSupported)
time.Sleep(1 * time.Second) // make connection slow down
@@ -223,7 +223,7 @@ func (this *HTTPListener) ServeHTTPWithAddr(rawWriter http.ResponseWriter, rawRe
IsHTTPS: this.isHTTPS,
IsHTTP3: this.isHTTP3,
nodeConfig: sharedNodeConfig,
nodeConfig: nodeConfig(),
}
req.Do()
@@ -259,8 +259,8 @@ func (this *HTTPListener) emptyServer() *serverconfigs.ServerConfig {
}
// 检查是否开启访问日志
if sharedNodeConfig != nil {
var globalServerConfig = sharedNodeConfig.GlobalServerConfig
if nodeConfig() != nil {
var globalServerConfig = nodeConfig().GlobalServerConfig
if globalServerConfig != nil && globalServerConfig.HTTPAccessLog.EnableServerNotFound {
var accessLogRef = serverconfigs.NewHTTPAccessLogRef()
accessLogRef.IsOn = true

View File

@@ -237,7 +237,7 @@ func (this *ListenerManager) findProcessNameWithPort(isUdp bool, port string) st
}
func (this *ListenerManager) addToFirewalld(groupAddrs []string) {
if !sharedNodeConfig.AutoOpenPorts {
if !nodeConfig().AutoOpenPorts {
return
}
@@ -246,7 +246,7 @@ func (this *ListenerManager) addToFirewalld(groupAddrs []string) {
}
// HTTP/3相关端口
var http3Ports = sharedNodeConfig.FindHTTP3Ports()
var http3Ports = nodeConfig().FindHTTP3Ports()
if len(http3Ports) > 0 {
for _, port := range http3Ports {
var groupAddr = "udp://:" + types.String(port)
@@ -347,12 +347,12 @@ func (this *ListenerManager) reloadFirewalld() {
this.locker.Lock()
defer this.locker.Unlock()
var nodeConfig = sharedNodeConfig
var cfg = nodeConfig()
// 所有的新地址
var groupAddrs = []string{}
var availableServerGroups = nodeConfig.AvailableGroups()
if !nodeConfig.IsOn {
var availableServerGroups = cfg.AvailableGroups()
if !cfg.IsOn {
availableServerGroups = []*serverconfigs.ServerAddressGroup{}
}

View File

@@ -49,11 +49,17 @@ import (
"sort"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
)
var sharedNodeConfig *nodeconfigs.NodeConfig
var sharedNodeConfig atomic.Pointer[nodeconfigs.NodeConfig]
// nodeConfig 返回当前节点配置(并发安全)
func nodeConfig() *nodeconfigs.NodeConfig {
return sharedNodeConfig.Load()
}
var nodeTaskNotify = make(chan bool, 8)
var nodeConfigChangedNotify = make(chan bool, 8)
var nodeConfigUpdatedAt int64
@@ -63,7 +69,7 @@ var nodeInstance *Node
// Node 节点
type Node struct {
isLoaded bool
isLoaded atomic.Bool
sock *gosock.Sock
locker sync.Mutex
@@ -200,11 +206,13 @@ func (this *Node) Start() {
remotelogs.ServerError(serverErr.Id, "NODE", serverErr.Message, nodeconfigs.NodeLogTypeServerConfigInitFailed, maps.Map{})
}
}
sharedNodeConfig = nodeConfig
sharedNodeConfig.Store(nodeConfig)
this.onReload(nodeConfig, true)
// 调整系统参数
go this.tuneSystemParameters()
goman.New(func() {
this.tuneSystemParameters()
})
// 发送事件
events.Notify(events.EventLoaded)
@@ -406,7 +414,7 @@ func (this *Node) syncConfig(taskVersion int64) error {
}
// 刷新配置
if this.isLoaded {
if this.isLoaded.Load() {
remotelogs.Println("NODE", "reloading node config ...")
} else {
remotelogs.Println("NODE", "loading node config ...")
@@ -417,11 +425,11 @@ func (this *Node) syncConfig(taskVersion int64) error {
// 发送事件
events.Notify(events.EventReload)
if this.isLoaded {
if this.isLoaded.Load() {
return sharedListenerManager.Start(nodeConfig)
}
this.isLoaded = true
this.isLoaded.Store(true)
// 预创建本地日志目录与空文件,便于 Fluent Bit 立即 tail无需等首条访问日志
_ = accesslogs.SharedFileWriter().EnsureInit()
@@ -885,7 +893,7 @@ func (this *Node) listenSock() error {
// 重载配置调用
func (this *Node) onReload(config *nodeconfigs.NodeConfig, reloadAll bool) {
nodeconfigs.ResetNodeConfig(config)
sharedNodeConfig = config
sharedNodeConfig.Store(config)
var accessLogFilePath string
if config != nil && config.GlobalServerConfig != nil {
@@ -1069,7 +1077,7 @@ func (this *Node) reloadServer() {
if countUpdatingServers > 0 {
var updatingServerMap = this.updatingServerMap
this.updatingServerMap = map[int64]*serverconfigs.ServerConfig{}
newNodeConfig, err := nodeconfigs.CloneNodeConfig(sharedNodeConfig)
newNodeConfig, err := nodeconfigs.CloneNodeConfig(nodeConfig())
if err != nil {
remotelogs.Error("NODE", "apply server config error: "+err.Error())
return
@@ -1121,7 +1129,7 @@ func (this *Node) tuneSystemParameters() {
return
}
if sharedNodeConfig == nil || !sharedNodeConfig.AutoSystemTuning {
if nodeConfig() == nil || !nodeConfig().AutoSystemTuning {
return
}

View File

@@ -51,14 +51,14 @@ func (this *Node) reloadCommonScripts() error {
return err
}
if len(configsResp.ScriptConfigsJSON) == 0 {
sharedNodeConfig.CommonScripts = nil
nodeConfig().CommonScripts = nil
} else {
var configs = []*serverconfigs.CommonScript{}
err = json.Unmarshal(configsResp.ScriptConfigsJSON, &configs)
if err != nil {
return fmt.Errorf("decode script configs failed: %w", err)
}
sharedNodeConfig.CommonScripts = configs
nodeConfig().CommonScripts = configs
}
// 通知更新
@@ -71,13 +71,20 @@ func (this *Node) reloadCommonScripts() error {
}
func (this *Node) reloadIPLibrary() {
if sharedNodeConfig.Edition == lastEdition {
var cfg = nodeConfig()
if cfg.Edition == lastEdition {
return
}
go func() {
defer func() {
if r := recover(); r != nil {
remotelogs.Error("IP_LIBRARY", fmt.Sprintf("goroutine panic: %v", r))
}
}()
var err error
lastEdition = sharedNodeConfig.Edition
lastEdition = cfg.Edition
if len(lastEdition) > 0 && (lists.ContainsString([]string{"pro", "ent", "max", "ultra"}, lastEdition)) {
err = iplib.InitPlus()
} else {
@@ -100,11 +107,12 @@ func (this *Node) notifyPlusChange() error {
return err
}
var isChanged = resp.Edition != sharedNodeConfig.Edition
var cfg = nodeConfig()
var isChanged = resp.Edition != cfg.Edition
if resp.IsPlus {
sharedNodeConfig.Edition = resp.Edition
cfg.Edition = resp.Edition
} else {
sharedNodeConfig.Edition = ""
cfg.Edition = ""
}
if isChanged {

View File

@@ -70,7 +70,8 @@ func (this *NodeStatusExecutor) Listen() {
}
func (this *NodeStatusExecutor) update() {
if sharedNodeConfig == nil {
var cfg = nodeConfig()
if cfg == nil {
return
}
@@ -83,7 +84,7 @@ func (this *NodeStatusExecutor) update() {
status.OS = runtime.GOOS
status.Arch = runtime.GOARCH
status.ExePath, _ = os.Executable()
status.ConfigVersion = sharedNodeConfig.Version
status.ConfigVersion = cfg.Version
status.IsActive = true
status.ConnectionCount = sharedListenerManager.TotalActiveConnections()
status.CacheTotalDiskSize = caches.SharedManager.TotalDiskSize()

View File

@@ -156,8 +156,9 @@ func (this *Node) execNodeLevelChangedTask(rpcClient *rpc.RPCClient) error {
return err
}
if sharedNodeConfig != nil {
sharedNodeConfig.Level = levelInfoResp.Level
var cfg = nodeConfig()
if cfg != nil {
cfg.Level = levelInfoResp.Level
}
var parentNodes = map[int64][]*nodeconfigs.ParentNodeConfig{}
@@ -168,8 +169,8 @@ func (this *Node) execNodeLevelChangedTask(rpcClient *rpc.RPCClient) error {
}
}
if sharedNodeConfig != nil {
sharedNodeConfig.ParentNodes = parentNodes
if cfg != nil {
cfg.ParentNodes = parentNodes
}
return nil
@@ -181,9 +182,10 @@ func (this *Node) execDDoSProtectionChangedTask(rpcClient *rpc.RPCClient) error
if err != nil {
return err
}
var cfg = nodeConfig()
if len(resp.DdosProtectionJSON) == 0 {
if sharedNodeConfig != nil {
sharedNodeConfig.DDoSProtection = nil
if cfg != nil {
cfg.DDoSProtection = nil
}
return nil
}
@@ -194,8 +196,8 @@ func (this *Node) execDDoSProtectionChangedTask(rpcClient *rpc.RPCClient) error
return fmt.Errorf("decode DDoS protection config failed: %w", err)
}
if ddosProtectionConfig != nil && sharedNodeConfig != nil {
sharedNodeConfig.DDoSProtection = ddosProtectionConfig
if ddosProtectionConfig != nil && cfg != nil {
cfg.DDoSProtection = ddosProtectionConfig
}
go func() {
@@ -227,8 +229,9 @@ func (this *Node) execGlobalServerConfigChangedTask(rpcClient *rpc.RPCClient) er
if err != nil {
return fmt.Errorf("validate global server config failed: %w", err)
}
if sharedNodeConfig != nil {
sharedNodeConfig.GlobalServerConfig = globalServerConfig
var cfg = nodeConfig()
if cfg != nil {
cfg.GlobalServerConfig = globalServerConfig
}
}
}
@@ -258,7 +261,7 @@ func (this *Node) execUserServersStateChangedTask(rpcClient *rpc.RPCClient, task
// 更新一组服务列表
func (this *Node) execUpdatingServersTask(rpcClient *rpc.RPCClient) error {
if this.lastUpdatingServerListId <= 0 {
this.lastUpdatingServerListId = sharedNodeConfig.UpdatingServerListId
this.lastUpdatingServerListId = nodeConfig().UpdatingServerListId
}
resp, err := rpcClient.UpdatingServerListRPC.FindUpdatingServerLists(rpcClient.Context(), &pb.FindUpdatingServerListsRequest{LastId: this.lastUpdatingServerListId})
@@ -353,7 +356,7 @@ func (this *Node) execWebPPolicyChangedTask(rpcClient *rpc.RPCClient) error {
webPPolicyMap[policy.NodeClusterId] = webPPolicy
}
}
sharedNodeConfig.UpdateWebPImagePolicies(webPPolicyMap)
nodeConfig().UpdateWebPImagePolicies(webPPolicyMap)
return nil
}

View File

@@ -47,7 +47,7 @@ func (this *Node) execUAMPolicyChangedTask(rpcClient *rpc.RPCClient) error {
uamPolicyMap[policy.NodeClusterId] = uamPolicy
}
}
sharedNodeConfig.UpdateUAMPolicies(uamPolicyMap)
nodeConfig().UpdateUAMPolicies(uamPolicyMap)
return nil
}
@@ -75,7 +75,7 @@ func (this *Node) execHTTPCCPolicyChangedTask(rpcClient *rpc.RPCClient) error {
httpCCPolicyMap[policy.NodeClusterId] = httpCCPolicy
}
}
sharedNodeConfig.UpdateHTTPCCPolicies(httpCCPolicyMap)
nodeConfig().UpdateHTTPCCPolicies(httpCCPolicyMap)
return nil
}
@@ -104,7 +104,7 @@ func (this *Node) execHTTP3PolicyChangedTask(rpcClient *rpc.RPCClient) error {
}
}
sharedNodeConfig.UpdateHTTP3Policies(http3PolicyMap)
nodeConfig().UpdateHTTP3Policies(http3PolicyMap)
// 加入端口到防火墙
sharedListenerManager.reloadFirewalld()
@@ -143,7 +143,7 @@ func (this *Node) execHTTPPagesPolicyChangedTask(rpcClient *rpc.RPCClient) error
httpPagesPolicyMap[policy.NodeClusterId] = httpPagesPolicy
}
}
sharedNodeConfig.UpdateHTTPPagesPolicies(httpPagesPolicyMap)
nodeConfig().UpdateHTTPPagesPolicies(httpPagesPolicyMap)
return nil
}
@@ -191,7 +191,7 @@ func (this *Node) execPlanChangedTask(rpcClient *rpc.RPCClient) error {
}
}
sharedNodeConfig.UpdatePlans(planMap)
nodeConfig().UpdatePlans(planMap)
sharedPlanBandwidthLimiter.UpdatePlans(planMap)
return nil

View File

@@ -72,8 +72,8 @@ func (this *OriginStateManager) Stop() {
// Loop 单次循环检查
func (this *OriginStateManager) Loop() error {
var nodeConfig = sharedNodeConfig // 复制
if nodeConfig == nil {
var cfg = nodeConfig() // 复制
if cfg == nil {
return nil
}
@@ -84,7 +84,7 @@ func (this *OriginStateManager) Loop() error {
this.locker.Lock()
for originId, state := range this.stateMap {
// 检查Origin是否正在使用
var originConfig = nodeConfig.FindOrigin(originId)
var originConfig = cfg.FindOrigin(originId)
if originConfig == nil || !originConfig.IsOn {
delete(this.stateMap, originId)
continue

View File

@@ -18,10 +18,11 @@ func init() {
}
events.On(events.EventLoaded, func() {
if sharedNodeConfig == nil {
var cfg = nodeConfig()
if cfg == nil {
return
}
sharedPlanBandwidthLimiter.UpdatePlans(sharedNodeConfig.FindAllPlans())
sharedPlanBandwidthLimiter.UpdatePlans(cfg.FindAllPlans())
})
}

View File

@@ -43,8 +43,9 @@ func init() {
}
events.On(events.EventReload, func() {
if sharedNodeConfig != nil {
var scripts = sharedNodeConfig.CommonScripts // 拷贝为了安全操作
var cfg = nodeConfig()
if cfg != nil {
var scripts = cfg.CommonScripts // 拷贝为了安全操作
if js.IsSameCommonScripts(scripts) {
if SharedJSPool == nil {
createPool()
@@ -60,8 +61,9 @@ func init() {
goman.New(func() {
for range commonScriptsChangesChan {
if sharedNodeConfig != nil {
var scripts = sharedNodeConfig.CommonScripts // 拷贝为了安全操作
var cfg = nodeConfig()
if cfg != nil {
var scripts = cfg.CommonScripts // 拷贝为了安全操作
if js.IsSameCommonScripts(scripts) {
if SharedJSPool == nil {
createPool()

View File

@@ -43,15 +43,16 @@ func NewSystemServiceManager() *SystemServiceManager {
}
func (this *SystemServiceManager) Setup() error {
if sharedNodeConfig == nil || !sharedNodeConfig.IsOn {
var cfg = nodeConfig()
if cfg == nil || !cfg.IsOn {
return nil
}
if len(sharedNodeConfig.SystemServices) == 0 {
if len(cfg.SystemServices) == 0 {
return nil
}
systemdParams, ok := sharedNodeConfig.SystemServices[nodeconfigs.SystemServiceTypeSystemd]
systemdParams, ok := cfg.SystemServices[nodeconfigs.SystemServiceTypeSystemd]
if ok {
err := this.setupSystemd(systemdParams)
if err != nil {

View File

@@ -21,7 +21,7 @@ func init() {
}
events.On(events.EventLoaded, func() {
sharedOCSPTask.version = sharedNodeConfig.OCSPVersion
sharedOCSPTask.version = nodeConfig().OCSPVersion
goman.New(func() {
sharedOCSPTask.Start()
@@ -76,10 +76,11 @@ func (this *OCSPUpdateTask) Loop() error {
return err
}
var cfg = nodeConfig()
for _, ocsp := range resp.SslCertOCSP {
// 更新OCSP
if sharedNodeConfig != nil {
sharedNodeConfig.UpdateCertOCSP(ocsp.SslCertId, ocsp.Data, ocsp.ExpiresAt)
if cfg != nil {
cfg.UpdateCertOCSP(ocsp.SslCertId, ocsp.Data, ocsp.ExpiresAt)
}
// 修改版本

View File

@@ -62,7 +62,8 @@ func (this *SyncAPINodesTask) Stop() {
func (this *SyncAPINodesTask) Loop() error {
// 如果有节点定制的API节点地址
var hasCustomizedAPINodeAddrs = sharedNodeConfig != nil && len(sharedNodeConfig.APINodeAddrs) > 0
var cfg = nodeConfig()
var hasCustomizedAPINodeAddrs = cfg != nil && len(cfg.APINodeAddrs) > 0
config, err := configs.LoadAPIConfig()
if err != nil {

View File

@@ -47,11 +47,11 @@ func (this *TrimDisksTask) loop() error {
return nil
}
var nodeConfig = sharedNodeConfig
if nodeConfig == nil {
var cfg = nodeConfig()
if cfg == nil {
return nil
}
if !nodeConfig.AutoTrimDisks {
if !cfg.AutoTrimDisks {
return nil
}

View File

@@ -27,7 +27,7 @@ func init() {
}
events.On(events.EventLoaded, func() {
err := sharedTOAManager.Apply(sharedNodeConfig.TOA)
err := sharedTOAManager.Apply(nodeConfig().TOA)
if err != nil {
remotelogs.Error("TOA", err.Error())
}

View File

@@ -364,6 +364,9 @@ func (this *HTTPRequestStatManager) Upload() error {
// 城市
for k, stat := range cityMap {
var pieces = strings.SplitN(k, "@", 4)
if len(pieces) < 4 {
continue
}
var serverId = types.Int64(pieces[0])
pbCities = append(pbCities, &pb.UploadServerHTTPRequestStatRequest_RegionCity{
ServerId: serverId,
@@ -397,6 +400,9 @@ func (this *HTTPRequestStatManager) Upload() error {
// 运营商
for k, count := range providerMap {
var pieces = strings.SplitN(k, "@", 2)
if len(pieces) < 2 {
continue
}
var serverId = types.Int64(pieces[0])
pbProviders = append(pbProviders, &pb.UploadServerHTTPRequestStatRequest_RegionProvider{
ServerId: serverId,
@@ -425,6 +431,9 @@ func (this *HTTPRequestStatManager) Upload() error {
// 操作系统
for k, count := range systemMap {
var pieces = strings.SplitN(k, "@", 3)
if len(pieces) < 3 {
continue
}
var serverId = types.Int64(pieces[0])
pbSystems = append(pbSystems, &pb.UploadServerHTTPRequestStatRequest_System{
ServerId: serverId,
@@ -454,6 +463,9 @@ func (this *HTTPRequestStatManager) Upload() error {
// 浏览器
for k, count := range browserMap {
var pieces = strings.SplitN(k, "@", 3)
if len(pieces) < 3 {
continue
}
var serverId = types.Int64(pieces[0])
pbBrowsers = append(pbBrowsers, &pb.UploadServerHTTPRequestStatRequest_Browser{
ServerId: serverId,

View File

@@ -0,0 +1,85 @@
package agents
import (
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"os"
)
func MigrateSQLiteDBToKV(sqlitePath string) error {
_, err := os.Stat(sqlitePath)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
remotelogs.Println("AGENT_MANAGER", "migrating sqlite data to kvstore ...")
src := NewSQLiteDB(sqlitePath)
err = src.Init()
if err != nil {
return err
}
defer func() {
_ = src.Close()
}()
dst := NewKVDB()
err = dst.Init()
if err != nil {
return err
}
defer func() {
_ = dst.Close()
}()
err = dst.table.DB().Truncate()
if err != nil {
return err
}
var offset int64
const size int64 = 1000
for {
agentIPs, listErr := src.ListAgentIPs(offset, size)
if listErr != nil {
return listErr
}
if len(agentIPs) == 0 {
break
}
for _, agentIP := range agentIPs {
insertErr := dst.InsertAgentIP(agentIP.Id, agentIP.IP, agentIP.AgentCode)
if insertErr != nil {
return insertErr
}
}
offset += size
}
err = dst.Flush()
if err != nil {
return err
}
err = removeSQLiteDBFiles(sqlitePath)
if err != nil {
return err
}
remotelogs.Println("AGENT_MANAGER", "migrated sqlite data to kvstore")
return nil
}
func removeSQLiteDBFiles(path string) error {
for _, filename := range []string{path, path + "-shm", path + "-wal"} {
err := os.Remove(filename)
if err != nil && !os.IsNotExist(err) {
return err
}
}
return nil
}

View File

@@ -199,12 +199,20 @@ func (this *Manager) loadDB() error {
var sqlitePath = Tea.Root + "/data/agents.db"
_, sqliteErr := os.Stat(sqlitePath)
var db DB
if sqliteErr == nil || !teaconst.EnableKVCacheStore {
db = NewSQLiteDB(sqlitePath)
var err error
if sqliteErr == nil {
err = MigrateSQLiteDBToKV(sqlitePath)
if err != nil {
remotelogs.Error("AGENT_MANAGER", "migrate sqlite data failed: "+err.Error())
db = NewSQLiteDB(sqlitePath)
} else {
db = NewKVDB()
}
} else {
db = NewKVDB()
}
err := db.Init()
err = db.Init()
if err != nil {
return err
}

View File

@@ -73,6 +73,12 @@ func OpenFS(dir string, options *FSOptions) (*FS, error) {
func (this *FS) init() {
go func() {
defer func() {
if r := recover(); r != nil {
log.Println("[BFS]sync goroutine panic:", r)
}
}()
// sync in background
for range this.syncTicker.C {
this.syncLoop()
@@ -80,6 +86,12 @@ func (this *FS) init() {
}()
go func() {
defer func() {
if r := recover(); r != nil {
log.Println("[BFS]closing goroutine panic:", r)
}
}()
for {
this.processClosingBFiles()
}

File diff suppressed because one or more lines are too long

View File

View File

@@ -70,10 +70,10 @@ func (this *DB) Inspect(fn func(key []byte, value []byte)) error {
if valueErr != nil {
return valueErr
}
fn(it.Key(), value)
fn(append([]byte(nil), it.Key()...), append([]byte(nil), value...))
}
return nil
return it.Error()
}
// Truncate the database

View File

@@ -0,0 +1,10 @@
package kvstore
import "fmt"
func wrapRecoveredPanic(message string, panicErr any) error {
if resultErr, ok := panicErr.(error); ok {
return fmt.Errorf("%s: %w", message, resultErr)
}
return fmt.Errorf("%s: %v", message, panicErr)
}

View File

@@ -5,7 +5,6 @@ package kvstore
import (
"bytes"
"errors"
"fmt"
byteutils "github.com/TeaOSLab/EdgeNode/internal/utils/byte"
)
@@ -66,6 +65,9 @@ func (this *Query[T]) SetTable(table *Table[T]) *Query[T] {
func (this *Query[T]) SetTx(tx *Tx[T]) *Query[T] {
this.tx = tx
if tx != nil {
this.table = tx.table
}
return this
}
@@ -128,7 +130,12 @@ func (this *Query[T]) FieldPrefix(fieldName string, fieldPrefix string) *Query[T
}
func (this *Query[T]) FieldOffset(fieldOffset []byte) *Query[T] {
this.fieldOffsetKey = fieldOffset
if len(fieldOffset) == 0 {
this.fieldOffsetKey = nil
return this
}
this.fieldOffsetKey = append([]byte(nil), fieldOffset...)
return this
}
@@ -172,28 +179,12 @@ func (this *Query[T]) FindAll(fn IteratorFunc[T]) (err error) {
defer func() {
var panicErr = recover()
if panicErr != nil {
resultErr, ok := panicErr.(error)
if ok {
err = fmt.Errorf("execute query failed: %w", resultErr)
}
err = wrapRecoveredPanic("execute query failed", panicErr)
}
}()
if this.tx != nil {
defer func() {
_ = this.tx.Close()
}()
var itErr error
if len(this.fieldName) == 0 {
itErr = this.iterateKeys(fn)
} else {
itErr = this.iterateFields(fn)
}
if itErr != nil {
return itErr
}
return this.tx.Commit()
return this.findAllWithTx(this.tx, fn)
}
if this.table != nil {
@@ -205,19 +196,29 @@ func (this *Query[T]) FindAll(fn IteratorFunc[T]) (err error) {
}
return txFn(func(tx *Tx[T]) error {
this.tx = tx
if len(this.fieldName) == 0 {
return this.iterateKeys(fn)
}
return this.iterateFields(fn)
return this.findAllWithTx(tx, fn)
})
}
return errors.New("current query require 'table' or 'tx'")
}
func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
func (this *Query[T]) findAllWithTx(tx *Tx[T], fn IteratorFunc[T]) error {
if tx == nil {
return errors.New("current query require valid tx")
}
if this.table == nil {
this.table = tx.table
}
if len(this.fieldName) == 0 {
return this.iterateKeys(tx, fn)
}
return this.iterateFields(tx, fn)
}
func (this *Query[T]) iterateKeys(tx *Tx[T], fn IteratorFunc[T]) error {
if this.limit == 0 {
return nil
}
@@ -262,7 +263,7 @@ func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
var hasOffsetKey = len(this.offsetKey) > 0
it, itErr := this.tx.NewIterator(opt)
it, itErr := tx.NewIterator(opt)
if itErr != nil {
return itErr
}
@@ -297,7 +298,7 @@ func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
}
}
goNext, callbackErr := fn(this.tx, Item[T]{
goNext, callbackErr := fn(tx, Item[T]{
Key: string(keyBytes[prefixLen:]),
Value: value,
})
@@ -346,10 +347,10 @@ func (this *Query[T]) iterateKeys(fn IteratorFunc[T]) error {
}
}
return nil
return it.Error()
}
func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
func (this *Query[T]) iterateFields(tx *Tx[T], fn IteratorFunc[T]) error {
if this.limit == 0 {
return nil
}
@@ -390,7 +391,7 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
opt.UpperBound = byteutils.Append(prefix, 0xFF)
}
it, itErr := this.tx.NewIterator(opt)
it, itErr := tx.NewIterator(opt)
if itErr != nil {
return itErr
}
@@ -427,10 +428,10 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
var resultItem = Item[T]{
Key: string(keyBytes),
FieldKey: fieldKeyBytes,
FieldKey: append([]byte(nil), fieldKeyBytes...),
}
if !this.keysOnly {
value, getErr := this.table.getWithKeyBytes(this.tx, this.table.FullKeyBytes(keyBytes))
value, getErr := this.table.getWithKeyBytes(tx, this.table.FullKeyBytes(keyBytes))
if getErr != nil {
if IsNotFound(getErr) {
return true, nil
@@ -441,7 +442,7 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
resultItem.Value = value
}
goNextItem, err = fn(this.tx, resultItem)
goNextItem, err = fn(tx, resultItem)
if err != nil {
if IsSkipError(err) {
return true, nil
@@ -487,7 +488,7 @@ func (this *Query[T]) iterateFields(fn IteratorFunc[T]) error {
}
}
return nil
return it.Error()
}
func (this *Query[T]) matchOperators(fieldValueBytes []byte) bool {

View File

@@ -0,0 +1,225 @@
package kvstore_test
import (
"bytes"
"fmt"
"github.com/TeaOSLab/EdgeNode/internal/utils/kvstore"
"strings"
"testing"
"time"
)
func openIsolatedTable[T any](t *testing.T, tableName string, encoder kvstore.ValueEncoder[T]) *kvstore.Table[T] {
storeName := fmt.Sprintf("test-%d", time.Now().UnixNano())
store, err := kvstore.OpenStore(storeName)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
_ = store.Close()
_ = kvstore.RemoveStore(store.Path())
})
db, err := store.NewDB("db1")
if err != nil {
t.Fatal(err)
}
table, err := kvstore.NewTable[T](tableName, encoder)
if err != nil {
t.Fatal(err)
}
db.AddTable(table)
return table
}
func TestQuery_FieldKeyIsStableAcrossPaging(t *testing.T) {
table := openIsolatedTable[*testCachedItem](t, "cache_items", &testCacheItemEncoder[*testCachedItem]{})
err := table.AddFields("expiresAt")
if err != nil {
t.Fatal(err)
}
for i, key := range []string{"a1", "a2", "a3"} {
err = table.Set(key, &testCachedItem{
Hash: key,
URL: "https://example.com/" + key,
ExpiresAt: int64(i + 1),
})
if err != nil {
t.Fatal(err)
}
}
var firstFieldKey []byte
var firstFieldKeySnapshot []byte
var count int
err = table.Query().
FieldAsc("expiresAt").
Limit(2).
FindAll(func(tx *kvstore.Tx[*testCachedItem], item kvstore.Item[*testCachedItem]) (bool, error) {
switch count {
case 0:
firstFieldKey = item.FieldKey
firstFieldKeySnapshot = append([]byte(nil), item.FieldKey...)
case 1:
if !bytes.Equal(firstFieldKey, firstFieldKeySnapshot) {
t.Fatalf("field key mutated during iteration: got %q want %q", firstFieldKey, firstFieldKeySnapshot)
}
}
count++
return true, nil
})
if err != nil {
t.Fatal(err)
}
var keys []string
err = table.Query().
FieldAsc("expiresAt").
FieldOffset(firstFieldKey).
Limit(2).
FindAll(func(tx *kvstore.Tx[*testCachedItem], item kvstore.Item[*testCachedItem]) (bool, error) {
keys = append(keys, item.Key)
return true, nil
})
if err != nil {
t.Fatal(err)
}
if len(keys) != 2 || keys[0] != "a2" || keys[1] != "a3" {
t.Fatalf("unexpected paged keys: %v", keys)
}
}
func TestQuery_FindAll_StringPanicReturnsError(t *testing.T) {
table := openIsolatedTable[string](t, "users", kvstore.NewStringValueEncoder[string]())
err := table.Set("a1", "value-1")
if err != nil {
t.Fatal(err)
}
err = table.Query().
Limit(1).
FindAll(func(tx *kvstore.Tx[string], item kvstore.Item[string]) (bool, error) {
panic("boom")
})
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), "execute query failed: boom") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestQuery_ReusesFreshTableTransactionEachRun(t *testing.T) {
table := openIsolatedTable[string](t, "users", kvstore.NewStringValueEncoder[string]())
for _, key := range []string{"a1", "a2", "a3"} {
err := table.Set(key, "value-"+key)
if err != nil {
t.Fatal(err)
}
}
query := table.Query().Limit(1)
for i := 0; i < 2; i++ {
var keys []string
err := query.FindAll(func(tx *kvstore.Tx[string], item kvstore.Item[string]) (bool, error) {
keys = append(keys, item.Key)
return true, nil
})
if err != nil {
t.Fatal(err)
}
if len(keys) != 1 || keys[0] != "a1" {
t.Fatalf("unexpected result on run %d: %v", i, keys)
}
}
}
func TestQuery_UsesExistingTxWithoutClosingIt(t *testing.T) {
table := openIsolatedTable[string](t, "users", kvstore.NewStringValueEncoder[string]())
for _, key := range []string{"a1", "a2", "a3"} {
err := table.Set(key, "value-"+key)
if err != nil {
t.Fatal(err)
}
}
err := table.ReadTx(func(tx *kvstore.Tx[string]) error {
var keys []string
err := tx.Query().
Limit(2).
FindAll(func(queryTx *kvstore.Tx[string], item kvstore.Item[string]) (bool, error) {
if queryTx != tx {
return false, fmt.Errorf("query did not reuse the current tx")
}
keys = append(keys, item.Key)
return true, nil
})
if err != nil {
return err
}
if len(keys) != 2 {
return fmt.Errorf("unexpected query size: %d", len(keys))
}
_, err = tx.Get("a1")
return err
})
if err != nil {
t.Fatal(err)
}
}
func TestDB_InspectProvidesStableBuffers(t *testing.T) {
table := openIsolatedTable[string](t, "users", kvstore.NewStringValueEncoder[string]())
for _, pair := range []struct {
key string
value string
}{
{key: "a1", value: "value-1"},
{key: "a2", value: "value-2"},
} {
err := table.Set(pair.key, pair.value)
if err != nil {
t.Fatal(err)
}
}
var firstKey []byte
var firstValue []byte
var firstKeySnapshot []byte
var firstValueSnapshot []byte
var count int
err := table.DB().Inspect(func(key []byte, value []byte) {
switch count {
case 0:
firstKey = key
firstValue = value
firstKeySnapshot = append([]byte(nil), key...)
firstValueSnapshot = append([]byte(nil), value...)
case 1:
if !bytes.Equal(firstKey, firstKeySnapshot) {
t.Fatalf("inspect key mutated after next iteration: got %q want %q", firstKey, firstKeySnapshot)
}
if !bytes.Equal(firstValue, firstValueSnapshot) {
t.Fatalf("inspect value mutated after next iteration: got %q want %q", firstValue, firstValueSnapshot)
}
}
count++
})
if err != nil {
t.Fatal(err)
}
}

View File

@@ -116,11 +116,11 @@ func OpenStoreDir(dir string, storeName string) (*Store, error) {
}
var storeOnce = &sync.Once{}
var defaultSore *Store
var defaultStore *Store
func DefaultStore() (*Store, error) {
if defaultSore != nil {
return defaultSore, nil
if defaultStore != nil {
return defaultStore, nil
}
var resultErr error
@@ -137,10 +137,10 @@ func DefaultStore() (*Store, error) {
remotelogs.Error("KV", resultErr.Error())
return
}
defaultSore = store
defaultStore = store
})
return defaultSore, resultErr
return defaultStore, resultErr
}
func (this *Store) Path() string {

View File

@@ -298,7 +298,7 @@ func (this *Table[T]) Count() (int64, error) {
count++
}
return count, err
return count, it.Error()
}
func (this *Table[T]) FullKey(realKey string) []byte {
@@ -325,9 +325,14 @@ func (this *Table[T]) DecodeFieldKey(fieldName string, fieldKey []byte) (fieldVa
return
}
var fieldValueLen = binary.BigEndian.Uint16(fieldKey[l-2:])
var fieldValueLen = int(binary.BigEndian.Uint16(fieldKey[l-2:]))
var data = fieldKey[baseLen-4 : l-2]
if fieldValueLen+2 > len(data) {
err = errors.New("invalid field value length")
return
}
fieldValue = data[:fieldValueLen]
key = data[fieldValueLen+2: /** separator length **/]

View File

@@ -4,10 +4,13 @@ package kvstore
import (
"errors"
"fmt"
"github.com/cockroachdb/pebble"
)
var commitBatch = func(batch *pebble.Batch, opt *pebble.WriteOptions) error {
return batch.Commit(opt)
}
type Tx[T any] struct {
table *Table[T]
readOnly bool
@@ -129,12 +132,9 @@ func (this *Tx[T]) commit(opt *pebble.WriteOptions) (err error) {
defer func() {
var panicErr = recover()
if panicErr != nil {
resultErr, ok := panicErr.(error)
if ok {
err = fmt.Errorf("commit batch failed: %w", resultErr)
}
err = wrapRecoveredPanic("commit batch failed", panicErr)
}
}()
return this.batch.Commit(opt)
return commitBatch(this.batch, opt)
}

View File

@@ -0,0 +1,27 @@
package kvstore
import (
"strings"
"testing"
"github.com/cockroachdb/pebble"
)
func TestTx_commit_StringPanicReturnsError(t *testing.T) {
oldCommitBatch := commitBatch
commitBatch = func(batch *pebble.Batch, opt *pebble.WriteOptions) error {
panic("boom")
}
defer func() {
commitBatch = oldCommitBatch
}()
tx := &Tx[string]{}
err := tx.commit(DefaultWriteOptions)
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), "commit batch failed: boom") {
t.Fatalf("unexpected error: %v", err)
}
}

View File

@@ -814,6 +814,9 @@ func (this *Rule) ipToInt64(ip net.IP) int64 {
if len(ip) == 16 {
return int64(binary.BigEndian.Uint32(ip[12:16]))
}
if len(ip) < 4 {
return 0
}
return int64(binary.BigEndian.Uint32(ip))
}