Initial commit (code only without large binaries)

This commit is contained in:
robin
2026-02-15 18:58:44 +08:00
commit 35df75498f
9442 changed files with 1495866 additions and 0 deletions

View File

@@ -0,0 +1,7 @@
package setup
type Config struct {
APINodeProtocol string
APINodeHost string
APINodePort int
}

View File

@@ -0,0 +1,209 @@
package setup
import (
"encoding/json"
"fmt"
"github.com/TeaOSLab/EdgeAPI/internal/configs"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/cmd"
"github.com/iwind/TeaGo/types"
"os"
"strconv"
"strings"
)
type Setup struct {
config *Config
// 要返回的数据
AdminNodeId string
AdminNodeSecret string
logFp *os.File
}
func NewSetup(config *Config) *Setup {
return &Setup{
config: config,
}
}
func NewSetupFromCmd() *Setup {
var args = cmd.ParseArgs(strings.Join(os.Args[1:], " "))
var config = &Config{}
for _, arg := range args {
var index = strings.Index(arg, "=")
if index <= 0 {
continue
}
var value = arg[index+1:]
value = strings.Trim(value, "\"'")
switch arg[:index] {
case "-api-node-protocol":
config.APINodeProtocol = value
case "-api-node-host":
config.APINodeHost = value
case "-api-node-port":
config.APINodePort = types.Int(value)
}
}
var setup = NewSetup(config)
// log writer
var tmpDir = os.TempDir()
if len(tmpDir) > 0 {
fp, err := os.OpenFile(tmpDir+"/edge-install.log", os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0666)
if err == nil {
setup.logFp = fp
}
}
return setup
}
func (this *Setup) Run() error {
if this.config == nil {
return errors.New("config should not be nil")
}
if len(this.config.APINodeProtocol) == 0 {
return errors.New("api node protocol should not be empty")
}
if this.config.APINodeProtocol != "http" && this.config.APINodeProtocol != "https" {
return errors.New("invalid api node protocol: " + this.config.APINodeProtocol)
}
if len(this.config.APINodeHost) == 0 {
return errors.New("api node host should not be empty")
}
if this.config.APINodePort <= 0 {
return errors.New("api node port should not be less than 1")
}
// 执行SQL
config, err := configs.LoadDBConfig()
if err != nil {
return err
}
for _, db := range config.DBs {
// 可以同时运行多条语句
if !strings.Contains(db.Dsn, "multiStatements=") {
if strings.Contains(db.Dsn, "?") {
db.Dsn += "&multiStatements=true"
} else {
db.Dsn += "?multiStatements=true"
}
}
}
dbConfig, ok := config.DBs[Tea.Env]
if !ok {
return errors.New("can not find database config for env '" + Tea.Env + "'")
}
var executor = NewSQLExecutor(dbConfig)
if this.logFp != nil {
executor.SetLogWriter(this.logFp)
defer func() {
_ = this.logFp.Close()
_ = os.Remove(this.logFp.Name())
}()
}
err = executor.Run(false)
if err != nil {
return err
}
// Admin节点信息
var apiTokenDAO = models.NewApiTokenDAO()
token, err := apiTokenDAO.FindEnabledTokenWithRole(nil, "admin")
if err != nil {
return err
}
if token == nil {
return errors.New("can not find admin node token, please run the setup again")
}
this.AdminNodeId = token.NodeId
this.AdminNodeSecret = token.Secret
// 检查API节点
var dao = models.NewAPINodeDAO()
apiNodeId, err := dao.FindEnabledAPINodeIdWithAddr(nil, this.config.APINodeProtocol, this.config.APINodeHost, this.config.APINodePort)
if err != nil {
return err
}
if apiNodeId == 0 {
var addr = &serverconfigs.NetworkAddressConfig{
Protocol: serverconfigs.Protocol(this.config.APINodeProtocol),
Host: this.config.APINodeHost,
PortRange: strconv.Itoa(this.config.APINodePort),
}
addrsJSON, err := json.Marshal([]*serverconfigs.NetworkAddressConfig{addr})
if err != nil {
return fmt.Errorf("json encode api node addr failed: %w", err)
}
var httpJSON []byte = nil
var httpsJSON []byte = nil
if this.config.APINodeProtocol == "http" {
httpConfig := &serverconfigs.HTTPProtocolConfig{}
httpConfig.IsOn = true
httpConfig.Listen = []*serverconfigs.NetworkAddressConfig{
{
Protocol: "http",
PortRange: strconv.Itoa(this.config.APINodePort),
},
}
httpJSON, err = json.Marshal(httpConfig)
if err != nil {
return fmt.Errorf("json encode api node http config failed: %w", err)
}
}
if this.config.APINodeProtocol == "https" {
// TODO 如果在安装过程中开启了HTTPS需要同时上传SSL证书
var httpsConfig = &serverconfigs.HTTPSProtocolConfig{}
httpsConfig.IsOn = true
httpsConfig.Listen = []*serverconfigs.NetworkAddressConfig{
{
Protocol: "https",
PortRange: strconv.Itoa(this.config.APINodePort),
},
}
httpsJSON, err = json.Marshal(httpsConfig)
if err != nil {
return fmt.Errorf("json encode api node https config failed: %w", err)
}
}
// 创建API节点
nodeId, err := dao.CreateAPINode(nil, "默认API节点", "这是默认创建的第一个API节点", httpJSON, httpsJSON, false, nil, nil, addrsJSON, true)
if err != nil {
return fmt.Errorf("create api node in database failed: %w", err)
}
apiNodeId = nodeId
}
apiNode, err := dao.FindEnabledAPINode(nil, apiNodeId, nil)
if err != nil {
return err
}
if apiNode == nil {
return errors.New("apiNode should not be nil")
}
// 保存配置
var apiConfig = &configs.APIConfig{
NodeId: apiNode.UniqueId,
Secret: apiNode.Secret,
}
err = apiConfig.WriteFile(Tea.ConfigFile("api.yaml"))
if err != nil {
return fmt.Errorf("save config failed: %w", err)
}
return nil
}

View File

@@ -0,0 +1,19 @@
package setup
import (
_ "github.com/iwind/TeaGo/bootstrap"
"testing"
)
func TestSetup_Run(t *testing.T) {
setup := NewSetup(&Config{
APINodeProtocol: "http",
APINodeHost: "127.0.0.1",
APINodePort: 8003,
})
err := setup.Run()
if err != nil {
t.Fatal(err)
}
t.Log("OK")
}

View File

@@ -0,0 +1,8 @@
package setup
import (
_ "embed"
)
//go:embed sql.json
var sqlData []byte

