Initial commit (code only without large binaries)

This commit is contained in:
robin
2026-02-15 18:58:44 +08:00
commit 35df75498f
9442 changed files with 1495866 additions and 0 deletions

View File

@@ -0,0 +1,309 @@
package apps
import (
"errors"
"fmt"
teaconst "github.com/TeaOSLab/EdgeUser/internal/const"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"github.com/iwind/gosock/pkg/gosock"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
)
// AppCmd App命令帮助
type AppCmd struct {
product string
version string
usages []string
options []*CommandHelpOption
appendStrings []string
directives []*Directive
sock *gosock.Sock
}
func NewAppCmd() *AppCmd {
return &AppCmd{
sock: gosock.NewTmpSock(teaconst.ProcessName),
}
}
type CommandHelpOption struct {
Code string
Description string
}
// Product 产品
func (this *AppCmd) Product(product string) *AppCmd {
this.product = product
return this
}
// Version 版本
func (this *AppCmd) Version(version string) *AppCmd {
this.version = version
return this
}
// Usage 使用方法
func (this *AppCmd) Usage(usage string) *AppCmd {
this.usages = append(this.usages, usage)
return this
}
// Option 选项
func (this *AppCmd) Option(code string, description string) *AppCmd {
this.options = append(this.options, &CommandHelpOption{
Code: code,
Description: description,
})
return this
}
// Append 附加内容
func (this *AppCmd) Append(appendString string) *AppCmd {
this.appendStrings = append(this.appendStrings, appendString)
return this
}
// Print 打印
func (this *AppCmd) Print() {
fmt.Println(this.product + " v" + this.version)
fmt.Println("Usage:")
for _, usage := range this.usages {
fmt.Println(" " + usage)
}
if len(this.options) > 0 {
fmt.Println("")
fmt.Println("Options:")
spaces := 20
max := 40
for _, option := range this.options {
l := len(option.Code)
if l < max && l > spaces {
spaces = l + 4
}
}
for _, option := range this.options {
if len(option.Code) > max {
fmt.Println("")
fmt.Println(" " + option.Code)
option.Code = ""
}
fmt.Printf(" %-"+strconv.Itoa(spaces)+"s%s\n", option.Code, ": "+option.Description)
}
}
if len(this.appendStrings) > 0 {
fmt.Println("")
for _, s := range this.appendStrings {
fmt.Println(s)
}
}
}
// On 添加指令
func (this *AppCmd) On(arg string, callback func()) {
this.directives = append(this.directives, &Directive{
Arg: arg,
Callback: callback,
})
}
// Run 运行
func (this *AppCmd) Run(main func()) {
// 获取参数
args := os.Args[1:]
if len(args) > 0 {
switch args[0] {
case "-v", "version", "-version", "--version":
this.runVersion()
return
case "?", "help", "-help", "h", "-h":
this.runHelp()
return
case "start":
this.runStart()
return
case "stop":
this.runStop()
return
case "restart":
this.runRestart()
return
case "status":
this.runStatus()
return
}
// 查找指令
for _, directive := range this.directives {
if directive.Arg == args[0] {
directive.Callback()
return
}
}
fmt.Println("unknown command '" + args[0] + "'")
return
}
// 日志
writer := new(LogWriter)
writer.Init()
logs.SetWriter(writer)
// 运行主函数
main()
}
// 版本号
func (this *AppCmd) runVersion() {
fmt.Println(this.product+" v"+this.version, "(build: "+runtime.Version(), runtime.GOOS, runtime.GOARCH+")")
}
// 帮助
func (this *AppCmd) runHelp() {
this.Print()
}
// 启动
func (this *AppCmd) runStart() {
var pid = this.getPID()
if pid > 0 {
fmt.Println(this.product+" already started, pid:", pid)
return
}
var cmd = exec.Command(this.exe())
err := cmd.Start()
if err != nil {
fmt.Println(this.product+" start failed:", err.Error())
return
}
// create symbolic links
_ = this.createSymLinks()
fmt.Println(this.product+" started ok, pid:", cmd.Process.Pid)
}
// 停止
func (this *AppCmd) runStop() {
var pid = this.getPID()
if pid == 0 {
fmt.Println(this.product + " not started yet")
return
}
_, _ = this.sock.Send(&gosock.Command{Code: "stop"})
fmt.Println(this.product+" stopped ok, pid:", types.String(pid))
}
// 重启
func (this *AppCmd) runRestart() {
this.runStop()
time.Sleep(1 * time.Second)
this.runStart()
}
// RunRestart 重启
func (this *AppCmd) RunRestart() {
this.runStop()
time.Sleep(1 * time.Second)
this.runStart()
}
// 状态
func (this *AppCmd) runStatus() {
var pid = this.getPID()
if pid == 0 {
fmt.Println(this.product + " not started yet")
return
}
fmt.Println(this.product + " is running, pid: " + types.String(pid))
}
// 获取当前的PID
func (this *AppCmd) getPID() int {
if !this.sock.IsListening() {
return 0
}
reply, err := this.sock.Send(&gosock.Command{Code: "pid"})
if err != nil {
return 0
}
return maps.NewMap(reply.Params).GetInt("pid")
}
func (this *AppCmd) exe() string {
var exe, _ = os.Executable()
if len(exe) == 0 {
exe = os.Args[0]
}
return exe
}
// 创建软链接
func (this *AppCmd) createSymLinks() error {
if runtime.GOOS != "linux" {
return nil
}
var exe, _ = os.Executable()
if len(exe) == 0 {
return nil
}
var errorList = []string{}
// bin
{
var target = "/usr/bin/" + teaconst.ProcessName
old, _ := filepath.EvalSymlinks(target)
if old != exe {
_ = os.Remove(target)
err := os.Symlink(exe, target)
if err != nil {
errorList = append(errorList, err.Error())
}
}
}
// log
{
var realPath = filepath.Dir(filepath.Dir(exe)) + "/logs/run.log"
var target = "/var/log/" + teaconst.ProcessName + ".log"
old, _ := filepath.EvalSymlinks(target)
if old != realPath {
_ = os.Remove(target)
err := os.Symlink(realPath, target)
if err != nil {
errorList = append(errorList, err.Error())
}
}
}
if len(errorList) > 0 {
return errors.New(strings.Join(errorList, "\n"))
}
return nil
}

View File

@@ -0,0 +1,6 @@
package apps
type Directive struct {
Arg string
Callback func()
}

View File

@@ -0,0 +1,51 @@
package apps
import (
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/files"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/utils/time"
"log"
)
type LogWriter struct {
fileAppender *files.Appender
}
func (this *LogWriter) Init() {
// 创建目录
dir := files.NewFile(Tea.LogDir())
if !dir.Exists() {
err := dir.Mkdir()
if err != nil {
log.Println("[error]" + err.Error())
}
}
logFile := files.NewFile(Tea.LogFile("run.log"))
// 打开要写入的日志文件
appender, err := logFile.Appender()
if err != nil {
logs.Error(err)
} else {
this.fileAppender = appender
}
}
func (this *LogWriter) Write(message string) {
log.Println(message)
if this.fileAppender != nil {
_, err := this.fileAppender.AppendString(timeutil.Format("Y/m/d H:i:s ") + message + "\n")
if err != nil {
log.Println("[error]" + err.Error())
}
}
}
func (this *LogWriter) Close() {
if this.fileAppender != nil {
_ = this.fileAppender.Close()
}
}

View File

@@ -0,0 +1,5 @@
package configloaders
import "sync"
var locker sync.Mutex

View File

@@ -0,0 +1,118 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package configloaders
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/systemconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/userconfigs"
"github.com/TeaOSLab/EdgeUser/internal/events"
"github.com/TeaOSLab/EdgeUser/internal/remotelogs"
"github.com/TeaOSLab/EdgeUser/internal/rpc"
"github.com/iwind/TeaGo/Tea"
"time"
)
var sharedRegisterConfig *userconfigs.UserRegisterConfig
func init() {
var ticker = time.NewTicker(1 * time.Minute)
if Tea.IsTesting() {
ticker = time.NewTicker(10 * time.Second)
}
events.On(events.EventStart, func() {
go func() {
for range ticker.C {
err := reloadRegisterConfig()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("CONFIG_LOADER", "reload register config failed: "+err.Error())
} else {
remotelogs.Error("CONFIG_LOADER", "reload register config failed: "+err.Error())
}
}
}
}()
})
}
// LoadRegisterConfig 加载注册配置
func LoadRegisterConfig() (*userconfigs.UserRegisterConfig, error) {
locker.Lock()
if sharedRegisterConfig != nil {
locker.Unlock()
return sharedRegisterConfig, nil
}
locker.Unlock()
rpcClient, err := rpc.SharedRPC()
if err != nil {
return nil, err
}
resp, err := rpcClient.SysSettingRPC().ReadSysSetting(rpcClient.Context(0), &pb.ReadSysSettingRequest{Code: systemconfigs.SettingCodeUserRegisterConfig})
if err != nil {
return nil, err
}
var config = userconfigs.DefaultUserRegisterConfig()
if len(resp.ValueJSON) > 0 {
err = json.Unmarshal(resp.ValueJSON, config)
if err != nil {
return nil, err
}
locker.Lock()
sharedRegisterConfig = config
locker.Unlock()
}
return config, nil
}
func RequireVerification() bool {
locker.Lock()
defer locker.Unlock()
if sharedRegisterConfig == nil {
return false
}
return sharedRegisterConfig.RequireVerification
}
func RequireIdentity() bool {
locker.Lock()
defer locker.Unlock()
if sharedRegisterConfig == nil {
return false
}
return sharedRegisterConfig.RequireIdentity
}
// 刷新注册配置
func reloadRegisterConfig() error {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
resp, err := rpcClient.SysSettingRPC().ReadSysSetting(rpcClient.Context(0), &pb.ReadSysSettingRequest{Code: systemconfigs.SettingCodeUserRegisterConfig})
if err != nil {
return err
}
var config = userconfigs.DefaultUserRegisterConfig()
if len(resp.ValueJSON) > 0 {
err = json.Unmarshal(resp.ValueJSON, config)
if err != nil {
return err
}
locker.Lock()
sharedRegisterConfig = config
locker.Unlock()
}
return nil
}

View File

@@ -0,0 +1,33 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package configloaders
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/systemconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/userconfigs"
"github.com/TeaOSLab/EdgeUser/internal/rpc"
)
// LoadServerConfig 加载服务配置
func LoadServerConfig() (*userconfigs.UserServerConfig, error) {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return nil, err
}
resp, err := rpcClient.SysSettingRPC().ReadSysSetting(rpcClient.Context(0), &pb.ReadSysSettingRequest{Code: systemconfigs.SettingCodeUserServerConfig})
if err != nil {
return nil, err
}
var config = userconfigs.DefaultUserServerConfig()
if len(resp.ValueJSON) > 0 {
err = json.Unmarshal(resp.ValueJSON, config)
if err != nil {
return nil, err
}
}
return config, nil
}

View File

@@ -0,0 +1,115 @@
package configloaders
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/systemconfigs"
"github.com/TeaOSLab/EdgeUser/internal/rpc"
"github.com/iwind/TeaGo/logs"
"reflect"
"time"
)
var sharedUIConfig *systemconfigs.UserUIConfig = nil
func init() {
// 更新任务
// TODO 改成实时更新
var ticker = time.NewTicker(1 * time.Minute)
go func() {
for range ticker.C {
err := reloadUIConfig()
if err != nil {
logs.Println("[CONFIG_LOADERS]load ui config failed: " + err.Error())
}
}
}()
}
func LoadUIConfig() (*systemconfigs.UserUIConfig, error) {
locker.Lock()
defer locker.Unlock()
config, err := loadUIConfig()
if err != nil {
return nil, err
}
v := reflect.Indirect(reflect.ValueOf(config)).Interface().(systemconfigs.UserUIConfig)
return &v, nil
}
func loadUIConfig() (*systemconfigs.UserUIConfig, error) {
if sharedUIConfig != nil {
return sharedUIConfig, nil
}
var rpcClient, err = rpc.SharedRPC()
if err != nil {
return nil, err
}
resp, err := rpcClient.SysSettingRPC().ReadSysSetting(rpcClient.Context(0), &pb.ReadSysSettingRequest{
Code: systemconfigs.SettingCodeUserUIConfig,
})
if err != nil {
return nil, err
}
if len(resp.ValueJSON) == 0 {
sharedUIConfig = systemconfigs.NewUserUIConfig()
return sharedUIConfig, nil
}
var config = systemconfigs.NewUserUIConfig()
err = json.Unmarshal(resp.ValueJSON, config)
if err != nil {
logs.Println("[UI_MANAGER]" + err.Error())
sharedUIConfig = systemconfigs.NewUserUIConfig()
return sharedUIConfig, nil
}
sharedUIConfig = config
// 时区
updateTimeZone(config)
return sharedUIConfig, nil
}
func reloadUIConfig() error {
var rpcClient, err = rpc.SharedRPC()
if err != nil {
return err
}
resp, err := rpcClient.SysSettingRPC().ReadSysSetting(rpcClient.Context(0), &pb.ReadSysSettingRequest{
Code: systemconfigs.SettingCodeUserUIConfig,
})
if err != nil {
return err
}
if len(resp.ValueJSON) == 0 {
return nil
}
var config = systemconfigs.NewUserUIConfig()
err = json.Unmarshal(resp.ValueJSON, config)
if err != nil {
return err
}
var oldConfig = sharedUIConfig
sharedUIConfig = config
// 时区
if oldConfig == nil || oldConfig.TimeZone != config.TimeZone {
updateTimeZone(config)
}
return nil
}
// 修改时区
func updateTimeZone(config *systemconfigs.UserUIConfig) {
if len(config.TimeZone) > 0 {
location, err := time.LoadLocation(config.TimeZone)
if err == nil && time.Local != location {
time.Local = location
}
}
}

View File

@@ -0,0 +1,29 @@
package configloaders
import (
_ "github.com/iwind/TeaGo/bootstrap"
"testing"
"time"
)
func TestLoadUIConfig(t *testing.T) {
for i := 0; i < 10; i++ {
before := time.Now()
config, err := LoadUIConfig()
if err != nil {
t.Fatal(err)
}
t.Log(time.Since(before).Seconds()*1000, "ms")
t.Logf("%p", config)
}
}
func TestLoadUIConfig2(t *testing.T) {
for i := 0; i < 10; i++ {
config, err := LoadUIConfig()
if err != nil {
t.Fatal(err)
}
t.Log(config)
}
}

View File

@@ -0,0 +1,76 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package configloaders
import (
"bytes"
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/systemconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/userconfigs"
"github.com/TeaOSLab/EdgeUser/internal/remotelogs"
"github.com/TeaOSLab/EdgeUser/internal/rpc"
"time"
)
var sharedUserPriceConfig *userconfigs.UserPriceConfig
var sharedUserPriceJSON []byte
func init() {
var ticker = time.NewTicker(1 * time.Minute)
go func() {
for range ticker.C {
_, err := LoadUserPriceConfig()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("LoadUserPriceConfig", err.Error())
} else {
remotelogs.Error("LoadUserPriceConfig", err.Error())
}
}
}
}()
}
// LoadUserPriceConfig 加载用户计费设置
// 在没有error的情况下需要保证一定会返回一个不为空的配置
func LoadUserPriceConfig() (*userconfigs.UserPriceConfig, error) {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return nil, err
}
resp, err := rpcClient.SysSettingRPC().ReadSysSetting(rpcClient.Context(0), &pb.ReadSysSettingRequest{Code: systemconfigs.SettingCodeUserPriceConfig})
if err != nil {
return nil, err
}
// 如果是相同,则不更新
if bytes.Equal(resp.ValueJSON, sharedUserPriceJSON) {
return sharedUserPriceConfig, nil
}
var config = userconfigs.DefaultUserPriceConfig()
if len(resp.ValueJSON) == 0 {
return config, nil
}
err = json.Unmarshal(resp.ValueJSON, config)
if err != nil {
return nil, err
}
sharedUserPriceConfig = config
sharedUserPriceJSON = resp.ValueJSON
return config, nil
}
func LoadCacheableUserPriceConfig() (*userconfigs.UserPriceConfig, error) {
if sharedUserPriceConfig != nil {
// clone是防止被修改
return sharedUserPriceConfig.Clone()
}
return LoadUserPriceConfig()
}

View File

@@ -0,0 +1,96 @@
package configs
import (
"errors"
"github.com/iwind/TeaGo/Tea"
"gopkg.in/yaml.v3"
"os"
)
const ConfigFileName = "api_user.yaml"
const oldConfigFileName = "api.yaml"
var SharedAPIConfig *APIConfig
// APIConfig API配置
type APIConfig struct {
OldRPC struct {
Endpoints []string `yaml:"endpoints"`
DisableUpdate bool `yaml:"disableUpdate"`
} `yaml:"rpc,omitempty"`
RPCEndpoints []string `yaml:"rpc.endpoints,flow" json:"rpc.endpoints"`
RPCDisableUpdate bool `yaml:"rpc.disableUpdate" json:"rpc.disableUpdate"`
NodeId string `yaml:"nodeId"`
Secret string `yaml:"secret"`
NumberId int64 `yaml:"numberId"`
}
// LoadAPIConfig 加载API配置
func LoadAPIConfig() (*APIConfig, error) {
if SharedAPIConfig != nil {
return SharedAPIConfig, nil
}
for _, filename := range []string{ConfigFileName, oldConfigFileName} {
data, err := os.ReadFile(Tea.ConfigFile(filename))
if err != nil {
if os.IsNotExist(err) {
continue
}
return nil, err
}
var config = &APIConfig{}
err = yaml.Unmarshal(data, config)
if err != nil {
return nil, err
}
err = config.Init()
if err != nil {
return nil, errors.New("init error: " + err.Error())
}
// 自动生成新的配置文件
if filename == oldConfigFileName {
config.OldRPC.Endpoints = nil
_ = config.WriteFile(Tea.ConfigFile(ConfigFileName))
}
SharedAPIConfig = config
return config, nil
}
return nil, errors.New("no config file '" + ConfigFileName + "' found")
}
func (this *APIConfig) Init() error {
// compatible with old
if len(this.RPCEndpoints) == 0 && len(this.OldRPC.Endpoints) > 0 {
this.RPCEndpoints = this.OldRPC.Endpoints
this.RPCDisableUpdate = this.OldRPC.DisableUpdate
}
if len(this.RPCEndpoints) == 0 {
return errors.New("no valid 'rpc.endpoints'")
}
if len(this.NodeId) == 0 {
return errors.New("'nodeId' required")
}
if len(this.Secret) == 0 {
return errors.New("'secret' required")
}
return nil
}
// WriteFile 写入API配置
func (this *APIConfig) WriteFile(path string) error {
data, err := yaml.Marshal(this)
if err != nil {
return err
}
return os.WriteFile(path, data, 0666)
}

View File

@@ -0,0 +1,21 @@
package configs
import (
_ "github.com/iwind/TeaGo/bootstrap"
"gopkg.in/yaml.v3"
"testing"
)
func TestLoadAPIConfig(t *testing.T) {
config, err := LoadAPIConfig()
if err != nil {
t.Fatal(err)
}
t.Logf("%+v", config)
configData, err := yaml.Marshal(config)
if err != nil {
t.Fatal(err)
}
t.Log(string(configData))
}

View File

@@ -0,0 +1,3 @@
package configs
var Secret = ""

View File

@@ -0,0 +1,24 @@
package teaconst
const (
Version = "1.4.7" //1.3.8.2
ProductName = "Edge User"
ProcessName = "edge-user"
ProductNameZH = "Edge"
Role = "user"
EncryptKey = "8f983f4d69b83aaa0d74b21a212f6967"
EncryptMethod = "aes-256-cfb"
ErrServer = "服务器出了点小问题,请联系技术人员处理。"
CookieSID = "edgeusid"
SessionUserId = "userId"
SystemdServiceName = "edge-user"
UpdatesURL = "https://goedge.cn/api/boot/versions?os=${os}&arch=${arch}&version=${version}"
IsPlus = true
)

View File

@@ -0,0 +1,25 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package teaconst
import (
"os"
"strings"
)
var (
IsMain = checkMain()
IsDemoMode = false
)
// 检查是否为主程序
func checkMain() bool {
if len(os.Args) == 1 ||
(len(os.Args) >= 2 && os.Args[1] == "pprof") {
return true
}
exe, _ := os.Executable()
return strings.HasSuffix(exe, ".test") ||
strings.HasSuffix(exe, ".test.exe") ||
strings.Contains(exe, "___")
}

View File

@@ -0,0 +1,58 @@
package csrf
import (
"sync"
"time"
)
var sharedTokenManager = NewTokenManager()
func init() {
go func() {
ticker := time.NewTicker(1 * time.Hour)
for range ticker.C {
sharedTokenManager.Clean()
}
}()
}
type TokenManager struct {
tokenMap map[string]int64 // token => timestamp
locker sync.Mutex
}
func NewTokenManager() *TokenManager {
return &TokenManager{
tokenMap: map[string]int64{},
}
}
func (this *TokenManager) Put(token string) {
this.locker.Lock()
this.tokenMap[token] = time.Now().Unix()
this.locker.Unlock()
}
func (this *TokenManager) Exists(token string) bool {
this.locker.Lock()
_, ok := this.tokenMap[token]
this.locker.Unlock()
return ok
}
func (this *TokenManager) Delete(token string) {
this.locker.Lock()
delete(this.tokenMap, token)
this.locker.Unlock()
}
func (this *TokenManager) Clean() {
this.locker.Lock()
for token, timestamp := range this.tokenMap {
if time.Now().Unix()-timestamp > 3600 { // 删除一个小时前的
delete(this.tokenMap, token)
}
}
this.locker.Unlock()
}

View File

@@ -0,0 +1,66 @@
package csrf
import (
"crypto/sha256"
"encoding/base64"
"fmt"
"github.com/TeaOSLab/EdgeUser/internal/configs"
"github.com/iwind/TeaGo/types"
"strconv"
"time"
)
// 生成Token
func Generate() string {
timestamp := strconv.FormatInt(time.Now().Unix(), 10)
h := sha256.New()
h.Write([]byte(configs.Secret))
h.Write([]byte(timestamp))
s := h.Sum(nil)
token := base64.StdEncoding.EncodeToString([]byte(timestamp + fmt.Sprintf("%x", s)))
sharedTokenManager.Put(token)
return token
}
// 校验Token
func Validate(token string) (b bool) {
if len(token) == 0 {
return
}
if !sharedTokenManager.Exists(token) {
return
}
defer func() {
sharedTokenManager.Delete(token)
}()
data, err := base64.StdEncoding.DecodeString(token)
if err != nil {
return
}
hashString := string(data)
if len(hashString) < 10+32 {
return
}
timestampString := hashString[:10]
hashString = hashString[10:]
h := sha256.New()
h.Write([]byte(configs.Secret))
h.Write([]byte(timestampString))
hashData := h.Sum(nil)
if hashString != fmt.Sprintf("%x", hashData) {
return
}
timestamp := types.Int64(timestampString)
if timestamp < time.Now().Unix()-1800 { // 有效期半个小时
return
}
return true
}

View File