259850
EdgeAPI/internal/setup/sql.json Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,581 @@
package setup
import (
"errors"
"fmt"
"github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"io"
"regexp"
"runtime"
"sort"
"strings"
"sync"
)
var recordsTables = []*SQLRecordsTable{
{
TableName: "edgeRegionCountries",
UniqueFields: []string{"name"},
ExceptFields: []string{"customName", "customCodes"},
},
{
TableName: "edgeRegionProvinces",
UniqueFields: []string{"name", "countryId"},
ExceptFields: []string{"customName", "customCodes"},
},
{
TableName: "edgeRegionCities",
UniqueFields: []string{"name", "provinceId"},
ExceptFields: []string{"customName", "customCodes"},
},
{
TableName: "edgeRegionTowns",
UniqueFields: []string{"name", "cityId"},
ExceptFields: []string{"customName", "customCodes"},
},
{
TableName: "edgeRegionProviders",
UniqueFields: []string{"name"},
ExceptFields: []string{"customName", "customCodes"},
},
{
TableName: "edgeFormalClientSystems",
UniqueFields: []string{"dataId"},
},
{
TableName: "edgeFormalClientBrowsers",
UniqueFields: []string{"dataId"},
},
{
TableName: "edgeClientAgents",
UniqueFields: []string{"code"},
ExceptFields: []string{"countIPs"},
},
{
TableName: "edgeClientAgentIPs",
UniqueFields: []string{"agentId", "ip"},
IgnoreId: true,
},
{
TableName: "edgeMessageMedias",
UniqueFields: []string{"type"},
IgnoreId: true,
},
}
type sqlItem struct {
sqlString string
args []any
}
type SQLDump struct {
logWriter io.Writer
}
func NewSQLDump() *SQLDump {
return &SQLDump{}
}
func (this *SQLDump) SetLogWriter(logWriter io.Writer) {
this.logWriter = logWriter
}
// Dump 导出数据
func (this *SQLDump) Dump(db *dbs.DB, includingRecords bool) (result *SQLDumpResult, err error) {
result = &SQLDumpResult{}
tableNames, err := db.TableNames()
if err != nil {
return result, err
}
fullTableMap, err := this.findFullTables(db, tableNames)
if err != nil {
return nil, err
}
var autoIncrementReg = regexp.MustCompile(` AUTO_INCREMENT=\d+`)
for _, table := range fullTableMap {
var tableName = table.Name
// 忽略一些分表
if strings.HasPrefix(strings.ToLower(tableName), strings.ToLower("edgeHTTPAccessLogs_")) {
continue
}
if strings.HasPrefix(strings.ToLower(tableName), strings.ToLower("edgeNSAccessLogs_")) {
continue
}
var sqlTable = &SQLTable{
Name: table.Name,
Engine: table.Engine,
Charset: table.Collation,
Definition: autoIncrementReg.ReplaceAllString(table.Code, ""),
}
// 字段
var fields = []*SQLField{}
for _, field := range table.Fields {
fields = append(fields, &SQLField{
Name: field.Name,
Definition: field.Definition(),
})
}
sqlTable.Fields = fields
// 索引
var indexes = []*SQLIndex{}
for _, index := range table.Indexes {
indexes = append(indexes, &SQLIndex{
Name: index.Name,
Definition: index.Definition(),
})
}
sqlTable.Indexes = indexes
// Records
var records = []*SQLRecord{}
if includingRecords {
recordsTable := this.findRecordsTable(tableName)
if recordsTable != nil {
ones, _, err := db.FindOnes("SELECT * FROM " + tableName + " ORDER BY id ASC")
if err != nil {
return result, err
}
for _, one := range ones {
record := &SQLRecord{
Id: one.GetInt64("id"),
Values: map[string]string{},
UniqueFields: recordsTable.UniqueFields,
ExceptFields: recordsTable.ExceptFields,
}
for k, v := range one {
// 需要排除的字段
if lists.ContainsString(record.ExceptFields, k) {
continue
}
record.Values[k] = types.String(v)
}
records = append(records, record)
}
}
}
sqlTable.Records = records
result.Tables = append(result.Tables, sqlTable)
}
return
}
// Apply 应用数据
func (this *SQLDump) Apply(db *dbs.DB, newResult *SQLDumpResult, showLog bool) (ops []string, err error) {
// 设置Innodb事务提交模式
{
// 检查是否为root用户
config, _ := db.Config()
if config == nil {
return nil, nil
}
dsnConfig, err := mysql.ParseDSN(config.Dsn)
if err != nil || dsnConfig == nil {
return nil, err
}
if dsnConfig.User == "root" {
result, err := db.FindOne("SHOW VARIABLES WHERE variable_name='innodb_flush_log_at_trx_commit'")
if err == nil && result != nil {
var oldValue = result.GetInt("Value")
if oldValue == 1 {
_, _ = db.Exec("SET GLOBAL innodb_flush_log_at_trx_commit=2")
}
}
}
}
// 执行队列
var execQueue = make(chan *sqlItem, 256)
var threads = 32
var wg = sync.WaitGroup{}
wg.Add(threads + 1 /** applyQueue **/)
var applyOps []string
var applyErr error
go func() {
defer wg.Done()
defer close(execQueue)
applyOps, applyErr = this.applyQueue(db, newResult, showLog, execQueue)
}()
var sqlErrors = []error{}
var sqlErrLocker = &sync.Mutex{}
for i := 0; i < threads; i++ {
go func() {
defer wg.Done()
for item := range execQueue {
_, err := db.Exec(item.sqlString, item.args...)
if err != nil {
sqlErrLocker.Lock()
sqlErrors = append(sqlErrors, errors.New(item.sqlString+": "+err.Error()))
sqlErrLocker.Unlock()
break
}
}
}()
}
wg.Wait()
if applyErr != nil {
return nil, applyErr
}
if len(sqlErrors) == 0 {
// 升级数据
err = UpgradeSQLData(db)
if err != nil {
return nil, errors.New("upgrade data failed: " + err.Error())
}
return applyOps, nil
}
return nil, sqlErrors[0]
}
func (this *SQLDump) applyQueue(db *dbs.DB, newResult *SQLDumpResult, showLog bool, queue chan *sqlItem) (ops []string, err error) {
var execSQL = func(sqlString string, args ...any) {
queue <- &sqlItem{
sqlString: sqlString,
args: args,
}
}
currentResult, err := this.Dump(db, false)
if err != nil {
return nil, err
}
// 新增表格
for _, newTable := range newResult.Tables {
var oldTable = currentResult.FindTable(newTable.Name)
if oldTable == nil {
var op = "+ table " + newTable.Name
ops = append(ops, op)
if showLog {
this.log(op)
}
if len(newTable.Records) == 0 {
execSQL(newTable.Definition)
} else {
_, err = db.Exec(newTable.Definition)
if err != nil {
return nil, errors.New("'" + op + "' failed: " + err.Error())
}
}
} else if oldTable.Definition != newTable.Definition {
// 对比字段
// +
for _, newField := range newTable.Fields {
var oldField = oldTable.FindField(newField.Name)
if oldField == nil {
var op = "+ " + newTable.Name + " " + newField.Name
ops = append(ops, op)
if showLog {
this.log(op)
}
_, err = db.Exec("ALTER TABLE " + newTable.Name + " ADD `" + newField.Name + "` " + newField.Definition)
if err != nil {
return nil, errors.New("'" + op + "' failed: " + err.Error())
}
} else if !newField.EqualDefinition(oldField.Definition) {
var op = "* " + newTable.Name + " " + newField.Name
ops = append(ops, op)
if showLog {
this.log(op)
}
_, err = db.Exec("ALTER TABLE " + newTable.Name + " MODIFY `" + newField.Name + "` " + newField.Definition)
if err != nil {
return nil, errors.New("'" + op + "' failed: " + err.Error())
}
}
}
// 对比索引
// +
for _, newIndex := range newTable.Indexes {
var oldIndex = oldTable.FindIndex(newIndex.Name)
if oldIndex == nil {
var op = "+ index " + newTable.Name + " " + newIndex.Name
ops = append(ops, op)
if showLog {
this.log(op)
}
_, err = db.Exec("ALTER TABLE " + newTable.Name + " ADD " + newIndex.Definition)
if err != nil {
err = this.tryCreateIndex(err, db, newTable.Name, newIndex.Definition)
if err != nil {
return nil, errors.New("'" + op + "' failed: " + err.Error())
}
}
} else if oldIndex.Definition != newIndex.Definition {
var op = "* index " + newTable.Name + " " + newIndex.Name
ops = append(ops, op)
if showLog {
this.log(op)
}
_, err = db.Exec("ALTER TABLE " + newTable.Name + " DROP KEY " + newIndex.Name)
if err != nil {
return nil, errors.New("'" + op + "' drop old key failed: " + err.Error())
}
_, err = db.Exec("ALTER TABLE " + newTable.Name + " ADD " + newIndex.Definition)
if err != nil {
err = this.tryCreateIndex(err, db, newTable.Name, newIndex.Definition)
if err != nil {
return nil, errors.New("'" + op + "' failed: " + err.Error())
}
}
}
}
// -
for _, oldIndex := range oldTable.Indexes {
var newIndex = newTable.FindIndex(oldIndex.Name)
if newIndex == nil {
var op = "- index " + oldTable.Name + " " + oldIndex.Name
ops = append(ops, op)
if showLog {
this.log(op)
}
_, err = db.Exec("ALTER TABLE " + oldTable.Name + " DROP KEY " + oldIndex.Name)
if err != nil {
return nil, errors.New("'" + op + "' failed: " + err.Error())
}
}
}
// 对比字段
// -
for _, oldField := range oldTable.Fields {
var newField = newTable.FindField(oldField.Name)
if newField == nil {
var op = "- field " + oldTable.Name + " " + oldField.Name
ops = append(ops, op)
if showLog {
this.log(op)
}
_, err = db.Exec("ALTER TABLE " + oldTable.Name + " DROP COLUMN `" + oldField.Name + "`")
if err != nil {
return nil, errors.New("'" + op + "' failed: " + err.Error())
}
}
}
}
// 对比记录
// +
var newRecordsTable = this.findRecordsTable(newTable.Name)
for _, record := range newTable.Records {
var queryArgs = []string{}
var queryValues = []any{}
var valueStrings = []string{}
for _, field := range record.UniqueFields {
queryArgs = append(queryArgs, field+"=?")
queryValues = append(queryValues, record.Values[field])
valueStrings = append(valueStrings, record.Values[field])
}
var recordId int64
for field, recordValue := range record.Values {
if field == "id" {
recordId = types.Int64(recordValue)
break
}
}
var one maps.Map
if newRecordsTable != nil && newRecordsTable.IgnoreId {
one, err = db.FindOne("SELECT * FROM "+newTable.Name+" WHERE (("+strings.Join(queryArgs, " AND ")+"))", queryValues...)
} else {
queryValues = append(queryValues, recordId)
one, err = db.FindOne("SELECT * FROM "+newTable.Name+" WHERE (("+strings.Join(queryArgs, " AND ")+") OR id=?)", queryValues...)
}
if err != nil {
return nil, err
}
if one == nil {
ops = append(ops, "+ record "+newTable.Name+" "+strings.Join(valueStrings, ", "))
if showLog {
// 不记录详细日志,防止小白用户误解日志内容
// this.log("+ record " + newTable.Name + " " + strings.Join(valueStrings, ", "))
}
var params = []string{}
var args = []string{}
var values = []any{}
for k, v := range record.Values {
// 需要排除的字段
if lists.ContainsString(record.ExceptFields, k) {
continue
}
if newRecordsTable != nil && newRecordsTable.IgnoreId && k == "id" {
continue
}
params = append(params, "`"+k+"`")
args = append(args, "?")
values = append(values, v)
}
execSQL("INSERT INTO "+newTable.Name+" ("+strings.Join(params, ", ")+") VALUES ("+strings.Join(args, ", ")+")", values...)
} else if !record.ValuesEquals(one) {
ops = append(ops, "* record "+newTable.Name+" "+strings.Join(valueStrings, ", "))
if showLog {
// 不记录详细日志,防止小白用户误解日志内容
// this.log("* record " + newTable.Name + " " + strings.Join(valueStrings, ", "))
}
var args = []string{}
var values = []any{}
for k, v := range record.Values {
if k == "id" {
continue
}
// 需要排除的字段
if lists.ContainsString(record.ExceptFields, k) {
continue
}
args = append(args, "`"+k+"`"+"=?")
values = append(values, v)
}
values = append(values, one.GetInt("id"))
execSQL("UPDATE "+newTable.Name+" SET "+strings.Join(args, ", ")+" WHERE id=?", values...)
}
}
}
// 减少表格
// 由于我们不删除任何表格,所以这里什么都不做
return
}
// 查找所有表的完整信息
func (this *SQLDump) findFullTables(db *dbs.DB, tableNames []string) ([]*dbs.Table, error) {
var fullTables = []*dbs.Table{}
if len(tableNames) == 0 {
return fullTables, nil
}
var locker = &sync.Mutex{}
var queue = make(chan string, len(tableNames))
for _, tableName := range tableNames {
queue <- tableName
}
var wg = &sync.WaitGroup{}
var concurrent = 8
if runtime.NumCPU() > 4 {
concurrent = 32
}
wg.Add(concurrent)
var lastErr error
for i := 0; i < concurrent; i++ {
go func() {
defer wg.Done()
for {
select {
case tableName := <-queue:
table, err := db.FindFullTable(tableName)
if err != nil {
locker.Lock()
lastErr = err
locker.Unlock()
return
}
locker.Lock()
table.Name = tableName
fullTables = append(fullTables, table)
locker.Unlock()
default:
return
}
}
}()
}
wg.Wait()
if lastErr != nil {
return nil, lastErr
}
// 排序
sort.Slice(fullTables, func(i, j int) bool {
return fullTables[i].Name < fullTables[j].Name
})
return fullTables, nil
}
// 查找有记录的表
func (this *SQLDump) findRecordsTable(tableName string) *SQLRecordsTable {
for _, table := range recordsTables {
if table.TableName == tableName {
return table
}
}
return nil
}
// 创建索引
func (this *SQLDump) tryCreateIndex(err error, db *dbs.DB, tableName string, indexDefinition string) error {
if err == nil {
return nil
}
// 处理Duplicate entry
if strings.Contains(err.Error(), "Error 1062: Duplicate entry") && (strings.HasSuffix(tableName, "Stats") || strings.HasSuffix(tableName, "Values")) {
var tries = 5 // 尝试次数
for i := 0; i < tries; i++ {
_, err = db.Exec("TRUNCATE TABLE " + tableName)
if err != nil {
if i == tries-1 {
return err
}
continue
}
_, err = db.Exec("ALTER TABLE " + tableName + " ADD " + indexDefinition)
if err != nil {
if i == tries-1 {
return err
}
} else {
return nil
}
}
}
return err
}
// 打印操作日志
func (this *SQLDump) log(message string) {
if this.logWriter != nil {
_, _ = this.logWriter.Write([]byte(message + "\n"))
} else {
fmt.Println(message)
}
}

View File

@@ -0,0 +1,16 @@
package setup
import "strings"
type SQLDumpResult struct {
Tables []*SQLTable `json:"tables"`
}
func (this *SQLDumpResult) FindTable(tableName string) *SQLTable {
for _, table := range this.Tables {
if strings.EqualFold(table.Name, tableName) {
return table
}
}
return nil
}

View File

@@ -0,0 +1,99 @@
package setup
import (
"encoding/json"
"github.com/iwind/TeaGo/dbs"
"testing"
"time"
)
func TestSQLDump_Dump(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
dump := NewSQLDump()
result, err := dump.Dump(db, true)
if err != nil {
t.Fatal(err)
}
// Table
for _, table := range result.Tables {
_ = table
//t.Log(table.Name, table.Engine, table.Charset)
/**for _, field := range table.Fields {
t.Log("===", field.Name, ":", field.Definition)
}**/
/**for _, index := range table.Indexes {
t.Log("===", index.Name, ":", index.Definition)
}**/
/**for _, record := range table.Records {
t.Log(record.Id, record.Values)
}**/
}
data, err := json.Marshal(result)
if err != nil {
t.Fatal(err)
}
t.Log(len(data), "bytes")
}
func TestSQLDump_Apply(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
var dump = NewSQLDump()
result, err := dump.Dump(db, true)
if err != nil {
t.Fatal(err)
}
var before = time.Now()
defer func() {
t.Log("cost:", time.Since(before))
}()
db2, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "edge:123456@tcp(192.168.2.60:3306)/db_edge_new?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db2.Close()
}()
ops, err := dump.Apply(db2, result, false)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
/**if len(ops) > 0 {
for _, op := range ops {
t.Log("", op)
}
}**/
_ = ops
}

View File