@@ -0,0 +1,41 @@
package encrypt
import (
"github.com/iwind/TeaGo/logs"
)
const (
MagicKey = "f1c8eafb543f03023e97b7be864a4e9b"
)
// 加密特殊信息
func MagicKeyEncode(data []byte) []byte {
method, err := NewMethodInstance("aes-256-cfb", MagicKey, MagicKey[:16])
if err != nil {
logs.Println("[MagicKeyEncode]" + err.Error())
return data
}
dst, err := method.Encrypt(data)
if err != nil {
logs.Println("[MagicKeyEncode]" + err.Error())
return data
}
return dst
}
// 解密特殊信息
func MagicKeyDecode(data []byte) []byte {
method, err := NewMethodInstance("aes-256-cfb", MagicKey, MagicKey[:16])
if err != nil {
logs.Println("[MagicKeyEncode]" + err.Error())
return data
}
src, err := method.Decrypt(data)
if err != nil {
logs.Println("[MagicKeyEncode]" + err.Error())
return data
}
return src
}

View File

@@ -0,0 +1,11 @@
package encrypt
import "testing"
func TestMagicKeyEncode(t *testing.T) {
dst := MagicKeyEncode([]byte("Hello,World"))
t.Log("dst:", string(dst))
src := MagicKeyDecode(dst)
t.Log("src:", string(src))
}

View File

@@ -0,0 +1,12 @@
package encrypt
type MethodInterface interface {
// 初始化
Init(key []byte, iv []byte) error
// 加密
Encrypt(src []byte) (dst []byte, err error)
// 解密
Decrypt(dst []byte) (src []byte, err error)
}

View File

@@ -0,0 +1,73 @@
package encrypt
import (
"bytes"
"crypto/aes"
"crypto/cipher"
)
type AES128CFBMethod struct {
iv []byte
block cipher.Block
}
func (this *AES128CFBMethod) Init(key, iv []byte) error {
// 判断key是否为32长度
l := len(key)
if l > 16 {
key = key[:16]
} else if l < 16 {
key = append(key, bytes.Repeat([]byte{' '}, 16-l)...)
}
// 判断iv长度
l2 := len(iv)
if l2 > aes.BlockSize {
iv = iv[:aes.BlockSize]
} else if l2 < aes.BlockSize {
iv = append(iv, bytes.Repeat([]byte{' '}, aes.BlockSize-l2)...)
}
this.iv = iv
// block
block, err := aes.NewCipher(key)
if err != nil {
return err
}
this.block = block
return nil
}
func (this *AES128CFBMethod) Encrypt(src []byte) (dst []byte, err error) {
if len(src) == 0 {
return
}
defer func() {
err = RecoverMethodPanic(recover())
}()
dst = make([]byte, len(src))
encrypter := cipher.NewCFBEncrypter(this.block, this.iv)
encrypter.XORKeyStream(dst, src)
return
}
func (this *AES128CFBMethod) Decrypt(dst []byte) (src []byte, err error) {
if len(dst) == 0 {
return
}
defer func() {
err = RecoverMethodPanic(recover())
}()
src = make([]byte, len(dst))
encrypter := cipher.NewCFBDecrypter(this.block, this.iv)
encrypter.XORKeyStream(src, dst)
return
}

View File

@@ -0,0 +1,90 @@
package encrypt
import (
"runtime"
"strings"
"testing"
)
func TestAES128CFBMethod_Encrypt(t *testing.T) {
method, err := NewMethodInstance("aes-128-cfb", "abc", "123")
if err != nil {
t.Fatal(err)
}
src := []byte("Hello, World")
dst, err := method.Encrypt(src)
if err != nil {
t.Fatal(err)
}
dst = dst[:len(src)]
t.Log("dst:", string(dst))
src, err = method.Decrypt(dst)
if err != nil {
t.Fatal(err)
}
t.Log("src:", string(src))
}
func TestAES128CFBMethod_Encrypt2(t *testing.T) {
method, err := NewMethodInstance("aes-128-cfb", "abc", "123")
if err != nil {
t.Fatal(err)
}
sources := [][]byte{}
{
a := []byte{1}
_, err = method.Encrypt(a)
if err != nil {
t.Fatal(err)
}
}
for i := 0; i < 10; i++ {
src := []byte(strings.Repeat("Hello", 1))
dst, err := method.Encrypt(src)
if err != nil {
t.Fatal(err)
}
sources = append(sources, dst)
}
{
a := []byte{1}
_, err = method.Decrypt(a)
if err != nil {
t.Fatal(err)
}
}
for _, dst := range sources {
dst2 := append([]byte{}, dst...)
src2, err := method.Decrypt(dst2)
if err != nil {
t.Fatal(err)
}
t.Log(string(src2))
}
}
func BenchmarkAES128CFBMethod_Encrypt(b *testing.B) {
runtime.GOMAXPROCS(1)
method, err := NewMethodInstance("aes-128-cfb", "abc", "123")
if err != nil {
b.Fatal(err)
}
src := []byte(strings.Repeat("Hello", 1024))
for i := 0; i < b.N; i++ {
dst, err := method.Encrypt(src)
if err != nil {
b.Fatal(err)
}
_ = dst
}
}

View File

@@ -0,0 +1,74 @@
package encrypt
import (
"bytes"
"crypto/aes"
"crypto/cipher"
)
type AES192CFBMethod struct {
block cipher.Block
iv []byte
}
func (this *AES192CFBMethod) Init(key, iv []byte) error {
// 判断key是否为24长度
l := len(key)
if l > 24 {
key = key[:24]
} else if l < 24 {
key = append(key, bytes.Repeat([]byte{' '}, 24-l)...)
}
block, err := aes.NewCipher(key)
if err != nil {
return err
}
this.block = block
// 判断iv长度
l2 := len(iv)
if l2 > aes.BlockSize {
iv = iv[:aes.BlockSize]
} else if l2 < aes.BlockSize {
iv = append(iv, bytes.Repeat([]byte{' '}, aes.BlockSize-l2)...)
}
this.iv = iv
return nil
}
func (this *AES192CFBMethod) Encrypt(src []byte) (dst []byte, err error) {
if len(src) == 0 {
return
}
defer func() {
err = RecoverMethodPanic(recover())
}()
dst = make([]byte, len(src))
encrypter := cipher.NewCFBEncrypter(this.block, this.iv)
encrypter.XORKeyStream(dst, src)
return
}
func (this *AES192CFBMethod) Decrypt(dst []byte) (src []byte, err error) {
if len(dst) == 0 {
return
}
defer func() {
err = RecoverMethodPanic(recover())
}()
src = make([]byte, len(dst))
decrypter := cipher.NewCFBDecrypter(this.block, this.iv)
decrypter.XORKeyStream(src, dst)
return
}

View File

@@ -0,0 +1,45 @@
package encrypt
import (
"runtime"
"strings"
"testing"
)
func TestAES192CFBMethod_Encrypt(t *testing.T) {
method, err := NewMethodInstance("aes-192-cfb", "abc", "123")
if err != nil {
t.Fatal(err)
}
src := []byte("Hello, World")
dst, err := method.Encrypt(src)
if err != nil {
t.Fatal(err)
}
dst = dst[:len(src)]
t.Log("dst:", string(dst))
src, err = method.Decrypt(dst)
if err != nil {
t.Fatal(err)
}
t.Log("src:", string(src))
}
func BenchmarkAES192CFBMethod_Encrypt(b *testing.B) {
runtime.GOMAXPROCS(1)
method, err := NewMethodInstance("aes-192-cfb", "abc", "123")
if err != nil {
b.Fatal(err)
}
src := []byte(strings.Repeat("Hello", 1024))
for i := 0; i < b.N; i++ {
dst, err := method.Encrypt(src)
if err != nil {
b.Fatal(err)
}
_ = dst
}
}

View File

@@ -0,0 +1,72 @@
package encrypt
import (
"bytes"
"crypto/aes"
"crypto/cipher"
)
type AES256CFBMethod struct {
block cipher.Block
iv []byte
}
func (this *AES256CFBMethod) Init(key, iv []byte) error {
// 判断key是否为32长度
l := len(key)
if l > 32 {
key = key[:32]
} else if l < 32 {
key = append(key, bytes.Repeat([]byte{' '}, 32-l)...)
}
block, err := aes.NewCipher(key)
if err != nil {
return err
}
this.block = block
// 判断iv长度
l2 := len(iv)
if l2 > aes.BlockSize {
iv = iv[:aes.BlockSize]
} else if l2 < aes.BlockSize {
iv = append(iv, bytes.Repeat([]byte{' '}, aes.BlockSize-l2)...)
}
this.iv = iv
return nil
}
func (this *AES256CFBMethod) Encrypt(src []byte) (dst []byte, err error) {
if len(src) == 0 {
return
}
defer func() {
err = RecoverMethodPanic(recover())
}()
dst = make([]byte, len(src))
encrypter := cipher.NewCFBEncrypter(this.block, this.iv)
encrypter.XORKeyStream(dst, src)
return
}
func (this *AES256CFBMethod) Decrypt(dst []byte) (src []byte, err error) {
if len(dst) == 0 {
return
}
defer func() {
err = RecoverMethodPanic(recover())
}()
src = make([]byte, len(dst))
decrypter := cipher.NewCFBDecrypter(this.block, this.iv)
decrypter.XORKeyStream(src, dst)
return
}

View File

@@ -0,0 +1,42 @@
package encrypt
import "testing"
func TestAES256CFBMethod_Encrypt(t *testing.T) {
method, err := NewMethodInstance("aes-256-cfb", "abc", "123")
if err != nil {
t.Fatal(err)
}
src := []byte("Hello, World")
dst, err := method.Encrypt(src)
if err != nil {
t.Fatal(err)
}
dst = dst[:len(src)]
t.Log("dst:", string(dst))
src, err = method.Decrypt(dst)
if err != nil {
t.Fatal(err)
}
t.Log("src:", string(src))
}
func TestAES256CFBMethod_Encrypt2(t *testing.T) {
method, err := NewMethodInstance("aes-256-cfb", "abc", "123")
if err != nil {
t.Fatal(err)
}
src := []byte("Hello, World")
dst, err := method.Encrypt(src)
if err != nil {
t.Fatal(err)
}
t.Log("dst:", string(dst))
src, err = method.Decrypt(dst)
if err != nil {
t.Fatal(err)
}
t.Log("src:", string(src))
}

View File

@@ -0,0 +1,26 @@
package encrypt
type RawMethod struct {
}
func (this *RawMethod) Init(key, iv []byte) error {
return nil
}
func (this *RawMethod) Encrypt(src []byte) (dst []byte, err error) {
if len(src) == 0 {
return
}
dst = make([]byte, len(src))
copy(dst, src)
return
}
func (this *RawMethod) Decrypt(dst []byte) (src []byte, err error) {
if len(dst) == 0 {
return
}
src = make([]byte, len(dst))
copy(src, dst)
return
}

View File

@@ -0,0 +1,23 @@
package encrypt
import "testing"
func TestRawMethod_Encrypt(t *testing.T) {
method, err := NewMethodInstance("raw", "abc", "123")
if err != nil {
t.Fatal(err)
}
src := []byte("Hello, World")
dst, err := method.Encrypt(src)
if err != nil {
t.Fatal(err)
}
dst = dst[:len(src)]
t.Log("dst:", string(dst))
src, err = method.Decrypt(dst)
if err != nil {
t.Fatal(err)
}
t.Log("src:", string(src))
}

View File

@@ -0,0 +1,43 @@
package encrypt
import (
"errors"
"reflect"
)
var methods = map[string]reflect.Type{
"raw": reflect.TypeOf(new(RawMethod)).Elem(),
"aes-128-cfb": reflect.TypeOf(new(AES128CFBMethod)).Elem(),
"aes-192-cfb": reflect.TypeOf(new(AES192CFBMethod)).Elem(),
"aes-256-cfb": reflect.TypeOf(new(AES256CFBMethod)).Elem(),
}
func NewMethodInstance(method string, key string, iv string) (MethodInterface, error) {
valueType, ok := methods[method]
if !ok {
return nil, errors.New("method '" + method + "' not found")
}
instance, ok := reflect.New(valueType).Interface().(MethodInterface)
if !ok {
return nil, errors.New("method '" + method + "' must implement MethodInterface")
}
err := instance.Init([]byte(key), []byte(iv))
return instance, err
}
func RecoverMethodPanic(err interface{}) error {
if err != nil {
s, ok := err.(string)
if ok {
return errors.New(s)
}
e, ok := err.(error)
if ok {
return e
}
return errors.New("unknown error")
}
return nil
}

View File

@@ -0,0 +1,8 @@
package encrypt
import "testing"
func TestFindMethodInstance(t *testing.T) {
t.Log(NewMethodInstance("a", "b", ""))
t.Log(NewMethodInstance("aes-256-cfb", "123456", ""))
}

View File

@@ -0,0 +1,56 @@
package errors
import (
"errors"
"path/filepath"
"runtime"
"strconv"
)
type errorObj struct {
err error
file string
line int
funcName string
}
func (this *errorObj) Error() string {
s := this.err.Error() + "\n " + this.file
if len(this.funcName) > 0 {
s += ":" + this.funcName + "()"
}
s += ":" + strconv.Itoa(this.line)
return s
}
// 新错误
func New(errText string) error {
ptr, file, line, ok := runtime.Caller(1)
funcName := ""
if ok {
frame, _ := runtime.CallersFrames([]uintptr{ptr}).Next()
funcName = filepath.Base(frame.Function)
}
return &errorObj{
err: errors.New(errText),
file: file,
line: line,
funcName: funcName,
}
}
// 包装已有错误
func Wrap(err error) error {
ptr, file, line, ok := runtime.Caller(1)
funcName := ""
if ok {
frame, _ := runtime.CallersFrames([]uintptr{ptr}).Next()
funcName = filepath.Base(frame.Function)
}
return &errorObj{
err: err,
file: file,
line: line,
funcName: funcName,
}
}

View File

@@ -0,0 +1,22 @@
package errors
import (
"errors"
"testing"
)
func TestNew(t *testing.T) {
t.Log(New("hello"))
t.Log(Wrap(errors.New("hello")))
t.Log(testError1())
t.Log(Wrap(testError1()))
t.Log(Wrap(testError2()))
}
func testError1() error {
return New("test error1")
}
func testError2() error {
return Wrap(testError1())
}

View File

@@ -0,0 +1,14 @@
package errors
import "strings"
// IsResourceNotFound 判断是否为资源无法查看错误
func IsResourceNotFound(err error) bool {
if err == nil {
return false
}
if strings.Contains(err.Error(), "resource not found") {
return true
}
return false
}

View File

@@ -0,0 +1,10 @@
package events
type Event = string
const (
EventStart Event = "start" // start loading
EventQuit Event = "quit" // quit node gracefully
EventSecurityConfigChanged Event = "securityConfigChanged" // 安全设置变更
)

View File

@@ -0,0 +1,27 @@
package events
import "sync"
var eventsMap = map[string][]func(){} // event => []callbacks
var locker = sync.Mutex{}
// 增加事件回调
func On(event string, callback func()) {
locker.Lock()
defer locker.Unlock()
var callbacks = eventsMap[event]
callbacks = append(callbacks, callback)
eventsMap[event] = callbacks
}
// 通知事件
func Notify(event string) {
locker.Lock()
var callbacks = eventsMap[event]
locker.Unlock()
for _, callback := range callbacks {
callback()
}
}

View File

@@ -0,0 +1,16 @@
package events
import "testing"
func TestOn(t *testing.T) {
On("hello", func() {
t.Log("world")
})
On("hello", func() {
t.Log("world2")
})
On("hello2", func() {
t.Log("world2")
})
Notify("hello")
}

View File

@@ -0,0 +1,149 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package gen
import (
"bytes"
"encoding/json"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/TeaOSLab/EdgeUser/internal/web/actions/default/servers/server/settings/conds/condutils"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/files"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/maps"
"io"
"os"
"path/filepath"
)
func Generate() error {
err := generateComponentsJSFile()
if err != nil {
return fmt.Errorf("generate 'components.src.js' failed: %w", err)
}
return nil
}
// 生成Javascript文件
func generateComponentsJSFile() error {
var buffer = bytes.NewBuffer([]byte{})
var webRoot string
if Tea.IsTesting() {
webRoot = Tea.Root + "/../web/public/js/components/"
} else {
webRoot = Tea.Root + "/web/public/js/components/"
}
f := files.NewFile(webRoot)
f.Range(func(file *files.File) {
if !file.IsFile() {
return
}
if file.Ext() != ".js" {
return
}
data, err := file.ReadAll()
if err != nil {
logs.Error(err)
return
}
buffer.Write(data)
buffer.Write([]byte{'\n', '\n'})
})
// 条件组件
typesJSON, err := json.Marshal(condutils.ReadAllAvailableCondTypes())
if err != nil {
logs.Println("ComponentsAction marshal request cond types failed: " + err.Error())
} else {
buffer.WriteString("window.REQUEST_COND_COMPONENTS = ")
buffer.Write(typesJSON)
buffer.Write([]byte{';', '\n', '\n'})
}
// 条件操作符
requestOperatorsJSON, err := json.Marshal(shared.AllRequestOperators())
if err != nil {
logs.Println("ComponentsAction marshal request operators failed: " + err.Error())
} else {
buffer.WriteString("window.REQUEST_COND_OPERATORS = ")
buffer.Write(requestOperatorsJSON)
buffer.Write([]byte{';', '\n', '\n'})
}
// 请求变量
requestVariablesJSON, err := json.Marshal(shared.DefaultRequestVariables())
if err != nil {
logs.Println("ComponentsAction marshal request variables failed: " + err.Error())
} else {
buffer.WriteString("window.REQUEST_VARIABLES = ")
buffer.Write(requestVariablesJSON)
buffer.Write([]byte{';', '\n', '\n'})
}
// 指标
metricHTTPKeysJSON, err := json.Marshal(serverconfigs.FindAllMetricKeyDefinitions(serverconfigs.MetricItemCategoryHTTP))
if err != nil {
logs.Println("ComponentsAction marshal metric http keys failed: " + err.Error())
} else {
buffer.WriteString("window.METRIC_HTTP_KEYS = ")
buffer.Write(metricHTTPKeysJSON)
buffer.Write([]byte{';', '\n', '\n'})
}
// WAF checkpoints
var wafCheckpointsMaps = []maps.Map{}
for _, checkpoint := range firewallconfigs.AllCheckpoints {
wafCheckpointsMaps = append(wafCheckpointsMaps, maps.Map{
"name": checkpoint.Name,
"prefix": checkpoint.Prefix,
"description": checkpoint.Description,
})
}
wafCheckpointsJSON, err := json.Marshal(wafCheckpointsMaps)
if err != nil {
logs.Println("ComponentsAction marshal waf rule checkpoints failed: " + err.Error())
} else {
buffer.WriteString("window.WAF_RULE_CHECKPOINTS = ")
buffer.Write(wafCheckpointsJSON)
buffer.Write([]byte{';', '\n', '\n'})
}
// WAF操作符
wafOperatorsJSON, err := json.Marshal(firewallconfigs.AllRuleOperators)
if err != nil {
logs.Println("ComponentsAction marshal waf rule operators failed: " + err.Error())
} else {
buffer.WriteString("window.WAF_RULE_OPERATORS = ")
buffer.Write(wafOperatorsJSON)
buffer.Write([]byte{';', '\n', '\n'})
}
// WAF验证码类型
captchaTypesJSON, err := json.Marshal(firewallconfigs.FindAllCaptchaTypes())
if err != nil {
logs.Println("ComponentsAction marshal captcha types failed: " + err.Error())
} else {
buffer.WriteString("window.WAF_CAPTCHA_TYPES = ")
buffer.Write(captchaTypesJSON)
buffer.Write([]byte{';', '\n', '\n'})
}
fp, err := os.OpenFile(filepath.Clean(Tea.PublicFile("/js/components.src.js")), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0777)
if err != nil {
return err
}
_, err = io.Copy(fp, buffer)
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,13 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package gen
import "testing"
func TestGenerate(t *testing.T) {
err := Generate()
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -0,0 +1,10 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package monitor
// ItemValue 数据值定义
type ItemValue struct {
Item string
ValueJSON []byte
CreatedAt int64
}

View File

@@ -0,0 +1,84 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package monitor
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeUser/internal/events"
"github.com/TeaOSLab/EdgeUser/internal/remotelogs"
"github.com/TeaOSLab/EdgeUser/internal/rpc"
"github.com/iwind/TeaGo/maps"
"time"
)
var SharedValueQueue = NewValueQueue()
func init() {
events.On(events.EventStart, func() {
go SharedValueQueue.Start()
})
}
// ValueQueue 数据记录队列
type ValueQueue struct {
valuesChan chan *ItemValue
}
func NewValueQueue() *ValueQueue {
return &ValueQueue{
valuesChan: make(chan *ItemValue, 1024),
}
}
// Start 启动队列
func (this *ValueQueue) Start() {
// 这里单次循环就行因为Loop里已经使用了Range通道
err := this.Loop()
if err != nil {
remotelogs.Error("MONITOR_QUEUE", err.Error())
}
}
// Add 添加数据
func (this *ValueQueue) Add(item string, value maps.Map) {
valueJSON, err := json.Marshal(value)
if err != nil {
remotelogs.Error("MONITOR_QUEUE", "marshal value error: "+err.Error())
return
}
select {
case this.valuesChan <- &ItemValue{
Item: item,
ValueJSON: valueJSON,
CreatedAt: time.Now().Unix(),
}:
default:
}
}
// Loop 单次循环
func (this *ValueQueue) Loop() error {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
for value := range this.valuesChan {
_, err = rpcClient.NodeValueRPC().CreateNodeValue(rpcClient.Context(0), &pb.CreateNodeValueRequest{
Item: value.Item,
ValueJSON: value.ValueJSON,
CreatedAt: value.CreatedAt,
})
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("MONITOR", err.Error())
} else {
remotelogs.Error("MONITOR", err.Error())
}
continue
}
}
return nil
}

View File