@@ -0,0 +1,582 @@
package setup
import (
"crypto/rand"
"encoding/json"
"fmt"
"io"
"time"
"github.com/TeaOSLab/EdgeAPI/internal/configs"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
)
// SQLExecutor 安装或升级SQL执行器
type SQLExecutor struct {
dbConfig *dbs.DBConfig
logWriter io.Writer
}
func NewSQLExecutor(dbConfig *dbs.DBConfig) *SQLExecutor {
return &SQLExecutor{
dbConfig: dbConfig,
}
}
func NewSQLExecutorFromCmd() (*SQLExecutor, error) {
// 执行SQL
config, err := configs.LoadDBConfig()
if err != nil {
return nil, err
}
return NewSQLExecutor(config.DBs[Tea.Env]), nil
}
func (this *SQLExecutor) SetLogWriter(logWriter io.Writer) {
this.logWriter = logWriter
}
func (this *SQLExecutor) Run(showLog bool) error {
db, err := dbs.NewInstanceFromConfig(this.dbConfig)
if err != nil {
return err
}
// prevent default configure loading
var globalConfig = dbs.GlobalConfig()
if globalConfig != nil && len(globalConfig.DBs) == 0 {
// 同时设置 dev 和 prod 环境,确保 DAO 初始化时能找到配置
globalConfig.DBs = map[string]*dbs.DBConfig{
"dev": this.dbConfig,
"prod": this.dbConfig,
}
globalConfig.Default.DB = "prod"
}
defer func() {
_ = db.Close()
}()
var sqlDump = NewSQLDump()
sqlDump.SetLogWriter(this.logWriter)
if this.logWriter != nil {
showLog = true
}
var sqlResult = &SQLDumpResult{}
err = json.Unmarshal(sqlData, sqlResult)
if err != nil {
return fmt.Errorf("decode sql data failed: %w", err)
}
_, err = sqlDump.Apply(db, sqlResult, showLog)
if err != nil {
return err
}
// 检查数据
err = this.checkData(db)
if err != nil {
return err
}
return nil
}
// 检查数据
func (this *SQLExecutor) checkData(db *dbs.DB) error {
// 检查管理员平台节点
err := this.checkAdminNode(db)
if err != nil {
return fmt.Errorf("check admin node failed: %w", err)
}
// 检查用户平台节点
err = this.checkUserNode(db)
if err != nil {
return fmt.Errorf("check user node failed: %w", err)
}
// 检查集群配置
err = this.checkCluster(db)
if err != nil {
return fmt.Errorf("check cluster failed: %w", err)
}
// 检查初始化用户
// 需要放在检查集群后面
err = this.checkUser(db)
if err != nil {
return fmt.Errorf("check user failed: %w", err)
}
// 检查IP名单
err = this.checkIPList(db)
if err != nil {
return fmt.Errorf("check ip list failed: %w", err)
}
// 检查指标设置
err = this.checkMetricItems(db)
if err != nil {
return fmt.Errorf("check metric items failed: %w", err)
}
// 检查自建DNS全局设置
err = this.checkNS(db)
if err != nil {
return fmt.Errorf("check ns failed: %w", err)
}
// 更新Agents
err = this.checkClientAgents(db)
if err != nil {
return fmt.Errorf("check client agents failed: %w", err)
}
// 更新版本号
err = this.updateVersion(db, ComposeSQLVersion())
if err != nil {
return fmt.Errorf("update version failed: %w", err)
}
return nil
}
// 创建初始用户
func (this *SQLExecutor) checkUser(db *dbs.DB) error {
one, err := db.FindOne("SELECT id FROM edgeUsers LIMIT 1")
if err != nil {
return err
}
if len(one) > 0 {
return nil
}
// 读取默认集群ID
// Read default cluster id
clusterId, err := db.FindCol(0, "SELECT id FROM edgeNodeClusters WHERE state=1 ORDER BY id ASC LIMIT 1")
if err != nil {
return err
}
_, err = db.Exec("INSERT INTO edgeUsers (`username`, `password`, `fullname`, `isOn`, `state`, `createdAt`, `clusterId`) VALUES (?, ?, ?, ?, ?, ?, ?)", "USER_"+rands.HexString(10), stringutil.Md5(rands.HexString(32)), "默认用户", 1, 1, time.Now().Unix(), clusterId)
return err
}
// 检查管理员平台节点
func (this *SQLExecutor) checkAdminNode(db *dbs.DB) error {
stmt, err := db.Prepare("SELECT COUNT(*) FROM edgeAPITokens WHERE role='admin'")
if err != nil {
return err
}
defer func() {
_ = stmt.Close()
}()
col, err := stmt.FindCol(0)
if err != nil {
return err
}
var count = types.Int(col)
if count > 0 {
return nil
}
var nodeId = rands.HexString(32)
var secret = rands.String(32)
_, err = db.Exec("INSERT INTO edgeAPITokens (nodeId, secret, role) VALUES (?, ?, ?)", nodeId, secret, "admin")
if err != nil {
return err
}
return nil
}
// 检查用户平台节点
func (this *SQLExecutor) checkUserNode(db *dbs.DB) error {
stmt, err := db.Prepare("SELECT COUNT(*) FROM edgeAPITokens WHERE role='user'")
if err != nil {
return err
}
defer func() {
_ = stmt.Close()
}()
col, err := stmt.FindCol(0)
if err != nil {
return err
}
var count = types.Int(col)
if count > 0 {
return nil
}
var nodeId = rands.HexString(32)
var secret = rands.String(32)
_, err = db.Exec("INSERT INTO edgeAPITokens (nodeId, secret, role) VALUES (?, ?, ?)", nodeId, secret, "user")
if err != nil {
return err
}
return nil
}
// 检查集群配置
func (this *SQLExecutor) checkCluster(db *dbs.DB) error {
/// 检查是否有集群数据
stmt, err := db.Prepare("SELECT COUNT(*) FROM edgeNodeClusters")
if err != nil {
return fmt.Errorf("query clusters failed: %w", err)
}
defer func() {
_ = stmt.Close()
}()
col, err := stmt.FindCol(0)
if err != nil {
return fmt.Errorf("query clusters failed: %w", err)
}
var count = types.Int(col)
if count > 0 {
return nil
}
// 创建默认集群
var uniqueId = rands.HexString(32)
var secret = rands.String(32)
var clusterDNSConfig = &dnsconfigs.ClusterDNSConfig{
NodesAutoSync: true,
ServersAutoSync: true,
CNAMERecords: []string{},
CNAMEAsDomain: true,
TTL: 0,
IncludingLnNodes: true,
}
clusterDNSConfigJSON, err := json.Marshal(clusterDNSConfig)
if err != nil {
return err
}
var defaultDNSName = "g" + rands.HexString(6) + ".cdn"
{
var b = make([]byte, 3)
_, err = rand.Read(b)
if err == nil {
defaultDNSName = fmt.Sprintf("g%x.cdn", b)
}
}
_, err = db.Exec("INSERT INTO edgeNodeClusters (name, useAllAPINodes, state, uniqueId, secret, dns, dnsName) VALUES (?, ?, ?, ?, ?, ?, ?)", "默认集群", 1, 1, uniqueId, secret, string(clusterDNSConfigJSON), defaultDNSName)
if err != nil {
return err
}
// 创建APIToken
_, err = db.Exec("INSERT INTO edgeAPITokens (nodeId, secret, role, state) VALUES (?, ?, 'cluster', 1)", uniqueId, secret)
if err != nil {
return err
}
// 默认缓存策略
models.SharedHTTPCachePolicyDAO = models.NewHTTPCachePolicyDAO()
models.SharedHTTPCachePolicyDAO.Instance = db
tx, err := db.Begin()
if err != nil {
return err
}
policyId, err := models.SharedHTTPCachePolicyDAO.CreateDefaultCachePolicy(tx, "默认集群")
if err != nil {
_ = tx.Rollback()
return err
}
err = tx.Commit()
if err != nil {
return err
}
_, err = db.Exec("UPDATE edgeNodeClusters SET cachePolicyId=?", policyId)
if err != nil {
return err
}
// 默认WAf策略
models.SharedHTTPFirewallPolicyDAO = models.NewHTTPFirewallPolicyDAO()
models.SharedHTTPFirewallPolicyDAO.Instance = db
models.SharedHTTPFirewallRuleGroupDAO = models.NewHTTPFirewallRuleGroupDAO()
models.SharedHTTPFirewallRuleGroupDAO.Instance = db
models.SharedHTTPFirewallRuleSetDAO = models.NewHTTPFirewallRuleSetDAO()
models.SharedHTTPFirewallRuleSetDAO.Instance = db
models.SharedHTTPFirewallRuleDAO = models.NewHTTPFirewallRuleDAO()
models.SharedHTTPFirewallRuleDAO.Instance = db
models.SharedHTTPWebDAO = models.NewHTTPWebDAO()
models.SharedHTTPWebDAO.Instance = db
models.SharedServerDAO = models.NewServerDAO()
models.SharedServerDAO.Instance = db
models.SharedNodeClusterDAO = models.NewNodeClusterDAO()
models.SharedNodeClusterDAO.Instance = db
models.SharedIPListDAO = models.NewIPListDAO()
models.SharedIPListDAO.Instance = db
tx, err = db.Begin()
if err != nil {
return err
}
policyId, err = models.SharedHTTPFirewallPolicyDAO.CreateDefaultFirewallPolicy(tx, "默认集群")
if err != nil {
_ = tx.Rollback()
return err
}
err = tx.Commit()
if err != nil {
return err
}
_, err = db.Exec("UPDATE edgeNodeClusters SET httpFirewallPolicyId=?", policyId)
if err != nil {
return err
}
return nil
}
// 检查IP名单
func (this *SQLExecutor) checkIPList(db *dbs.DB) error {
stmt, err := db.Prepare("SELECT COUNT(*) FROM edgeIPLists")
if err != nil {
return fmt.Errorf("query ip lists failed: %w", err)
}
defer func() {
_ = stmt.Close()
}()
col, err := stmt.FindCol(0)
if err != nil {
return fmt.Errorf("query ip lists failed: %w", err)
}
var count = types.Int(col)
if count > 0 {
return nil
}
// 创建名单
_, err = db.Exec("INSERT INTO edgeIPLists(name, type, code, isPublic, isGlobal, createdAt) VALUES (?, ?, ?, ?, ?, ?)", "公共黑名单", "black", "black", 1, 1, time.Now().Unix())
if err != nil {
return err
}
_, err = db.Exec("INSERT INTO edgeIPLists(name, type, code, isPublic, isGlobal, createdAt) VALUES (?, ?, ?, ?, ?, ?)", "公共白名单", "white", "white", 1, 1, time.Now().Unix())
if err != nil {
return err
}
return nil
}
// 检查统计指标
func (this *SQLExecutor) checkMetricItems(db *dbs.DB) error {
var createMetricItem = func(code string,
category string,
name string,
keys []string,
period int,
periodUnit string,
value string,
chartMaps []maps.Map,
) error {
// 检查是否已创建
itemMap, err := db.FindOne("SELECT id FROM edgeMetricItems WHERE code=? LIMIT 1", code)
if err != nil {
return err
}
if len(itemMap) == 0 {
keysJSON, err := json.Marshal(keys)
if err != nil {
return err
}
_, err = db.Exec("INSERT INTO edgeMetricItems (isOn, code, category, name, `keys`, period, periodUnit, value, state, isPublic) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", 1, code, category, name, keysJSON, period, periodUnit, value, 1, 1)
if err != nil {
return err
}
// 再次查询
itemMap, err = db.FindOne("SELECT id FROM edgeMetricItems WHERE code=? LIMIT 1", code)
if err != nil {
return err
}
}
var itemId = itemMap.GetInt64("id")
// chart
for _, chartMap := range chartMaps {
var chartCode = chartMap.GetString("code")
one, err := db.FindOne("SELECT id FROM edgeMetricCharts WHERE itemId=? AND code=? LIMIT 1", itemId, chartCode)
if err != nil {
return err
}
if len(one) == 0 {
_, err = db.Exec("INSERT INTO edgeMetricCharts (itemId, name, code, type, widthDiv, params, isOn, state) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", itemId, chartMap.GetString("name"), chartCode, chartMap.GetString("type"), chartMap.GetInt("widthDiv"), "{}", 1, 1)
if err != nil {
return err
}
}
}
return nil
}
{
err := createMetricItem("ip_requests", serverconfigs.MetricItemCategoryHTTP, "独立IP请求数", []string{"${remoteAddr}"}, 1, "day", "${countRequest}", []maps.Map{
{
"name": "独立IP排行",
"type": "bar",
"widthDiv": 0,
"code": "ip_requests_bar",
},
})
if err != nil {
return err
}
}
{
err := createMetricItem("ip_traffic_out", serverconfigs.MetricItemCategoryHTTP, "独立IP下行流量", []string{"${remoteAddr}"}, 1, "day", "${countTrafficOut}", []maps.Map{
{
"name": "独立IP排行",
"type": "bar",
"widthDiv": 0,
"code": "ip_traffic_out_bar",
},
})
if err != nil {
return err
}
}
{
err := createMetricItem("request_path", serverconfigs.MetricItemCategoryHTTP, "请求路径统计", []string{"${requestPath}"}, 1, "day", "${countRequest}", []maps.Map{
{
"name": "请求路径排行",
"type": "bar",
"widthDiv": 0,
"code": "request_path_bar",
},
})
if err != nil {
return err
}
}
{
err := createMetricItem("request_method", serverconfigs.MetricItemCategoryHTTP, "请求方法统计", []string{"${requestMethod}"}, 1, "day", "${countRequest}", []maps.Map{
{
"name": "请求方法分布",
"type": "pie",
"widthDiv": 2,
"code": "request_method_pie",
},
})
if err != nil {
return err
}
}
{
err := createMetricItem("status", serverconfigs.MetricItemCategoryHTTP, "状态码统计", []string{"${status}"}, 1, "day", "${countRequest}", []maps.Map{
{
"name": "状态码分布",
"type": "pie",
"widthDiv": 2,
"code": "status_pie",
},
})
if err != nil {
return err
}
}
{
err := createMetricItem("request_referer_host", serverconfigs.MetricItemCategoryHTTP, "请求来源统计", []string{"${referer.host}"}, 1, "day", "${countRequest}", []maps.Map{
{
"name": "请求来源排行",
"type": "bar",
"widthDiv": 0,
"code": "request_referer_host_bar",
},
})
if err != nil {
return err
}
}
return nil
}
// 更新Agents表
func (this *SQLExecutor) checkClientAgents(db *dbs.DB) error {
ones, _, err := db.FindOnes("SELECT id FROM edgeClientAgents")
if err != nil {
return err
}
for _, one := range ones {
var agentId = one.GetInt64("id")
countIPs, err := db.FindCol(0, "SELECT COUNT(*) FROM edgeClientAgentIPs WHERE agentId=?", agentId)
if err != nil {
return err
}
_, err = db.Exec("UPDATE edgeClientAgents SET countIPs=? WHERE id=?", countIPs, agentId)
if err != nil {
return err
}
}
return nil
}
// 更新版本号
func (this *SQLExecutor) updateVersion(db *dbs.DB, version string) error {
stmt, err := db.Prepare("SELECT COUNT(*) FROM edgeVersions")
if err != nil {
return fmt.Errorf("query version failed: %w", err)
}
defer func() {
_ = stmt.Close()
}()
col, err := stmt.FindCol(0)
if err != nil {
return fmt.Errorf("query version failed: %w", err)
}
var count = types.Int(col)
if count > 0 {
_, err = db.Exec("UPDATE edgeVersions SET version=?", version)
if err != nil {
return fmt.Errorf("update version failed: %w", err)
}
return nil
}
_, err = db.Exec("INSERT edgeVersions (version) VALUES (?)", version)
if err != nil {
return fmt.Errorf("create version failed: %w", err)
}
return nil
}