@@ -0,0 +1,217 @@
package nodes
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeUser/internal/configs"
teaconst "github.com/TeaOSLab/EdgeUser/internal/const"
"github.com/TeaOSLab/EdgeUser/internal/events"
"github.com/TeaOSLab/EdgeUser/internal/monitor"
"github.com/TeaOSLab/EdgeUser/internal/remotelogs"
"github.com/TeaOSLab/EdgeUser/internal/rpc"
"github.com/TeaOSLab/EdgeUser/internal/utils"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/shirou/gopsutil/v3/cpu"
"github.com/shirou/gopsutil/v3/disk"
"os"
"runtime"
"strings"
"time"
)
type NodeStatusExecutor struct {
isFirstTime bool
cpuUpdatedTime time.Time
cpuLogicalCount int
cpuPhysicalCount int
}
func NewNodeStatusExecutor() *NodeStatusExecutor {
return &NodeStatusExecutor{}
}
func (this *NodeStatusExecutor) Listen() {
this.isFirstTime = true
this.cpuUpdatedTime = time.Now()
this.update()
// TODO 这个时间间隔可以配置
ticker := time.NewTicker(30 * time.Second)
events.On(events.EventQuit, func() {
remotelogs.Println("NODE_STATUS", "quit executor")
ticker.Stop()
})
for range ticker.C {
this.isFirstTime = false
this.update()
}
}
func (this *NodeStatusExecutor) update() {
status := &nodeconfigs.NodeStatus{}
status.BuildVersion = teaconst.Version
status.BuildVersionCode = utils.VersionToLong(teaconst.Version)
status.OS = runtime.GOOS
status.Arch = runtime.GOARCH
status.ConfigVersion = 0
status.IsActive = true
status.ConnectionCount = 0 // TODO 将来显示连接数
hostname, _ := os.Hostname()
status.Hostname = hostname
this.updateCPU(status)
this.updateMem(status)
this.updateLoad(status)
this.updateDisk(status)
status.UpdatedAt = time.Now().Unix()
status.Timestamp = status.UpdatedAt
// 发送数据
jsonData, err := json.Marshal(status)
if err != nil {
remotelogs.Error("NODE_STATUS", "serial NodeStatus fail: "+err.Error())
return
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("NODE_STATUS", "failed to open rpc: "+err.Error())
} else {
remotelogs.Error("NODE_STATUS", "failed to open rpc: "+err.Error())
}
return
}
nodeId := int64(0)
if configs.SharedAPIConfig != nil {
nodeId = configs.SharedAPIConfig.NumberId
}
_, err = rpcClient.UserNodeRPC().UpdateUserNodeStatus(rpcClient.Context(0), &pb.UpdateUserNodeStatusRequest{
UserNodeId: nodeId,
StatusJSON: jsonData,
})
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("NODE_STATUS", "rpc UpdateUserNodeStatus() failed: "+err.Error())
} else {
remotelogs.Error("NODE_STATUS", "rpc UpdateUserNodeStatus() failed: "+err.Error())
}
return
}
}
// 更新CPU
func (this *NodeStatusExecutor) updateCPU(status *nodeconfigs.NodeStatus) {
duration := time.Duration(0)
if this.isFirstTime {
duration = 100 * time.Millisecond
}
percents, err := cpu.Percent(duration, false)
if err != nil {
status.Error = "cpu.Percent(): " + err.Error()
return
}
if len(percents) == 0 {
return
}
status.CPUUsage = percents[0] / 100
if time.Since(this.cpuUpdatedTime) > 300*time.Second { // 每隔5分钟才会更新一次
this.cpuUpdatedTime = time.Now()
status.CPULogicalCount, err = cpu.Counts(true)
if err != nil {
status.Error = "cpu.Counts(): " + err.Error()
return
}
status.CPUPhysicalCount, err = cpu.Counts(false)
if err != nil {
status.Error = "cpu.Counts(): " + err.Error()
return
}
this.cpuLogicalCount = status.CPULogicalCount
this.cpuPhysicalCount = status.CPUPhysicalCount
} else {
status.CPULogicalCount = this.cpuLogicalCount
status.CPUPhysicalCount = this.cpuPhysicalCount
}
// 记录监控数据
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemCPU, maps.Map{
"usage": status.CPUUsage,
"cores": runtime.NumCPU(),
})
}
// 更新硬盘
func (this *NodeStatusExecutor) updateDisk(status *nodeconfigs.NodeStatus) {
partitions, err := disk.Partitions(false)
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("NODE_STATUS", err.Error())
} else {
remotelogs.Error("NODE_STATUS", err.Error())
}
return
}
lists.Sort(partitions, func(i int, j int) bool {
p1 := partitions[i]
p2 := partitions[j]
return p1.Mountpoint > p2.Mountpoint
})
// 当前TeaWeb所在的fs
var rootFS = ""
var rootTotal = uint64(0)
if lists.ContainsString([]string{"darwin", "linux", "freebsd"}, runtime.GOOS) {
for _, p := range partitions {
if p.Mountpoint == "/" {
rootFS = p.Fstype
usage, _ := disk.Usage(p.Mountpoint)
if usage != nil {
rootTotal = usage.Total
}
break
}
}
}
var total = rootTotal
var totalUsage = uint64(0)
maxUsage := float64(0)
for _, partition := range partitions {
if runtime.GOOS != "windows" && !strings.Contains(partition.Device, "/") && !strings.Contains(partition.Device, "\\") {
continue
}
// 跳过不同fs的
if len(rootFS) > 0 && rootFS != partition.Fstype {
continue
}
usage, err := disk.Usage(partition.Mountpoint)
if err != nil {
continue
}
if partition.Mountpoint != "/" && (usage.Total != rootTotal || total == 0) {
total += usage.Total
}
totalUsage += usage.Used
if usage.UsedPercent >= maxUsage {
maxUsage = usage.UsedPercent
status.DiskMaxUsagePartition = partition.Mountpoint
}
}
status.DiskTotal = total
if total > 0 {
status.DiskUsage = float64(totalUsage) / float64(total)
}
status.DiskMaxUsage = maxUsage / 100
}

View File

@@ -0,0 +1,27 @@
package nodes
import (
"github.com/shirou/gopsutil/v3/cpu"
"testing"
"time"
)
func TestNodeStatusExecutor_CPU(t *testing.T) {
countLogicCPU, err := cpu.Counts(true)
if err != nil {
t.Fatal(err)
}
t.Log("logic count:", countLogicCPU)
countPhysicalCPU, err := cpu.Counts(false)
if err != nil {
t.Fatal(err)
}
t.Log("physical count:", countPhysicalCPU)
percents, err := cpu.Percent(100*time.Millisecond, false)
if err != nil {
t.Fatal(err)
}
t.Log(percents)
}

View File

@@ -0,0 +1,58 @@
//go:build !windows
// +build !windows
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeUser/internal/monitor"
"github.com/iwind/TeaGo/maps"
"github.com/shirou/gopsutil/v3/load"
"github.com/shirou/gopsutil/v3/mem"
)
// 更新内存
func (this *NodeStatusExecutor) updateMem(status *nodeconfigs.NodeStatus) {
stat, err := mem.VirtualMemory()
if err != nil {
return
}
// 重新计算内存
if stat.Total > 0 {
stat.Used = stat.Total - stat.Free - stat.Buffers - stat.Cached
status.MemoryUsage = float64(stat.Used) / float64(stat.Total)
}
status.MemoryTotal = stat.Total
// 记录监控数据
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemMemory, maps.Map{
"usage": status.MemoryUsage,
"total": status.MemoryTotal,
"used": stat.Used,
})
}
// 更新负载
func (this *NodeStatusExecutor) updateLoad(status *nodeconfigs.NodeStatus) {
stat, err := load.Avg()
if err != nil {
status.Error = err.Error()
return
}
if stat == nil {
status.Error = "load is nil"
return
}
status.Load1m = stat.Load1
status.Load5m = stat.Load5
status.Load15m = stat.Load15
// 记录监控数据
monitor.SharedValueQueue.Add(nodeconfigs.NodeValueItemLoad, maps.Map{
"load1m": status.Load1m,
"load5m": status.Load5m,
"load15m": status.Load15m,
})
}

View File

@@ -0,0 +1,102 @@
//go:build windows
// +build windows
package nodes
import (
"context"
"github.com/shirou/gopsutil/v3/cpu"
"github.com/shirou/gopsutil/v3/mem"
"math"
"sync"
"time"
)
type WindowsLoadValue struct {
Timestamp int64
Value int
}
var windowsLoadValues = []*WindowsLoadValue{}
var windowsLoadLocker = &sync.Mutex{}
// 更新内存
func (this *NodeStatusExecutor) updateMem(status *NodeStatus) {
stat, err := mem.VirtualMemory()
if err != nil {
status.Error = err.Error()
return
}
status.MemoryUsage = stat.UsedPercent
status.MemoryTotal = stat.Total
}
// 更新负载
func (this *NodeStatusExecutor) updateLoad(status *NodeStatus) {
timestamp := time.Now().Unix()
currentLoad := 0
info, err := cpu.ProcInfo()
if err == nil && len(info) > 0 && info[0].ProcessorQueueLength < 1000 {
currentLoad = int(info[0].ProcessorQueueLength)
}
// 删除15分钟之前的数据
windowsLoadLocker.Lock()
result := []*WindowsLoadValue{}
for _, v := range windowsLoadValues {
if timestamp-v.Timestamp > 15*60 {
continue
}
result = append(result, v)
}
result = append(result, &WindowsLoadValue{
Timestamp: timestamp,
Value: currentLoad,
})
windowsLoadValues = result
total1 := 0
count1 := 0
total5 := 0
count5 := 0
total15 := 0
count15 := 0
for _, v := range result {
if timestamp-v.Timestamp <= 60 {
total1 += v.Value
count1++
}
if timestamp-v.Timestamp <= 300 {
total5 += v.Value
count5++
}
total15 += v.Value
count15++
}
load1 := float64(0)
load5 := float64(0)
load15 := float64(0)
if count1 > 0 {
load1 = math.Round(float64(total1*100)/float64(count1)) / 100
}
if count5 > 0 {
load5 = math.Round(float64(total5*100)/float64(count5)) / 100
}
if count15 > 0 {
load15 = math.Round(float64(total15*100)/float64(count15)) / 100
}
windowsLoadLocker.Unlock()
// 在老Windows上不显示错误
if err == context.DeadlineExceeded {
err = nil
}
status.Load1m = load1
status.Load5m = load5
status.Load15m = load15
}

View File

@@ -0,0 +1,70 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package serverutils
import (
"errors"
"github.com/iwind/TeaGo"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/types"
"gopkg.in/yaml.v3"
"net"
"os"
"time"
)
const configFilename = "server.yaml"
// LoadServerConfig 读取当前服务配置
func LoadServerConfig() (*TeaGo.ServerConfig, error) {
var configFile = Tea.ConfigFile(configFilename)
data, err := os.ReadFile(configFile)
if err != nil {
return nil, err
}
var serverConfig = &TeaGo.ServerConfig{}
err = yaml.Unmarshal(data, serverConfig)
if err != nil {
return nil, err
}
return serverConfig, nil
}
// ReadServerHTTPS 检查HTTPS地址
func ReadServerHTTPS() (port int, err error) {
config, err := LoadServerConfig()
if err != nil {
return 0, err
}
if config == nil {
return 0, errors.New("could not load server config")
}
if config.Https.On && len(config.Https.Listen) > 0 {
for _, listen := range config.Https.Listen {
_, portString, splitErr := net.SplitHostPort(listen)
if splitErr == nil {
var portInt = types.Int(portString)
if portInt > 0 {
// 是否已经启动
checkErr := func() error {
conn, connErr := net.DialTimeout("tcp", ":"+portString, 1*time.Second)
if connErr != nil {
return connErr
}
_ = conn.Close()
return nil
}()
if checkErr != nil {
continue
}
port = portInt
err = nil
break
}
}
}
}
return
}

View File

@@ -0,0 +1,118 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeUser/internal/rpc"
"github.com/TeaOSLab/EdgeUser/internal/ttlcache"
"github.com/iwind/TeaGo/actions"
"github.com/iwind/TeaGo/logs"
"strings"
"time"
)
// SessionManager SESSION管理
type SessionManager struct {
life uint
}
func NewSessionManager() (*SessionManager, error) {
return &SessionManager{}, nil
}
func (this *SessionManager) Init(config *actions.SessionConfig) {
this.life = config.Life
}
func (this *SessionManager) Read(sid string) map[string]string {
// 忽略OTP
if strings.HasSuffix(sid, "_otp") {
return map[string]string{}
}
var result = map[string]string{}
var cacheKey = "SESSION@" + sid
var item = ttlcache.DefaultCache.Read(cacheKey)
if item != nil && item.Value != nil {
itemMap, ok := item.Value.(map[string]string)
if ok {
return itemMap
}
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
return map[string]string{}
}
resp, err := rpcClient.LoginSessionRPC().FindLoginSession(rpcClient.Context(0), &pb.FindLoginSessionRequest{Sid: sid})
if err != nil {
logs.Println("SESSION", "read '"+sid+"' failed: "+err.Error())
result["@error"] = err.Error()
return result
}
var session = resp.LoginSession
if session == nil || len(session.ValuesJSON) == 0 {
return result
}
err = json.Unmarshal(session.ValuesJSON, &result)
if err != nil {
logs.Println("SESSION", "decode '"+sid+"' values failed: "+err.Error())
}
// Write to cache
ttlcache.DefaultCache.Write(cacheKey, result, time.Now().Unix()+300 /** must not be too long **/)
return result
}
func (this *SessionManager) WriteItem(sid string, key string, value string) bool {
// 删除缓存
defer ttlcache.DefaultCache.Delete("SESSION@" + sid)
// 忽略OTP
if strings.HasSuffix(sid, "_otp") {
return false
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
return false
}
_, err = rpcClient.LoginSessionRPC().WriteLoginSessionValue(rpcClient.Context(0), &pb.WriteLoginSessionValueRequest{
Sid: sid,
Key: key,
Value: value,
})
if err != nil {
logs.Println("SESSION", "write sid:'"+sid+"' key:'"+key+"' failed: "+err.Error())
}
return true
}
func (this *SessionManager) Delete(sid string) bool {
// 删除缓存
defer ttlcache.DefaultCache.Delete("SESSION@" + sid)
// 忽略OTP
if strings.HasSuffix(sid, "_otp") {
return false
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
return false
}
_, err = rpcClient.LoginSessionRPC().DeleteLoginSession(rpcClient.Context(0), &pb.DeleteLoginSessionRequest{Sid: sid})
if err != nil {
logs.Println("SESSION", "delete '"+sid+"' failed: "+err.Error())
}
return true
}

View File

@@ -0,0 +1,486 @@
package nodes
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
"github.com/TeaOSLab/EdgeUser/internal/configloaders"
"github.com/TeaOSLab/EdgeUser/internal/configs"
teaconst "github.com/TeaOSLab/EdgeUser/internal/const"
"github.com/TeaOSLab/EdgeUser/internal/events"
"github.com/TeaOSLab/EdgeUser/internal/rpc"
_ "github.com/TeaOSLab/EdgeUser/internal/tasks"
"github.com/TeaOSLab/EdgeUser/internal/utils"
_ "github.com/TeaOSLab/EdgeUser/internal/web"
"github.com/iwind/TeaGo"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
"github.com/iwind/gosock/pkg/gosock"
"gopkg.in/yaml.v3"
"log"
"net"
"os"
"os/exec"
"time"
)
type UserNode struct {
sock *gosock.Sock
}
func NewUserNode() *UserNode {
return &UserNode{
sock: gosock.NewTmpSock(teaconst.ProcessName),
}
}
func (this *UserNode) Run() {
// 启动用户界面
var secret = this.genSecret()
configs.Secret = secret
// 本地Sock
err := this.listenSock()
if err != nil {
logs.Println("[USER_NODE]" + err.Error())
return
}
// 检查server配置
err = this.checkServer()
if err != nil {
logs.Println("[USER_NODE]" + err.Error())
return
}
// 触发事件
events.Notify(events.EventStart)
// 拉取配置
err = this.pullConfig()
if err != nil {
logs.Println("[USER_NODE]pull config failed: " + err.Error())
return
}
// 设置DNS
this.setupDNS()
// 监控状态
go NewNodeStatusExecutor().Listen()
logs.Println("[USER_NODE]initializing ip library ...")
err = iplibrary.InitPlus()
if err != nil {
logs.Println("[USER_NODE]initialize ip library failed: " + err.Error())
}
// 启动Web服务
sessionManager, err := NewSessionManager()
if err != nil {
log.Fatal("start session failed: " + err.Error())
return
}
TeaGo.NewServer(false).
AccessLog(false).
EndAll().
Session(sessionManager, teaconst.CookieSID).
ReadHeaderTimeout(3*time.Second).
ReadTimeout(600*time.Second).
Static("/www", Tea.Root+"/www").
Start()
}
// Daemon 实现守护进程
func (this *UserNode) Daemon() {
var isDebug = lists.ContainsString(os.Args, "debug")
for {
conn, err := this.sock.Dial()
if err != nil {
if isDebug {
log.Println("[DAEMON]starting ...")
}
// 尝试启动
err = func() error {
exe, err := os.Executable()
if err != nil {
return err
}
cmd := exec.Command(exe)
err = cmd.Start()
if err != nil {
return err
}
err = cmd.Wait()
if err != nil {
return err
}
return nil
}()
if err != nil {
if isDebug {
log.Println("[DAEMON]", err)
}
time.Sleep(1 * time.Second)
} else {
time.Sleep(5 * time.Second)
}
} else {
_ = conn.Close()
time.Sleep(5 * time.Second)
}
}
}
// InstallSystemService 安装系统服务
func (this *UserNode) InstallSystemService() error {
var shortName = teaconst.SystemdServiceName
exe, err := os.Executable()
if err != nil {
return err
}
manager := utils.NewServiceManager(shortName, teaconst.ProductName)
err = manager.Install(exe, []string{})
if err != nil {
return err
}
return nil
}
// 检查Server配置
func (this *UserNode) checkServer() error {
var configFile = Tea.ConfigFile("server.yaml")
_, err := os.Stat(configFile)
if err == nil {
return nil
}
if os.IsNotExist(err) {
// 创建文件
var templateFile = Tea.ConfigFile("server.template.yaml")
data, err := os.ReadFile(templateFile)
if err == nil {
err = os.WriteFile(configFile, data, 0666)
if err != nil {
return fmt.Errorf("create config file failed: %w", err)
}
} else {
templateYAML := `# environment code
env: prod
# http
http:
"on": true
listen: [ "0.0.0.0:7789" ]
# https
https:
"on": false
listen: [ "0.0.0.0:443"]
cert: ""
key: ""
`
err = os.WriteFile(configFile, []byte(templateYAML), 0666)
if err != nil {
return fmt.Errorf("create config file failed: %w", err)
}
}
} else {
return fmt.Errorf("can not read config from 'configs/server.yaml': %w", err)
}
return nil
}
// 生成Secret
func (this *UserNode) genSecret() string {
var tmpFile = os.TempDir() + "/edge-user-secret.tmp"
data, err := os.ReadFile(tmpFile)
if err == nil && len(data) == 32 {
return string(data)
}
secret := rands.String(32)
_ = os.WriteFile(tmpFile, []byte(secret), 0666)
return secret
}
// 拉取配置
func (this *UserNode) pullConfig() error {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
var nodeResp *pb.FindCurrentUserNodeResponse
for i := 0; i < 10; i++ { // may retry many times
nodeResp, err = rpcClient.UserNodeRPC().FindCurrentUserNode(rpcClient.Context(0), &pb.FindCurrentUserNodeRequest{})
if err != nil {
time.Sleep(1 * time.Second)
continue
}
break
}
if err != nil {
return err
}
var node = nodeResp.UserNode
if node == nil {
return errors.New("invalid 'nodeId' or 'secret'")
}
if configs.SharedAPIConfig != nil {
configs.SharedAPIConfig.NumberId = node.Id
}
// 读取Web服务配置
var serverConfig = &TeaGo.ServerConfig{
Env: Tea.EnvProd,
}
if Tea.IsTesting() {
serverConfig.Env = Tea.EnvDev
}
// HTTP
httpConfig, err := this.decodeHTTP(node)
if err != nil {
return fmt.Errorf("decode http config failed: %w", err)
}
if httpConfig != nil && httpConfig.IsOn && len(httpConfig.Listen) > 0 {
serverConfig.Http.On = true
var listens = []string{}
for _, listen := range httpConfig.Listen {
listens = append(listens, listen.Addresses()...)
}
serverConfig.Http.Listen = listens
}
// HTTPS
httpsConfig, err := this.DecodeHTTPS(node)
if err != nil {
return fmt.Errorf("decode https config failed: %w", err)
}
if httpsConfig != nil && httpsConfig.IsOn && len(httpsConfig.Listen) > 0 {
serverConfig.Https.On = true
serverConfig.Https.Cert = "configs/https.cert.pem"
serverConfig.Https.Key = "configs/https.key.pem"
var listens = []string{}
for _, listen := range httpsConfig.Listen {
listens = append(listens, listen.Addresses()...)
}
serverConfig.Https.Listen = listens
}
// 保存到文件
serverYAML, err := yaml.Marshal(serverConfig)
if err != nil {
return err
}
err = os.WriteFile(Tea.ConfigFile("server.yaml"), serverYAML, 0666)
if err != nil {
return err
}
// add to local firewall
var ports = []int{}
for _, listens := range [][]string{serverConfig.Http.Listen, serverConfig.Https.Listen} {
for _, listen := range listens {
_, portString, err := net.SplitHostPort(listen)
if err == nil {
var port = types.Int(portString)
if port > 0 && !lists.ContainsInt(ports, port) {
ports = append(ports, port)
}
}
}
}
if len(ports) > 0 {
go utils.AddPortsToFirewall(ports)
}
return nil
}
// 解析HTTP配置
func (this *UserNode) decodeHTTP(node *pb.UserNode) (*serverconfigs.HTTPProtocolConfig, error) {
if len(node.HttpJSON) == 0 {
return nil, nil
}
config := &serverconfigs.HTTPProtocolConfig{}
err := json.Unmarshal(node.HttpJSON, config)
if err != nil {
return nil, err
}
err = config.Init()
if err != nil {
return nil, err
}
return config, nil
}
// DecodeHTTPS 解析HTTPS配置
func (this *UserNode) DecodeHTTPS(node *pb.UserNode) (*serverconfigs.HTTPSProtocolConfig, error) {
if len(node.HttpsJSON) == 0 {
return nil, nil
}
var config = &serverconfigs.HTTPSProtocolConfig{}
err := json.Unmarshal(node.HttpsJSON, config)
if err != nil {
return nil, err
}
err = config.Init(context.TODO())
if err != nil {
return nil, err
}
if config.SSLPolicyRef != nil {
policyId := config.SSLPolicyRef.SSLPolicyId
if policyId > 0 {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return nil, err
}
policyConfigResp, err := rpcClient.SSLPolicyRPC().FindEnabledSSLPolicyConfig(rpcClient.Context(0), &pb.FindEnabledSSLPolicyConfigRequest{SslPolicyId: policyId})
if err != nil {
return nil, err
}
if len(policyConfigResp.SslPolicyJSON) > 0 {
policyConfig := &sslconfigs.SSLPolicy{}
err = json.Unmarshal(policyConfigResp.SslPolicyJSON, policyConfig)
if err != nil {
return nil, err
}
if len(policyConfig.Certs) > 0 {
err = os.WriteFile(Tea.ConfigFile("https.cert.pem"), policyConfig.Certs[0].CertData, 0666)
if err != nil {
return nil, err
}
err = os.WriteFile(Tea.ConfigFile("https.key.pem"), policyConfig.Certs[0].KeyData, 0666)
if err != nil {
return nil, err
}
}
}
}
}
err = config.Init(context.TODO())
if err != nil {
return nil, err
}
return config, nil
}
// 监听本地sock
func (this *UserNode) listenSock() error {
// 检查是否在运行
if this.sock.IsListening() {
reply, err := this.sock.Send(&gosock.Command{Code: "pid"})
if err == nil {
return errors.New("error: the process is already running, pid: " + maps.NewMap(reply.Params).GetString("pid"))
} else {
return errors.New("error: the process is already running")
}
}
// 启动监听
go func() {
this.sock.OnCommand(func(cmd *gosock.Command) {
switch cmd.Code {
case "pid":
_ = cmd.Reply(&gosock.Command{
Code: "pid",
Params: map[string]interface{}{
"pid": os.Getpid(),
},
})
case "info":
exePath, _ := os.Executable()
_ = cmd.Reply(&gosock.Command{
Code: "info",
Params: map[string]interface{}{
"pid": os.Getpid(),
"version": teaconst.Version,
"path": exePath,
},
})
case "stop":
_ = cmd.ReplyOk()
// 退出主进程
events.Notify(events.EventQuit)
os.Exit(0)
case "dev": // 切换到dev
Tea.Env = Tea.EnvDev
_ = cmd.ReplyOk()
case "prod": // 切换到prod
Tea.Env = Tea.EnvProd
_ = cmd.ReplyOk()
case "demo":
teaconst.IsDemoMode = !teaconst.IsDemoMode
_ = cmd.Reply(&gosock.Command{
Params: map[string]interface{}{"isDemo": teaconst.IsDemoMode},
})
}
})
err := this.sock.Listen()
if err != nil {
logs.Println("NODE", err.Error())
}
}()
events.On(events.EventQuit, func() {
logs.Println("NODE", "quit unix sock")
_ = this.sock.Close()
})
return nil
}
// 设置DNS相关
func (this *UserNode) setupDNS() {
config, loadErr := configloaders.LoadUIConfig()
if loadErr != nil {
// 默认使用go原生
err := os.Setenv("GODEBUG", "netdns=go")
if err != nil {
logs.Println("[DNS_RESOLVER]set env failed: " + err.Error())
}
return
}
var err error
switch config.DNSResolver.Type {
case nodeconfigs.DNSResolverTypeGoNative:
err = os.Setenv("GODEBUG", "netdns=go")
case nodeconfigs.DNSResolverTypeCGO:
err = os.Setenv("GODEBUG", "netdns=cgo")
default:
// 默认使用go原生
err = os.Setenv("GODEBUG", "netdns=go")
}
if err != nil {
logs.Println("[DNS_RESOLVER]set env failed: " + err.Error())
}
}

View File

@@ -0,0 +1,10 @@
package oplogs
const (
LevelNone = "none"
LevelInfo = "info"
LevelDebug = "debug"
LevelWarn = "warn"
LevelError = "error"
LevelFatal = "fatal"
)

View File

@@ -0,0 +1,122 @@
package remotelogs
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeUser/internal/configs"
teaconst "github.com/TeaOSLab/EdgeUser/internal/const"
"github.com/TeaOSLab/EdgeUser/internal/rpc"
"github.com/iwind/TeaGo/logs"
"time"
)
var logChan = make(chan *pb.NodeLog, 64) // 队列数量不需要太长,因为日志通常仅仅为调试用
func init() {
// 定期上传日志
ticker := time.NewTicker(60 * time.Second)
go func() {
for range ticker.C {
err := uploadLogs()
if err != nil {
logs.Println("[LOG]" + err.Error())
}
}
}()
}
// Debug 打印调试信息
func Debug(tag string, description string) {
logs.Println("[" + tag + "]" + description)
}
// Println 打印普通信息
func Println(tag string, description string) {
logs.Println("[" + tag + "]" + description)
nodeConfig, _ := configs.LoadAPIConfig()
if nodeConfig == nil {
return
}
select {
case logChan <- &pb.NodeLog{
Role: teaconst.Role,
Tag: tag,
Description: description,
Level: "info",
NodeId: nodeConfig.NumberId,
CreatedAt: time.Now().Unix(),
}:
default:
}
}
// Warn 打印警告信息
func Warn(tag string, description string) {
logs.Println("[" + tag + "]" + description)
nodeConfig, _ := configs.LoadAPIConfig()
if nodeConfig == nil {
return
}
select {
case logChan <- &pb.NodeLog{
Role: teaconst.Role,
Tag: tag,
Description: description,
Level: "warning",
NodeId: nodeConfig.NumberId,
CreatedAt: time.Now().Unix(),
}:
default:
}
}
// Error 打印错误信息
func Error(tag string, description string) {
logs.Println("[" + tag + "]" + description)
nodeConfig, _ := configs.LoadAPIConfig()
if nodeConfig == nil {
return
}
select {
case logChan <- &pb.NodeLog{
Role: teaconst.Role,
Tag: tag,
Description: description,
Level: "error",
NodeId: nodeConfig.NumberId,
CreatedAt: time.Now().Unix(),
}:
default:
}
}
// 上传日志
func uploadLogs() error {
logList := []*pb.NodeLog{}
Loop:
for {
select {
case log := <-logChan:
logList = append(logList, log)
default:
break Loop
}
}
if len(logList) == 0 {
return nil
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
_, err = rpcClient.NodeLogRPC().CreateNodeLogs(rpcClient.Context(0), &pb.CreateNodeLogsRequest{NodeLogs: logList})
return err
}

View File

@@ -0,0 +1,15 @@
package remotelogs
import (
_ "github.com/iwind/TeaGo/bootstrap"
"testing"
)
func TestPrintln(t *testing.T) {
Println("test", "123")
err := uploadLogs()
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -0,0 +1,633 @@
package rpc
import (
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/dao"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeUser/internal/configs"
teaconst "github.com/TeaOSLab/EdgeUser/internal/const"
"github.com/TeaOSLab/EdgeUser/internal/encrypt"
"github.com/TeaOSLab/EdgeUser/internal/utils"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/rands"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/encoding/gzip"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"net/url"
"sync"
"time"
)
// RPCClient RPC客户端
type RPCClient struct {
apiConfig *configs.APIConfig
conns []*grpc.ClientConn
locker sync.Mutex
}
// NewRPCClient 构造新的RPC客户端
func NewRPCClient(apiConfig *configs.APIConfig) (*RPCClient, error) {
if apiConfig == nil {
return nil, errors.New("api config should not be nil")
}
client := &RPCClient{
apiConfig: apiConfig,
}
err := client.init()
if err != nil {
return nil, err
}
// 设置DAO的RPC
dao.SetRPC(client)
return client, nil
}
func (this *RPCClient) SysSettingRPC() pb.SysSettingServiceClient {
return pb.NewSysSettingServiceClient(this.pickConn())
}
func (this *RPCClient) SysLockerRPC() pb.SysLockerServiceClient {
return pb.NewSysLockerServiceClient(this.pickConn())
}
func (this *RPCClient) NodeRPC() pb.NodeServiceClient {
return pb.NewNodeServiceClient(this.pickConn())
}
func (this *RPCClient) NodeClusterRPC() pb.NodeClusterServiceClient {
return pb.NewNodeClusterServiceClient(this.pickConn())
}
func (this *RPCClient) NodeRegionRPC() pb.NodeRegionServiceClient {
return pb.NewNodeRegionServiceClient(this.pickConn())
}
func (this *RPCClient) NodePriceItemRPC() pb.NodePriceItemServiceClient {
return pb.NewNodePriceItemServiceClient(this.pickConn())
}
func (this *RPCClient) NodeLogRPC() pb.NodeLogServiceClient {
return pb.NewNodeLogServiceClient(this.pickConn())
}
func (this *RPCClient) NodeValueRPC() pb.NodeValueServiceClient {
return pb.NewNodeValueServiceClient(this.pickConn())
}
func (this *RPCClient) ServerRPC() pb.ServerServiceClient {
return pb.NewServerServiceClient(this.pickConn())
}
func (this *RPCClient) ServerClientSystemMonthlyStatRPC() pb.ServerClientSystemMonthlyStatServiceClient {
return pb.NewServerClientSystemMonthlyStatServiceClient(this.pickConn())
}
func (this *RPCClient) ServerClientBrowserMonthlyStatRPC() pb.ServerClientBrowserMonthlyStatServiceClient {
return pb.NewServerClientBrowserMonthlyStatServiceClient(this.pickConn())
}
func (this *RPCClient) ServerRegionCountryMonthlyStatRPC() pb.ServerRegionCountryMonthlyStatServiceClient {
return pb.NewServerRegionCountryMonthlyStatServiceClient(this.pickConn())
}
func (this *RPCClient) ServerRegionProvinceMonthlyStatRPC() pb.ServerRegionProvinceMonthlyStatServiceClient {
return pb.NewServerRegionProvinceMonthlyStatServiceClient(this.pickConn())
}
func (this *RPCClient) ServerRegionCityMonthlyStatRPC() pb.ServerRegionCityMonthlyStatServiceClient {
return pb.NewServerRegionCityMonthlyStatServiceClient(this.pickConn())
}
func (this *RPCClient) ServerRegionProviderMonthlyStatRPC() pb.ServerRegionProviderMonthlyStatServiceClient {
return pb.NewServerRegionProviderMonthlyStatServiceClient(this.pickConn())
}
func (this *RPCClient) ServerHTTPFirewallDailyStatRPC() pb.ServerHTTPFirewallDailyStatServiceClient {
return pb.NewServerHTTPFirewallDailyStatServiceClient(this.pickConn())
}
func (this *RPCClient) ServerGroupRPC() pb.ServerGroupServiceClient {
return pb.NewServerGroupServiceClient(this.pickConn())
}
func (this *RPCClient) ServerDailyRPC() pb.ServerDailyStatServiceClient {
return pb.NewServerDailyStatServiceClient(this.pickConn())
}
func (this *RPCClient) OriginRPC() pb.OriginServiceClient {
return pb.NewOriginServiceClient(this.pickConn())
}
func (this *RPCClient) HTTPWebRPC() pb.HTTPWebServiceClient {
return pb.NewHTTPWebServiceClient(this.pickConn())
}
func (this *RPCClient) HTTPAuthPolicyRPC() pb.HTTPAuthPolicyServiceClient {
return pb.NewHTTPAuthPolicyServiceClient(this.pickConn())
}
func (this *RPCClient) ReverseProxyRPC() pb.ReverseProxyServiceClient {
return pb.NewReverseProxyServiceClient(this.pickConn())
}
func (this *RPCClient) HTTPGzipRPC() pb.HTTPGzipServiceClient {
return pb.NewHTTPGzipServiceClient(this.pickConn())
}
func (this *RPCClient) HTTPHeaderRPC() pb.HTTPHeaderServiceClient {
return pb.NewHTTPHeaderServiceClient(this.pickConn())
}
func (this *RPCClient) HTTPHeaderPolicyRPC() pb.HTTPHeaderPolicyServiceClient {
return pb.NewHTTPHeaderPolicyServiceClient(this.pickConn())
}
func (this *RPCClient) HTTPPageRPC() pb.HTTPPageServiceClient {
return pb.NewHTTPPageServiceClient(this.pickConn())
}
func (this *RPCClient) HTTPAccessLogPolicyRPC() pb.HTTPAccessLogPolicyServiceClient {
return pb.NewHTTPAccessLogPolicyServiceClient(this.pickConn())
}
func (this *RPCClient) HTTPCachePolicyRPC() pb.HTTPCachePolicyServiceClient {
return pb.NewHTTPCachePolicyServiceClient(this.pickConn())
}
func (this *RPCClient) HTTPCacheTaskRPC() pb.HTTPCacheTaskServiceClient {
return pb.NewHTTPCacheTaskServiceClient(this.pickConn())
}
func (this *RPCClient) HTTPCacheTaskKeyRPC() pb.HTTPCacheTaskKeyServiceClient {
return pb.NewHTTPCacheTaskKeyServiceClient(this.pickConn())
}
func (this *RPCClient) HTTPFirewallPolicyRPC() pb.HTTPFirewallPolicyServiceClient {
return pb.NewHTTPFirewallPolicyServiceClient(this.pickConn())
}
func (this *RPCClient) HTTPFirewallRuleGroupRPC() pb.HTTPFirewallRuleGroupServiceClient {
return pb.NewHTTPFirewallRuleGroupServiceClient(this.pickConn())
}
func (this *RPCClient) HTTPFirewallRuleSetRPC() pb.HTTPFirewallRuleSetServiceClient {
return pb.NewHTTPFirewallRuleSetServiceClient(this.pickConn())
}
func (this *RPCClient) HTTPLocationRPC() pb.HTTPLocationServiceClient {
return pb.NewHTTPLocationServiceClient(this.pickConn())
}
func (this *RPCClient) HTTPWebsocketRPC() pb.HTTPWebsocketServiceClient {
return pb.NewHTTPWebsocketServiceClient(this.pickConn())
}
func (this *RPCClient) HTTPRewriteRuleRPC() pb.HTTPRewriteRuleServiceClient {
return pb.NewHTTPRewriteRuleServiceClient(this.pickConn())
}
// HTTPAccessLogRPC 访问日志
func (this *RPCClient) HTTPAccessLogRPC() pb.HTTPAccessLogServiceClient {
return pb.NewHTTPAccessLogServiceClient(this.pickConn())
}
func (this *RPCClient) SSLCertRPC() pb.SSLCertServiceClient {
return pb.NewSSLCertServiceClient(this.pickConn())
}
func (this *RPCClient) SSLPolicyRPC() pb.SSLPolicyServiceClient {
return pb.NewSSLPolicyServiceClient(this.pickConn())
}
func (this *RPCClient) MessageRPC() pb.MessageServiceClient {
return pb.NewMessageServiceClient(this.pickConn())
}
func (this *RPCClient) IPListRPC() pb.IPListServiceClient {
return pb.NewIPListServiceClient(this.pickConn())
}
func (this *RPCClient) IPItemRPC() pb.IPItemServiceClient {
return pb.NewIPItemServiceClient(this.pickConn())
}
func (this *RPCClient) FileRPC() pb.FileServiceClient {
return pb.NewFileServiceClient(this.pickConn())
}
func (this *RPCClient) FileChunkRPC() pb.FileChunkServiceClient {
return pb.NewFileChunkServiceClient(this.pickConn())
}
func (this *RPCClient) RegionCountryRPC() pb.RegionCountryServiceClient {
return pb.NewRegionCountryServiceClient(this.pickConn())
}
func (this *RPCClient) RegionProvinceRPC() pb.RegionProvinceServiceClient {
return pb.NewRegionProvinceServiceClient(this.pickConn())
}
func (this *RPCClient) RegionCityRPC() pb.RegionCityServiceClient {
return pb.NewRegionCityServiceClient(this.pickConn())
}
func (this *RPCClient) RegionProviderRPC() pb.RegionProviderServiceClient {
return pb.NewRegionProviderServiceClient(this.pickConn())
}
func (this *RPCClient) LogRPC() pb.LogServiceClient {
return pb.NewLogServiceClient(this.pickConn())
}
func (this *RPCClient) DNSProviderRPC() pb.DNSProviderServiceClient {
return pb.NewDNSProviderServiceClient(this.pickConn())
}
func (this *RPCClient) DNSDomainRPC() pb.DNSDomainServiceClient {
return pb.NewDNSDomainServiceClient(this.pickConn())
}
func (this *RPCClient) DNSRPC() pb.DNSServiceClient {
return pb.NewDNSServiceClient(this.pickConn())
}
func (this *RPCClient) NSClusterRPC() pb.NSClusterServiceClient {
return pb.NewNSClusterServiceClient(this.pickConn())
}
func (this *RPCClient) ACMEUserRPC() pb.ACMEUserServiceClient {
return pb.NewACMEUserServiceClient(this.pickConn())
}
func (this *RPCClient) ACMETaskRPC() pb.ACMETaskServiceClient {
return pb.NewACMETaskServiceClient(this.pickConn())
}
func (this *RPCClient) ACMEProviderRPC() pb.ACMEProviderServiceClient {
return pb.NewACMEProviderServiceClient(this.pickConn())
}
func (this *RPCClient) ACMEProviderAccountRPC() pb.ACMEProviderAccountServiceClient {
return pb.NewACMEProviderAccountServiceClient(this.pickConn())
}
func (this *RPCClient) UserRPC() pb.UserServiceClient {
return pb.NewUserServiceClient(this.pickConn())
}
func (this *RPCClient) UserBillRPC() pb.UserBillServiceClient {
return pb.NewUserBillServiceClient(this.pickConn())
}
func (this *RPCClient) UserIdentityRPC() pb.UserIdentityServiceClient {
return pb.NewUserIdentityServiceClient(this.pickConn())
}
func (this *RPCClient) ServerBillRPC() pb.ServerBillServiceClient {
return pb.NewServerBillServiceClient(this.pickConn())
}
func (this *RPCClient) UserTrafficBillRPC() pb.UserTrafficBillServiceClient {
return pb.NewUserTrafficBillServiceClient(this.pickConn())
}
func (this *RPCClient) UserNodeRPC() pb.UserNodeServiceClient {
return pb.NewUserNodeServiceClient(this.pickConn())
}
func (this *RPCClient) APINodeRPC() pb.APINodeServiceClient {
return pb.NewAPINodeServiceClient(this.pickConn())
}
func (this *RPCClient) UserAccessKeyRPC() pb.UserAccessKeyServiceClient {
return pb.NewUserAccessKeyServiceClient(this.pickConn())
}
func (this *RPCClient) IPLibraryRPC() pb.IPLibraryServiceClient {
return pb.NewIPLibraryServiceClient(this.pickConn())
}
func (this *RPCClient) PlanRPC() pb.PlanServiceClient {
return pb.NewPlanServiceClient(this.pickConn())
}
func (this *RPCClient) UserPlanRPC() pb.UserPlanServiceClient {
return pb.NewUserPlanServiceClient(this.pickConn())
}
func (this *RPCClient) UserAccountRPC() pb.UserAccountServiceClient {
return pb.NewUserAccountServiceClient(this.pickConn())
}
func (this *RPCClient) UserAccountLogRPC() pb.UserAccountLogServiceClient {
return pb.NewUserAccountLogServiceClient(this.pickConn())
}
func (this *RPCClient) UserOrderRPC() pb.UserOrderServiceClient {
return pb.NewUserOrderServiceClient(this.pickConn())
}
func (this *RPCClient) UserEmailVerificationRPC() pb.UserEmailVerificationServiceClient {
return pb.NewUserEmailVerificationServiceClient(this.pickConn())
}
func (this *RPCClient) UserMobileVerificationRPC() pb.UserMobileVerificationServiceClient {
return pb.NewUserMobileVerificationServiceClient(this.pickConn())
}
func (this *RPCClient) UserVerifyCodeRPC() pb.UserVerifyCodeServiceClient {
return pb.NewUserVerifyCodeServiceClient(this.pickConn())
}
func (this *RPCClient) OrderMethodRPC() pb.OrderMethodServiceClient {
return pb.NewOrderMethodServiceClient(this.pickConn())
}
func (this *RPCClient) UserTicketRPC() pb.UserTicketServiceClient {
return pb.NewUserTicketServiceClient(this.pickConn())
}
func (this *RPCClient) UserTicketLogRPC() pb.UserTicketLogServiceClient {
return pb.NewUserTicketLogServiceClient(this.pickConn())
}
func (this *RPCClient) UserTicketCategoryRPC() pb.UserTicketCategoryServiceClient {
return pb.NewUserTicketCategoryServiceClient(this.pickConn())
}
func (this *RPCClient) LoginRPC() pb.LoginServiceClient {
return pb.NewLoginServiceClient(this.pickConn())
}
func (this *RPCClient) NSDomainRPC() pb.NSDomainServiceClient {
return pb.NewNSDomainServiceClient(this.pickConn())
}
func (this *RPCClient) NSDomainGroupRPC() pb.NSDomainGroupServiceClient {
return pb.NewNSDomainGroupServiceClient(this.pickConn())
}
func (this *RPCClient) NSRecordRPC() pb.NSRecordServiceClient {
return pb.NewNSRecordServiceClient(this.pickConn())
}
func (this *RPCClient) NSKeyRPC() pb.NSKeyServiceClient {
return pb.NewNSKeyServiceClient(this.pickConn())
}
func (this *RPCClient) NSRouteRPC() pb.NSRouteServiceClient {
return pb.NewNSRouteServiceClient(this.pickConn())
}
func (this *RPCClient) NSAccessLogRPC() pb.NSAccessLogServiceClient {
return pb.NewNSAccessLogServiceClient(this.pickConn())
}
func (this *RPCClient) NSRPC() pb.NSServiceClient {
return pb.NewNSServiceClient(this.pickConn())
}
func (this *RPCClient) NSQuestionOptionRPC() pb.NSQuestionOptionServiceClient {
return pb.NewNSQuestionOptionServiceClient(this.pickConn())
}
func (this *RPCClient) NSNodeRPC() pb.NSNodeServiceClient {
return pb.NewNSNodeServiceClient(this.pickConn())
}
func (this *RPCClient) NSRecordHourlyStatRPC() pb.NSRecordHourlyStatServiceClient {
return pb.NewNSRecordHourlyStatServiceClient(this.pickConn())
}
func (this *RPCClient) NSUserPlanRPC() pb.NSUserPlanServiceClient {
return pb.NewNSUserPlanServiceClient(this.pickConn())
}
func (this *RPCClient) NSPlanRPC() pb.NSPlanServiceClient {
return pb.NewNSPlanServiceClient(this.pickConn())
}
func (this *RPCClient) ServerBandwidthStatRPC() pb.ServerBandwidthStatServiceClient {
return pb.NewServerBandwidthStatServiceClient(this.pickConn())
}
func (this *RPCClient) ServerDailyStatRPC() pb.ServerDailyStatServiceClient {
return pb.NewServerDailyStatServiceClient(this.pickConn())
}
func (this *RPCClient) PriceRPC() pb.PriceServiceClient {
return pb.NewPriceServiceClient(this.pickConn())
}
func (this *RPCClient) TrafficPackageRPC() pb.TrafficPackageServiceClient {
return pb.NewTrafficPackageServiceClient(this.pickConn())
}
func (this *RPCClient) TrafficPackagePeriodRPC() pb.TrafficPackagePeriodServiceClient {
return pb.NewTrafficPackagePeriodServiceClient(this.pickConn())
}
func (this *RPCClient) TrafficPackagePriceRPC() pb.TrafficPackagePriceServiceClient {
return pb.NewTrafficPackagePriceServiceClient(this.pickConn())
}
func (this *RPCClient) UserTrafficPackageRPC() pb.UserTrafficPackageServiceClient {
return pb.NewUserTrafficPackageServiceClient(this.pickConn())
}
func (this *RPCClient) ClientAgentRPC() pb.ClientAgentServiceClient {
return pb.NewClientAgentServiceClient(this.pickConn())
}
func (this *RPCClient) ADPackageRPC() pb.ADPackageServiceClient {
return pb.NewADPackageServiceClient(this.pickConn())
}
func (this *RPCClient) ADPackageInstanceRPC() pb.ADPackageInstanceServiceClient {
return pb.NewADPackageInstanceServiceClient(this.pickConn())
}
func (this *RPCClient) ADPackagePeriodRPC() pb.ADPackagePeriodServiceClient {
return pb.NewADPackagePeriodServiceClient(this.pickConn())
}
func (this *RPCClient) ADPackagePriceRPC() pb.ADPackagePriceServiceClient {
return pb.NewADPackagePriceServiceClient(this.pickConn())
}
func (this *RPCClient) ADNetworkRPC() pb.ADNetworkServiceClient {
return pb.NewADNetworkServiceClient(this.pickConn())
}
func (this *RPCClient) UserADInstanceRPC() pb.UserADInstanceServiceClient {
return pb.NewUserADInstanceServiceClient(this.pickConn())
}
func (this *RPCClient) UserScriptRPC() pb.UserScriptServiceClient {
return pb.NewUserScriptServiceClient(this.pickConn())
}
func (this *RPCClient) PostRPC() pb.PostServiceClient {
return pb.NewPostServiceClient(this.pickConn())
}
func (this *RPCClient) LoginSessionRPC() pb.LoginSessionServiceClient {
return pb.NewLoginSessionServiceClient(this.pickConn())
}
func (this *RPCClient) LoginTicketRPC() pb.LoginTicketServiceClient {
return pb.NewLoginTicketServiceClient(this.pickConn())
}
// Context 构造用户上下文
func (this *RPCClient) Context(userId int64) context.Context {
ctx := context.Background()
m := maps.Map{
"timestamp": time.Now().Unix(),
"type": "user",
"userId": userId,
}
method, err := encrypt.NewMethodInstance(teaconst.EncryptMethod, this.apiConfig.Secret, this.apiConfig.NodeId)
if err != nil {
utils.PrintError(err)
return context.Background()
}
data, err := method.Encrypt(m.AsJSON())
if err != nil {
utils.PrintError(err)
return context.Background()
}
token := base64.StdEncoding.EncodeToString(data)
ctx = metadata.AppendToOutgoingContext(ctx, "nodeId", this.apiConfig.NodeId, "token", token)
return ctx
}
// APIContext 构造API上下文
func (this *RPCClient) APIContext(apiNodeId int64) context.Context {
ctx := context.Background()
m := maps.Map{
"timestamp": time.Now().Unix(),
"type": "api",
"userId": apiNodeId,
}
method, err := encrypt.NewMethodInstance(teaconst.EncryptMethod, this.apiConfig.Secret, this.apiConfig.NodeId)
if err != nil {
utils.PrintError(err)
return context.Background()
}
data, err := method.Encrypt(m.AsJSON())
if err != nil {
utils.PrintError(err)
return context.Background()
}
token := base64.StdEncoding.EncodeToString(data)
ctx = metadata.AppendToOutgoingContext(ctx, "nodeId", this.apiConfig.NodeId, "token", token)
return ctx
}
// UpdateConfig 修改配置
func (this *RPCClient) UpdateConfig(config *configs.APIConfig) error {
this.apiConfig = config
this.locker.Lock()
err := this.init()
this.locker.Unlock()
return err
}
// 初始化
func (this *RPCClient) init() error {
// 重新连接
var conns = []*grpc.ClientConn{}
for _, endpoint := range this.apiConfig.RPCEndpoints {
u, err := url.Parse(endpoint)
if err != nil {
return fmt.Errorf("parse endpoint failed: %w", err)
}
var conn *grpc.ClientConn
var callOptions = grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(128<<20),
grpc.MaxCallSendMsgSize(128<<20),
grpc.UseCompressor(gzip.Name),
)
var keepaliveParams = grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
})
if u.Scheme == "http" {
conn, err = grpc.Dial(u.Host, grpc.WithTransportCredentials(insecure.NewCredentials()), callOptions, keepaliveParams)
} else if u.Scheme == "https" {
conn, err = grpc.Dial(u.Host, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
InsecureSkipVerify: true,
})), callOptions, keepaliveParams)
} else {
return errors.New("parse endpoint failed: invalid scheme '" + u.Scheme + "'")
}
if err != nil {
return err
}
conns = append(conns, conn)
}
if len(conns) == 0 {
return errors.New("[RPC]no available endpoints")
}
// 这里不需要加锁因为会和pickConn冲突
this.conns = conns
return nil
}
// 随机选择一个连接
func (this *RPCClient) pickConn() *grpc.ClientConn {
this.locker.Lock()
defer this.locker.Unlock()
// 检查连接状态
var countConns = len(this.conns)
if countConns > 0 {
if countConns == 1 {
return this.conns[0]
}
for _, state := range []connectivity.State{
connectivity.Ready,
connectivity.Idle,
connectivity.Connecting,
connectivity.TransientFailure,
} {
var availableConns = []*grpc.ClientConn{}
for _, conn := range this.conns {
if conn.GetState() == state {
availableConns = append(availableConns, conn)
}
}
if len(availableConns) > 0 {
return this.randConn(availableConns)
}
}
}
return this.randConn(this.conns)
}
func (this *RPCClient) randConn(conns []*grpc.ClientConn) *grpc.ClientConn {
var l = len(conns)
if l == 0 {
return nil
}
if l == 1 {
return conns[0]
}
return conns[rands.Int(0, l-1)]
}