View File

@@ -0,0 +1,13 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package setup
import (
"github.com/iwind/TeaGo/dbs"
)
// 检查自建DNS全局设置
func (this *SQLExecutor) checkNS(db *dbs.DB) error {
return nil
}

View File

@@ -0,0 +1,38 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package setup
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/systemconfigs"
"github.com/iwind/TeaGo/dbs"
)
// 检查自建DNS全局设置
func (this *SQLExecutor) checkNS(db *dbs.DB) error {
// 访问日志
{
one, err := db.FindOne("SELECT id FROM edgeSysSettings WHERE code=? LIMIT 1", systemconfigs.SettingCodeNSAccessLogSetting)
if err != nil {
return err
}
if len(one) == 0 {
ref := &dnsconfigs.NSAccessLogRef{
IsPrior: false,
IsOn: true,
LogMissingDomains: false,
}
refJSON, err := json.Marshal(ref)
if err != nil {
return err
}
_, err = db.Exec("INSERT edgeSysSettings (code, value) VALUES (?, ?)", systemconfigs.SettingCodeNSAccessLogSetting, refJSON)
if err != nil {
return err
}
}
}
return nil
}

View File

@@ -0,0 +1,103 @@
package setup
import (
"github.com/iwind/TeaGo/dbs"
"testing"
)
func TestSQLExecutor_Run(t *testing.T) {
var executor = NewSQLExecutor(&dbs.DBConfig{
Driver: "mysql",
Prefix: "edge",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge_new?charset=utf8mb4&multiStatements=true",
})
err := executor.Run(false)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestSQLExecutor_checkCluster(t *testing.T) {
var executor = NewSQLExecutor(&dbs.DBConfig{
Driver: "mysql",
Prefix: "edge",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge_new?charset=utf8mb4&multiStatements=true",
})
db, err := dbs.NewInstanceFromConfig(executor.dbConfig)
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = executor.checkCluster(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestSQLExecutor_checkMetricItems(t *testing.T) {
var executor = NewSQLExecutor(&dbs.DBConfig{
Driver: "mysql",
Prefix: "edge",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge_new?charset=utf8mb4&multiStatements=true",
})
db, err := dbs.NewInstanceFromConfig(executor.dbConfig)
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = executor.checkMetricItems(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestSQLExecutor_checkNS(t *testing.T) {
var executor = NewSQLExecutor(&dbs.DBConfig{
Driver: "mysql",
Prefix: "edge",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge_new?charset=utf8mb4&multiStatements=true",
})
db, err := dbs.NewInstanceFromConfig(executor.dbConfig)
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = executor.checkNS(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestSQLExecutor_checkClientAgents(t *testing.T) {
var executor = NewSQLExecutor(&dbs.DBConfig{
Driver: "mysql",
Prefix: "edge",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&multiStatements=true",
})
db, err := dbs.NewInstanceFromConfig(executor.dbConfig)
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = executor.checkClientAgents(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -0,0 +1,19 @@
package setup
import "regexp"
type SQLField struct {
Name string `json:"name"`
Definition string `json:"definition"`
}
func (this *SQLField) EqualDefinition(definition2 string) bool {
if this.Definition == definition2 {
return true
}
// 针对MySQL v8.0.17以后
def1 := regexp.MustCompile(`(?)(tinyint|smallint|mediumint|int|bigint)\(\d+\)`).
ReplaceAllString(this.Definition, "${1}")
return def1 == definition2
}

View File

@@ -0,0 +1,6 @@
package setup
type SQLIndex struct {
Name string `json:"name"`
Definition string `json:"definition"`
}

View File

@@ -0,0 +1,34 @@
package setup
import (
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
)
type SQLRecord struct {
Id int64 `json:"id"`
Values map[string]string `json:"values"`
UniqueFields []string `json:"uniqueFields"`
ExceptFields []string `json:"exceptFields"`
}
func (this *SQLRecord) ValuesEquals(values maps.Map) bool {
for k, v := range values {
// 跳过ID
if k == "id" {
continue
}
// 需要排除的字段
if lists.ContainsString(this.ExceptFields, k) {
continue
}
var vString = types.String(v)
if this.Values[k] != vString {
return false
}
}
return true
}

View File

@@ -0,0 +1,8 @@
package setup
type SQLRecordsTable struct {
TableName string
UniqueFields []string
ExceptFields []string
IgnoreId bool // 是否可以排除ID
}

View File

@@ -0,0 +1,38 @@
package setup
type SQLTable struct {
Name string `json:"name"`
Engine string `json:"engine"`
Charset string `json:"charset"`
Definition string `json:"definition"`
Fields []*SQLField `json:"fields"`
Indexes []*SQLIndex `json:"indexes"`
Records []*SQLRecord `json:"records"`
}
func (this *SQLTable) FindField(fieldName string) *SQLField {
for _, field := range this.Fields {
if field.Name == fieldName {
return field
}
}
return nil
}
func (this *SQLTable) FindIndex(indexName string) *SQLIndex {
for _, index := range this.Indexes {
if index.Name == indexName {
return index
}
}
return nil
}
func (this *SQLTable) FindRecord(id int64) *SQLRecord {
for _, record := range this.Records {
if record.Id == id {
return record
}
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,228 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package setup
import (
"encoding/json"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/systemconfigs"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time"
"regexp"
)
// v0.2.8.1
func upgradeV0_2_8_1(db *dbs.DB) error {
// 升级EdgeDNS线路
ones, _, err := db.FindOnes("SELECT id, dnsRoutes FROM edgeNodes WHERE dnsRoutes IS NOT NULL")
if err != nil {
return err
}
for _, one := range ones {
var nodeId = one.GetInt64("id")
var dnsRoutes = one.GetString("dnsRoutes")
if len(dnsRoutes) == 0 {
continue
}
var m = map[string][]string{}
err = json.Unmarshal([]byte(dnsRoutes), &m)
if err != nil {
continue
}
var isChanged = false
var reg = regexp.MustCompile(`^\d+$`)
for k, routes := range m {
for index, route := range routes {
if reg.MatchString(route) {
route = "id:" + route
isChanged = true
}
routes[index] = route
}
m[k] = routes
}
if isChanged {
mJSON, err := json.Marshal(m)
if err != nil {
return err
}
_, err = db.Exec("UPDATE edgeNodes SET dnsRoutes=? WHERE id=? LIMIT 1", string(mJSON), nodeId)
if err != nil {
return err
}
}
}
return nil
}
// v0.4.9
func upgradeV0_4_9(db *dbs.DB) error {
// 升级管理配置
{
one, err := db.FindOne("SELECT value FROM edgeSysSettings WHERE code=?", systemconfigs.SettingCodeAdminSecurityConfig)
if err != nil {
return err
}
if one != nil {
var valueJSON = one.GetBytes("value")
if len(valueJSON) > 0 {
var config = &systemconfigs.SecurityConfig{}
err = json.Unmarshal(valueJSON, config)
if err == nil {
config.DenySearchEngines = true
config.DenySpiders = true
configJSON, err := json.Marshal(config)
if err != nil {
return fmt.Errorf("encode SecurityConfig failed: %w", err)
} else {
_, err := db.Exec("UPDATE edgeSysSettings SET value=? WHERE code=?", configJSON, systemconfigs.SettingCodeAdminSecurityConfig)
if err != nil {
return err
}
}
}
}
}
}
return nil
}
// v0.5.3
func upgradeV0_5_3(db *dbs.DB) error {
// 升级集群服务配置
{
type oldGlobalConfig struct {
// HTTP & HTTPS相关配置
HTTPAll struct {
MatchDomainStrictly bool `yaml:"matchDomainStrictly" json:"matchDomainStrictly"` // 是否严格匹配域名
AllowMismatchDomains []string `yaml:"allowMismatchDomains" json:"allowMismatchDomains"` // 允许的不匹配的域名
DefaultDomain string `yaml:"defaultDomain" json:"defaultDomain"` // 默认的域名
DomainMismatchAction *serverconfigs.DomainMismatchAction `yaml:"domainMismatchAction" json:"domainMismatchAction"` // 不匹配时采取的动作
} `yaml:"httpAll" json:"httpAll"`
}
value, err := db.FindCol(0, "SELECT value FROM edgeSysSettings WHERE code='serverGlobalConfig'")
if err != nil {
return err
}
if value != nil {
var valueJSON = []byte(types.String(value))
var oldConfig = &oldGlobalConfig{}
err = json.Unmarshal(valueJSON, oldConfig)
if err == nil {
var newConfig = &serverconfigs.GlobalServerConfig{}
newConfig.HTTPAll.MatchDomainStrictly = oldConfig.HTTPAll.MatchDomainStrictly
newConfig.HTTPAll.AllowMismatchDomains = oldConfig.HTTPAll.AllowMismatchDomains
newConfig.HTTPAll.DefaultDomain = oldConfig.HTTPAll.DefaultDomain
if oldConfig.HTTPAll.DomainMismatchAction != nil {
newConfig.HTTPAll.DomainMismatchAction = oldConfig.HTTPAll.DomainMismatchAction
}
newConfig.HTTPAll.AllowNodeIP = true
newConfig.Log.RecordServerError = false
newConfigJSON, err := json.Marshal(newConfig)
if err == nil {
_, err = db.Exec("UPDATE edgeNodeClusters SET globalServerConfig=?", newConfigJSON)
if err != nil {
return err
}
}
}
}
}
return nil
}
// v0.5.6
func upgradeV0_5_6(db *dbs.DB) error {
// 修复默认集群的DNS设置
{
var id = 1
clusterMap, err := db.FindOne("SELECT dns FROM edgeNodeClusters WHERE id=? AND state=1", id)
if err != nil {
return err
}
if len(clusterMap) > 0 {
var dnsString = clusterMap.GetString("dns")
if len(dnsString) > 0 && dnsString != "null" {
var dnsData = []byte(dnsString)
var dnsConfig = &dnsconfigs.ClusterDNSConfig{
CNAMEAsDomain: true,
IncludingLnNodes: true,
}
err = json.Unmarshal(dnsData, dnsConfig)
if err == nil && !dnsConfig.NodesAutoSync && !dnsConfig.ServersAutoSync {
dnsConfig.NodesAutoSync = true
dnsConfig.ServersAutoSync = true
dnsConfigJSON, err := json.Marshal(dnsConfig)
if err != nil {
return err
}
_, err = db.Exec("UPDATE edgeNodeClusters SET dns=? WHERE id=?", dnsConfigJSON, id)
if err != nil {
return err
}
}
}
}
}
return nil
}
// v0.5.7
func upgradeV0_5_8(db *dbs.DB) error {
// node task versions
{
_, err := db.Exec("UPDATE edgeNodeTasks SET version=0 WHERE LENGTH(version)=19")
if err != nil {
return err
}
}
// 删除操作系统和浏览器相关统计
// 只删除当前月,避免因为数据过多阻塞
{
_, err := db.Exec("DELETE FROM edgeServerClientSystemMonthlyStats WHERE month=?", timeutil.Format("Ym"))
if err != nil {
return err
}
}
{
_, err := db.Exec("DELETE FROM edgeServerClientBrowserMonthlyStats WHERE month=?", timeutil.Format("Ym"))
if err != nil {
return err
}
}
// 修复默认黑白名单不是全局的问题
{
_, err := db.Exec("UPDATE edgeIPLists SET isGlobal=1 WHERE id IN (1, 2)")
if err != nil {
return err
}
}
return nil
}
// v1.2.9
func upgradeV1_2_9(db *dbs.DB) error {
// 升级WAF规则
{
_, err := db.Exec("UPDATE edgeHTTPFirewallRules SET value=? WHERE value=? AND param='${userAgent}'", "python|pycurl|http-client|httpclient|apachebench|nethttp|http_request|java|perl|ruby|scrapy|php\\b|rust", "python|pycurl|http-client|httpclient|apachebench|nethttp|http_request|java|perl|ruby|scrapy|php|rust")
if err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,484 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package setup
import (
"encoding/json"
"fmt"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/systemconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/userconfigs"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time"
"regexp"
"strings"
"time"
)
// v0.2.8.1
func upgradeV0_2_8_1(db *dbs.DB) error {
// 访问日志设置
{
one, err := db.FindOne("SELECT id FROM edgeSysSettings WHERE code=? LIMIT 1", systemconfigs.SettingCodeNSAccessLogSetting)
if err != nil {
return err
}
if len(one) == 0 {
var ref = &dnsconfigs.NSAccessLogRef{
IsPrior: false,
IsOn: true,
LogMissingDomains: false,
}
refJSON, err := json.Marshal(ref)
if err != nil {
return err
}
_, err = db.Exec("INSERT edgeSysSettings (code, value) VALUES (?, ?)", systemconfigs.SettingCodeNSAccessLogSetting, refJSON)
if err != nil {
return err
}
}
}
// 升级EdgeDNS线路
ones, _, err := db.FindOnes("SELECT id, dnsRoutes FROM edgeNodes WHERE dnsRoutes IS NOT NULL")
if err != nil {
return err
}
for _, one := range ones {
var nodeId = one.GetInt64("id")
var dnsRoutes = one.GetString("dnsRoutes")
if len(dnsRoutes) == 0 {
continue
}
var m = map[string][]string{}
err = json.Unmarshal([]byte(dnsRoutes), &m)
if err != nil {
continue
}
var isChanged = false
var reg = regexp.MustCompile(`^\d+$`)
for k, routes := range m {
for index, route := range routes {
if reg.MatchString(route) {
route = "id:" + route
isChanged = true
}
routes[index] = route
}
m[k] = routes
}
if isChanged {
mJSON, err := json.Marshal(m)
if err != nil {
return err
}
_, err = db.Exec("UPDATE edgeNodes SET dnsRoutes=? WHERE id=? LIMIT 1", string(mJSON), nodeId)
if err != nil {
return err
}
}
}
return nil
}
// v0.4.9
func upgradeV0_4_9(db *dbs.DB) error {
// 升级用户UI配置
{
one, err := db.FindOne("SELECT value FROM edgeSysSettings WHERE code=?", systemconfigs.SettingCodeUserUIConfig)
if err != nil {
return err
}
if one != nil {
var valueJSON = one.GetBytes("value")
if len(valueJSON) > 0 {
var config = &systemconfigs.UserUIConfig{}
err = json.Unmarshal(valueJSON, config)
if err == nil {
config.ShowTrafficCharts = true
config.ShowBandwidthCharts = true
config.BandwidthUnit = systemconfigs.BandwidthUnitBit
configJSON, err := json.Marshal(config)
if err != nil {
return fmt.Errorf("encode UserUIConfig failed: %w", err)
} else {
_, err := db.Exec("UPDATE edgeSysSettings SET value=? WHERE code=?", configJSON, systemconfigs.SettingCodeUserUIConfig)
if err != nil {
return err
}
}
}
}
}
}
// 升级管理配置
{
one, err := db.FindOne("SELECT value FROM edgeSysSettings WHERE code=?", systemconfigs.SettingCodeAdminSecurityConfig)
if err != nil {
return err
}
if one != nil {
var valueJSON = one.GetBytes("value")
if len(valueJSON) > 0 {
var config = &systemconfigs.SecurityConfig{}
err = json.Unmarshal(valueJSON, config)
if err == nil {
config.DenySearchEngines = true
config.DenySpiders = true
configJSON, err := json.Marshal(config)
if err != nil {
return fmt.Errorf("encode SecurityConfig failed: %w", err)
} else {
_, err := db.Exec("UPDATE edgeSysSettings SET value=? WHERE code=?", configJSON, systemconfigs.SettingCodeAdminSecurityConfig)
if err != nil {
return err
}
}
}
}
}
}
return nil
}
// v0.5.3
func upgradeV0_5_3(db *dbs.DB) error {
// 升级ns domains中的status字段
{
_, err := db.Exec("UPDATE edgeNSDomains SET status='" + dnsconfigs.NSDomainStatusVerified + "'")
if err != nil {
return err
}
}
// 升级集群服务配置
{
type oldGlobalConfig struct {
// HTTP & HTTPS相关配置
HTTPAll struct {
MatchDomainStrictly bool `yaml:"matchDomainStrictly" json:"matchDomainStrictly"` // 是否严格匹配域名
AllowMismatchDomains []string `yaml:"allowMismatchDomains" json:"allowMismatchDomains"` // 允许的不匹配的域名
DefaultDomain string `yaml:"defaultDomain" json:"defaultDomain"` // 默认的域名
DomainMismatchAction *serverconfigs.DomainMismatchAction `yaml:"domainMismatchAction" json:"domainMismatchAction"` // 不匹配时采取的动作
} `yaml:"httpAll" json:"httpAll"`
}
value, err := db.FindCol(0, "SELECT value FROM edgeSysSettings WHERE code='serverGlobalConfig'")
if err != nil {
return err
}
if value != nil {
var valueJSON = []byte(types.String(value))
var oldConfig = &oldGlobalConfig{}
err = json.Unmarshal(valueJSON, oldConfig)
if err == nil {
var newConfig = &serverconfigs.GlobalServerConfig{}
newConfig.HTTPAll.MatchDomainStrictly = oldConfig.HTTPAll.MatchDomainStrictly
newConfig.HTTPAll.AllowMismatchDomains = oldConfig.HTTPAll.AllowMismatchDomains
newConfig.HTTPAll.DefaultDomain = oldConfig.HTTPAll.DefaultDomain
if oldConfig.HTTPAll.DomainMismatchAction != nil {
newConfig.HTTPAll.DomainMismatchAction = oldConfig.HTTPAll.DomainMismatchAction
}
newConfig.HTTPAll.AllowNodeIP = true
newConfig.Log.RecordServerError = false
newConfigJSON, err := json.Marshal(newConfig)
if err == nil {
_, err = db.Exec("UPDATE edgeNodeClusters SET globalServerConfig=?", newConfigJSON)
if err != nil {
return err
}
}
}
}
}
return nil
}
func upgradeV0_5_6(db *dbs.DB) error {
// 升级PriceConfig enablePlans
err := func() error {
countPlans, err := db.FindCol(0, "SELECT COUNT(*) FROM edgePlans WHERE state=1")
if err != nil {
return err
}
var countPlansInt = types.Int64(countPlans)
if countPlansInt > 0 {
countUserPlans, err := db.FindCol(0, "SELECT COUNT(*) FROM edgeUserPlans WHERE state=1")
if err != nil {
return err
}
var countUserPlansInt = types.Int64(countUserPlans)
if countUserPlansInt > 0 {
countServers, err := db.FindCol(0, "SELECT COUNT(*) FROM edgeServers WHERE state=1 AND userPlanId>0")
if err != nil {
return err
}
var countServersInt = types.Int64(countServers)
if countServersInt > 0 {
var config = userconfigs.DefaultUserPriceConfig()
configValue, err := db.FindCol(0, "SELECT value FROM edgeSysSettings WHERE code='"+systemconfigs.SettingCodeUserPriceConfig+"'")
if err != nil {
return err
}
var configValueString = types.String(configValue)
if len(configValueString) > 0 {
err = json.Unmarshal([]byte(configValueString), config)
if err == nil && config.IsOn { // 如果已经设置了,则不重复设置
return nil
}
}
if err == nil {
config.IsOn = true
config.EnablePlans = true
configJSON, err := json.Marshal(config)
if err != nil {
return fmt.Errorf("encode price config failed: %w", err)
}
if len(configValueString) > 0 { // update
_, err = db.Exec("UPDATE edgeSysSettings SET value=? WHERE code=?", configJSON, systemconfigs.SettingCodeUserPriceConfig)
} else { // insert
_, err = db.Exec("INSERT edgeSysSettings (code, value) VALUES (?, ?)", systemconfigs.SettingCodeUserPriceConfig, configJSON)
}
if err != nil {
return err
}
}
}
}
}
return nil
}()
if err != nil {
return err
}
// 修复默认集群的DNS设置
{
var id = 1
clusterMap, err := db.FindOne("SELECT dns FROM edgeNodeClusters WHERE id=? AND state=1", id)
if err != nil {
return err
}
if len(clusterMap) > 0 {
var dnsString = clusterMap.GetString("dns")
if len(dnsString) > 0 && dnsString != "null" {
var dnsData = []byte(dnsString)
var dnsConfig = &dnsconfigs.ClusterDNSConfig{
CNAMEAsDomain: true,
IncludingLnNodes: true,
}
err = json.Unmarshal(dnsData, dnsConfig)
if err == nil && !dnsConfig.NodesAutoSync && !dnsConfig.ServersAutoSync {
dnsConfig.NodesAutoSync = true
dnsConfig.ServersAutoSync = true
dnsConfigJSON, err := json.Marshal(dnsConfig)
if err != nil {
return err
}
_, err = db.Exec("UPDATE edgeNodeClusters SET dns=? WHERE id=?", dnsConfigJSON, id)
if err != nil {
return err
}
}
}
}
}
return nil
}
// v0.5.8
func upgradeV0_5_8(db *dbs.DB) error {
// node task versions
{
_, err := db.Exec("UPDATE edgeNodeTasks SET version=0 WHERE LENGTH(version)=19")
if err != nil {
return err
}
}
// ns mx records
{
_, err := db.Exec("UPDATE edgeNSRecords SET mxPriority=10 WHERE type='MX' AND mxPriority=0")
if err != nil {
return err
}
}
// 删除操作系统和浏览器相关统计
// 只删除当前月,避免因为数据过多阻塞
{
_, err := db.Exec("DELETE FROM edgeServerClientSystemMonthlyStats WHERE month=?", timeutil.Format("Ym"))
if err != nil {
return err
}
}
{
_, err := db.Exec("DELETE FROM edgeServerClientBrowserMonthlyStats WHERE month=?", timeutil.Format("Ym"))
if err != nil {
return err
}
}
// 修复默认黑白名单不是全局的问题
{
_, err := db.Exec("UPDATE edgeIPLists SET isGlobal=1 WHERE id IN (1, 2)")
if err != nil {
return err
}
}
return nil
}
// v1.2.9
func upgradeV1_2_9(db *dbs.DB) error {
// 升级WAF规则
{
_, err := db.Exec("UPDATE edgeHTTPFirewallRules SET value=? WHERE value=? AND param='${userAgent}'", "python|pycurl|http-client|httpclient|apachebench|nethttp|http_request|java|perl|ruby|scrapy|php\\b|rust", "python|pycurl|http-client|httpclient|apachebench|nethttp|http_request|java|perl|ruby|scrapy|php|rust")
if err != nil {
return err
}
}
// 升级套餐网站数限制
{
_, err := db.Exec("UPDATE edgePlans SET totalServers=1 WHERE totalServers=0")
if err != nil {
return err
}
}
// 升级网站流量限制状态
{
_, err := db.Exec("UPDATE edgeServers SET trafficLimitStatus=NULL WHERE trafficLimitStatus IS NOT NULL")
if err != nil {
return err
}
}
// 升级套餐按日/按月统计数据
{
// 检查是否已升级
countUserPlanStatsValue, err := db.FindCol(0, "SELECT COUNT(*) FROM edgeUserPlanStats")
if err != nil {
return err
}
var countUserPlanStats = 0
if countUserPlanStatsValue != nil {
countUserPlanStats = types.Int(countUserPlanStatsValue)
}
if countUserPlanStats == 0 {
var upgradeFunc = func(userPlanId int64, serverId int64, month string) error {
var bandwidthTable = "edgeServerBandwidthStats_" + types.String(serverId%models.ServerBandwidthStatTablePartitions)
sumMap, err := db.FindOne("SELECT SUM(totalBytes) AS totalBytes,SUM(countRequests) AS countRequests FROM "+bandwidthTable+" WHERE serverId=? AND day BETWEEN ? AND ?", serverId, month+"01", month+"31")
if err != nil {
return err
}
if sumMap != nil && len(sumMap) >= 2 {
var totalBytes = sumMap.GetInt64("totalBytes")
var countRequests = sumMap.GetInt64("countRequests")
if totalBytes > 0 || countRequests > 0 {
_, err = db.Exec("INSERT INTO edgeUserPlanStats (userPlanId, date, dateType, trafficBytes, countRequests) VALUES (?, ?, ?, ?, ?)", userPlanId, month, "month", totalBytes, countRequests)
if err != nil {
var errMessage = strings.ToLower(err.Error())
if !strings.Contains(errMessage, "duplicate") && !strings.Contains(errMessage, "1062") {
return err
}
err = nil
}
// daily
for i := 1; i <= 31; i++ {
var day = month + fmt.Sprintf("%02d", i)
dailySumMap, err := db.FindOne("SELECT SUM(totalBytes) AS totalBytes,SUM(countRequests) AS countRequests FROM "+bandwidthTable+" WHERE serverId=? AND day=?", serverId, day)
if err != nil {
return err
}
var dailyTotalBytes = dailySumMap.GetInt64("totalBytes")
var dailyCountRequests = dailySumMap.GetInt64("countRequests")
if dailyTotalBytes > 0 || dailyCountRequests > 0 {
_, err = db.Exec("INSERT INTO edgeUserPlanStats (userPlanId, date, dateType, trafficBytes, countRequests) VALUES (?, ?, ?, ?, ?)", userPlanId, day, "day", dailyTotalBytes, dailyCountRequests)
if err != nil {
var errMessage = strings.ToLower(err.Error())
if !strings.Contains(errMessage, "duplicate") && !strings.Contains(errMessage, "1062") {
return err
}
err = nil
}
}
}
// userPlanId
_, err = db.Exec("UPDATE "+bandwidthTable+" SET userPlanId=? WHERE serverId=? AND day BETWEEN ? AND ?", userPlanId, serverId, month+"01", month+"31")
if err != nil {
return err
}
// userPlanBandwidth
{
ones, _, err := db.FindOnes("SELECT * FROM "+bandwidthTable+" WHERE serverId=? AND userPlanId=?", serverId, userPlanId)
if err != nil {
return err
}
for _, one := range ones {
_, err = db.Exec("INSERT INTO edgeUserPlanBandwidthStats_"+types.String(userPlanId%models.UserPlanBandwidthStatTablePartitions)+" (userId, userPlanId, day, timeAt, bytes, regionId, totalBytes, avgBytes, cachedBytes, attackBytes, countRequests, countCachedRequests, countAttackRequests) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", one.GetInt64("userId"), userPlanId, one.GetString("day"), one.GetString("timeAt"), one.GetInt64("bytes"), one.GetInt64("regionId"), one.GetInt64("totalBytes"), one.GetInt64("avgBytes"), one.GetInt64("cachedBytes"), one.GetInt64("attackBytes"), one.GetInt64("countRequests"), one.GetInt64("countCachedRequests"), one.GetInt64("countAttackRequests"))
if err != nil {
var errMessage = err.Error()
if !strings.Contains(errMessage, "duplicate") && !strings.Contains(errMessage, "1062") {
return err
}
}
}
}
}
}
return nil
}
userPlans, _, err := db.FindOnes("SELECT id FROM edgeUserPlans WHERE state=1")
if err != nil {
return err
}
for _, userPlan := range userPlans {
var userPlanId = userPlan.GetInt64("id")
servers, _, err := db.FindOnes("SELECT id FROM edgeServers WHERE userPlanId=?", userPlanId)
if err != nil {
return err
}
for _, server := range servers {
var serverId = server.GetInt64("id")
{
var lastMonth = timeutil.Format("Ym", time.Now().AddDate(0, -1, 0))
err = upgradeFunc(userPlanId, serverId, lastMonth)
if err != nil {
return err
}
}
{
var month = timeutil.Format("Ym")
err = upgradeFunc(userPlanId, serverId, month)
if err != nil {
return err
}
}
}
}
}
}
return nil
}

View File

@@ -0,0 +1,49 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package setup
import (
"github.com/iwind/TeaGo/dbs"
"testing"
)
func TestUpgradeSQLData_v0_5_6(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV0_5_6(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestUpgradeSQLData_v1_3_4(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV1_3_4(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -0,0 +1,65 @@
//go:build plus
package setup
import (
"github.com/iwind/TeaGo/dbs"
"testing"
)
func TestUpgradeSQLData_v0_5_6(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV0_5_6(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestUpgradeSQLData_v0_5_8(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV0_5_8(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestUpgradeSQLData_v1_2_9(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV1_2_9(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -0,0 +1,310 @@
package setup
import (
"github.com/iwind/TeaGo/dbs"
"testing"
)
func TestUpgradeSQLData(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge_new?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = UpgradeSQLData(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestUpgradeSQLData_v0_3_1(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge_new?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV0_3_1(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestUpgradeSQLData_v0_3_2(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV0_3_2(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestUpgradeSQLData_v0_3_3(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV0_3_3(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestUpgradeSQLData_v0_3_7(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV0_3_7(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestUpgradeSQLData_v0_4_0(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV0_4_0(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestUpgradeSQLData_v0_4_1(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV0_4_1(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestUpgradeSQLData_v0_4_5(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV0_4_5(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestUpgradeSQLData_v0_4_7(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV0_4_7(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestUpgradeSQLData_v0_4_8(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV0_4_8(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestUpgradeSQLData_v0_4_9(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV0_4_9(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestUpgradeSQLData_v0_4_11(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV0_4_11(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestUpgradeSQLData_v0_5_3(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV0_5_3(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestUpgradeSQLData_v1_2_1(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV1_2_1(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestUpgradeSQLData_v1_2_10(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV1_2_10(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestUpgradeSQLData_v1_3_2(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV1_3_2(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -0,0 +1,34 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package setup
import (
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
"strings"
)
// ComposeSQLVersion 组合SQL的版本号
func ComposeSQLVersion() string {
return teaconst.Version
}
// CompareVersion 对比版本
func CompareVersion(version1 string, version2 string) int8 {
if len(version1) == 0 || len(version2) == 0 {
return 0
}
return stringutil.VersionCompare(fixVersion(version1), fixVersion(version2))
}
func fixVersion(version string) string {
var pieces = strings.Split(version, ".")
var lastPiece = types.Int(pieces[len(pieces)-1])
if lastPiece > 10 {
// 这个是以前使用的SQL版本号我们给去掉
version = strings.Join(pieces[:len(pieces)-1], ".")
}
return version
}

View File

@@ -0,0 +1,23 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package setup_test
import (
"github.com/TeaOSLab/EdgeAPI/internal/setup"
"github.com/iwind/TeaGo/assert"
"testing"
)
func TestComposeSQLVersion(t *testing.T) {
t.Log(setup.ComposeSQLVersion())
}
func TestCompareVersion(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(setup.CompareVersion("1.3.4", "1.3.4") == 0)
a.IsTrue(setup.CompareVersion("1.3.4", "1.3.3") > 0)
a.IsTrue(setup.CompareVersion("1.3.4", "1.3.5") < 0)
a.IsTrue(setup.CompareVersion("1.3.4.3", "1.3.4.12") > 0) // because 12 > 10
a.IsTrue(setup.CompareVersion("1.3.4.3", "1.3.4.2") > 0)
a.IsTrue(setup.CompareVersion("1.3.4.3", "1.3.4.4") < 0)
}