View File

@@ -0,0 +1,5 @@
package rpc
import (
_ "github.com/iwind/TeaGo/bootstrap"
)

View File

@@ -0,0 +1,54 @@
package rpc
import (
"github.com/TeaOSLab/EdgeUser/internal/configs"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"strings"
"sync"
)
var sharedRPC *RPCClient = nil
var locker = &sync.Mutex{}
// SharedRPC 获取GRPC客户端
func SharedRPC() (*RPCClient, error) {
locker.Lock()
defer locker.Unlock()
if sharedRPC != nil {
return sharedRPC, nil
}
config, err := configs.LoadAPIConfig()
if err != nil {
return nil, err
}
client, err := NewRPCClient(config)
if err != nil {
return nil, err
}
sharedRPC = client
return sharedRPC, nil
}
// IsConnError 是否为连接错误
func IsConnError(err error) bool {
if err == nil {
return false
}
// 检查是否为连接错误
statusErr, ok := status.FromError(err)
if ok {
var errorCode = statusErr.Code()
return errorCode == codes.Unavailable || errorCode == codes.Canceled
}
if strings.Contains(err.Error(), "code = Canceled") {
return true
}
return false
}

View File

@@ -0,0 +1,156 @@
package tasks
import (
"context"
"crypto/tls"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeUser/internal/configs"
"github.com/TeaOSLab/EdgeUser/internal/events"
"github.com/TeaOSLab/EdgeUser/internal/rpc"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/logs"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"net/url"
"sort"
"strings"
"sync"
"time"
)
func init() {
events.On(events.EventStart, func() {
task := NewSyncAPINodesTask()
go task.Start()
})
}
// SyncAPINodesTask API节点同步任务
type SyncAPINodesTask struct {
ticker *time.Ticker
}
func NewSyncAPINodesTask() *SyncAPINodesTask {
return &SyncAPINodesTask{}
}
func (this *SyncAPINodesTask) Start() {
this.ticker = time.NewTicker(5 * time.Minute)
if Tea.IsTesting() {
// 快速测试
this.ticker = time.NewTicker(1 * time.Minute)
}
for range this.ticker.C {
err := this.Loop()
if err != nil {
logs.Println("[TASK][SYNC_API_NODES]" + err.Error())
}
}
}
func (this *SyncAPINodesTask) Loop() error {
config, err := configs.LoadAPIConfig()
if err != nil {
return err
}
// 是否禁止自动升级
if config.RPCDisableUpdate {
return nil
}
// 获取所有可用的节点
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
resp, err := rpcClient.APINodeRPC().FindAllEnabledAPINodes(rpcClient.Context(0), &pb.FindAllEnabledAPINodesRequest{})
if err != nil {
return err
}
var newEndpoints = []string{}
for _, node := range resp.ApiNodes {
if !node.IsOn {
continue
}
newEndpoints = append(newEndpoints, node.AccessAddrs...)
}
// 和现有的对比
if this.isSame(newEndpoints, config.RPCEndpoints) {
return nil
}
// 测试是否有API节点可用
var hasOk = this.testEndpoints(newEndpoints)
if !hasOk {
return nil
}
// 修改RPC对象配置
config.RPCEndpoints = newEndpoints
err = rpcClient.UpdateConfig(config)
if err != nil {
return err
}
// 保存到文件
err = config.WriteFile(Tea.ConfigFile(configs.ConfigFileName))
if err != nil {
return err
}
return nil
}
func (this *SyncAPINodesTask) isSame(endpoints1 []string, endpoints2 []string) bool {
sort.Strings(endpoints1)
sort.Strings(endpoints2)
return strings.Join(endpoints1, "&") == strings.Join(endpoints2, "&")
}
func (this *SyncAPINodesTask) testEndpoints(endpoints []string) bool {
if len(endpoints) == 0 {
return false
}
var wg = sync.WaitGroup{}
wg.Add(len(endpoints))
var ok = false
for _, endpoint := range endpoints {
go func(endpoint string) {
defer wg.Done()
u, err := url.Parse(endpoint)
if err != nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer func() {
cancel()
}()
var conn *grpc.ClientConn
if u.Scheme == "http" {
conn, err = grpc.DialContext(ctx, u.Host, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock())
} else if u.Scheme == "https" {
conn, err = grpc.DialContext(ctx, u.Host, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
InsecureSkipVerify: true,
})), grpc.WithBlock())
}
if err != nil {
return
}
_ = conn.Close()
ok = true
}(endpoint)
}
wg.Wait()
return ok
}

View File

@@ -0,0 +1,145 @@
package ttlcache
import (
"github.com/TeaOSLab/EdgeUser/internal/utils"
"time"
)
var DefaultCache = NewCache()
// TTL缓存
// 最大的缓存时间为30 * 86400
// Piece数据结构
//
// Piece1 | Piece2 | Piece3 | ...
// [ Item1, Item2, ... | ...
//
// KeyMap列表数据结构
// { timestamp1 => [key1, key2, ...] }, ...
type Cache struct {
isDestroyed bool
pieces []*Piece
countPieces uint64
maxItems int
gcPieceIndex int
ticker *utils.Ticker
}
func NewCache(opt ...OptionInterface) *Cache {
countPieces := 128
maxItems := 1_000_000
for _, option := range opt {
if option == nil {
continue
}
switch o := option.(type) {
case *PiecesOption:
if o.Count > 0 {
countPieces = o.Count
}
case *MaxItemsOption:
if o.Count > 0 {
maxItems = o.Count
}
}
}
cache := &Cache{
countPieces: uint64(countPieces),
maxItems: maxItems,
}
for i := 0; i < countPieces; i++ {
cache.pieces = append(cache.pieces, NewPiece(maxItems/countPieces))
}
// start timer
go func() {
cache.ticker = utils.NewTicker(5 * time.Second)
for cache.ticker.Next() {
cache.GC()
}
}()
return cache
}
func (this *Cache) Write(key string, value interface{}, expiredAt int64) {
if this.isDestroyed {
return
}
currentTimestamp := time.Now().Unix()
if expiredAt <= currentTimestamp {
return
}
maxExpiredAt := currentTimestamp + 30*86400
if expiredAt > maxExpiredAt {
expiredAt = maxExpiredAt
}
uint64Key := HashKey([]byte(key))
pieceIndex := uint64Key % this.countPieces
this.pieces[pieceIndex].Add(uint64Key, &Item{
Value: value,
expiredAt: expiredAt,
})
}
func (this *Cache) IncreaseInt64(key string, delta int64, expiredAt int64) int64 {
if this.isDestroyed {
return 0
}
currentTimestamp := time.Now().Unix()
if expiredAt <= currentTimestamp {
return 0
}
maxExpiredAt := currentTimestamp + 30*86400
if expiredAt > maxExpiredAt {
expiredAt = maxExpiredAt
}
uint64Key := HashKey([]byte(key))
pieceIndex := uint64Key % this.countPieces
return this.pieces[pieceIndex].IncreaseInt64(uint64Key, delta, expiredAt)
}
func (this *Cache) Read(key string) (item *Item) {
uint64Key := HashKey([]byte(key))
return this.pieces[uint64Key%this.countPieces].Read(uint64Key)
}
func (this *Cache) Delete(key string) {
uint64Key := HashKey([]byte(key))
this.pieces[uint64Key%this.countPieces].Delete(uint64Key)
}
func (this *Cache) Count() (count int) {
for _, piece := range this.pieces {
count += piece.Count()
}
return
}
func (this *Cache) GC() {
this.pieces[this.gcPieceIndex].GC()
newIndex := this.gcPieceIndex + 1
if newIndex >= int(this.countPieces) {
newIndex = 0
}
this.gcPieceIndex = newIndex
}
func (this *Cache) Destroy() {
this.isDestroyed = true
if this.ticker != nil {
this.ticker.Stop()
this.ticker = nil
}
for _, piece := range this.pieces {
piece.Destroy()
}
}

View File

@@ -0,0 +1,124 @@
package ttlcache
import (
"github.com/iwind/TeaGo/rands"
"runtime"
"strconv"
"testing"
"time"
)
func TestNewCache(t *testing.T) {
cache := NewCache()
cache.Write("a", 1, time.Now().Unix()+3600)
cache.Write("b", 2, time.Now().Unix()+3601)
cache.Write("a", 1, time.Now().Unix()+3602)
cache.Write("d", 1, time.Now().Unix()+1)
for _, piece := range cache.pieces {
if len(piece.m) > 0 {
for k, item := range piece.m {
t.Log(k, "=>", item.Value, item.expiredAt)
}
}
}
t.Log(cache.Read("a"))
time.Sleep(2 * time.Second)
t.Log(cache.Read("d"))
}
func BenchmarkCache_Add(b *testing.B) {
runtime.GOMAXPROCS(1)
cache := NewCache()
for i := 0; i < b.N; i++ {
cache.Write(strconv.Itoa(i), i, time.Now().Unix()+int64(i%1024))
}
}
func TestCache_IncreaseInt64(t *testing.T) {
var cache = NewCache()
{
cache.IncreaseInt64("a", 1, time.Now().Unix()+3600)
t.Log(cache.Read("a"))
}
{
cache.IncreaseInt64("a", 1, time.Now().Unix()+3600+1)
t.Log(cache.Read("a"))
}
{
cache.Write("b", 1, time.Now().Unix()+3600+2)
t.Log(cache.Read("b"))
}
{
cache.IncreaseInt64("b", 1, time.Now().Unix()+3600+3)
t.Log(cache.Read("b"))
}
}
func TestCache_Read(t *testing.T) {
runtime.GOMAXPROCS(1)
var cache = NewCache(PiecesOption{Count: 32})
for i := 0; i < 10_000_000; i++ {
cache.Write("HELLO_WORLD_"+strconv.Itoa(i), i, time.Now().Unix()+int64(i%10240)+1)
}
total := 0
for _, piece := range cache.pieces {
//t.Log(len(piece.m), "keys")
total += len(piece.m)
}
t.Log(total, "total keys")
before := time.Now()
for i := 0; i < 10_240; i++ {
_ = cache.Read("HELLO_WORLD_" + strconv.Itoa(i))
}
t.Log(time.Since(before).Seconds()*1000, "ms")
}
func TestCache_GC(t *testing.T) {
var cache = NewCache(&PiecesOption{Count: 5})
cache.Write("a", 1, time.Now().Unix()+1)
cache.Write("b", 2, time.Now().Unix()+2)
cache.Write("c", 3, time.Now().Unix()+3)
cache.Write("d", 4, time.Now().Unix()+4)
cache.Write("e", 5, time.Now().Unix()+10)
go func() {
for i := 0; i < 1000; i++ {
cache.Write("f", 1, time.Now().Unix()+1)
time.Sleep(10 * time.Millisecond)
}
}()
for i := 0; i < 20; i++ {
cache.GC()
t.Log("items:", cache.Count())
time.Sleep(1 * time.Second)
}
t.Log("now:", time.Now().Unix())
for _, p := range cache.pieces {
for k, v := range p.m {
t.Log(k, v.Value, v.expiredAt)
}
}
}
func TestCache_GC2(t *testing.T) {
runtime.GOMAXPROCS(1)
cache := NewCache()
for i := 0; i < 1_000_000; i++ {
cache.Write(strconv.Itoa(i), i, time.Now().Unix()+int64(rands.Int(0, 100)))
}
for i := 0; i < 100; i++ {
t.Log(cache.Count(), "items")
time.Sleep(1 * time.Second)
}
}

View File

@@ -0,0 +1,6 @@
package ttlcache
type Item struct {
Value interface{}
expiredAt int64
}

View File

@@ -0,0 +1,20 @@
package ttlcache
type OptionInterface interface {
}
type PiecesOption struct {
Count int
}
func NewPiecesOption(count int) *PiecesOption {
return &PiecesOption{Count: count}
}
type MaxItemsOption struct {
Count int
}
func NewMaxItemsOption(count int) *MaxItemsOption {
return &MaxItemsOption{Count: count}
}

View File

@@ -0,0 +1,88 @@
package ttlcache
import (
"github.com/iwind/TeaGo/types"
"sync"
"time"
)
type Piece struct {
m map[uint64]*Item
maxItems int
locker sync.RWMutex
}
func NewPiece(maxItems int) *Piece {
return &Piece{m: map[uint64]*Item{}, maxItems: maxItems}
}
func (this *Piece) Add(key uint64, item *Item) {
this.locker.Lock()
if len(this.m) >= this.maxItems {
this.locker.Unlock()
return
}
this.m[key] = item
this.locker.Unlock()
}
func (this *Piece) IncreaseInt64(key uint64, delta int64, expiredAt int64) (result int64) {
this.locker.Lock()
item, ok := this.m[key]
if ok {
result := types.Int64(item.Value) + delta
item.Value = result
item.expiredAt = expiredAt
} else {
if len(this.m) < this.maxItems {
result = delta
this.m[key] = &Item{
Value: delta,
expiredAt: expiredAt,
}
}
}
this.locker.Unlock()
return
}
func (this *Piece) Delete(key uint64) {
this.locker.Lock()
delete(this.m, key)
this.locker.Unlock()
}
func (this *Piece) Read(key uint64) (item *Item) {
this.locker.RLock()
item = this.m[key]
if item != nil && item.expiredAt < time.Now().Unix() {
item = nil
}
this.locker.RUnlock()
return
}
func (this *Piece) Count() (count int) {
this.locker.RLock()
count = len(this.m)
this.locker.RUnlock()
return
}
func (this *Piece) GC() {
this.locker.Lock()
timestamp := time.Now().Unix()
for k, item := range this.m {
if item.expiredAt <= timestamp {
delete(this.m, k)
}
}
this.locker.Unlock()
}
func (this *Piece) Destroy() {
this.locker.Lock()
this.m = nil
this.locker.Unlock()
}

View File

@@ -0,0 +1,60 @@
package ttlcache
import (
"github.com/iwind/TeaGo/rands"
"testing"
"time"
)
func TestPiece_Add(t *testing.T) {
piece := NewPiece(10)
piece.Add(1, &Item{expiredAt: time.Now().Unix() + 3600})
piece.Add(2, &Item{})
piece.Add(3, &Item{})
piece.Delete(3)
for key, item := range piece.m {
t.Log(key, item.Value)
}
t.Log(piece.Read(1))
}
func TestPiece_MaxItems(t *testing.T) {
piece := NewPiece(10)
for i := 0; i < 1000; i++ {
piece.Add(uint64(i), &Item{expiredAt: time.Now().Unix() + 3600})
}
t.Log(len(piece.m))
}
func TestPiece_GC(t *testing.T) {
piece := NewPiece(10)
piece.Add(1, &Item{Value: 1, expiredAt: time.Now().Unix() + 1})
piece.Add(2, &Item{Value: 2, expiredAt: time.Now().Unix() + 1})
piece.Add(3, &Item{Value: 3, expiredAt: time.Now().Unix() + 1})
t.Log("before gc ===")
for key, item := range piece.m {
t.Log(key, item.Value)
}
time.Sleep(1 * time.Second)
piece.GC()
t.Log("after gc ===")
for key, item := range piece.m {
t.Log(key, item.Value)
}
}
func TestPiece_GC2(t *testing.T) {
piece := NewPiece(10)
for i := 0; i < 10_000; i++ {
piece.Add(uint64(i), &Item{Value: 1, expiredAt: time.Now().Unix() + int64(rands.Int(1, 10))})
}
time.Sleep(1 * time.Second)
before := time.Now()
piece.GC()
t.Log(time.Since(before).Seconds()*1000, "ms")
t.Log(piece.Count())
}

View File

@@ -0,0 +1,7 @@
package ttlcache
import "github.com/cespare/xxhash/v2"
func HashKey(key []byte) uint64 {
return xxhash.Sum64(key)
}

View File

@@ -0,0 +1,13 @@
package ttlcache
import (
"runtime"
"testing"
)
func BenchmarkHashKey(b *testing.B) {
runtime.GOMAXPROCS(1)
for i := 0; i < b.N; i++ {
HashKey([]byte("HELLO,WORLDHELLO,WORLDHELLO,WORLDHELLO,WORLDHELLO,WORLDHELLO,WORLD"))
}
}

View File

@@ -0,0 +1,12 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package dateutils
// SplitYmd 分隔Ymd格式的日期
// Ymd => Y-m-d
func SplitYmd(day string) string {
if len(day) != 8 {
return day
}
return day[:4] + "-" + day[4:6] + "-" + day[6:]
}

View File

@@ -0,0 +1,19 @@
package domainutils
import (
"regexp"
"strings"
)
// ValidateDomainFormat 校验域名格式
func ValidateDomainFormat(domain string) bool {
pieces := strings.Split(domain, ".")
for _, piece := range pieces {
// \p{Han} 中文unicode字符集
if !regexp.MustCompile(`^[\p{Han}a-z0-9-]+$`).MatchString(piece) {
return false
}
}
return true
}

View File

@@ -0,0 +1,18 @@
package utils
import (
"github.com/iwind/TeaGo/logs"
"strings"
)
func PrintError(err error) {
// TODO 记录调用的文件名、行数
logs.Println("[ERROR]" + err.Error())
}
func IsNotFound(err error) bool {
if err == nil {
return false
}
return strings.Contains(strings.ToUpper(err.Error()), "NOT FOUND")
}

View File

@@ -0,0 +1,28 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package utils
import (
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/types"
"os/exec"
"runtime"
)
func AddPortsToFirewall(ports []int) {
for _, port := range ports {
// Linux
if runtime.GOOS == "linux" {
// firewalld
firewallCmd, _ := exec.LookPath("firewall-cmd")
if len(firewallCmd) > 0 {
err := exec.Command(firewallCmd, "--add-port="+types.String(port)+"/tcp").Run()
if err == nil {
logs.Println("API_NODE", "add port '"+types.String(port)+"' to firewalld")
_ = exec.Command(firewallCmd, "--add-port="+types.String(port)+"/tcp", "--permanent").Run()
}
}
}
}
}

View File

@@ -0,0 +1,62 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package utils
import (
"github.com/TeaOSLab/EdgeCommon/pkg/iputils"
"strings"
)
// ParseIPValue 解析IP值
func ParseIPValue(value string) (newValue string, ipFrom string, ipTo string, ok bool) {
if len(value) == 0 {
return
}
newValue = value
// ip1-ip2
if strings.Contains(value, "-") {
var pieces = strings.Split(value, "-")
if len(pieces) != 2 {
return
}
ipFrom = strings.TrimSpace(pieces[0])
ipTo = strings.TrimSpace(pieces[1])
if !iputils.IsValid(ipFrom) || !iputils.IsValid(ipTo) {
return
}
if !iputils.IsSameVersion(ipFrom, ipTo) {
return
}
if iputils.CompareIP(ipFrom, ipTo) > 0 {
ipFrom, ipTo = ipTo, ipFrom
newValue = ipFrom + "-" + ipTo
}
ok = true
return
}
// ip/mask
if strings.Contains(value, "/") {
cidr, err := iputils.ParseCIDR(value)
if err != nil {
return
}
return newValue, cidr.From().String(), cidr.To().String(), true
}
// single value
if iputils.IsValid(value) {
ipFrom = value
ok = true
return
}
return
}

View File

@@ -0,0 +1,60 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package utils
import (
"bytes"
"encoding/json"
"errors"
"reflect"
)
// JSONIsNull 判断JSON数据是否为null
func JSONIsNull(jsonData []byte) bool {
return len(jsonData) == 0 || bytes.Equal(jsonData, []byte("null"))
}
// JSONClone 使用JSON克隆对象
func JSONClone(v interface{}) (interface{}, error) {
data, err := json.Marshal(v)
if err != nil {
return nil, err
}
var nv = reflect.New(reflect.TypeOf(v).Elem()).Interface()
err = json.Unmarshal(data, nv)
if err != nil {
return nil, err
}
return nv, nil
}
// JSONDecodeConfig 解码并重新编码
// 是为了去除原有JSON中不需要的数据
func JSONDecodeConfig(data []byte, ptr any) (encodeJSON []byte, err error) {
err = json.Unmarshal(data, ptr)
if err != nil {
return
}
encodeJSON, err = json.Marshal(ptr)
if err != nil {
return
}
// validate config
if ptr != nil {
config, ok := ptr.(interface {
Init() error
})
if ok {
initErr := config.Init()
if initErr != nil {
err = errors.New("validate config failed: " + initErr.Error())
}
}
}
return
}

View File

@@ -0,0 +1,25 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package utils_test
import (
"github.com/TeaOSLab/EdgeUser/internal/utils"
"testing"
)
func TestJSONClone(t *testing.T) {
type A struct {
B int `json:"b"`
C string `json:"c"`
}
var a = &A{B: 123, C: "456"}
for i := 0; i < 5; i++ {
c, err := utils.JSONClone(a)
if err != nil {
t.Fatal(err)
}
t.Logf("%p, %#v", c, c)
}
}

View File

@@ -0,0 +1,174 @@
package utils
import (
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
teaconst "github.com/TeaOSLab/EdgeUser/internal/const"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/logs"
"github.com/miekg/dns"
"sync"
)
var sharedDNSClient *dns.Client
var sharedDNSConfig *dns.ClientConfig
func init() {
if !teaconst.IsMain {
return
}
config, err := dns.ClientConfigFromFile("/etc/resolv.conf")
if err != nil {
logs.Println("ERROR: configure dns client failed: " + err.Error())
return
}
sharedDNSConfig = config
sharedDNSClient = &dns.Client{}
}
// LookupCNAME 查询CNAME记录
// TODO 可以设置使用的DNS主机地址
func LookupCNAME(host string) (string, error) {
if sharedDNSClient == nil {
return "", errors.New("could not find dns client")
}
var m = new(dns.Msg)
m.SetQuestion(host+".", dns.TypeCNAME)
m.RecursionDesired = true
var lastErr error
var success = false
var result = ""
var serverAddrs = sharedDNSConfig.Servers
{
var publicDNSHosts = []string{"8.8.8.8" /** Google **/, "8.8.4.4" /** Google **/}
for _, publicDNSHost := range publicDNSHosts {
if !lists.ContainsString(serverAddrs, publicDNSHost) {
serverAddrs = append(serverAddrs, publicDNSHost)
}
}
}
var wg = &sync.WaitGroup{}
for _, serverAddr := range serverAddrs {
wg.Add(1)
go func(serverAddr string) {
defer wg.Done()
r, _, err := sharedDNSClient.Exchange(m, configutils.QuoteIP(serverAddr)+":"+sharedDNSConfig.Port)
if err != nil {
lastErr = err
return
}
success = true
if len(r.Answer) == 0 {
return
}
result = r.Answer[0].(*dns.CNAME).Target
}(serverAddr)
}
wg.Wait()
if success {
return result, nil
}
return "", lastErr
}
// LookupNS 查询NS记录
// TODO 可以设置使用的DNS主机地址
func LookupNS(host string) ([]string, error) {
config, err := dns.ClientConfigFromFile("/etc/resolv.conf")
if err != nil {
return nil, err
}
var c = new(dns.Client)
var m = new(dns.Msg)
m.SetQuestion(host+".", dns.TypeNS)
m.RecursionDesired = true
var result = []string{}
var lastErr error
var hasValidServer = false
for _, serverAddr := range config.Servers {
r, _, err := c.Exchange(m, configutils.QuoteIP(serverAddr)+":"+config.Port)
if err != nil {
lastErr = err
continue
}
hasValidServer = true
if len(r.Answer) == 0 {
continue
}
for _, answer := range r.Answer {
result = append(result, answer.(*dns.NS).Ns)
}
break
}
if hasValidServer {
return result, nil
}
return nil, lastErr
}
// LookupTXT 获取CNAME
// TODO 可以设置使用的DNS主机地址
func LookupTXT(host string) ([]string, error) {
config, err := dns.ClientConfigFromFile("/etc/resolv.conf")
if err != nil {
return nil, err
}
var c = new(dns.Client)
var m = new(dns.Msg)
m.SetQuestion(host+".", dns.TypeTXT)
m.RecursionDesired = true
var lastErr error
var result = []string{}
var hasValidServer = false
for _, serverAddr := range config.Servers {
r, _, err := c.Exchange(m, configutils.QuoteIP(serverAddr)+":"+config.Port)
if err != nil {
lastErr = err
continue
}
hasValidServer = true
if len(r.Answer) == 0 {
continue
}
for _, answer := range r.Answer {
result = append(result, answer.(*dns.TXT).Txt...)
}
break
}
if hasValidServer {
return result, nil
}
return nil, lastErr
}

View File

@@ -0,0 +1,23 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package utils_test
import (
"github.com/TeaOSLab/EdgeUser/internal/utils"
"testing"
)
func TestLookupCNAME(t *testing.T) {
for _, domain := range []string{"www.yun4s.cn", "example.com", "goedge.cn"} {
result, err := utils.LookupCNAME(domain)
t.Log(domain, "=>", result, err)
}
}
func TestLookupNS(t *testing.T) {
t.Log(utils.LookupNS("goedge.cn"))
}
func TestLookupTXT(t *testing.T) {
t.Log(utils.LookupTXT("_acme-challenge.dl"))
}

View File

@@ -0,0 +1,12 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package utils
import "regexp"
var mobileRegex = regexp.MustCompile(`^1\d{10}$`)
// IsValidMobile validate mobile number
func IsValidMobile(mobile string) bool {
return mobileRegex.MatchString(mobile)
}

View File

@@ -0,0 +1,17 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package utils_test
import (
"github.com/TeaOSLab/EdgeUser/internal/utils"
"github.com/iwind/TeaGo/assert"
"testing"
)
func TestIsValidMobile(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsFalse(utils.IsValidMobile("138"))
a.IsFalse(utils.IsValidMobile("1382222"))
a.IsFalse(utils.IsValidMobile("1381234567890"))
a.IsTrue(utils.IsValidMobile("13812345678"))
}

View File

@@ -0,0 +1,143 @@
package numberutils
import (
"fmt"
"github.com/iwind/TeaGo/types"
"regexp"
"strconv"
"strings"
)
func FormatInt64(value int64) string {
return strconv.FormatInt(value, 10)
}
func FormatInt(value int) string {
return strconv.Itoa(value)
}
func Pow1024(n int) int64 {
if n <= 0 {
return 1
}
if n == 1 {
return 1024
}
return Pow1024(n-1) * 1024
}
func FormatBytes(bytes int64) string {
if bytes < Pow1024(1) {
return FormatInt64(bytes) + "B"
} else if bytes < Pow1024(2) {
return TrimZeroSuffix(fmt.Sprintf("%.2fKiB", float64(bytes)/float64(Pow1024(1))))
} else if bytes < Pow1024(3) {
return TrimZeroSuffix(fmt.Sprintf("%.2fMiB", float64(bytes)/float64(Pow1024(2))))
} else if bytes < Pow1024(4) {
return TrimZeroSuffix(fmt.Sprintf("%.2fGiB", float64(bytes)/float64(Pow1024(3))))
} else if bytes < Pow1024(5) {
return TrimZeroSuffix(fmt.Sprintf("%.2fTiB", float64(bytes)/float64(Pow1024(4))))
} else if bytes < Pow1024(6) {
return TrimZeroSuffix(fmt.Sprintf("%.2fPiB", float64(bytes)/float64(Pow1024(5))))
} else {
return TrimZeroSuffix(fmt.Sprintf("%.2fEiB", float64(bytes)/float64(Pow1024(6))))
}
}
func FormatBits(bits int64) string {
if bits < Pow1024(1) {
return FormatInt64(bits) + "bps"
} else if bits < Pow1024(2) {
return TrimZeroSuffix(fmt.Sprintf("%.4fKbps", float64(bits)/float64(Pow1024(1))))
} else if bits < Pow1024(3) {
return TrimZeroSuffix(fmt.Sprintf("%.4fMbps", float64(bits)/float64(Pow1024(2))))
} else if bits < Pow1024(4) {
return TrimZeroSuffix(fmt.Sprintf("%.4fGbps", float64(bits)/float64(Pow1024(3))))
} else if bits < Pow1024(5) {
return TrimZeroSuffix(fmt.Sprintf("%.4fTbps", float64(bits)/float64(Pow1024(4))))
} else if bits < Pow1024(6) {
return TrimZeroSuffix(fmt.Sprintf("%.4fPbps", float64(bits)/float64(Pow1024(5))))
} else {
return TrimZeroSuffix(fmt.Sprintf("%.4fEbps", float64(bits)/float64(Pow1024(6))))
}
}
func FormatCount(count int64) string {
if count < 1000 {
return types.String(count)
}
if count < 1000*1000 {
return fmt.Sprintf("%.1fK", float32(count)/1000)
}
if count < 1000*1000*1000 {
return fmt.Sprintf("%.1fM", float32(count)/1000/1000)
}
return fmt.Sprintf("%.1fB", float32(count)/1000/1000/1000)
}
func FormatFloat(f interface{}, decimal int) string {
if f == nil {
return ""
}
switch x := f.(type) {
case float32, float64:
var s = fmt.Sprintf("%."+types.String(decimal)+"f", x)
// 分隔
var dotIndex = strings.Index(s, ".")
if dotIndex > 0 {
var d = s[:dotIndex]
var f2 = s[dotIndex:]
f2 = strings.TrimRight(strings.TrimRight(f2, "0"), ".")
return formatDigit(d) + f2
}
return s
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return formatDigit(types.String(x))
case string:
return x
}
return ""
}
var decimalReg = regexp.MustCompile(`^(\d+\.\d+)([a-zA-Z]+)?$`)
// TrimZeroSuffix 去除小数数字尾部多余的0
func TrimZeroSuffix(s string) string {
var matches = decimalReg.FindStringSubmatch(s)
if len(matches) < 3 {
return s
}
return strings.TrimRight(strings.TrimRight(matches[1], "0"), ".") + matches[2]
}
func formatDigit(d string) string {
if len(d) == 0 {
return d
}
var prefix = ""
if d[0] < '0' || d[0] > '9' {
prefix = d[:1]
d = d[1:]
}
var l = len(d)
if l > 3 {
var pieces = l / 3
var commIndex = l - pieces*3
var d2 = ""
if commIndex > 0 {
d2 = d[:commIndex] + ", "
}
for i := 0; i < pieces; i++ {
d2 += d[commIndex+i*3 : commIndex+i*3+3]
if i != pieces-1 {
d2 += ", "
}
}
return prefix + d2
}
return prefix + d
}

View File

@@ -0,0 +1,31 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package otputils
import (
"net/url"
)
// FixIssuer fix issuer in otp provisioning url
func FixIssuer(urlString string) string {
u, err := url.Parse(urlString)
if err != nil {
return urlString
}
var query = u.Query()
if query != nil {
var issuerName = query.Get("issuer")
if len(issuerName) > 0 {
unescapedIssuerName, unescapeErr := url.QueryUnescape(issuerName)
if unescapeErr == nil {
query.Set("issuer", unescapedIssuerName)
u.RawQuery = query.Encode()
}
}
return u.String()
}
return urlString
}

View File

@@ -0,0 +1,18 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package otputils_test
import (
"github.com/TeaOSLab/EdgeUser/internal/utils/otputils"
"testing"
)
func TestFixIssuer(t *testing.T) {
var beforeURL = "otpauth://totp/GoEdge%25E7%25AE%25A1%25E7%2590%2586%25E5%2591%2598%25E7%25B3%25BB%25E7%25BB%259F:admin?issuer=GoEdge%25E7%25AE%25A1%25E7%2590%2586%25E5%2591%2598%25E7%25B3%25BB%25E7%25BB%259F&secret=Q3J4WNOWBRFLP3HI"
var afterURL = otputils.FixIssuer(beforeURL)
t.Log(afterURL)
if beforeURL == afterURL {
t.Fatal("'afterURL' should not be equal to 'beforeURL'")
}
}

View File

@@ -0,0 +1,50 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package portalutils
import (
"github.com/TeaOSLab/EdgeUser/internal/remotelogs"
"github.com/iwind/TeaGo/Tea"
"io"
"os"
)
func HasPortalIndex() bool {
return len(checkPortalIndex()) > 0
}
func ReadPortalIndex(writer io.Writer) {
var indexPath = checkPortalIndex()
if len(indexPath) == 0 {
return
}
fp, err := os.Open(indexPath)
if err != nil {
remotelogs.Error("PORTAL", "read portal index failed: "+err.Error())
return
}
defer func() {
_ = fp.Close()
}()
_, _ = io.Copy(writer, fp)
}
func checkPortalIndex() string {
var indexes = []string{
"index.html",
}
for _, index := range indexes {
var path = Tea.Root + "/www/" + index
stat, err := os.Stat(path)
if err != nil {
continue
}
if stat.IsDir() {
continue
}
return path
}
return ""
}

View File

@@ -0,0 +1,12 @@
package utils
import (
"runtime/debug"
)
func Recover() {
e := recover()
if e != nil {
debug.PrintStack()
}
}

View File

@@ -0,0 +1,111 @@
package utils
import (
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/files"
"github.com/iwind/TeaGo/logs"
"log"
"os"
"path/filepath"
"runtime"
"sync"
)
// 服务管理器
type ServiceManager struct {
Name string
Description string
fp *os.File
logger *log.Logger
onceLocker sync.Once
}
// 获取对象
func NewServiceManager(name, description string) *ServiceManager {
manager := &ServiceManager{
Name: name,
Description: description,
}
// root
manager.resetRoot()
return manager
}
// 设置服务
func (this *ServiceManager) setup() {
this.onceLocker.Do(func() {
logFile := files.NewFile(Tea.Root + "/logs/service.log")
if logFile.Exists() {
_ = logFile.Delete()
}
//logger
fp, err := os.OpenFile(Tea.Root+"/logs/service.log", os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666)
if err != nil {
logs.Error(err)
return
}
this.fp = fp
this.logger = log.New(fp, "", log.LstdFlags)
})
}
// 记录普通日志
func (this *ServiceManager) Log(msg string) {
this.setup()
if this.logger == nil {
return
}
this.logger.Println("[info]" + msg)
}
// 记录错误日志
func (this *ServiceManager) LogError(msg string) {
this.setup()
if this.logger == nil {
return
}
this.logger.Println("[error]" + msg)
}
// 关闭
func (this *ServiceManager) Close() error {
if this.fp != nil {
return this.fp.Close()
}
return nil
}
// 重置Root
func (this *ServiceManager) resetRoot() {
if !Tea.IsTesting() {
exePath, err := os.Executable()
if err != nil {
exePath = os.Args[0]
}
link, err := filepath.EvalSymlinks(exePath)
if err == nil {
exePath = link
}
fullPath, err := filepath.Abs(exePath)
if err == nil {
Tea.UpdateRoot(filepath.Dir(filepath.Dir(fullPath)))
}
}
Tea.SetPublicDir(Tea.Root + Tea.DS + "web" + Tea.DS + "public")
Tea.SetViewsDir(Tea.Root + Tea.DS + "web" + Tea.DS + "views")
Tea.SetTmpDir(Tea.Root + Tea.DS + "web" + Tea.DS + "tmp")
}
// 保持命令行窗口是打开的
func (this *ServiceManager) PauseWindow() {
if runtime.GOOS != "windows" {
return
}
b := make([]byte, 1)
_, _ = os.Stdin.Read(b)
}

View File

@@ -0,0 +1,160 @@
//go:build linux
// +build linux
package utils
import (
"errors"
teaconst "github.com/TeaOSLab/EdgeUser/internal/const"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/files"
"os"
"os/exec"
"regexp"
)
var systemdServiceFile = "/etc/systemd/system/edge-user.service"
var initServiceFile = "/etc/init.d/" + teaconst.SystemdServiceName
// 安装服务
func (this *ServiceManager) Install(exePath string, args []string) error {
if os.Getgid() != 0 {
return errors.New("only root users can install the service")
}
systemd, err := exec.LookPath("systemctl")
if err != nil {
return this.installInitService(exePath, args)
}
return this.installSystemdService(systemd, exePath, args)
}
// 启动服务
func (this *ServiceManager) Start() error {
if os.Getgid() != 0 {
return errors.New("only root users can start the service")
}
if files.NewFile(systemdServiceFile).Exists() {
systemd, err := exec.LookPath("systemctl")
if err != nil {
return err
}
return exec.Command(systemd, "start", teaconst.SystemdServiceName+".service").Start()
}
return exec.Command("service", teaconst.ProcessName, "start").Start()
}
// 删除服务
func (this *ServiceManager) Uninstall() error {
if os.Getgid() != 0 {
return errors.New("only root users can uninstall the service")
}
if files.NewFile(systemdServiceFile).Exists() {
systemd, err := exec.LookPath("systemctl")
if err != nil {
return err
}
// disable service
_ = exec.Command(systemd, "disable", teaconst.SystemdServiceName+".service").Start()
// reload
_ = exec.Command(systemd, "daemon-reload").Start()
return files.NewFile(systemdServiceFile).Delete()
}
f := files.NewFile(initServiceFile)
if f.Exists() {
return f.Delete()
}
return nil
}
// install init service
func (this *ServiceManager) installInitService(exePath string, args []string) error {
shortName := teaconst.SystemdServiceName
scriptFile := Tea.Root + "/scripts/" + shortName
if !files.NewFile(scriptFile).Exists() {
return errors.New("'scripts/" + shortName + "' file not exists")
}
data, err := os.ReadFile(scriptFile)
if err != nil {
return err
}
data = regexp.MustCompile("INSTALL_DIR=.+").ReplaceAll(data, []byte("INSTALL_DIR="+Tea.Root))
err = os.WriteFile(initServiceFile, data, 0777)
if err != nil {
return err
}
chkCmd, err := exec.LookPath("chkconfig")
if err != nil {
return err
}
err = exec.Command(chkCmd, "--add", teaconst.ProcessName).Start()
if err != nil {
return err
}
return nil
}
// install systemd service
func (this *ServiceManager) installSystemdService(systemd, exePath string, args []string) error {
shortName := teaconst.SystemdServiceName
longName := "GoEdge User" // TODO 将来可以修改
var startCmd = exePath + " daemon"
bashPath, _ := exec.LookPath("bash")
if len(bashPath) > 0 {
startCmd = bashPath + " -c \"" + startCmd + "\""
}
desc := `# Provides: ` + shortName + `
# Required-Start: $all
# Required-Stop:
# Default-Start: 2 3 4 5
# Default-Stop:
# Short-Description: ` + longName + ` Service
### END INIT INFO
[Unit]
Description=` + longName + ` Service
Before=shutdown.target
After=network-online.target
[Service]
Type=simple
Restart=always
RestartSec=1s
ExecStart=` + startCmd + `
ExecStop=` + exePath + ` stop
ExecReload=` + exePath + ` reload
[Install]
WantedBy=multi-user.target`
// write file
err := os.WriteFile(systemdServiceFile, []byte(desc), 0777)
if err != nil {
return err
}
// stop current systemd service if running
_ = exec.Command(systemd, "stop", shortName+".service").Start()
// reload
_ = exec.Command(systemd, "daemon-reload").Start()
// enable
cmd := exec.Command(systemd, "enable", shortName+".service")
return cmd.Run()
}

View File

@@ -0,0 +1,19 @@
//go:build !linux && !windows
// +build !linux,!windows
package utils
// 安装服务
func (this *ServiceManager) Install(exePath string, args []string) error {
return nil
}
// 启动服务
func (this *ServiceManager) Start() error {
return nil
}
// 删除服务
func (this *ServiceManager) Uninstall() error {
return nil
}

View File

@@ -0,0 +1,12 @@
package utils
import (
teaconst "github.com/TeaOSLab/EdgeUser/internal/const"
"testing"
)
func TestServiceManager_Log(t *testing.T) {
manager := NewServiceManager(teaconst.ProductName, teaconst.ProductName+" Server")
manager.Log("Hello, World")
manager.LogError("Hello, World")
}

View File

@@ -0,0 +1,175 @@
//go:build windows
// +build windows
package utils
import (
"fmt"
teaconst "github.com/TeaOSLab/EdgeUser/internal/const"
"github.com/iwind/TeaGo/Tea"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/mgr"
"os/exec"
)
// 安装服务
func (this *ServiceManager) Install(exePath string, args []string) error {
m, err := mgr.Connect()
if err != nil {
return fmt.Errorf("connecting: %w please 'Run as administrator' again", err)
}
defer m.Disconnect()
s, err := m.OpenService(this.Name)
if err == nil {
s.Close()
return fmt.Errorf("service %s already exists", this.Name)
}
s, err = m.CreateService(this.Name, exePath, mgr.Config{
DisplayName: this.Name,
Description: this.Description,
StartType: windows.SERVICE_AUTO_START,
}, args...)
if err != nil {
return fmt.Errorf("creating: %w", err)
}
defer s.Close()
return nil
}
// 启动服务
func (this *ServiceManager) Start() error {
m, err := mgr.Connect()
if err != nil {
return err
}
defer m.Disconnect()
s, err := m.OpenService(this.Name)
if err != nil {
return fmt.Errorf("could not access service: %w", err)
}
defer s.Close()
err = s.Start("service")
if err != nil {
return fmt.Errorf("could not start service: %w", err)
}
return nil
}
// 删除服务
func (this *ServiceManager) Uninstall() error {
m, err := mgr.Connect()
if err != nil {
return fmt.Errorf("connecting: %w please 'Run as administrator' again", err)
}
defer m.Disconnect()
s, err := m.OpenService(this.Name)
if err != nil {
return fmt.Errorf("open service: %w", err)
}
// shutdown service
_, err = s.Control(svc.Stop)
if err != nil {
fmt.Printf("shutdown service: %s\n", err.Error())
}
defer s.Close()
err = s.Delete()
if err != nil {
return fmt.Errorf("deleting: %w", err)
}
return nil
}
// 运行
func (this *ServiceManager) Run() {
err := svc.Run(this.Name, this)
if err != nil {
this.LogError(err.Error())
}
}
// 同服务管理器的交互
func (this *ServiceManager) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, errno uint32) {
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown | svc.AcceptPauseAndContinue
changes <- svc.Status{
State: svc.StartPending,
}
changes <- svc.Status{
State: svc.Running,
Accepts: cmdsAccepted,
}
// start service
this.Log("start")
this.cmdStart()
loop:
for {
select {
case c := <-r:
switch c.Cmd {
case svc.Interrogate:
this.Log("cmd: Interrogate")
changes <- c.CurrentStatus
case svc.Stop, svc.Shutdown:
this.Log("cmd: Stop|Shutdown")
// stop service
this.cmdStop()
break loop
case svc.Pause:
this.Log("cmd: Pause")
// stop service
this.cmdStop()
changes <- svc.Status{
State: svc.Paused,
Accepts: cmdsAccepted,
}
case svc.Continue:
this.Log("cmd: Continue")
// start service
this.cmdStart()
changes <- svc.Status{
State: svc.Running,
Accepts: cmdsAccepted,
}
default:
this.LogError(fmt.Sprintf("unexpected control request #%d\r\n", c))
}
}
}
changes <- svc.Status{
State: svc.StopPending,
}
return
}
// 启动Web服务
func (this *ServiceManager) cmdStart() {
cmd := exec.Command(Tea.Root+Tea.DS+"bin"+Tea.DS+teaconst.SystemdServiceName+".exe", "start")
err := cmd.Start()
if err != nil {
this.LogError(err.Error())
}
}
// 停止Web服务
func (this *ServiceManager) cmdStop() {
cmd := exec.Command(Tea.Root+Tea.DS+"bin"+Tea.DS+teaconst.SystemdServiceName+".exe", "stop")
err := cmd.Start()
if err != nil {
this.LogError(err.Error())
}
}

View File

@@ -0,0 +1,10 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package sizes
const (
K int64 = 1024
M = 1024 * K
G = 1024 * M
T = 1024 * G
)

View File

@@ -0,0 +1,17 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package sizes_test
import (
"github.com/TeaOSLab/EdgeUser/internal/utils/sizes"
"github.com/iwind/TeaGo/assert"
"testing"
)
func TestSizes(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(sizes.K == 1024)
a.IsTrue(sizes.M == 1024*1024)
a.IsTrue(sizes.G == 1024*1024*1024)
a.IsTrue(sizes.T == 1024*1024*1024*1024)
}

View File

@@ -0,0 +1,31 @@
package utils
import (
"github.com/iwind/TeaGo/types"
"strings"
)
// format address
func FormatAddress(addr string) string {
if strings.HasSuffix(addr, "unix:") {
return addr
}
addr = strings.Replace(addr, " ", "", -1)
addr = strings.Replace(addr, "\t", "", -1)
addr = strings.Replace(addr, "", ":", -1)
addr = strings.TrimSpace(addr)
return addr
}
// 分割数字
func SplitNumbers(numbers string) (result []int64) {
if len(numbers) == 0 {
return
}
pieces := strings.Split(numbers, ",")
for _, piece := range pieces {
number := types.Int64(strings.TrimSpace(piece))
result = append(result, number)
}
return
}

View File

@@ -0,0 +1,67 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package utils
import (
"github.com/iwind/TeaGo/lists"
"strings"
)
func FilterNotEmpty(item string) bool {
return len(item) > 0
}
func MapAddPrefixFunc(prefix string) func(item string) string {
return func(item string) string {
if !strings.HasPrefix(item, prefix) {
return prefix + item
}
return item
}
}
type StringsStream struct {
s []string
}
func NewStringsStream(s []string) *StringsStream {
return &StringsStream{s: s}
}
func (this *StringsStream) Map(f ...func(item string) string) *StringsStream {
for index, item := range this.s {
for _, f1 := range f {
item = f1(item)
}
this.s[index] = item
}
return this
}
func (this *StringsStream) Filter(f ...func(item string) bool) *StringsStream {
for _, f1 := range f {
var newStrings = []string{}
for _, item := range this.s {
if f1(item) {
newStrings = append(newStrings, item)
}
}
this.s = newStrings
}
return this
}
func (this *StringsStream) Unique() *StringsStream {
var newStrings = []string{}
for _, item := range this.s {
if !lists.ContainsString(newStrings, item) {
newStrings = append(newStrings, item)
}
}
this.s = newStrings
return this
}
func (this *StringsStream) Result() []string {
return this.s
}

View File

@@ -0,0 +1,25 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package utils_test
import (
"github.com/TeaOSLab/EdgeUser/internal/utils"
"strings"
"testing"
)
func TestStringsStream_Filter(t *testing.T) {
var stream = utils.NewStringsStream([]string{"a", "b", "1", "2", "", "png", "a"})
stream.Filter(func(item string) bool {
return len(item) > 0
})
t.Log(stream.Result())
stream.Map(func(item string) string {
return "." + item
})
t.Log(stream.Result())
stream.Unique()
t.Log(stream.Result())
stream.Map(strings.ToUpper, strings.ToLower)
t.Log(stream.Result())
}

View File

@@ -0,0 +1,47 @@
package utils
import (
"time"
)
// 类似于time.Ticker但能够真正地停止
type Ticker struct {
raw *time.Ticker
S chan bool
C <-chan time.Time
isStopped bool
}
// 创建新Ticker
func NewTicker(duration time.Duration) *Ticker {
raw := time.NewTicker(duration)
return &Ticker{
raw: raw,
C: raw.C,
S: make(chan bool, 1),
}
}
// 查找下一个Tick
func (this *Ticker) Next() bool {
select {
case <-this.raw.C:
return true
case <-this.S:
return false
}
}
// 停止
func (this *Ticker) Stop() {
if this.isStopped {
return
}
this.isStopped = true
this.raw.Stop()
this.S <- true
}

View File

@@ -0,0 +1,55 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package utils
import (
"errors"
"fmt"
"github.com/iwind/TeaGo/types"
"regexp"
)
// RangeTimes 计算时间点
func RangeTimes(timeFrom string, timeTo string, everyMinutes int32) (result []string, err error) {
if everyMinutes <= 0 {
return nil, errors.New("invalid 'everyMinutes'")
}
var reg = regexp.MustCompile(`^\d{4}$`)
if !reg.MatchString(timeFrom) {
return nil, errors.New("invalid timeFrom '" + timeFrom + "'")
}
if !reg.MatchString(timeTo) {
return nil, errors.New("invalid timeTo '" + timeTo + "'")
}
if timeFrom > timeTo {
// swap
timeFrom, timeTo = timeTo, timeFrom
}
var everyMinutesInt = int(everyMinutes)
var fromHour = types.Int(timeFrom[:2])
var fromMinute = types.Int(timeFrom[2:])
var toHour = types.Int(timeTo[:2])
var toMinute = types.Int(timeTo[2:])
if fromMinute%everyMinutesInt == 0 {
result = append(result, timeFrom)
}
for {
fromMinute += everyMinutesInt
if fromMinute > 59 {
fromHour += fromMinute / 60
fromMinute = fromMinute % 60
}
if fromHour > toHour || (fromHour == toHour && fromMinute > toMinute) {
break
}
result = append(result, fmt.Sprintf("%02d%02d", fromHour, fromMinute))
}
return
}

View File

@@ -0,0 +1,95 @@
package utils
import (
"archive/zip"
"errors"
"io"
"os"
)
type Unzip struct {
zipFile string
targetDir string
}
func NewUnzip(zipFile string, targetDir string) *Unzip {
return &Unzip{
zipFile: zipFile,
targetDir: targetDir,
}
}
func (this *Unzip) Run() error {
if len(this.zipFile) == 0 {
return errors.New("zip file should not be empty")
}
if len(this.targetDir) == 0 {
return errors.New("target dir should not be empty")
}
reader, err := zip.OpenReader(this.zipFile)
if err != nil {
return err
}
defer func() {
_ = reader.Close()
}()
for _, file := range reader.File {
var info = file.FileInfo()
var target = this.targetDir + "/" + file.Name
// 目录
if info.IsDir() {
stat, err := os.Stat(target)
if err != nil {
if !os.IsNotExist(err) {
return err
} else {
err = os.MkdirAll(target, info.Mode())
if err != nil {
return err
}
}
} else if !stat.IsDir() {
err = os.MkdirAll(target, info.Mode())
if err != nil {
return err
}
}
continue
}
// 文件
err = func(file *zip.File, target string) error {
fileReader, err := file.Open()
if err != nil {
return err
}
defer func() {
_ = fileReader.Close()
}()
// remove old
_ = os.Remove(target)
// create new
fileWriter, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, file.FileInfo().Mode())
if err != nil {
return err
}
defer func() {
_ = fileWriter.Close()
}()
_, err = io.Copy(fileWriter, fileReader)
return err
}(file, target)
if err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,307 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package utils
import (
"bytes"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
teaconst "github.com/TeaOSLab/EdgeUser/internal/const"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
)
type UpgradeFileWriter struct {
rawWriter io.Writer
written int64
}
func NewUpgradeFileWriter(rawWriter io.Writer) *UpgradeFileWriter {
return &UpgradeFileWriter{rawWriter: rawWriter}
}
func (this *UpgradeFileWriter) Write(p []byte) (n int, err error) {
n, err = this.rawWriter.Write(p)
this.written += int64(n)
return
}
func (this *UpgradeFileWriter) TotalWritten() int64 {
return this.written
}
type UpgradeManager struct {
client *http.Client
component string
newVersion string
contentLength int64
isDownloading bool
writer *UpgradeFileWriter
body io.ReadCloser
isCancelled bool
downloadURL string
}
func NewUpgradeManager(component string, downloadURL string) *UpgradeManager {
return &UpgradeManager{
component: component,
downloadURL: downloadURL,
client: &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
CheckRedirect: nil,
Jar: nil,
Timeout: 30 * time.Minute,
},
}
}
func (this *UpgradeManager) Start() error {
if this.isDownloading {
return errors.New("another process is running")
}
this.isDownloading = true
defer func() {
this.client.CloseIdleConnections()
this.isDownloading = false
}()
// 检查unzip
unzipExe, _ := exec.LookPath("unzip")
// 检查cp
cpExe, _ := exec.LookPath("cp")
if len(cpExe) == 0 {
return errors.New("can not find 'cp' command")
}
// 检查新版本
var downloadURL = this.downloadURL
if len(downloadURL) == 0 {
var url = teaconst.UpdatesURL
var osName = runtime.GOOS
if Tea.IsTesting() && osName == "darwin" {
osName = "linux"
}
url = strings.ReplaceAll(url, "${os}", osName)
url = strings.ReplaceAll(url, "${arch}", runtime.GOARCH)
url = strings.ReplaceAll(url, "${version}", teaconst.Version)
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return fmt.Errorf("create url request failed: %w", err)
}
req.Header.Set("User-Agent", "Edge-User/"+teaconst.Version)
resp, err := this.client.Do(req)
if err != nil {
return fmt.Errorf("read latest version failed: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
return errors.New("read latest version failed: invalid response code '" + types.String(resp.StatusCode) + "'")
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("read latest version failed: %w", err)
}
var m = maps.Map{}
err = json.Unmarshal(data, &m)
if err != nil {
return fmt.Errorf("invalid response data: %w, origin data: %s", err, string(data))
}
var code = m.GetInt("code")
if code != 200 {
return errors.New(m.GetString("message"))
}
var dataMap = m.GetMap("data")
var downloadHost = dataMap.GetString("host")
var versions = dataMap.GetSlice("versions")
var downloadPath = ""
for _, component := range versions {
var componentMap = maps.NewMap(component)
if componentMap.Has("version") {
if componentMap.GetString("code") == this.component {
var version = componentMap.GetString("version")
if stringutil.VersionCompare(version, teaconst.Version) > 0 {
this.newVersion = version
downloadPath = componentMap.GetString("url")
break
}
}
}
}
if len(downloadPath) == 0 {
return errors.New("no latest version to download")
}
downloadURL = downloadHost + downloadPath
}
{
req, err := http.NewRequest(http.MethodGet, downloadURL, nil)
if err != nil {
return fmt.Errorf("create download request failed: %w", err)
}
req.Header.Set("User-Agent", "Edge-User/"+teaconst.Version)
resp, err := this.client.Do(req)
if err != nil {
return fmt.Errorf("download failed: '%s': %w", downloadURL, err)
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
return errors.New("download failed: " + downloadURL + ": invalid response code '" + types.String(resp.StatusCode) + "'")
}
this.contentLength = resp.ContentLength
this.body = resp.Body
// download to tmp
var tmpDir = os.TempDir()
var filename = filepath.Base(downloadURL)
var destFile = tmpDir + "/" + filename
_ = os.Remove(destFile)
fp, err := os.Create(destFile)
if err != nil {
return fmt.Errorf("create file failed: %w", err)
}
defer func() {
// 删除安装文件
_ = os.Remove(destFile)
}()
this.writer = NewUpgradeFileWriter(fp)
_, err = io.Copy(this.writer, resp.Body)
if err != nil {
_ = fp.Close()
if this.isCancelled {
return nil
}
return fmt.Errorf("download failed: %w", err)
}
_ = fp.Close()
// unzip
var unzipDir = tmpDir + "/edge-" + this.component + "-tmp"
stat, err := os.Stat(unzipDir)
if err == nil && stat.IsDir() {
err = os.RemoveAll(unzipDir)
if err != nil {
return fmt.Errorf("remove old dir '%s' failed: %w", unzipDir, err)
}
}
if len(unzipExe) > 0 {
var unzipCmd = exec.Command(unzipExe, "-q", "-o", destFile, "-d", unzipDir)
var unzipStderr = &bytes.Buffer{}
unzipCmd.Stderr = unzipStderr
err = unzipCmd.Run()
if err != nil {
return fmt.Errorf("unzip installation file failed: %w: %s", err, unzipStderr.String())
}
} else {
var unzipCmd = &Unzip{
zipFile: destFile,
targetDir: unzipDir,
}
err = unzipCmd.Run()
if err != nil {
return fmt.Errorf("unzip installation file failed: %w", err)
}
}
installationFiles, err := filepath.Glob(unzipDir + "/edge-" + this.component + "/*")
if err != nil {
return fmt.Errorf("lookup installation files failed: %w", err)
}
// cp to target dir
currentExe, err := os.Executable()
if err != nil {
return fmt.Errorf("reveal current executable file path failed: %w", err)
}
var targetDir = filepath.Dir(filepath.Dir(currentExe))
if !Tea.IsTesting() {
for _, installationFile := range installationFiles {
var cpCmd = exec.Command(cpExe, "-R", "-f", installationFile, targetDir)
var cpStderr = &bytes.Buffer{}
cpCmd.Stderr = cpStderr
err = cpCmd.Run()
if err != nil {
return errors.New("overwrite installation files failed: '" + cpCmd.String() + "': " + cpStderr.String())
}
}
}
// remove tmp
_ = os.RemoveAll(unzipDir)
}
return nil
}
func (this *UpgradeManager) IsDownloading() bool {
return this.isDownloading
}
func (this *UpgradeManager) Progress() float32 {
if this.contentLength <= 0 {
return -1
}
if this.writer == nil {
return -1
}
return float32(this.writer.TotalWritten()) / float32(this.contentLength)
}
func (this *UpgradeManager) NewVersion() string {
return this.newVersion
}
func (this *UpgradeManager) Cancel() error {
this.isCancelled = true
this.isDownloading = false
if this.body != nil {
_ = this.body.Close()
}
return nil
}

View File

@@ -0,0 +1,35 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package utils_test
import (
"github.com/TeaOSLab/EdgeUser/internal/utils"
"testing"
"time"
)
func TestNewUpgradeManager(t *testing.T) {
var manager = utils.NewUpgradeManager("user", "")
var ticker = time.NewTicker(2 * time.Second)
go func() {
for range ticker.C {
if manager.IsDownloading() {
t.Logf("%.2f%%", manager.Progress()*100)
}
}
}()
/**go func() {
time.Sleep(5 * time.Second)
if manager.IsDownloading() {
t.Log("cancel downloading")
_ = manager.Cancel()
}
}()**/
err := manager.Start()
if err != nil {
t.Fatal(err)
}
}

View File

@@ -0,0 +1,24 @@
package utils
import (
"encoding/binary"
"net"
"strings"
)
// VersionToLong 计算版本代号
func VersionToLong(version string) uint32 {
var countDots = strings.Count(version, ".")
if countDots == 2 {
version += ".0"
} else if countDots == 1 {
version += ".0.0"
} else if countDots == 0 {
version += ".0.0.0"
}
var ip = net.ParseIP(version)
if ip == nil || ip.To4() == nil {
return 0
}
return binary.BigEndian.Uint32(ip.To4())
}

View File

@@ -0,0 +1,32 @@
Copyright (c) 2012-2016, Nick Galbreath
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
https://github.com/client9/libinjection
http://opensource.org/licenses/BSD-3-Clause

View File

@@ -0,0 +1 @@
copy from https://github.com/libinjection/libinjection

View File

@@ -0,0 +1,65 @@
/**
* Copyright 2012-2016 Nick Galbreath
* nickg@client9.com
* BSD License -- see COPYING.txt for details
*
* https://libinjection.client9.com/
*
*/
#ifndef LIBINJECTION_H
#define LIBINJECTION_H
#ifdef __cplusplus
# define LIBINJECTION_BEGIN_DECLS extern "C" {
# define LIBINJECTION_END_DECLS }
#else
# define LIBINJECTION_BEGIN_DECLS
# define LIBINJECTION_END_DECLS
#endif
LIBINJECTION_BEGIN_DECLS
/*
* Pull in size_t
*/
#include <string.h>
/*
* Version info.
*
* This is moved into a function to allow SWIG and other auto-generated
* binding to not be modified during minor release changes. We change
* change the version number in the c source file, and not regenerated
* the binding
*
* See python's normalized version
* http://www.python.org/dev/peps/pep-0386/#normalizedversion
*/
const char* libinjection_version(void);
/**
* Simple API for SQLi detection - returns a SQLi fingerprint or NULL
* is benign input
*
* \param[in] s input string, may contain nulls, does not need to be null-terminated
* \param[in] slen input string length
* \param[out] fingerprint buffer of 8+ characters. c-string,
* \return 1 if SQLi, 0 if benign. fingerprint will be set or set to empty string.
*/
int libinjection_sqli(const char* s, size_t slen, char fingerprint[]);
/** ALPHA version of xss detector.
*
* NOT DONE.
*
* \param[in] s input string, may contain nulls, does not need to be null-terminated
* \param[in] slen input string length
* \return 1 if XSS found, 0 if benign
*
*/
int libinjection_xss(const char* s, size_t slen, int strictMode);
LIBINJECTION_END_DECLS
#endif /* LIBINJECTION_H */

View File

@@ -0,0 +1,868 @@
#include "libinjection_html5.h"
#include <string.h>
#include <assert.h>
#ifdef DEBUG
#include <stdio.h>
#define TRACE() printf("%s:%d\n", __FUNCTION__, __LINE__)
#else
#define TRACE()
#endif
#define CHAR_EOF -1
#define CHAR_NULL 0
#define CHAR_BANG 33
#define CHAR_DOUBLE 34
#define CHAR_PERCENT 37
#define CHAR_SINGLE 39
#define CHAR_DASH 45
#define CHAR_SLASH 47
#define CHAR_LT 60
#define CHAR_EQUALS 61
#define CHAR_GT 62
#define CHAR_QUESTION 63
#define CHAR_RIGHTB 93
#define CHAR_TICK 96
/* prototypes */
static int h5_skip_white(h5_state_t* hs);
static int h5_is_white(char ch);
static int h5_state_eof(h5_state_t* hs);
static int h5_state_data(h5_state_t* hs);
static int h5_state_tag_open(h5_state_t* hs);
static int h5_state_tag_name(h5_state_t* hs);
static int h5_state_tag_name_close(h5_state_t* hs);
static int h5_state_end_tag_open(h5_state_t* hs);
static int h5_state_self_closing_start_tag(h5_state_t* hs);
static int h5_state_attribute_name(h5_state_t* hs);
static int h5_state_after_attribute_name(h5_state_t* hs);
static int h5_state_before_attribute_name(h5_state_t* hs);
static int h5_state_before_attribute_value(h5_state_t* hs);
static int h5_state_attribute_value_double_quote(h5_state_t* hs);
static int h5_state_attribute_value_single_quote(h5_state_t* hs);
static int h5_state_attribute_value_back_quote(h5_state_t* hs);
static int h5_state_attribute_value_no_quote(h5_state_t* hs);
static int h5_state_after_attribute_value_quoted_state(h5_state_t* hs);
static int h5_state_comment(h5_state_t* hs);
static int h5_state_cdata(h5_state_t* hs);
/* 12.2.4.44 */
static int h5_state_bogus_comment(h5_state_t* hs);
static int h5_state_bogus_comment2(h5_state_t* hs);
/* 12.2.4.45 */
static int h5_state_markup_declaration_open(h5_state_t* hs);
/* 8.2.4.52 */
static int h5_state_doctype(h5_state_t* hs);
/**
* public function
*/
void libinjection_h5_init(h5_state_t* hs, const char* s, size_t len, enum html5_flags flags)
{
memset(hs, 0, sizeof(h5_state_t));
hs->s = s;
hs->len = len;
switch (flags) {
case DATA_STATE:
hs->state = h5_state_data;
break;
case VALUE_NO_QUOTE:
hs->state = h5_state_before_attribute_name;
break;
case VALUE_SINGLE_QUOTE:
hs->state = h5_state_attribute_value_single_quote;
break;
case VALUE_DOUBLE_QUOTE:
hs->state = h5_state_attribute_value_double_quote;
break;
case VALUE_BACK_QUOTE:
hs->state = h5_state_attribute_value_back_quote;
break;
}
}
/**
* public function
*/
int libinjection_h5_next(h5_state_t* hs)
{
assert(hs->state != NULL);
return (*hs->state)(hs);
}
/**
* Everything below here is private
*
*/
static int h5_is_white(char ch)
{
/*
* \t = horizontal tab = 0x09
* \n = newline = 0x0A
* \v = vertical tab = 0x0B
* \f = form feed = 0x0C
* \r = cr = 0x0D
*/
return strchr(" \t\n\v\f\r", ch) != NULL;
}
static int h5_skip_white(h5_state_t* hs)
{
char ch;
while (hs->pos < hs->len) {
ch = hs->s[hs->pos];
switch (ch) {
case 0x00: /* IE only */
case 0x20:
case 0x09:
case 0x0A:
case 0x0B: /* IE only */
case 0x0C:
case 0x0D: /* IE only */
hs->pos += 1;
break;
default:
return ch;
}
}
return CHAR_EOF;
}
static int h5_state_eof(h5_state_t* hs)
{
/* eliminate unused function argument warning */
(void)hs;
return 0;
}
static int h5_state_data(h5_state_t* hs)
{
const char* idx;
TRACE();
assert(hs->len >= hs->pos);
idx = (const char*) memchr(hs->s + hs->pos, CHAR_LT, hs->len - hs->pos);
if (idx == NULL) {
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = DATA_TEXT;
hs->state = h5_state_eof;
if (hs->token_len == 0) {
return 0;
}
} else {
hs->token_start = hs->s + hs->pos;
hs->token_type = DATA_TEXT;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->pos = (size_t)(idx - hs->s) + 1;
hs->state = h5_state_tag_open;
if (hs->token_len == 0) {
return h5_state_tag_open(hs);
}
}
return 1;
}
/**
* 12 2.4.8
*/
static int h5_state_tag_open(h5_state_t* hs)
{
char ch;
TRACE();
if (hs->pos >= hs->len) {
return 0;
}
ch = hs->s[hs->pos];
if (ch == CHAR_BANG) {
hs->pos += 1;
return h5_state_markup_declaration_open(hs);
} else if (ch == CHAR_SLASH) {
hs->pos += 1;
hs->is_close = 1;
return h5_state_end_tag_open(hs);
} else if (ch == CHAR_QUESTION) {
hs->pos += 1;
return h5_state_bogus_comment(hs);
} else if (ch == CHAR_PERCENT) {
/* this is not in spec.. alternative comment format used
by IE <= 9 and Safari < 4.0.3 */
hs->pos += 1;
return h5_state_bogus_comment2(hs);
} else if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')) {
return h5_state_tag_name(hs);
} else if (ch == CHAR_NULL) {
/* IE-ism NULL characters are ignored */
return h5_state_tag_name(hs);
} else {
/* user input mistake in configuring state */
if (hs->pos == 0) {
return h5_state_data(hs);
}
hs->token_start = hs->s + hs->pos - 1;
hs->token_len = 1;
hs->token_type = DATA_TEXT;
hs->state = h5_state_data;
return 1;
}
}
/**
* 12.2.4.9
*/
static int h5_state_end_tag_open(h5_state_t* hs)
{
char ch;
TRACE();
if (hs->pos >= hs->len) {
return 0;
}
ch = hs->s[hs->pos];
if (ch == CHAR_GT) {
return h5_state_data(hs);
} else if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')) {
return h5_state_tag_name(hs);
}
hs->is_close = 0;
return h5_state_bogus_comment(hs);
}
/*
*
*/
static int h5_state_tag_name_close(h5_state_t* hs)
{
TRACE();
hs->is_close = 0;
hs->token_start = hs->s + hs->pos;
hs->token_len = 1;
hs->token_type = TAG_NAME_CLOSE;
hs->pos += 1;
if (hs->pos < hs->len) {
hs->state = h5_state_data;
} else {
hs->state = h5_state_eof;
}
return 1;
}
/**
* 12.2.4.10
*/
static int h5_state_tag_name(h5_state_t* hs)
{
char ch;
size_t pos;
TRACE();
pos = hs->pos;
while (pos < hs->len) {
ch = hs->s[pos];
if (ch == 0) {
/* special non-standard case */
/* allow nulls in tag name */
/* some old browsers apparently allow and ignore them */
pos += 1;
} else if (h5_is_white(ch)) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->token_type = TAG_NAME_OPEN;
hs->pos = pos + 1;
hs->state = h5_state_before_attribute_name;
return 1;
} else if (ch == CHAR_SLASH) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->token_type = TAG_NAME_OPEN;
hs->pos = pos + 1;
hs->state = h5_state_self_closing_start_tag;
return 1;
} else if (ch == CHAR_GT) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
if (hs->is_close) {
hs->pos = pos + 1;
hs->is_close = 0;
hs->token_type = TAG_CLOSE;
hs->state = h5_state_data;
} else {
hs->pos = pos;
hs->token_type = TAG_NAME_OPEN;
hs->state = h5_state_tag_name_close;
}
return 1;
} else {
pos += 1;
}
}
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = TAG_NAME_OPEN;
hs->state = h5_state_eof;
return 1;
}
/**
* 12.2.4.34
*/
static int h5_state_before_attribute_name(h5_state_t* hs)
{
int ch;
TRACE();
/* for manual tail call optimization, see comment below */
tail_call:;
ch = h5_skip_white(hs);
switch (ch) {
case CHAR_EOF: {
return 0;
}
case CHAR_SLASH: {
hs->pos += 1;
/* Logically, We want to call h5_state_self_closing_start_tag(hs) here.
As this function may call us back and the compiler
might not implement automatic tail call optimization,
this might result in a deep recursion.
We detect this case here and start over with the current state.
*/
if (hs->pos < hs->len && hs->s[hs->pos] != CHAR_GT) {
goto tail_call;
}
return h5_state_self_closing_start_tag(hs);
}
case CHAR_GT: {
hs->state = h5_state_data;
hs->token_start = hs->s + hs->pos;
hs->token_len = 1;
hs->token_type = TAG_NAME_CLOSE;
hs->pos += 1;
return 1;
}
default: {
return h5_state_attribute_name(hs);
}
}
}
static int h5_state_attribute_name(h5_state_t* hs)
{
char ch;
size_t pos;
TRACE();
pos = hs->pos + 1;
while (pos < hs->len) {
ch = hs->s[pos];
if (h5_is_white(ch)) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->token_type = ATTR_NAME;
hs->state = h5_state_after_attribute_name;
hs->pos = pos + 1;
return 1;
} else if (ch == CHAR_SLASH) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->token_type = ATTR_NAME;
hs->state = h5_state_self_closing_start_tag;
hs->pos = pos + 1;
return 1;
} else if (ch == CHAR_EQUALS) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->token_type = ATTR_NAME;
hs->state = h5_state_before_attribute_value;
hs->pos = pos + 1;
return 1;
} else if (ch == CHAR_GT) {
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->token_type = ATTR_NAME;
hs->state = h5_state_tag_name_close;
hs->pos = pos;
return 1;
} else {
pos += 1;
}
}
/* EOF */
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = ATTR_NAME;
hs->state = h5_state_eof;
hs->pos = hs->len;
return 1;
}
/**
* 12.2.4.36
*/
static int h5_state_after_attribute_name(h5_state_t* hs)
{
int c;
TRACE();
c = h5_skip_white(hs);
switch (c) {
case CHAR_EOF: {
return 0;
}
case CHAR_SLASH: {
hs->pos += 1;
return h5_state_self_closing_start_tag(hs);
}
case CHAR_EQUALS: {
hs->pos += 1;
return h5_state_before_attribute_value(hs);
}
case CHAR_GT: {
return h5_state_tag_name_close(hs);
}
default: {
return h5_state_attribute_name(hs);
}
}
}
/**
* 12.2.4.37
*/
static int h5_state_before_attribute_value(h5_state_t* hs)
{
int c;
TRACE();
c = h5_skip_white(hs);
if (c == CHAR_EOF) {
hs->state = h5_state_eof;
return 0;
}
if (c == CHAR_DOUBLE) {
return h5_state_attribute_value_double_quote(hs);
} else if (c == CHAR_SINGLE) {
return h5_state_attribute_value_single_quote(hs);
} else if (c == CHAR_TICK) {
/* NON STANDARD IE */
return h5_state_attribute_value_back_quote(hs);
} else {
return h5_state_attribute_value_no_quote(hs);
}
}
static int h5_state_attribute_value_quote(h5_state_t* hs, char qchar)
{
const char* idx;
TRACE();
/* skip initial quote in normal case.
* don't do this "if (pos == 0)" since it means we have started
* in a non-data state. given an input of '><foo
* we want to make 0-length attribute name
*/
if (hs->pos > 0) {
hs->pos += 1;
}
idx = (const char*) memchr(hs->s + hs->pos, qchar, hs->len - hs->pos);
if (idx == NULL) {
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = ATTR_VALUE;
hs->state = h5_state_eof;
} else {
hs->token_start = hs->s + hs->pos;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->token_type = ATTR_VALUE;
hs->state = h5_state_after_attribute_value_quoted_state;
hs->pos += hs->token_len + 1;
}
return 1;
}
static
int h5_state_attribute_value_double_quote(h5_state_t* hs)
{
TRACE();
return h5_state_attribute_value_quote(hs, CHAR_DOUBLE);
}
static
int h5_state_attribute_value_single_quote(h5_state_t* hs)
{
TRACE();
return h5_state_attribute_value_quote(hs, CHAR_SINGLE);
}
static
int h5_state_attribute_value_back_quote(h5_state_t* hs)
{
TRACE();
return h5_state_attribute_value_quote(hs, CHAR_TICK);
}
static int h5_state_attribute_value_no_quote(h5_state_t* hs)
{
char ch;
size_t pos;
TRACE();
pos = hs->pos;
while (pos < hs->len) {
ch = hs->s[pos];
if (h5_is_white(ch)) {
hs->token_type = ATTR_VALUE;
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->pos = pos + 1;
hs->state = h5_state_before_attribute_name;
return 1;
} else if (ch == CHAR_GT) {
hs->token_type = ATTR_VALUE;
hs->token_start = hs->s + hs->pos;
hs->token_len = pos - hs->pos;
hs->pos = pos;
hs->state = h5_state_tag_name_close;
return 1;
}
pos += 1;
}
TRACE();
/* EOF */
hs->state = h5_state_eof;
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = ATTR_VALUE;
return 1;
}
/**
* 12.2.4.41
*/
static int h5_state_after_attribute_value_quoted_state(h5_state_t* hs)
{
char ch;
TRACE();
if (hs->pos >= hs->len) {
return 0;
}
ch = hs->s[hs->pos];
if (h5_is_white(ch)) {
hs->pos += 1;
return h5_state_before_attribute_name(hs);
} else if (ch == CHAR_SLASH) {
hs->pos += 1;
return h5_state_self_closing_start_tag(hs);
} else if (ch == CHAR_GT) {
hs->token_start = hs->s + hs->pos;
hs->token_len = 1;
hs->token_type = TAG_NAME_CLOSE;
hs->pos += 1;
hs->state = h5_state_data;
return 1;
} else {
return h5_state_before_attribute_name(hs);
}
}
/**
* 12.2.4.43
*
* WARNING: This function is partially inlined into h5_state_before_attribute_name()
*/
static int h5_state_self_closing_start_tag(h5_state_t* hs)
{
char ch;
TRACE();
if (hs->pos >= hs->len) {
return 0;
}
ch = hs->s[hs->pos];
if (ch == CHAR_GT) {
assert(hs->pos > 0);
hs->token_start = hs->s + hs->pos -1;
hs->token_len = 2;
hs->token_type = TAG_NAME_SELFCLOSE;
hs->state = h5_state_data;
hs->pos += 1;
return 1;
} else {
return h5_state_before_attribute_name(hs);
}
}
/**
* 12.2.4.44
*/
static int h5_state_bogus_comment(h5_state_t* hs)
{
const char* idx;
TRACE();
idx = (const char*) memchr(hs->s + hs->pos, CHAR_GT, hs->len - hs->pos);
if (idx == NULL) {
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->pos = hs->len;
hs->state = h5_state_eof;
} else {
hs->token_start = hs->s + hs->pos;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->pos = (size_t)(idx - hs->s) + 1;
hs->state = h5_state_data;
}
hs->token_type = TAG_COMMENT;
return 1;
}
/**
* 12.2.4.44 ALT
*/
static int h5_state_bogus_comment2(h5_state_t* hs)
{
const char* idx;
size_t pos;
TRACE();
pos = hs->pos;
while (1) {
idx = (const char*) memchr(hs->s + pos, CHAR_PERCENT, hs->len - pos);
if (idx == NULL || (idx + 1 >= hs->s + hs->len)) {
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->pos = hs->len;
hs->token_type = TAG_COMMENT;
hs->state = h5_state_eof;
return 1;
}
if (*(idx +1) != CHAR_GT) {
pos = (size_t)(idx - hs->s) + 1;
continue;
}
/* ends in %> */
hs->token_start = hs->s + hs->pos;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->pos = (size_t)(idx - hs->s) + 2;
hs->state = h5_state_data;
hs->token_type = TAG_COMMENT;
return 1;
}
}
/**
* 8.2.4.45
*/
static int h5_state_markup_declaration_open(h5_state_t* hs)
{
size_t remaining;
TRACE();
remaining = hs->len - hs->pos;
if (remaining >= 7 &&
/* case insensitive */
(hs->s[hs->pos + 0] == 'D' || hs->s[hs->pos + 0] == 'd') &&
(hs->s[hs->pos + 1] == 'O' || hs->s[hs->pos + 1] == 'o') &&
(hs->s[hs->pos + 2] == 'C' || hs->s[hs->pos + 2] == 'c') &&
(hs->s[hs->pos + 3] == 'T' || hs->s[hs->pos + 3] == 't') &&
(hs->s[hs->pos + 4] == 'Y' || hs->s[hs->pos + 4] == 'y') &&
(hs->s[hs->pos + 5] == 'P' || hs->s[hs->pos + 5] == 'p') &&
(hs->s[hs->pos + 6] == 'E' || hs->s[hs->pos + 6] == 'e')
) {
return h5_state_doctype(hs);
} else if (remaining >= 7 &&
/* upper case required */
hs->s[hs->pos + 0] == '[' &&
hs->s[hs->pos + 1] == 'C' &&
hs->s[hs->pos + 2] == 'D' &&
hs->s[hs->pos + 3] == 'A' &&
hs->s[hs->pos + 4] == 'T' &&
hs->s[hs->pos + 5] == 'A' &&
hs->s[hs->pos + 6] == '['
) {
hs->pos += 7;
return h5_state_cdata(hs);
} else if (remaining >= 2 &&
hs->s[hs->pos + 0] == '-' &&
hs->s[hs->pos + 1] == '-') {
hs->pos += 2;
return h5_state_comment(hs);
}
return h5_state_bogus_comment(hs);
}
/**
* 12.2.4.48
* 12.2.4.49
* 12.2.4.50
* 12.2.4.51
* state machine spec is confusing since it can only look
* at one character at a time but simply it's comments end by:
* 1) EOF
* 2) ending in -->
* 3) ending in -!>
*/
static int h5_state_comment(h5_state_t* hs)
{
char ch;
const char* idx;
size_t pos;
size_t offset;
const char* end = hs->s + hs->len;
TRACE();
pos = hs->pos;
while (1) {
idx = (const char*) memchr(hs->s + pos, CHAR_DASH, hs->len - pos);
/* did not find anything or has less than 3 chars left */
if (idx == NULL || idx > hs->s + hs->len - 3) {
hs->state = h5_state_eof;
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = TAG_COMMENT;
return 1;
}
offset = 1;
/* skip all nulls */
while (idx + offset < end && *(idx + offset) == 0) {
offset += 1;
}
if (idx + offset == end) {
hs->state = h5_state_eof;
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = TAG_COMMENT;
return 1;
}
ch = *(idx + offset);
if (ch != CHAR_DASH && ch != CHAR_BANG) {
pos = (size_t)(idx - hs->s) + 1;
continue;
}
/* need to test */
#if 0
/* skip all nulls */
while (idx + offset < end && *(idx + offset) == 0) {
offset += 1;
}
if (idx + offset == end) {
hs->state = h5_state_eof;
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = TAG_COMMENT;
return 1;
}
#endif
offset += 1;
if (idx + offset == end) {
hs->state = h5_state_eof;
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = TAG_COMMENT;
return 1;
}
ch = *(idx + offset);
if (ch != CHAR_GT) {
pos = (size_t)(idx - hs->s) + 1;
continue;
}
offset += 1;
/* ends in --> or -!> */
hs->token_start = hs->s + hs->pos;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->pos = (size_t)(idx + offset - hs->s);
hs->state = h5_state_data;
hs->token_type = TAG_COMMENT;
return 1;
}
}
static int h5_state_cdata(h5_state_t* hs)
{
const char* idx;
size_t pos;
TRACE();
pos = hs->pos;
while (1) {
idx = (const char*) memchr(hs->s + pos, CHAR_RIGHTB, hs->len - pos);
/* did not find anything or has less than 3 chars left */
if (idx == NULL || idx > hs->s + hs->len - 3) {
hs->state = h5_state_eof;
hs->token_start = hs->s + hs->pos;
hs->token_len = hs->len - hs->pos;
hs->token_type = DATA_TEXT;
return 1;
} else if ( *(idx+1) == CHAR_RIGHTB && *(idx+2) == CHAR_GT) {
hs->state = h5_state_data;
hs->token_start = hs->s + hs->pos;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->pos = (size_t)(idx - hs->s) + 3;
hs->token_type = DATA_TEXT;
return 1;
} else {
pos = (size_t)(idx - hs->s) + 1;
}
}
}
/**
* 8.2.4.52
* http://www.w3.org/html/wg/drafts/html/master/syntax.html#doctype-state
*/
static int h5_state_doctype(h5_state_t* hs)
{
const char* idx;
TRACE();
hs->token_start = hs->s + hs->pos;
hs->token_type = DOCTYPE;
idx = (const char*) memchr(hs->s + hs->pos, CHAR_GT, hs->len - hs->pos);
if (idx == NULL) {
hs->state = h5_state_eof;
hs->token_len = hs->len - hs->pos;
} else {
hs->state = h5_state_data;
hs->token_len = (size_t)(idx - hs->s) - hs->pos;
hs->pos = (size_t)(idx - hs->s) + 1;
}
return 1;
}

View File

@@ -0,0 +1,54 @@
#ifndef LIBINJECTION_HTML5
#define LIBINJECTION_HTML5
#ifdef __cplusplus
extern "C" {
#endif
/* pull in size_t */
#include <stddef.h>
enum html5_type {
DATA_TEXT
, TAG_NAME_OPEN
, TAG_NAME_CLOSE
, TAG_NAME_SELFCLOSE
, TAG_DATA
, TAG_CLOSE
, ATTR_NAME
, ATTR_VALUE
, TAG_COMMENT
, DOCTYPE
};
enum html5_flags {
DATA_STATE
, VALUE_NO_QUOTE
, VALUE_SINGLE_QUOTE
, VALUE_DOUBLE_QUOTE
, VALUE_BACK_QUOTE
};
struct h5_state;
typedef int (*ptr_html5_state)(struct h5_state*);
typedef struct h5_state {
const char* s;
size_t len;
size_t pos;
int is_close;
ptr_html5_state state;
const char* token_start;
size_t token_len;
enum html5_type token_type;
} h5_state_t;
void libinjection_h5_init(h5_state_t* hs, const char* s, size_t len, enum html5_flags);
int libinjection_h5_next(h5_state_t* hs);
#ifdef __cplusplus
}
#endif
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,294 @@
/**
* Copyright 2012-2016 Nick Galbreath
* nickg@client9.com
* BSD License -- see `COPYING.txt` for details
*
* https://libinjection.client9.com/
*
*/
#ifndef LIBINJECTION_SQLI_H
#define LIBINJECTION_SQLI_H
#ifdef __cplusplus
extern "C" {
#endif
/*
* Pull in size_t
*/
#include <string.h>
enum sqli_flags {
FLAG_NONE = 0
, FLAG_QUOTE_NONE = 1 /* 1 << 0 */
, FLAG_QUOTE_SINGLE = 2 /* 1 << 1 */
, FLAG_QUOTE_DOUBLE = 4 /* 1 << 2 */
, FLAG_SQL_ANSI = 8 /* 1 << 3 */
, FLAG_SQL_MYSQL = 16 /* 1 << 4 */
};
enum lookup_type {
LOOKUP_WORD = 1
, LOOKUP_TYPE = 2
, LOOKUP_OPERATOR = 3
, LOOKUP_FINGERPRINT = 4
};
struct libinjection_sqli_token {
#ifdef SWIG
%immutable;
#endif
/*
* position and length of token
* in original string
*/
size_t pos;
size_t len;
/* count:
* in type 'v', used for number of opening '@'
* but maybe used in other contexts
*/
int count;
char type;
char str_open;
char str_close;
char val[32];
};
typedef struct libinjection_sqli_token stoken_t;
/**
* Pointer to function, takes c-string input,
* returns '\0' for no match, else a char
*/
struct libinjection_sqli_state;
typedef char (*ptr_lookup_fn)(struct libinjection_sqli_state*, int lookuptype, const char* word, size_t len);
struct libinjection_sqli_state {
#ifdef SWIG
%immutable;
#endif
/*
* input, does not need to be null terminated.
* it is also not modified.
*/
const char *s;
/*
* input length
*/
size_t slen;
/*
* How to lookup a word or fingerprint
*/
ptr_lookup_fn lookup;
void* userdata;
/*
*
*/
int flags;
/*
* pos is the index in the string during tokenization
*/
size_t pos;
#ifndef SWIG
/* for SWIG.. don't use this.. use functional API instead */
/* MAX TOKENS + 1 since we use one extra token
* to determine the type of the previous token
*/
struct libinjection_sqli_token tokenvec[8];
#endif
/*
* Pointer to token position in tokenvec, above
*/
struct libinjection_sqli_token *current;
/*
* fingerprint pattern c-string
* +1 for ending null
* Minimum of 8 bytes to add gcc's -fstack-protector to work
*/
char fingerprint[8];
/*
* Line number of code that said decided if the input was SQLi or
* not. Most of the time it's line that said "it's not a matching
* fingerprint" but there is other logic that sometimes approves
* an input. This is only useful for debugging.
*
*/
int reason;
/* Number of ddw (dash-dash-white) comments
* These comments are in the form of
* '--[whitespace]' or '--[EOF]'
*
* All databases treat this as a comment.
*/
int stats_comment_ddw;
/* Number of ddx (dash-dash-[notwhite]) comments
*
* ANSI SQL treats these are comments, MySQL treats this as
* two unary operators '-' '-'
*
* If you are parsing result returns FALSE and
* stats_comment_dd > 0, you should reparse with
* COMMENT_MYSQL
*
*/
int stats_comment_ddx;
/*
* c-style comments found /x .. x/
*/
int stats_comment_c;
/* '#' operators or MySQL EOL comments found
*
*/
int stats_comment_hash;
/*
* number of tokens folded away
*/
int stats_folds;
/*
* total tokens processed
*/
int stats_tokens;
};
typedef struct libinjection_sqli_state sfilter;
struct libinjection_sqli_token* libinjection_sqli_get_token(
struct libinjection_sqli_state* sql_state, int i);
/*
* Version info.
*
* This is moved into a function to allow SWIG and other auto-generated
* binding to not be modified during minor release changes. We change
* change the version number in the c source file, and not regenerated
* the binding
*
* See python's normalized version
* http://www.python.org/dev/peps/pep-0386/#normalizedversion
*/
const char* libinjection_version(void);
/**
*
*/
void libinjection_sqli_init(struct libinjection_sqli_state *sf,
const char* s, size_t len,
int flags);
/**
* Main API: tests for SQLi in three possible contexts, no quotes,
* single quote and double quote
*
* \param sql_state core data structure
*
* \return 1 (true) if SQLi, 0 (false) if benign
*/
int libinjection_is_sqli(struct libinjection_sqli_state* sql_state);
/* FOR HACKERS ONLY
* provides deep hooks into the decision making process
*/
void libinjection_sqli_callback(struct libinjection_sqli_state *sf,
ptr_lookup_fn fn,
void* userdata);
/*
* Resets state, but keeps initial string and callbacks
*/
void libinjection_sqli_reset(struct libinjection_sqli_state *sf,
int flags);
/**
*
*/
/**
* This detects SQLi in a single context, mostly useful for custom
* logic and debugging.
*
* \param sql_state Main data structure
* \param flags flags to adjust parsing
*
* \returns a pointer to sfilter.fingerprint as convenience
* do not free!
*
*/
const char* libinjection_sqli_fingerprint(struct libinjection_sqli_state *sql_state,
int flags);
/**
* The default "word" to token-type or fingerprint function. This
* uses a ASCII case-insensitive binary tree.
*/
char libinjection_sqli_lookup_word(struct libinjection_sqli_state *sql_state,
int lookup_type,
const char* str,
size_t len);
/* Streaming tokenization interface.
*
* sql_state->current is updated with the current token.
*
* \returns 1, has a token, keep going, or 0 no tokens
*
*/
int libinjection_sqli_tokenize(struct libinjection_sqli_state *sf);
/**
* parses and folds input, up to 5 tokens
*
*/
int libinjection_sqli_fold(struct libinjection_sqli_state *sf);
/** The built-in default function to match fingerprints
* and do false negative/positive analysis. This calls the following
* two functions. With this, you over-ride one part or the other.
*
* return libinjection_sqli_blacklist(sql_state) &&
* libinjection_sqli_not_whitelist(sql_state);
*
* \param sql_state should be filled out after libinjection_sqli_fingerprint is called
*/
int libinjection_sqli_check_fingerprint(struct libinjection_sqli_state * sql_state);
/* Given a pattern determine if it's a SQLi pattern.
*
* \return TRUE if sqli, false otherwise
*/
int libinjection_sqli_blacklist(struct libinjection_sqli_state* sql_state);
/* Given a positive match for a pattern (i.e. pattern is SQLi), this function
* does additional analysis to reduce false positives.
*
* \return TRUE if SQLi, false otherwise
*/
int libinjection_sqli_not_whitelist(struct libinjection_sqli_state * sql_state);
#ifdef __cplusplus
}
#endif
#endif /* LIBINJECTION_SQLI_H */

Some files were not shown because too many files have changed in this diff Show More