This commit is contained in:
unknown
2026-02-04 20:27:13 +08:00
commit 3b042d1dad
9410 changed files with 1488147 additions and 0 deletions

View File

@@ -0,0 +1,257 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package accounts
import (
"context"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/accounts"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/userconfigs"
)
// OrderMethodService 订单支付方式相关服务
type OrderMethodService struct {
services.BaseService
}
// CreateOrderMethod 创建支付方式
func (this *OrderMethodService) CreateOrderMethod(ctx context.Context, req *pb.CreateOrderMethodRequest) (*pb.CreateOrderMethodResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查代号是否相同
exists, err := accounts.SharedOrderMethodDAO.ExistOrderMethodWithCode(tx, req.Code, 0)
if err != nil {
return nil, err
}
if exists {
return nil, errors.New("pay method code already exists")
}
var params any
if len(req.ParentCode) > 0 {
params, err = userconfigs.DecodePayMethodParams(req.ParentCode, req.ParamsJSON)
if err != nil {
return nil, errors.New("invalid params: " + err.Error())
}
}
methodId, err := accounts.SharedOrderMethodDAO.CreateMethod(tx, req.Name, req.Code, req.Url, req.Description, req.ParentCode, params, req.ClientType, req.QrcodeTitle)
if err != nil {
return nil, err
}
return &pb.CreateOrderMethodResponse{
OrderMethodId: methodId,
}, nil
}
// UpdateOrderMethod 修改支付方式
func (this *OrderMethodService) UpdateOrderMethod(ctx context.Context, req *pb.UpdateOrderMethodRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查代号是否相同
exists, err := accounts.SharedOrderMethodDAO.ExistOrderMethodWithCode(tx, req.Code, req.OrderMethodId)
if err != nil {
return nil, err
}
if exists {
return nil, errors.New("pay method code already exists")
}
method, err := accounts.SharedOrderMethodDAO.FindEnabledBasicOrderMethod(tx, req.OrderMethodId)
if err != nil {
return nil, err
}
if method == nil {
return nil, errors.New("could not find method")
}
var params any
if len(method.ParentCode) > 0 {
params, err = userconfigs.DecodePayMethodParams(method.ParentCode, req.ParamsJSON)
if err != nil {
return nil, errors.New("invalid params: " + err.Error())
}
}
err = accounts.SharedOrderMethodDAO.UpdateMethod(tx, req.OrderMethodId, req.Name, req.Code, req.Url, req.Description, params, req.ClientType, req.QrcodeTitle, req.IsOn)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteOrderMethod 删除支付方式
func (this *OrderMethodService) DeleteOrderMethod(ctx context.Context, req *pb.DeleteOrderMethodRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = accounts.SharedOrderMethodDAO.DisableOrderMethod(tx, req.OrderMethodId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindEnabledOrderMethod 查找单个支付方式
func (this *OrderMethodService) FindEnabledOrderMethod(ctx context.Context, req *pb.FindEnabledOrderMethodRequest) (*pb.FindEnabledOrderMethodResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
method, err := accounts.SharedOrderMethodDAO.FindEnabledOrderMethod(tx, req.OrderMethodId)
if err != nil {
return nil, err
}
if method == nil {
return &pb.FindEnabledOrderMethodResponse{
OrderMethod: nil,
}, nil
}
return &pb.FindEnabledOrderMethodResponse{
OrderMethod: &pb.OrderMethod{
Id: int64(method.Id),
Name: method.Name,
Code: method.Code,
Description: method.Description,
Url: method.Url,
Secret: method.Secret,
IsOn: method.IsOn,
ParentCode: method.ParentCode,
Params: method.Params, // 注意参数不能通过接口泄露给平台用户
ClientType: method.ClientType,
QrcodeTitle: method.QrcodeTitle,
},
}, nil
}
// FindEnabledOrderMethodWithCode 根据代号查找支付方式
func (this *OrderMethodService) FindEnabledOrderMethodWithCode(ctx context.Context, req *pb.FindEnabledOrderMethodWithCodeRequest) (*pb.FindEnabledOrderMethodWithCodeResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
method, err := accounts.SharedOrderMethodDAO.FindEnabledOrderMethodWithCode(tx, req.Code)
if err != nil {
return nil, err
}
if method == nil {
return &pb.FindEnabledOrderMethodWithCodeResponse{
OrderMethod: nil,
}, nil
}
// 保护数据
if userId > 0 {
method.Secret = ""
}
return &pb.FindEnabledOrderMethodWithCodeResponse{
OrderMethod: &pb.OrderMethod{
Id: int64(method.Id),
Name: method.Name,
Code: method.Code,
ParentCode: method.ParentCode,
Description: method.Description,
Url: method.Url,
Secret: method.Secret,
IsOn: method.IsOn,
ClientType: method.ClientType,
QrcodeTitle: method.QrcodeTitle,
},
}, nil
}
// FindAllEnabledOrderMethods 查找所有支付方式
func (this *OrderMethodService) FindAllEnabledOrderMethods(ctx context.Context, req *pb.FindAllEnabledOrderMethodsRequest) (*pb.FindAllEnabledOrderMethodsResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
methods, err := accounts.SharedOrderMethodDAO.FindAllEnabledMethodOrders(tx)
if err != nil {
return nil, err
}
var pbMethods = []*pb.OrderMethod{}
for _, method := range methods {
// 防止secret泄露
if userId > 0 {
method.Secret = ""
}
pbMethods = append(pbMethods, &pb.OrderMethod{
Id: int64(method.Id),
Name: method.Name,
Code: method.Code,
Description: method.Description,
Url: method.Url,
Secret: method.Secret,
IsOn: method.IsOn,
ParentCode: method.ParentCode, // 不要返回params以防止泄露
ClientType: method.ClientType,
})
}
return &pb.FindAllEnabledOrderMethodsResponse{
OrderMethods: pbMethods,
}, nil
}
// FindAllAvailableOrderMethods 查找所有已启用的支付方式
func (this *OrderMethodService) FindAllAvailableOrderMethods(ctx context.Context, req *pb.FindAllAvailableOrderMethodsRequest) (*pb.FindAllAvailableOrderMethodsResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
methods, err := accounts.SharedOrderMethodDAO.FindAllEnabledAndOnMethodOrders(tx)
if err != nil {
return nil, err
}
var pbMethods = []*pb.OrderMethod{}
for _, method := range methods {
// 防止secret泄露
if userId > 0 {
method.Secret = ""
}
pbMethods = append(pbMethods, &pb.OrderMethod{
Id: int64(method.Id),
Name: method.Name,
Code: method.Code,
Description: method.Description,
Url: method.Url,
Secret: method.Secret,
IsOn: method.IsOn,
ParentCode: method.ParentCode, // 不返回params防止泄露
ClientType: method.ClientType,
})
}
return &pb.FindAllAvailableOrderMethodsResponse{
OrderMethods: pbMethods,
}, nil
}

View File

@@ -0,0 +1,195 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package accounts
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/accounts"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps"
)
// UserAccountService 用户账户相关服务
type UserAccountService struct {
services.BaseService
}
// CountUserAccounts 计算账户数量
func (this *UserAccountService) CountUserAccounts(ctx context.Context, req *pb.CountUserAccountsRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := accounts.SharedUserAccountDAO.CountAllAccounts(tx, req.Keyword)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListUserAccounts 列出单页账户
func (this *UserAccountService) ListUserAccounts(ctx context.Context, req *pb.ListUserAccountsRequest) (*pb.ListUserAccountsResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
userAccounts, err := accounts.SharedUserAccountDAO.ListAccounts(tx, req.Keyword, req.Offset, req.Size)
if err != nil {
return nil, err
}
var pbAccounts = []*pb.UserAccount{}
for _, account := range userAccounts {
// 用户
user, err := models.SharedUserDAO.FindEnabledUser(tx, int64(account.UserId), nil)
if err != nil {
return nil, err
}
var pbUser = &pb.User{}
if user != nil {
pbUser = &pb.User{
Id: int64(user.Id),
Username: user.Username,
Fullname: user.Fullname,
}
}
pbAccounts = append(pbAccounts, &pb.UserAccount{
Id: int64(account.Id),
UserId: int64(account.UserId),
Total: account.Total,
TotalFrozen: account.TotalFrozen,
User: pbUser,
})
}
return &pb.ListUserAccountsResponse{UserAccounts: pbAccounts}, nil
}
// FindEnabledUserAccountWithUserId 根据用户ID查找单个账户
func (this *UserAccountService) FindEnabledUserAccountWithUserId(ctx context.Context, req *pb.FindEnabledUserAccountWithUserIdRequest) (*pb.FindEnabledUserAccountWithUserIdResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
var tx = this.NullTx()
account, err := accounts.SharedUserAccountDAO.FindUserAccountWithUserId(tx, req.UserId)
if err != nil {
return nil, err
}
if account == nil {
return &pb.FindEnabledUserAccountWithUserIdResponse{UserAccount: nil}, nil
}
// 用户
user, err := models.SharedUserDAO.FindEnabledUser(tx, int64(account.UserId), nil)
if err != nil {
return nil, err
}
var pbUser = &pb.User{}
if user != nil {
pbUser = &pb.User{
Id: int64(user.Id),
Username: user.Username,
Fullname: user.Fullname,
}
}
return &pb.FindEnabledUserAccountWithUserIdResponse{
UserAccount: &pb.UserAccount{
Id: int64(account.Id),
UserId: int64(account.UserId),
Total: account.Total,
TotalFrozen: account.TotalFrozen,
User: pbUser,
},
}, nil
}
// FindEnabledUserAccount 查找单个账户
func (this *UserAccountService) FindEnabledUserAccount(ctx context.Context, req *pb.FindEnabledUserAccountRequest) (*pb.FindEnabledUserAccountResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
err = accounts.SharedUserAccountDAO.CheckUserAccount(tx, userId, req.UserAccountId)
if err != nil {
return nil, err
}
}
account, err := accounts.SharedUserAccountDAO.FindUserAccountWithAccountId(tx, req.UserAccountId)
if err != nil {
return nil, err
}
if account == nil {
return &pb.FindEnabledUserAccountResponse{UserAccount: nil}, nil
}
// 用户
user, err := models.SharedUserDAO.FindEnabledUser(tx, int64(account.UserId), nil)
if err != nil {
return nil, err
}
var pbUser = &pb.User{}
if user != nil {
pbUser = &pb.User{
Id: int64(user.Id),
Username: user.Username,
Fullname: user.Fullname,
}
}
return &pb.FindEnabledUserAccountResponse{
UserAccount: &pb.UserAccount{
Id: int64(account.Id),
UserId: int64(account.UserId),
Total: account.Total,
TotalFrozen: account.TotalFrozen,
User: pbUser,
},
}, nil
}
// UpdateUserAccount 修改用户账户
func (this *UserAccountService) UpdateUserAccount(ctx context.Context, req *pb.UpdateUserAccountRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var params = maps.Map{}
if len(req.ParamsJSON) > 0 {
err = json.Unmarshal(req.ParamsJSON, &params)
if err != nil {
return nil, err
}
}
err = this.RunTx(func(tx *dbs.Tx) error {
err := accounts.SharedUserAccountDAO.UpdateUserAccount(tx, req.UserAccountId, req.Delta, req.EventType, req.Description, params)
if err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,132 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package accounts
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/accounts"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/lists"
"strings"
)
// UserAccountDailyStatService 用户账户统计服务
type UserAccountDailyStatService struct {
services.BaseService
}
// ListUserAccountDailyStats 列出按天统计
func (this *UserAccountDailyStatService) ListUserAccountDailyStats(ctx context.Context, req *pb.ListUserAccountDailyStatsRequest) (*pb.ListUserAccountDailyStatsResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var dayFrom = req.DayFrom
var dayTo = req.DayTo
dayFrom = strings.ReplaceAll(dayFrom, "-", "")
dayTo = strings.ReplaceAll(dayTo, "-", "")
days, err := utils.RangeDays(dayFrom, dayTo)
if err != nil {
return nil, err
}
if len(days) == 0 {
return &pb.ListUserAccountDailyStatsResponse{Stats: nil}, nil
}
var tx = this.NullTx()
stats, err := accounts.SharedUserAccountDailyStatDAO.FindDailyStats(tx, dayFrom, dayTo)
if err != nil {
return nil, err
}
var statMap = map[string]*accounts.UserAccountDailyStat{} // day => Stat
for _, stat := range stats {
statMap[stat.Day] = stat
}
var pbStats = []*pb.ListUserAccountDailyStatsResponse_Stat{}
for _, day := range days {
stat, ok := statMap[day]
if ok {
pbStats = append(pbStats, &pb.ListUserAccountDailyStatsResponse_Stat{
Day: day,
Income: float32(stat.Income),
Expense: float32(stat.Expense),
})
} else {
pbStats = append(pbStats, &pb.ListUserAccountDailyStatsResponse_Stat{
Day: day,
Income: 0,
Expense: 0,
})
}
}
// 反向排序
lists.Reverse(pbStats)
return &pb.ListUserAccountDailyStatsResponse{Stats: pbStats}, nil
}
// ListUserAccountMonthlyStats 列出按月统计
func (this *UserAccountDailyStatService) ListUserAccountMonthlyStats(ctx context.Context, req *pb.ListUserAccountMonthlyStatsRequest) (*pb.ListUserAccountMonthlyStatsResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var dayFrom = req.DayFrom
var dayTo = req.DayTo
dayFrom = strings.ReplaceAll(dayFrom, "-", "")
dayTo = strings.ReplaceAll(dayTo, "-", "")
months, err := utils.RangeMonths(dayFrom, dayTo)
if err != nil {
return nil, err
}
if len(months) == 0 {
return &pb.ListUserAccountMonthlyStatsResponse{Stats: nil}, nil
}
var tx = this.NullTx()
stats, err := accounts.SharedUserAccountDailyStatDAO.FindMonthlyStats(tx, dayFrom, dayTo)
if err != nil {
return nil, err
}
var statMap = map[string]*accounts.UserAccountDailyStat{} // month => Stat
for _, stat := range stats {
statMap[stat.Month] = stat
}
var pbStats = []*pb.ListUserAccountMonthlyStatsResponse_Stat{}
for _, month := range months {
stat, ok := statMap[month]
if ok {
pbStats = append(pbStats, &pb.ListUserAccountMonthlyStatsResponse_Stat{
Month: month,
Income: float32(stat.Income),
Expense: float32(stat.Expense),
})
} else {
pbStats = append(pbStats, &pb.ListUserAccountMonthlyStatsResponse_Stat{
Month: month,
Income: 0,
Expense: 0,
})
}
}
// 反向排序
lists.Reverse(pbStats)
return &pb.ListUserAccountMonthlyStatsResponse{Stats: pbStats}, nil
}

View File

@@ -0,0 +1,81 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package accounts
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/accounts"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// UserAccountLogService 用户账户日志服务
type UserAccountLogService struct {
services.BaseService
}
// CountUserAccountLogs 计算日志数量
func (this *UserAccountLogService) CountUserAccountLogs(ctx context.Context, req *pb.CountUserAccountLogsRequest) (*pb.RPCCountResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := accounts.SharedUserAccountLogDAO.CountAccountLogs(tx, userId, req.UserAccountId, req.Keyword, req.EventType)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListUserAccountLogs 列出单页日志
func (this *UserAccountLogService) ListUserAccountLogs(ctx context.Context, req *pb.ListUserAccountLogsRequest) (*pb.ListUserAccountLogsResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
accountLogs, err := accounts.SharedUserAccountLogDAO.ListAccountLogs(tx, userId, req.UserAccountId, req.Keyword, req.EventType, req.Offset, req.Size)
if err != nil {
return nil, err
}
var pbLogs = []*pb.UserAccountLog{}
var cacheMap = utils.NewCacheMap()
for _, log := range accountLogs {
// 用户
var pbUser = &pb.User{Id: int64(log.UserId)}
user, err := models.SharedUserDAO.FindEnabledUser(tx, int64(log.UserId), cacheMap)
if err != nil {
return nil, err
}
if user != nil {
pbUser = &pb.User{Id: int64(user.Id), Fullname: user.Fullname, Username: user.Username}
}
// 账户
var pbAccount = &pb.UserAccount{Id: int64(log.AccountId)}
pbLogs = append(pbLogs, &pb.UserAccountLog{
Id: int64(log.Id),
UserId: int64(log.UserId),
UserAccountId: int64(log.AccountId),
Delta: log.Delta,
DeltaFrozen: log.DeltaFrozen,
Total: log.Total,
TotalFrozen: log.TotalFrozen,
EventType: log.EventType,
Description: log.Description,
CreatedAt: int64(log.CreatedAt),
ParamsJSON: log.Params,
User: pbUser,
UserAccount: pbAccount,
})
}
return &pb.ListUserAccountLogsResponse{UserAccountLogs: pbLogs}, nil
}

View File

@@ -0,0 +1,342 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package accounts
import (
"context"
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/accounts"
"github.com/TeaOSLab/EdgeAPI/internal/payments"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/userconfigs"
"github.com/iwind/TeaGo/dbs"
"net/url"
)
// UserOrderService 用户订单相关服务
type UserOrderService struct {
services.BaseService
}
// CreateUserOrder 创建订单
func (this *UserOrderService) CreateUserOrder(ctx context.Context, req *pb.CreateUserOrderRequest) (*pb.CreateUserOrderResponse, error) {
userId, err := this.ValidateUserNode(ctx, false)
if err != nil {
return nil, err
}
if !userconfigs.IsValidOrderType(req.Type) {
return nil, errors.New("invalid order type '" + req.Type + "'")
}
var tx = this.NullTx()
method, err := accounts.SharedOrderMethodDAO.FindEnabledOrderMethodWithCode(tx, req.OrderMethodCode)
if err != nil {
return nil, err
}
if method == nil {
return nil, errors.New("can not find order method with code '" + req.OrderMethodCode + "'")
}
if !method.IsOn {
return nil, errors.New("method is not enabled")
}
var methodId = int64(method.Id)
if req.Amount <= 0 {
return nil, errors.New("'amount' should be greater than 0")
}
var orderCode = ""
err = this.RunTx(func(tx *dbs.Tx) error {
_, code, err := accounts.SharedUserOrderDAO.CreateOrder(tx, 0, userId, req.Type, methodId, req.Amount, req.ParamsJSON)
if err != nil {
return err
}
orderCode = code
return nil
})
if err != nil {
return nil, err
}
order, err := accounts.SharedUserOrderDAO.FindUserOrderWithCode(tx, orderCode)
if err != nil {
return nil, err
}
if order == nil {
return nil, errors.New("can not find order with generated code '" + orderCode + "'")
}
payURL, err := payments.GeneratePayURL(order, method)
if err != nil {
return nil, errors.New("generate pay url failed: " + err.Error())
}
return &pb.CreateUserOrderResponse{
Code: orderCode,
PayURL: payURL,
}, nil
}
// FindEnabledUserOrder 查看订单
func (this *UserOrderService) FindEnabledUserOrder(ctx context.Context, req *pb.FindEnabledUserOrderRequest) (*pb.FindEnabledUserOrderResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
order, err := accounts.SharedUserOrderDAO.FindUserOrderWithCode(tx, req.Code)
if err != nil {
return nil, err
}
if order == nil {
return &pb.FindEnabledUserOrderResponse{UserOrder: nil}, nil
}
// 检查用户权限
if userId > 0 {
if int64(order.UserId) != userId {
return &pb.FindEnabledUserOrderResponse{UserOrder: nil}, nil
}
}
// 用户
var cacheMap = utils.NewCacheMap()
user, err := models.SharedUserDAO.FindEnabledUser(tx, int64(order.UserId), cacheMap)
if err != nil {
return nil, err
}
var pbUser *pb.User
if user != nil {
pbUser = &pb.User{
Id: int64(user.Id),
Username: user.Username,
Fullname: user.Fullname,
}
}
// 支付方式
method, err := accounts.SharedOrderMethodDAO.FindEnabledOrderMethod(tx, int64(order.MethodId))
if err != nil {
return nil, err
}
var pbMethod *pb.OrderMethod
if method != nil {
pbMethod = &pb.OrderMethod{
Id: int64(method.Id),
Name: method.Name,
Code: method.Code,
IsOn: method.IsOn,
ParentCode: method.ParentCode,
ClientType: method.ClientType,
QrcodeTitle: method.QrcodeTitle,
}
}
// 支付URL
payURL, err := payments.GeneratePayURL(order, method)
if err != nil {
return nil, err
}
return &pb.FindEnabledUserOrderResponse{UserOrder: &pb.UserOrder{
UserId: int64(order.UserId),
Code: order.Code,
Type: order.Type,
OrderMethodId: int64(order.MethodId),
Status: order.Status,
Amount: float32(order.Amount),
ParamsJSON: order.Params,
CreatedAt: int64(order.CreatedAt),
CancelledAt: int64(order.CancelledAt),
FinishedAt: int64(order.FinishedAt),
IsExpired: order.IsExpired(),
User: pbUser,
OrderMethod: pbMethod,
CanPay: !order.IsExpired() && order.Status == userconfigs.OrderStatusNone,
PayURL: payURL,
}}, nil
}
// CancelUserOrder 取消订单
func (this *UserOrderService) CancelUserOrder(ctx context.Context, req *pb.CancelUserOrderRequest) (*pb.RPCSuccess, error) {
adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
order, err := accounts.SharedUserOrderDAO.FindUserOrderWithCode(tx, req.Code)
if err != nil {
return nil, err
}
if order == nil {
return nil, errors.New("can not find order")
}
if userId > 0 {
if int64(order.UserId) != userId {
return nil, errors.New("can not find order")
}
}
err = this.RunTx(func(tx *dbs.Tx) error {
return accounts.SharedUserOrderDAO.CancelOrder(tx, adminId, userId, int64(order.Id))
})
if err != nil {
return nil, err
}
return this.Success()
}
// FinishUserOrder 完成订单
func (this *UserOrderService) FinishUserOrder(ctx context.Context, req *pb.FinishUserOrderRequest) (*pb.RPCSuccess, error) {
adminId, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
order, err := accounts.SharedUserOrderDAO.FindUserOrderWithCode(tx, req.Code)
if err != nil {
return nil, err
}
if order == nil {
return nil, errors.New("can not find order")
}
err = this.RunTx(func(tx *dbs.Tx) error {
return accounts.SharedUserOrderDAO.FinishOrder(tx, adminId, 0, int64(order.Id))
})
if err != nil {
return nil, err
}
return this.Success()
}
// CountEnabledUserOrders 计算订单数量
func (this *UserOrderService) CountEnabledUserOrders(ctx context.Context, req *pb.CountEnabledUserOrdersRequest) (*pb.RPCCountResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
var tx = this.NullTx()
count, err := accounts.SharedUserOrderDAO.CountEnabledUserOrders(tx, req.UserId, req.Status, req.Keyword)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListEnabledUserOrders 列出单页订单
func (this *UserOrderService) ListEnabledUserOrders(ctx context.Context, req *pb.ListEnabledUserOrdersRequest) (*pb.ListEnabledUserOrdersResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
var tx = this.NullTx()
orders, err := accounts.SharedUserOrderDAO.ListEnabledUserOrders(tx, req.UserId, req.Status, req.Keyword, req.Offset, req.Size)
if err != nil {
return nil, err
}
var pbOrders = []*pb.UserOrder{}
var cacheMap = utils.NewCacheMap()
for _, order := range orders {
// 用户
user, err := models.SharedUserDAO.FindEnabledUser(tx, int64(order.UserId), cacheMap)
if err != nil {
return nil, err
}
var pbUser *pb.User
if user != nil {
pbUser = &pb.User{
Id: int64(user.Id),
Username: user.Username,
Fullname: user.Fullname,
}
}
// 支付方式
method, err := accounts.SharedOrderMethodDAO.FindEnabledOrderMethod(tx, int64(order.MethodId))
if err != nil {
return nil, err
}
var pbMethod *pb.OrderMethod
if method != nil {
pbMethod = &pb.OrderMethod{
Id: int64(method.Id),
Name: method.Name,
Code: method.Code,
IsOn: method.IsOn,
}
}
pbOrders = append(pbOrders, &pb.UserOrder{
UserId: int64(order.UserId),
Code: order.Code,
Type: order.Type,
OrderMethodId: int64(order.MethodId),
Status: order.Status,
Amount: float32(order.Amount),
ParamsJSON: order.Params,
CreatedAt: int64(order.CreatedAt),
CancelledAt: int64(order.CancelledAt),
FinishedAt: int64(order.FinishedAt),
IsExpired: order.IsExpired(),
User: pbUser,
OrderMethod: pbMethod,
})
}
return &pb.ListEnabledUserOrdersResponse{
UserOrders: pbOrders,
}, nil
}
// NotifyUserOrderPayment 订单支付通知
func (this *UserOrderService) NotifyUserOrderPayment(ctx context.Context, req *pb.NotifyUserOrderPaymentRequest) (*pb.RPCSuccess, error) {
userId, err := this.ValidateUserNode(ctx, false)
if err != nil {
return nil, err
}
var formValues = url.Values{}
err = json.Unmarshal(req.FormData, &formValues)
if err != nil {
return nil, errors.New("decode form values failed: " + err.Error())
}
err = this.RunTx(func(tx *dbs.Tx) error {
// 校验检查订单
orderId, err := payments.Verify(req.PayMethod, formValues)
if err != nil {
return err
}
return accounts.SharedUserOrderDAO.FinishOrder(tx, 0, userId, orderId)
})
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,140 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package antiddos
import (
"context"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
type ADNetworkService struct {
services.BaseService
}
// CreateADNetwork 创建线路
func (this *ADNetworkService) CreateADNetwork(ctx context.Context, req *pb.CreateADNetworkRequest) (*pb.CreateADNetworkResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
networkId, err := models.SharedADNetworkDAO.CreateNetwork(tx, req.Name, req.Description)
if err != nil {
return nil, err
}
return &pb.CreateADNetworkResponse{AdNetworkId: networkId}, nil
}
// UpdateADNetwork 修改线路
func (this *ADNetworkService) UpdateADNetwork(ctx context.Context, req *pb.UpdateADNetworkRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if req.AdNetworkId <= 0 {
return nil, errors.New("invalid adNetworkId")
}
err = models.SharedADNetworkDAO.UpdateNetwork(tx, req.AdNetworkId, req.IsOn, req.Name, req.Description)
if err != nil {
return nil, err
}
return this.Success()
}
// FindADNetwork 查找单个线路
func (this *ADNetworkService) FindADNetwork(ctx context.Context, req *pb.FindADNetworkRequest) (*pb.FindADNetworkResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
network, err := models.SharedADNetworkDAO.FindEnabledADNetwork(tx, req.AdNetworkId)
if err != nil {
return nil, err
}
if network == nil {
return &pb.FindADNetworkResponse{AdNetwork: nil}, nil
}
return &pb.FindADNetworkResponse{AdNetwork: &pb.ADNetwork{
Id: int64(network.Id),
IsOn: network.IsOn,
Name: network.Name,
Description: network.Description,
}}, nil
}
// FindAllADNetworks 列出所有线路
func (this *ADNetworkService) FindAllADNetworks(ctx context.Context, req *pb.FindAllADNetworkRequest) (*pb.FindAllADNetworkResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
networks, err := models.SharedADNetworkDAO.FindAllNetworks(tx)
if err != nil {
return nil, err
}
var pbNetworks = []*pb.ADNetwork{}
for _, network := range networks {
pbNetworks = append(pbNetworks, &pb.ADNetwork{
Id: int64(network.Id),
IsOn: network.IsOn,
Name: network.Name,
Description: network.Description,
})
}
return &pb.FindAllADNetworkResponse{
AdNetworks: pbNetworks,
}, nil
}
// FindAllAvailableADNetworks 列出所有可用的线路
func (this *ADNetworkService) FindAllAvailableADNetworks(ctx context.Context, req *pb.FindAllAvailableADNetworksRequest) (*pb.FindAllAvailableADNetworksResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
networks, err := models.SharedADNetworkDAO.FindAllAvailableNetworks(tx)
if err != nil {
return nil, err
}
var pbNetworks = []*pb.ADNetwork{}
for _, network := range networks {
pbNetworks = append(pbNetworks, &pb.ADNetwork{
Id: int64(network.Id),
IsOn: network.IsOn,
Name: network.Name,
Description: network.Description,
})
}
return &pb.FindAllAvailableADNetworksResponse{
AdNetworks: pbNetworks,
}, nil
}
// DeleteADNetwork 删除线路
func (this *ADNetworkService) DeleteADNetwork(ctx context.Context, req *pb.DeleteADNetworkRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedADNetworkDAO.DisableADNetwork(tx, req.AdNetworkId)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,273 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package antiddos
import (
"context"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/types"
)
// ADPackageService 高防产品服务
type ADPackageService struct {
services.BaseService
}
// CreateADPackage 创建高防产品
func (this *ADPackageService) CreateADPackage(ctx context.Context, req *pb.CreateADPackageRequest) (*pb.CreateADPackageResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查线路
if req.AdNetworkId <= 0 {
return nil, errors.New("invalid adNetworkId")
}
network, err := models.SharedADNetworkDAO.FindEnabledADNetwork(tx, req.AdNetworkId)
if err != nil {
return nil, err
}
if network == nil {
return nil, errors.New("invalid network")
}
packageId, err := models.SharedADPackageDAO.CreatePackage(tx, req.AdNetworkId, req.ProtectionBandwidthSize, req.ProtectionBandwidthUnit, req.ServerBandwidthSize, req.ServerBandwidthUnit)
if err != nil {
return nil, err
}
return &pb.CreateADPackageResponse{AdPackageId: packageId}, nil
}
// UpdateADPackage 修改高防产品
func (this *ADPackageService) UpdateADPackage(ctx context.Context, req *pb.UpdateADPackageRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查线路
if req.AdNetworkId <= 0 {
return nil, errors.New("invalid adNetworkId")
}
network, err := models.SharedADNetworkDAO.FindEnabledADNetwork(tx, req.AdNetworkId)
if err != nil {
return nil, err
}
if network == nil {
return nil, errors.New("invalid network")
}
err = models.SharedADPackageDAO.UpdatePackage(tx, req.AdPackageId, req.IsOn, req.AdNetworkId, req.ProtectionBandwidthSize, req.ProtectionBandwidthUnit, req.ServerBandwidthSize, req.ServerBandwidthUnit)
if err != nil {
return nil, err
}
return this.Success()
}
// FindADPackage 查找单个高防产品
func (this *ADPackageService) FindADPackage(ctx context.Context, req *pb.FindADPackageRequest) (*pb.FindADPackageResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
adPackage, err := models.SharedADPackageDAO.FindEnabledADPackage(tx, req.AdPackageId)
if err != nil {
return nil, err
}
if adPackage == nil {
return &pb.FindADPackageResponse{AdPackage: nil}, nil
}
// 线路
var pbNetwork *pb.ADNetwork
var network *models.ADNetwork
if adPackage.NetworkId > 0 {
network, err = models.SharedADNetworkDAO.FindEnabledADNetwork(tx, int64(adPackage.NetworkId))
if err != nil {
return nil, err
}
if network != nil {
pbNetwork = &pb.ADNetwork{
Id: int64(network.Id),
IsOn: network.IsOn,
Name: network.Name,
Description: network.Description,
}
}
}
return &pb.FindADPackageResponse{AdPackage: &pb.ADPackage{
Id: int64(adPackage.Id),
IsOn: adPackage.IsOn,
AdNetworkId: int64(adPackage.NetworkId),
ProtectionBandwidthSize: types.Int32(adPackage.ProtectionBandwidthSize),
ProtectionBandwidthUnit: adPackage.ProtectionBandwidthUnit,
ServerBandwidthSize: types.Int32(adPackage.ServerBandwidthSize),
ServerBandwidthUnit: adPackage.ServerBandwidthUnit,
Summary: adPackage.Summary(network),
AdNetwork: pbNetwork,
}}, nil
}
// CountADPackages 查询高防产品数量
func (this *ADPackageService) CountADPackages(ctx context.Context, req *pb.CountADPackagesRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedADPackageDAO.CountAllPackages(tx, req.AdNetworkId)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// CountAllIdleADPackages 查询可用的产品数量
func (this *ADPackageService) CountAllIdleADPackages(ctx context.Context, req *pb.CountAllIdleADPackages) (*pb.RPCCountResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedADPackageDAO.CountAllIdlePackages(tx)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListADPackages 列出单页高防产品
func (this *ADPackageService) ListADPackages(ctx context.Context, req *pb.ListADPackagesRequest) (*pb.ListADPackagesResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
packages, err := models.SharedADPackageDAO.ListPackages(tx, req.AdNetworkId, req.Offset, req.Size)
if err != nil {
return nil, err
}
var pbPackages = []*pb.ADPackage{}
for _, p := range packages {
// 线路
var pbNetwork *pb.ADNetwork
var network *models.ADNetwork
if p.NetworkId > 0 {
network, err = models.SharedADNetworkDAO.FindEnabledADNetwork(tx, int64(p.NetworkId))
if err != nil {
return nil, err
}
if network != nil {
pbNetwork = &pb.ADNetwork{
Id: int64(network.Id),
IsOn: network.IsOn,
Name: network.Name,
Description: network.Description,
}
}
}
pbPackages = append(pbPackages, &pb.ADPackage{
Id: int64(p.Id),
IsOn: p.IsOn,
AdNetworkId: int64(p.NetworkId),
ProtectionBandwidthSize: int32(p.ProtectionBandwidthSize),
ProtectionBandwidthUnit: p.ProtectionBandwidthUnit,
ServerBandwidthSize: int32(p.ServerBandwidthSize),
ServerBandwidthUnit: p.ServerBandwidthUnit,
Summary: p.Summary(network),
AdNetwork: pbNetwork,
})
}
return &pb.ListADPackagesResponse{AdPackages: pbPackages}, nil
}
// FindAllIdleADPackages 列出所有可用的高防产品
func (this *ADPackageService) FindAllIdleADPackages(ctx context.Context, req *pb.FindAllIdleADPackagesRequest) (*pb.FindAllIdleADPackagesResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
packages, err := models.SharedADPackageDAO.FindAllIdlePackages(tx)
if err != nil {
return nil, err
}
var pbPackages = []*pb.ADPackage{}
for _, p := range packages {
// 线路
var pbNetwork *pb.ADNetwork
var network *models.ADNetwork
if p.NetworkId > 0 {
network, err = models.SharedADNetworkDAO.FindEnabledADNetwork(tx, int64(p.NetworkId))
if err != nil {
return nil, err
}
if network != nil {
pbNetwork = &pb.ADNetwork{
Id: int64(network.Id),
IsOn: network.IsOn,
Name: network.Name,
Description: network.Description,
}
}
}
// 可用实例
countIdleInstances, err := models.SharedADPackageInstanceDAO.CountIdleInstances(tx, int64(p.Id))
if err != nil {
return nil, err
}
pbPackages = append(pbPackages, &pb.ADPackage{
Id: int64(p.Id),
IsOn: p.IsOn,
AdNetworkId: int64(p.NetworkId),
ProtectionBandwidthSize: int32(p.ProtectionBandwidthSize),
ProtectionBandwidthUnit: p.ProtectionBandwidthUnit,
ServerBandwidthSize: int32(p.ServerBandwidthSize),
ServerBandwidthUnit: p.ServerBandwidthUnit,
Summary: p.Summary(network),
AdNetwork: pbNetwork,
CountIdleADPackageInstances: countIdleInstances,
})
}
return &pb.FindAllIdleADPackagesResponse{AdPackages: pbPackages}, nil
}
// DeleteADPackage 删除高防产品
func (this *ADPackageService) DeleteADPackage(ctx context.Context, req *pb.DeleteADPackageRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedADPackageDAO.DisableADPackage(tx, req.AdPackageId)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,378 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package antiddos
import (
"context"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/types"
)
// ADPackageInstanceService 高防实例服务
type ADPackageInstanceService struct {
services.BaseService
}
// CreateADPackageInstance 创建实例
func (this *ADPackageInstanceService) CreateADPackageInstance(ctx context.Context, req *pb.CreateADPackageInstanceRequest) (*pb.CreateADPackageInstanceResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// validate
if req.AdPackageId <= 0 {
return nil, errors.New("invalid 'adPackageId'")
}
if req.NodeClusterId <= 0 {
return nil, errors.New("invalid 'nodeClusterId'")
}
instanceId, err := models.SharedADPackageInstanceDAO.CreateInstance(tx, req.AdPackageId, req.NodeClusterId, req.NodeIds, req.IpAddresses)
if err != nil {
return nil, err
}
return &pb.CreateADPackageInstanceResponse{AdPackageInstanceId: instanceId}, nil
}
// UpdateADPackageInstance 修改实例
func (this *ADPackageInstanceService) UpdateADPackageInstance(ctx context.Context, req *pb.UpdateADPackageInstanceRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// validate
if req.AdPackageInstanceId <= 0 {
return nil, errors.New("invalid 'adPackageInstanceId'")
}
if req.NodeClusterId <= 0 {
return nil, errors.New("invalid 'nodeClusterId'")
}
err = models.SharedADPackageInstanceDAO.UpdateInstance(tx, req.AdPackageInstanceId, req.NodeClusterId, req.NodeIds, req.IpAddresses, req.IsOn)
if err != nil {
return nil, err
}
return this.Success()
}
// FindADPackageInstance 查找单个实例
func (this *ADPackageInstanceService) FindADPackageInstance(ctx context.Context, req *pb.FindADPackageInstanceRequest) (*pb.FindADPackageInstanceResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
instance, err := models.SharedADPackageInstanceDAO.FindEnabledADPackageInstance(tx, req.AdPackageInstanceId)
if err != nil {
return nil, err
}
if instance == nil {
return &pb.FindADPackageInstanceResponse{
AdPackageInstance: nil,
}, nil
}
// package
adPackage, err := models.SharedADPackageDAO.FindEnabledADPackage(tx, int64(instance.PackageId))
if err != nil {
return nil, err
}
if adPackage == nil {
return &pb.FindADPackageInstanceResponse{
AdPackageInstance: nil,
}, nil
}
var pbPackage *pb.ADPackage
// network
var pbNetwork *pb.ADNetwork
network, err := models.SharedADNetworkDAO.FindEnabledADNetwork(tx, int64(adPackage.NetworkId))
if err != nil {
return nil, err
}
if network == nil {
return &pb.FindADPackageInstanceResponse{
AdPackageInstance: nil,
}, nil
}
pbNetwork = &pb.ADNetwork{
Id: int64(network.Id),
IsOn: network.IsOn,
Name: network.Name,
Description: network.Description,
}
pbPackage = &pb.ADPackage{
Id: int64(adPackage.Id),
ProtectionBandwidthSize: types.Int32(adPackage.ProtectionBandwidthSize),
ProtectionBandwidthUnit: adPackage.ProtectionBandwidthUnit,
ServerBandwidthSize: types.Int32(adPackage.ServerBandwidthSize),
ServerBandwidthUnit: adPackage.ServerBandwidthUnit,
AdNetwork: pbNetwork,
IsOn: adPackage.IsOn,
Summary: adPackage.Summary(network),
}
return &pb.FindADPackageInstanceResponse{AdPackageInstance: &pb.ADPackageInstance{
Id: int64(instance.Id),
IsOn: instance.IsOn,
AdPackageId: int64(instance.PackageId),
NodeClusterId: int64(instance.ClusterId),
NodeIds: instance.DecodeNodeIds(),
IpAddresses: instance.DecodeIPAddresses(),
AdPackage: pbPackage,
}}, nil
}
// FindAllADPackageInstances 列出单个高防产品所有实例
func (this *ADPackageInstanceService) FindAllADPackageInstances(ctx context.Context, req *pb.FindAllADPackageInstancesRequest) (*pb.FindAllADPackageInstancesResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
var pbInstances = []*pb.ADPackageInstance{}
instances, err := models.SharedADPackageInstanceDAO.FindAllPackageInstances(tx, req.AdPackageId)
if err != nil {
return nil, err
}
// package
adPackage, err := models.SharedADPackageDAO.FindEnabledADPackage(tx, req.AdPackageId)
if err != nil {
return nil, err
}
if adPackage == nil {
return &pb.FindAllADPackageInstancesResponse{
AdPackageInstances: nil,
}, nil
}
var pbPackage *pb.ADPackage
// network
var pbNetwork *pb.ADNetwork
network, err := models.SharedADNetworkDAO.FindEnabledADNetwork(tx, int64(adPackage.NetworkId))
if err != nil {
return nil, err
}
if network != nil {
pbNetwork = &pb.ADNetwork{
Id: int64(network.Id),
IsOn: network.IsOn,
Name: network.Name,
Description: network.Description,
}
}
pbPackage = &pb.ADPackage{
Id: int64(adPackage.Id),
ProtectionBandwidthSize: types.Int32(adPackage.ProtectionBandwidthSize),
ProtectionBandwidthUnit: adPackage.ProtectionBandwidthUnit,
ServerBandwidthSize: types.Int32(adPackage.ServerBandwidthSize),
ServerBandwidthUnit: adPackage.ServerBandwidthUnit,
AdNetwork: pbNetwork,
IsOn: adPackage.IsOn,
Summary: adPackage.Summary(network),
}
for _, instance := range instances {
// 集群
var pbCluster *pb.NodeCluster
cluster, err := models.SharedNodeClusterDAO.FindClusterBasicInfo(tx, int64(instance.ClusterId), nil)
if err != nil {
return nil, err
}
if cluster != nil {
pbCluster = &pb.NodeCluster{
Id: int64(cluster.Id),
Name: cluster.Name,
IsOn: cluster.IsOn,
}
}
pbInstances = append(pbInstances, &pb.ADPackageInstance{
Id: int64(instance.Id),
IsOn: instance.IsOn,
AdPackageId: int64(instance.PackageId),
NodeClusterId: int64(instance.ClusterId),
NodeIds: instance.DecodeNodeIds(),
IpAddresses: instance.DecodeIPAddresses(),
NodeCluster: pbCluster,
AdPackage: pbPackage,
})
}
return &pb.FindAllADPackageInstancesResponse{AdPackageInstances: pbInstances}, nil
}
// DeleteADPackageInstance 删除实例
func (this *ADPackageInstanceService) DeleteADPackageInstance(ctx context.Context, req *pb.DeleteADPackageInstanceRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedADPackageInstanceDAO.DisableADPackageInstance(tx, req.AdPackageInstanceId)
if err != nil {
return nil, err
}
return this.Success()
}
// CountIdleADPackageInstances 计算可购的实例数量
func (this *ADPackageInstanceService) CountIdleADPackageInstances(ctx context.Context, req *pb.CountIdleADPackageInstancesRequest) (*pb.RPCCountResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedADPackageInstanceDAO.CountIdleInstances(tx, req.AdPackageId)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// CountADPackageInstances 计算实例数量
func (this *ADPackageInstanceService) CountADPackageInstances(ctx context.Context, req *pb.CountADPackageInstancesRequest) (*pb.RPCCountResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
var tx = this.NullTx()
count, err := models.SharedADPackageInstanceDAO.CountInstances(tx, req.UserId, req.AdNetworkId, req.AdPackageId, req.Ip)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListADPackageInstances 列出单页实例
func (this *ADPackageInstanceService) ListADPackageInstances(ctx context.Context, req *pb.ListADPackageInstancesRequest) (*pb.ListADPackageInstancesResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
var tx = this.NullTx()
var pbInstances = []*pb.ADPackageInstance{}
instances, err := models.SharedADPackageInstanceDAO.ListInstances(tx, req.UserId, req.AdNetworkId, req.AdPackageId, req.Ip, req.Offset, req.Size)
if err != nil {
return nil, err
}
for _, instance := range instances {
// 集群
var pbCluster *pb.NodeCluster
if instance.ClusterId > 0 {
cluster, err := models.SharedNodeClusterDAO.FindClusterBasicInfo(tx, int64(instance.ClusterId), nil)
if err != nil {
return nil, err
}
if cluster != nil {
pbCluster = &pb.NodeCluster{
Id: int64(cluster.Id),
Name: cluster.Name,
IsOn: cluster.IsOn,
}
}
}
// package
adPackage, err := models.SharedADPackageDAO.FindEnabledADPackage(tx, int64(instance.PackageId))
if err != nil {
return nil, err
}
if adPackage == nil {
continue
}
var pbPackage *pb.ADPackage
// network
var pbNetwork *pb.ADNetwork
network, err := models.SharedADNetworkDAO.FindEnabledADNetwork(tx, int64(adPackage.NetworkId))
if err != nil {
return nil, err
}
if network == nil {
continue
}
pbNetwork = &pb.ADNetwork{
Id: int64(network.Id),
IsOn: network.IsOn,
Name: network.Name,
Description: network.Description,
}
pbPackage = &pb.ADPackage{
Id: int64(adPackage.Id),
ProtectionBandwidthSize: types.Int32(adPackage.ProtectionBandwidthSize),
ProtectionBandwidthUnit: adPackage.ProtectionBandwidthUnit,
ServerBandwidthSize: types.Int32(adPackage.ServerBandwidthSize),
ServerBandwidthUnit: adPackage.ServerBandwidthUnit,
AdNetwork: pbNetwork,
IsOn: adPackage.IsOn,
Summary: adPackage.Summary(network),
}
// user
var pbUser *pb.User
user, err := instance.CurrentUser()
if err != nil {
return nil, err
}
if user != nil {
pbUser = &pb.User{
Id: int64(user.Id),
Fullname: user.Fullname,
Username: user.Username,
}
}
pbInstances = append(pbInstances, &pb.ADPackageInstance{
Id: int64(instance.Id),
IsOn: instance.IsOn,
AdPackageId: int64(instance.PackageId),
NodeClusterId: int64(instance.ClusterId),
NodeIds: instance.DecodeNodeIds(),
IpAddresses: instance.DecodeIPAddresses(),
UserId: int64(instance.UserId),
UserDayTo: instance.UserDayTo,
NodeCluster: pbCluster,
AdPackage: pbPackage,
User: pbUser,
})
}
return &pb.ListADPackageInstancesResponse{AdPackageInstances: pbInstances}, nil
}

View File

@@ -0,0 +1,147 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package antiddos
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// ADPackagePeriodService 高防实例有效期服务
type ADPackagePeriodService struct {
services.BaseService
}
// CreateADPackagePeriod 创建有效期
func (this *ADPackagePeriodService) CreateADPackagePeriod(ctx context.Context, req *pb.CreateADPackagePeriodRequest) (*pb.CreateADPackagePeriodResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
periodId, err := models.SharedADPackagePeriodDAO.CreatePeriod(tx, req.Count, req.Unit)
if err != nil {
return nil, err
}
return &pb.CreateADPackagePeriodResponse{
AdPackagePeriodId: periodId,
}, nil
}
// UpdateADPackagePeriod 修改有效期
func (this *ADPackagePeriodService) UpdateADPackagePeriod(ctx context.Context, req *pb.UpdateADPackagePeriodRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedADPackagePeriodDAO.UpdatePeriod(tx, req.AdPackagePeriodId, req.IsOn)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteADPackagePeriod 删除有效期
func (this *ADPackagePeriodService) DeleteADPackagePeriod(ctx context.Context, req *pb.DeleteADPackagePeriodRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedADPackagePeriodDAO.DisableADPackagePeriod(tx, req.AdPackagePeriodId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindADPackagePeriod 查找有效期
func (this *ADPackagePeriodService) FindADPackagePeriod(ctx context.Context, req *pb.FindADPackagePeriodRequest) (*pb.FindADPackagePeriodResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
period, err := models.SharedADPackagePeriodDAO.FindEnabledADPackagePeriod(tx, req.AdPackagePeriodId)
if err != nil {
return nil, err
}
if period == nil {
return &pb.FindADPackagePeriodResponse{AdPackagePeriod: nil}, nil
}
return &pb.FindADPackagePeriodResponse{AdPackagePeriod: &pb.ADPackagePeriod{
Id: int64(period.Id),
IsOn: period.IsOn,
Count: int32(period.Count),
Unit: period.Unit,
Months: int32(period.Months),
}}, nil
}
// FindAllADPackagePeriods 列出所有有效期
func (this *ADPackagePeriodService) FindAllADPackagePeriods(ctx context.Context, req *pb.FindAllADPackagePeriodsRequest) (*pb.FindAllADPackagePeriodsResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
periods, err := models.SharedADPackagePeriodDAO.FindAllPeriods(tx)
if err != nil {
return nil, err
}
var pbPeriods = []*pb.ADPackagePeriod{}
for _, period := range periods {
pbPeriods = append(pbPeriods, &pb.ADPackagePeriod{
Id: int64(period.Id),
IsOn: period.IsOn,
Count: int32(period.Count),
Unit: period.Unit,
Months: int32(period.Months),
})
}
return &pb.FindAllADPackagePeriodsResponse{
AdPackagePeriods: pbPeriods,
}, nil
}
// FindAllAvailableADPackagePeriods 列出所有可用有效期
func (this *ADPackagePeriodService) FindAllAvailableADPackagePeriods(ctx context.Context, req *pb.FindAllAvailableADPackagePeriodsRequest) (*pb.FindAllAvailableADPackagePeriodsResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
periods, err := models.SharedADPackagePeriodDAO.FindAllAvailablePeriods(tx)
if err != nil {
return nil, err
}
var pbPeriods = []*pb.ADPackagePeriod{}
for _, period := range periods {
pbPeriods = append(pbPeriods, &pb.ADPackagePeriod{
Id: int64(period.Id),
IsOn: period.IsOn,
Count: int32(period.Count),
Unit: period.Unit,
Months: int32(period.Months),
})
}
return &pb.FindAllAvailableADPackagePeriodsResponse{
AdPackagePeriods: pbPeriods,
}, nil
}

View File

@@ -0,0 +1,120 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package antiddos
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
type ADPackagePriceService struct {
services.BaseService
}
// UpdateADPackagePrice 设置高防产品价格
func (this *ADPackagePriceService) UpdateADPackagePrice(ctx context.Context, req *pb.UpdateADPackagePriceRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// TODO 检查各项参数的有效性
err = models.SharedADPackagePriceDAO.UpdatePackagePrice(tx, req.AdPackageId, req.AdPackagePeriodId, req.Price)
if err != nil {
return nil, err
}
return this.Success()
}
// FindADPackagePrice 获取单个高防产品具体价格
func (this *ADPackagePriceService) FindADPackagePrice(ctx context.Context, req *pb.FindADPackagePriceRequest) (*pb.FindADPackagePriceResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
price, err := models.SharedADPackagePriceDAO.FindPackagePrice(tx, req.AdPackageId, req.AdPackagePeriodId)
if err != nil {
return nil, err
}
return &pb.FindADPackagePriceResponse{
Price: price,
Amount: float64(req.Count) * price,
}, nil
}
// CountADPackagePrices 计算高防产品价格项数量
func (this *ADPackagePriceService) CountADPackagePrices(ctx context.Context, req *pb.CountADPackagePricesRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedADPackagePriceDAO.CountPackagePrices(tx, req.AdPackageId)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// FindADPackagePrices 查找高防产品价格
func (this *ADPackagePriceService) FindADPackagePrices(ctx context.Context, req *pb.FindADPackagePricesRequest) (*pb.FindADPackagePricesResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
prices, err := models.SharedADPackagePriceDAO.FindPackagePrices(tx, req.AdPackageId)
if err != nil {
return nil, err
}
var pbPrices = []*pb.ADPackagePrice{}
for _, price := range prices {
pbPrices = append(pbPrices, &pb.ADPackagePrice{
AdPackageId: int64(price.PackageId),
AdPackagePeriodId: int64(price.PeriodId),
Price: price.Price,
})
}
return &pb.FindADPackagePricesResponse{
AdPackagePrices: pbPrices,
}, nil
}
// FindAllADPackagePrices 查找所有高防产品价格
func (this *ADPackagePriceService) FindAllADPackagePrices(ctx context.Context, req *pb.FindAllADPackagePricesRequest) (*pb.FindAllADPackagePricesResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
prices, err := models.SharedADPackagePriceDAO.FindAllPackagePrices(tx)
if err != nil {
return nil, err
}
var pbPrices = []*pb.ADPackagePrice{}
for _, price := range prices {
pbPrices = append(pbPrices, &pb.ADPackagePrice{
AdPackageId: int64(price.PackageId),
AdPackagePeriodId: int64(price.PeriodId),
Price: price.Price,
})
}
return &pb.FindAllADPackagePricesResponse{
AdPackagePrices: pbPrices,
}, nil
}

View File

@@ -0,0 +1,732 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package antiddos
import (
"context"
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/accounts"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/userconfigs"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"strings"
)
// UserADInstanceService 用户高防实例服务
type UserADInstanceService struct {
services.BaseService
}
// CreateUserADInstance 创建用户高防实例
func (this *UserADInstanceService) CreateUserADInstance(ctx context.Context, req *pb.CreateUserADInstanceRequest) (*pb.CreateUserADInstanceResponse, error) {
adminId, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 高防产品
if req.AdPackageId <= 0 {
return nil, errors.New("invalid 'adPackageId'")
}
adPackage, err := models.SharedADPackageDAO.FindEnabledADPackage(tx, req.AdPackageId)
if err != nil {
return nil, err
}
if adPackage == nil {
return nil, errors.New("invalid 'adPackage'")
}
// 有效期选项
if req.AdPackagePeriodId <= 0 {
return nil, errors.New("invalid 'adPackagePeriodId'")
}
period, err := models.SharedADPackagePeriodDAO.FindEnabledADPackagePeriod(tx, req.AdPackagePeriodId)
if err != nil {
return nil, err
}
if period == nil || !period.IsOn {
return nil, errors.New("could not find instance period with id '" + types.String(req.AdPackagePeriodId) + "'")
}
_, dayTo := period.DayPeriod()
// 数量
if req.Count <= 0 {
return nil, errors.New("invalid 'count'")
}
instances, err := models.SharedADPackageInstanceDAO.FindIdlePackageInstances(tx, req.AdPackageId, req.Count)
if err != nil {
return nil, err
}
var countInstances = int32(len(instances))
if countInstances < req.Count {
return nil, errors.New("no enough instances")
}
if countInstances > req.Count {
instances = instances[:req.Count]
}
var userInstanceIds = []int64{}
err = this.RunTx(func(tx *dbs.Tx) error {
for _, instance := range instances {
var instanceId = int64(instance.Id)
userInstanceId, err := models.SharedUserADInstanceDAO.CreateUserInstance(tx, req.UserId, adminId, instanceId, req.AdPackagePeriodId)
if err != nil {
return err
}
err = models.SharedADPackageInstanceDAO.UpdateInstanceUser(tx, instanceId, req.UserId, dayTo, userInstanceId)
if err != nil {
return err
}
userInstanceIds = append(userInstanceIds, userInstanceId)
}
return nil
})
if err != nil {
return nil, err
}
return &pb.CreateUserADInstanceResponse{
UserADInstanceIds: userInstanceIds,
}, nil
}
// BuyUserADInstance 购买用户高防实例
func (this *UserADInstanceService) BuyUserADInstance(ctx context.Context, req *pb.BuyUserADInstanceRequest) (*pb.BuyUserADInstanceResponse, error) {
adminId, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
userId = req.UserId
var tx = this.NullTx()
// 高防产品
if req.AdPackageId <= 0 {
return nil, errors.New("invalid 'adPackageId'")
}
adPackage, err := models.SharedADPackageDAO.FindEnabledADPackage(tx, req.AdPackageId)
if err != nil {
return nil, err
}
if adPackage == nil {
return nil, errors.New("invalid 'adPackage'")
}
// 线路
network, err := models.SharedADNetworkDAO.FindEnabledADNetwork(tx, int64(adPackage.NetworkId))
if err != nil {
return nil, err
}
if network == nil {
return nil, errors.New("invalid 'network'")
}
// 有效期选项
if req.AdPackagePeriodId <= 0 {
return nil, errors.New("invalid 'adPackagePeriodId'")
}
period, err := models.SharedADPackagePeriodDAO.FindEnabledADPackagePeriod(tx, req.AdPackagePeriodId)
if err != nil {
return nil, err
}
if period == nil || !period.IsOn {
return nil, errors.New("could not find instance period with id '" + types.String(req.AdPackagePeriodId) + "'")
}
_, dayTo := period.DayPeriod()
// 数量
if req.Count <= 0 {
return nil, errors.New("invalid 'count'")
}
var userInstanceIds = []int64{}
err = this.RunTx(func(tx *dbs.Tx) error {
var packageId = int64(adPackage.Id)
instances, err := models.SharedADPackageInstanceDAO.FindIdlePackageInstances(tx, packageId, req.Count)
if err != nil {
return err
}
var countInstances = int32(len(instances))
if countInstances < req.Count {
return errors.New("no enough instances")
}
if countInstances > req.Count {
instances = instances[:req.Count]
}
// 获取价格
price, err := models.SharedADPackagePriceDAO.FindPackagePrice(tx, packageId, req.AdPackagePeriodId)
if err != nil {
return err
}
if price <= 0 {
return errors.New("invalid package price, id:" + types.String(packageId))
}
var amount = price * float64(req.Count)
// 先减少余额
account, err := accounts.SharedUserAccountDAO.FindUserAccountWithUserId(tx, userId)
if err != nil {
return err
}
if account == nil || account.Total < amount {
return errors.New("no enough balance to buy the package")
}
err = accounts.SharedUserAccountDAO.UpdateUserAccount(tx, int64(account.Id), -amount, userconfigs.AccountEventTypeBuyAntiDDoSPackage, "购买DDoS高防实例线路"+network.Name+" / 防护带宽:"+types.String(adPackage.ProtectionBandwidthSize)+userconfigs.ADPackageSizeFullUnit(adPackage.ProtectionBandwidthUnit)+" / 业务带宽:"+types.String(adPackage.ServerBandwidthSize)+userconfigs.ADPackageSizeFullUnit(adPackage.ServerBandwidthUnit)+" / "+types.String(period.Count)+userconfigs.ADPackagePeriodUnitName(period.Unit)+"\" x "+types.String(req.Count), maps.Map{
"adNetworkId": network.Id,
"adPackageId": packageId,
"protectionBandwidthSize": adPackage.ProtectionBandwidthSize,
"protectionBandwidthUnit": adPackage.ProtectionBandwidthUnit,
"serverBandwidthSize": adPackage.ServerBandwidthSize,
"serverBandwidthUnit": adPackage.ServerBandwidthUnit,
"adPackagePeriodId": req.AdPackagePeriodId,
"adPackagePeriodCount": period.Count,
"adPackagePeriodUnit": period.Unit,
"count": req.Count,
})
if err != nil {
return err
}
for _, instance := range instances {
var instanceId = int64(instance.Id)
userInstanceId, err := models.SharedUserADInstanceDAO.CreateUserInstance(tx, req.UserId, adminId, instanceId, req.AdPackagePeriodId)
if err != nil {
return err
}
err = models.SharedADPackageInstanceDAO.UpdateInstanceUser(tx, instanceId, req.UserId, dayTo, userInstanceId)
if err != nil {
return err
}
userInstanceIds = append(userInstanceIds, userInstanceId)
}
return nil
})
if err != nil {
return nil, err
}
return &pb.BuyUserADInstanceResponse{
UserADInstanceIds: userInstanceIds,
}, nil
}
// CountUserADInstances 查询当前高防实例数量
func (this *UserADInstanceService) CountUserADInstances(ctx context.Context, req *pb.CountUserADInstancesRequest) (*pb.RPCCountResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
var tx = this.NullTx()
count, err := models.SharedUserADInstanceDAO.CountUserInstances(tx, req.AdNetworkId, req.UserId, req.AdPackagePeriodId, req.ExpiresDay, req.AvailableOnly)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListUserADInstances 列出单页高防实例
func (this *UserADInstanceService) ListUserADInstances(ctx context.Context, req *pb.ListUserADInstancesRequest) (*pb.ListUserADInstancesResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var fromUser = false
if userId > 0 {
fromUser = true
req.UserId = userId
}
var tx = this.NullTx()
userInstances, err := models.SharedUserADInstanceDAO.ListUserInstances(tx, req.AdNetworkId, req.UserId, req.AdPackagePeriodId, req.ExpiresDay, req.AvailableOnly, req.Offset, req.Size)
if err != nil {
return nil, err
}
var pbUserInstances = []*pb.UserADInstance{}
for _, userInstance := range userInstances {
instanceIsAvailable, err := userInstance.CheckAvailable(tx)
if err != nil {
return nil, err
}
// instance
instance, err := models.SharedADPackageInstanceDAO.FindEnabledADPackageInstance(tx, int64(userInstance.InstanceId))
if err != nil {
return nil, err
}
if instance == nil {
continue
}
// package
p, err := models.SharedADPackageDAO.FindEnabledADPackage(tx, int64(instance.PackageId))
if err != nil {
return nil, err
}
var pbPackage *pb.ADPackage
if p != nil {
// network
var pbNetwork *pb.ADNetwork
network, err := models.SharedADNetworkDAO.FindEnabledADNetwork(tx, int64(p.NetworkId))
if err != nil {
return nil, err
}
if network != nil {
pbNetwork = &pb.ADNetwork{
Id: int64(network.Id),
IsOn: network.IsOn,
Name: network.Name,
Description: network.Description,
}
}
pbPackage = &pb.ADPackage{
Id: int64(p.Id),
ProtectionBandwidthSize: types.Int32(p.ProtectionBandwidthSize),
ProtectionBandwidthUnit: p.ProtectionBandwidthUnit,
ServerBandwidthSize: types.Int32(p.ServerBandwidthSize),
ServerBandwidthUnit: p.ServerBandwidthUnit,
Summary: p.Summary(network),
AdNetwork: pbNetwork,
IsOn: p.IsOn,
}
}
// 集群
var pbCluster *pb.NodeCluster
if !fromUser {
cluster, err := models.SharedNodeClusterDAO.FindClusterBasicInfo(tx, int64(instance.ClusterId), nil)
if err != nil {
return nil, err
}
if cluster != nil {
pbCluster = &pb.NodeCluster{
Id: int64(cluster.Id),
Name: cluster.Name,
IsOn: cluster.IsOn,
}
}
}
var pbInstance = &pb.ADPackageInstance{
Id: int64(instance.Id),
NodeClusterId: int64(instance.ClusterId),
NodeIds: instance.DecodeNodeIds(),
IpAddresses: instance.DecodeIPAddresses(),
NodeCluster: pbCluster,
AdPackage: pbPackage,
UserInstanceId: int64(instance.UserInstanceId),
}
// user
var pbUser *pb.User
user, err := models.SharedUserDAO.FindEnabledBasicUser(tx, int64(userInstance.UserId))
if err != nil {
return nil, err
}
if user != nil {
pbUser = &pb.User{
Id: int64(user.Id),
Username: user.Username,
Fullname: user.Fullname,
IsOn: user.IsOn,
}
}
pbUserInstances = append(pbUserInstances, &pb.UserADInstance{
Id: int64(userInstance.Id),
UserId: int64(userInstance.UserId),
AdPackageInstanceId: int64(userInstance.InstanceId),
AdPackagePeriodId: int64(userInstance.PeriodId),
AdPackagePeriodCount: int32(userInstance.PeriodCount),
AdPackagePeriodUnit: userInstance.PeriodUnit,
DayFrom: userInstance.DayFrom,
DayTo: userInstance.DayTo,
CreatedAt: int64(userInstance.CreatedAt),
MaxObjects: types.Int32(userInstance.MaxObjects),
ObjectCodes: userInstance.DecodeObjectCodes(),
AdPackageInstance: pbInstance,
User: pbUser,
CanDelete: userInstance.AdminId > 0,
IsAvailable: instanceIsAvailable,
CountObjects: int32(len(userInstance.DecodeObjectCodes())),
})
}
return &pb.ListUserADInstancesResponse{
UserADInstances: pbUserInstances,
}, nil
}
// FindUserADInstance 查找单个用户高防实例
func (this *UserADInstanceService) FindUserADInstance(ctx context.Context, req *pb.FindUserADInstanceRequest) (*pb.FindUserADInstanceResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
userInstance, err := models.SharedUserADInstanceDAO.FindEnabledUserADInstance(tx, req.UserADInstanceId)
if err != nil {
return nil, err
}
if userInstance == nil {
return &pb.FindUserADInstanceResponse{
UserADInstance: nil,
}, nil
}
// 检查用户
if userId > 0 && int64(userInstance.UserId) != userId {
return nil, this.PermissionError()
}
// 是否有效
instanceIsAvailable, err := userInstance.CheckAvailable(tx)
if err != nil {
return nil, err
}
// 防护对象
objects, err := userInstance.DecodeObjects()
if err != nil {
return nil, err
}
if objects == nil {
objects = []maps.Map{}
}
objectsJSON, err := json.Marshal(objects)
if err != nil {
return nil, err
}
// instance
instance, err := models.SharedADPackageInstanceDAO.FindEnabledADPackageInstance(tx, int64(userInstance.InstanceId))
if err != nil {
return nil, err
}
if instance == nil {
return &pb.FindUserADInstanceResponse{
UserADInstance: nil,
}, nil
}
// package
p, err := models.SharedADPackageDAO.FindEnabledADPackage(tx, int64(instance.PackageId))
if err != nil {
return nil, err
}
var pbPackage *pb.ADPackage
if p != nil {
// network
var pbNetwork *pb.ADNetwork
network, err := models.SharedADNetworkDAO.FindEnabledADNetwork(tx, int64(p.NetworkId))
if err != nil {
return nil, err
}
if network != nil {
pbNetwork = &pb.ADNetwork{
Id: int64(network.Id),
IsOn: network.IsOn,
Name: network.Name,
Description: network.Description,
}
}
pbPackage = &pb.ADPackage{
Id: int64(p.Id),
ProtectionBandwidthSize: types.Int32(p.ProtectionBandwidthSize),
ProtectionBandwidthUnit: p.ProtectionBandwidthUnit,
ServerBandwidthSize: types.Int32(p.ServerBandwidthSize),
ServerBandwidthUnit: p.ServerBandwidthUnit,
Summary: p.Summary(network),
AdNetwork: pbNetwork,
IsOn: p.IsOn,
}
}
var pbInstance = &pb.ADPackageInstance{
Id: int64(instance.Id),
AdPackageId: int64(instance.PackageId),
NodeClusterId: int64(instance.ClusterId),
NodeIds: instance.DecodeNodeIds(),
IpAddresses: instance.DecodeIPAddresses(),
NodeCluster: nil,
AdPackage: pbPackage,
UserInstanceId: int64(instance.UserInstanceId),
}
// user
var pbUser *pb.User
user, err := models.SharedUserDAO.FindEnabledBasicUser(tx, int64(userInstance.UserId))
if err != nil {
return nil, err
}
if user != nil {
pbUser = &pb.User{
Id: int64(user.Id),
Username: user.Username,
Fullname: user.Fullname,
}
}
return &pb.FindUserADInstanceResponse{
UserADInstance: &pb.UserADInstance{
Id: int64(userInstance.Id),
UserId: int64(userInstance.UserId),
AdPackageInstanceId: int64(userInstance.InstanceId),
AdPackagePeriodId: int64(userInstance.PeriodId),
AdPackagePeriodCount: int32(userInstance.PeriodCount),
IsAvailable: instanceIsAvailable,
AdPackagePeriodUnit: userInstance.PeriodUnit,
DayFrom: userInstance.DayFrom,
DayTo: userInstance.DayTo,
CreatedAt: int64(userInstance.CreatedAt),
MaxObjects: types.Int32(userInstance.MaxObjects),
ObjectCodes: userInstance.DecodeObjectCodes(),
ObjectsJSON: objectsJSON,
AdPackageInstance: pbInstance,
User: pbUser,
CanDelete: userInstance.AdminId > 0,
}}, nil
}
// DeleteUserADInstance 删除高防实例
func (this *UserADInstanceService) DeleteUserADInstance(ctx context.Context, req *pb.DeleteUserADInstanceRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
userInstance, err := models.SharedUserADInstanceDAO.FindEnabledUserADInstance(tx, req.UserADInstanceId)
if err != nil {
return nil, err
}
if userInstance == nil {
// 不存在,则直接成功
return this.Success()
}
// 检查用户
if userId > 0 {
if userId != int64(userInstance.UserId) {
return nil, this.PermissionError()
}
}
var instanceId = int64(userInstance.InstanceId)
err = this.RunTx(func(tx *dbs.Tx) error {
err = models.SharedUserADInstanceDAO.DisableUserADInstance(tx, req.UserADInstanceId)
if err != nil {
return err
}
return models.SharedADPackageInstanceDAO.ResetInstanceUser(tx, instanceId)
})
return this.Success()
}
// RenewUserADInstance 续期
func (this *UserADInstanceService) RenewUserADInstance(ctx context.Context, req *pb.RenewUserADInstanceRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
userInstance, err := models.SharedUserADInstanceDAO.FindEnabledUserADInstance(tx, req.UserADInstanceId)
if err != nil {
return nil, err
}
if userInstance == nil {
return nil, errors.New("could not find user instance to renew")
}
// 检查用户
if userId > 0 {
if userId != int64(userInstance.UserId) {
return nil, errors.New("could not find user instance to renew")
}
}
err = this.RunTx(func(tx *dbs.Tx) error {
// 查找实例信息
var instanceId = int64(userInstance.InstanceId)
instance, err := models.SharedADPackageInstanceDAO.FindEnabledADPackageInstance(tx, instanceId)
if err != nil {
return err
}
if instance == nil {
return errors.New("the instance has been invalid")
}
// 确保操作的是同一个实例
if instance.UserInstanceId > 0 && int64(instance.UserInstanceId) != req.UserADInstanceId {
return errors.New("the instance has been token by other user")
}
var packageId = int64(instance.PackageId)
adPackage, err := models.SharedADPackageDAO.FindEnabledADPackage(tx, packageId)
if err != nil {
return err
}
if adPackage == nil || !adPackage.IsOn {
return errors.New("the package has been invalid")
}
// 线路
network, err := models.SharedADNetworkDAO.FindEnabledADNetwork(tx, int64(adPackage.NetworkId))
if err != nil {
return err
}
if network == nil {
return errors.New("the network has been invalid")
}
// 检查有效期
if req.AdPackagePeriodId <= 0 {
return errors.New("invalid 'adPackagePeriodId'")
}
period, err := models.SharedADPackagePeriodDAO.FindEnabledADPackagePeriod(tx, req.AdPackagePeriodId)
if err != nil {
return err
}
if period == nil {
return errors.New("could not find period '" + types.String(req.AdPackagePeriodId) + "'")
}
price, err := models.SharedADPackagePriceDAO.FindPackagePrice(tx, packageId, req.AdPackagePeriodId)
if err != nil {
return err
}
if price <= 0 {
return errors.New("can not find price for the instance")
}
// 如果是用户需要支付费用
if userId > 0 {
var amount = price
// 先减少余额
account, err := accounts.SharedUserAccountDAO.FindUserAccountWithUserId(tx, userId)
if err != nil {
return err
}
if account == nil || account.Total < amount {
return errors.New("no enough balance to buy the package")
}
err = accounts.SharedUserAccountDAO.UpdateUserAccount(tx, int64(account.Id), -amount, userconfigs.AccountEventTypeRenewAntiDDoSPackage, "续费高DDoS防实例线路"+network.Name+" / 防护带宽:"+types.String(adPackage.ProtectionBandwidthSize)+userconfigs.ADPackageSizeFullUnit(adPackage.ProtectionBandwidthUnit)+" / 业务带宽:"+types.String(adPackage.ServerBandwidthSize)+userconfigs.ADPackageSizeFullUnit(adPackage.ServerBandwidthUnit)+" / 高防IP"+strings.Join(instance.DecodeIPAddresses(), "")+" / "+types.String(period.Count)+userconfigs.ADPackagePeriodUnitName(period.Unit), maps.Map{
"adNetworkId": network.Id,
"adPackageId": packageId,
"protectionBandwidthSize": adPackage.ProtectionBandwidthSize,
"protectionBandwidthUnit": adPackage.ProtectionBandwidthUnit,
"serverBandwidthSize": adPackage.ServerBandwidthSize,
"serverBandwidthUnit": adPackage.ServerBandwidthUnit,
"adPackagePeriodId": req.AdPackagePeriodId,
"adPackagePeriodCount": period.Count,
"adPackagePeriodUnit": period.Unit,
"count": 1,
})
if err != nil {
return err
}
}
_, err = models.SharedUserADInstanceDAO.RenewUserInstance(tx, userInstance, period)
if err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateUserADInstanceObjects 修改实例防护对象
func (this *UserADInstanceService) UpdateUserADInstanceObjects(ctx context.Context, req *pb.UpdateUserADInstanceObjectsRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
userInstance, err := models.SharedUserADInstanceDAO.FindEnabledUserADInstance(tx, req.UserADInstanceId)
if err != nil {
return nil, err
}
if userInstance == nil {
return nil, errors.New("could not find user instance with id '" + types.String(req.UserADInstanceId) + "'")
}
var instanceId = int64(userInstance.InstanceId)
// 检查用户
if userId > 0 {
if int64(userInstance.UserId) != userId {
return nil, this.PermissionError()
}
}
// 检查当前实例是否有效
isAvailable, err := userInstance.CheckAvailable(tx)
if err != nil {
return nil, err
}
if !isAvailable {
return nil, errors.New("the user instance is not available")
}
// TODO 检查有没有超出最大防护对象数量
err = this.RunTx(func(tx *dbs.Tx) error {
err = models.SharedUserADInstanceDAO.UpdateUserInstanceObjects(tx, req.UserADInstanceId, req.ObjectCodes)
if err != nil {
return err
}
return models.SharedADPackageInstanceDAO.UpdateInstanceObjects(tx, instanceId, req.ObjectCodes)
})
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,40 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package clients
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/clients"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// ClientAgentService Agent服务
type ClientAgentService struct {
services.BaseService
}
// FindAllClientAgents 查找所有Agent
func (this *ClientAgentService) FindAllClientAgents(ctx context.Context, req *pb.FindAllClientAgentsRequest) (*pb.FindAllClientAgentsResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
agents, err := clients.SharedClientAgentDAO.FindAllAgents(tx)
if err != nil {
return nil, err
}
var pbAgents = []*pb.ClientAgent{}
for _, agent := range agents {
pbAgents = append(pbAgents, &pb.ClientAgent{
Id: int64(agent.Id),
Name: agent.Name,
Code: agent.Code,
Description: agent.Description,
CountIPs: int64(agent.CountIPs),
})
}
return &pb.FindAllClientAgentsResponse{ClientAgents: pbAgents}, nil
}

View File

@@ -0,0 +1,97 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package clients
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/clients"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// ClientAgentIPService Agent IP服务
type ClientAgentIPService struct {
services.BaseService
}
// CreateClientAgentIPs 创建一组IP
func (this *ClientAgentIPService) CreateClientAgentIPs(ctx context.Context, req *pb.CreateClientAgentIPsRequest) (*pb.RPCSuccess, error) {
_, _, err := this.ValidateNodeId(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeDNS, rpcutils.UserTypeNode)
if err != nil {
return nil, err
}
if len(req.AgentIPs) == 0 {
return this.Success()
}
var tx = this.NullTx()
for _, agentIP := range req.AgentIPs {
agentId, err := clients.SharedClientAgentDAO.FindAgentIdWithCode(tx, agentIP.AgentCode)
if err != nil {
return nil, err
}
if agentId <= 0 {
continue
}
err = clients.SharedClientAgentIPDAO.CreateIP(tx, agentId, agentIP.Ip, agentIP.Ptr)
if err != nil {
return nil, err
}
}
return this.Success()
}
// ListClientAgentIPsAfterId 查询最新的IP
func (this *ClientAgentIPService) ListClientAgentIPsAfterId(ctx context.Context, req *pb.ListClientAgentIPsAfterIdRequest) (*pb.ListClientAgentIPsAfterIdResponse, error) {
_, _, err := this.ValidateNodeId(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeDNS, rpcutils.UserTypeNode)
if err != nil {
return nil, err
}
if req.Size <= 0 {
req.Size = 10000
}
var tx = this.NullTx()
var agentMap = map[int64]*clients.ClientAgent{} // agentId => agentCode
agentIPs, err := clients.SharedClientAgentIPDAO.ListIPsAfterId(tx, req.Id, req.Size)
if err != nil {
return nil, err
}
var pbIPs = []*pb.ClientAgentIP{}
for _, agentIP := range agentIPs {
var agentId = int64(agentIP.AgentId)
agent, ok := agentMap[agentId]
if !ok {
agent, err = clients.SharedClientAgentDAO.FindAgent(tx, agentId)
if err != nil {
return nil, err
}
if agent == nil {
continue
}
agentMap[agentId] = agent
}
pbIPs = append(pbIPs, &pb.ClientAgentIP{
Id: int64(agentIP.Id),
Ip: agentIP.IP,
Ptr: agentIP.Ptr, // 导出时需要
ClientAgent: &pb.ClientAgent{
Id: agentId,
Name: "",
Code: agent.Code,
Description: "",
},
})
}
return &pb.ListClientAgentIPsAfterIdResponse{
ClientAgentIPs: pbIPs,
}, nil
}

View File

@@ -0,0 +1,144 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package clients
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/types"
)
// FormalClientBrowserService 浏览器信息库服务
type FormalClientBrowserService struct {
services.BaseService
}
// CreateFormalClientBrowser 创建浏览器信息
func (this *FormalClientBrowserService) CreateFormalClientBrowser(ctx context.Context, req *pb.CreateFormalClientBrowserRequest) (*pb.CreateFormalClientBrowserResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
// 检查dataId是否存在
var tx = this.NullTx()
browser, err := models.SharedFormalClientBrowserDAO.FindBrowserWithDataId(tx, req.DataId)
if err != nil {
return nil, err
}
if browser != nil {
return nil, errors.New("dataId '" + req.DataId + "' already exists")
}
browserId, err := models.SharedFormalClientBrowserDAO.CreateBrowser(tx, req.Name, req.Codes, req.DataId)
if err != nil {
return nil, err
}
return &pb.CreateFormalClientBrowserResponse{
FormalClientBrowserId: browserId,
}, nil
}
// CountFormalClientBrowsers 计算浏览器信息数量
func (this *FormalClientBrowserService) CountFormalClientBrowsers(ctx context.Context, req *pb.CountFormalClientBrowsersRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedFormalClientBrowserDAO.CountBrowsers(tx, req.Keyword)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListFormalClientBrowsers 列出单页浏览器信息
func (this *FormalClientBrowserService) ListFormalClientBrowsers(ctx context.Context, req *pb.ListFormalClientBrowsersRequest) (*pb.ListFormalClientBrowsersResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
browsers, err := models.SharedFormalClientBrowserDAO.ListBrowsers(tx, req.Keyword, req.Offset, req.Size)
if err != nil {
return nil, err
}
var pbBrowsers = []*pb.FormalClientBrowser{}
for _, browser := range browsers {
pbBrowsers = append(pbBrowsers, &pb.FormalClientBrowser{
Id: int64(browser.Id),
Name: browser.Name,
Codes: browser.DecodeCodes(),
DataId: browser.DataId,
State: types.Int32(browser.State),
})
}
return &pb.ListFormalClientBrowsersResponse{
FormalClientBrowsers: pbBrowsers,
}, nil
}
// UpdateFormalClientBrowser 修改浏览器信息
func (this *FormalClientBrowserService) UpdateFormalClientBrowser(ctx context.Context, req *pb.UpdateFormalClientBrowserRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
if len(req.DataId) == 0 {
return nil, errors.New("invalid dataId")
}
var tx = this.NullTx()
// 检查dataId是否已经被使用
oldBrowser, err := models.SharedFormalClientBrowserDAO.FindBrowserWithDataId(tx, req.DataId)
if err != nil {
return nil, err
}
if oldBrowser != nil && int64(oldBrowser.Id) != req.FormalClientBrowserId {
return nil, errors.New("the dataId '" + req.DataId + "' already has been used")
}
err = models.SharedFormalClientBrowserDAO.UpdateBrowser(tx, req.FormalClientBrowserId, req.Name, req.Codes, req.DataId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindFormalClientBrowserWithDataId 通过dataId查询浏览器信息
func (this *FormalClientBrowserService) FindFormalClientBrowserWithDataId(ctx context.Context, req *pb.FindFormalClientBrowserWithDataIdRequest) (*pb.FindFormalClientBrowserWithDataIdResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
browser, err := models.SharedFormalClientBrowserDAO.FindBrowserWithDataId(tx, req.DataId)
if err != nil {
return nil, err
}
if browser == nil {
return &pb.FindFormalClientBrowserWithDataIdResponse{
FormalClientBrowser: nil,
}, nil
}
return &pb.FindFormalClientBrowserWithDataIdResponse{
FormalClientBrowser: &pb.FormalClientBrowser{
Id: int64(browser.Id),
Name: browser.Name,
Codes: browser.DecodeCodes(),
DataId: browser.DataId,
State: types.Int32(browser.State),
}}, nil
}

View File

@@ -0,0 +1,144 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package clients
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/types"
)
// FormalClientSystemService 操作系统信息库服务
type FormalClientSystemService struct {
services.BaseService
}
// CreateFormalClientSystem 创建操作系统信息
func (this *FormalClientSystemService) CreateFormalClientSystem(ctx context.Context, req *pb.CreateFormalClientSystemRequest) (*pb.CreateFormalClientSystemResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
// 检查dataId是否存在
var tx = this.NullTx()
system, err := models.SharedFormalClientSystemDAO.FindSystemWithDataId(tx, req.DataId)
if err != nil {
return nil, err
}
if system != nil {
return nil, errors.New("dataId '" + req.DataId + "' already exists")
}
systemId, err := models.SharedFormalClientSystemDAO.CreateSystem(tx, req.Name, req.Codes, req.DataId)
if err != nil {
return nil, err
}
return &pb.CreateFormalClientSystemResponse{
FormalClientSystemId: systemId,
}, nil
}
// CountFormalClientSystems 计算操作系统信息数量
func (this *FormalClientSystemService) CountFormalClientSystems(ctx context.Context, req *pb.CountFormalClientSystemsRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedFormalClientSystemDAO.CountSystems(tx, req.Keyword)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListFormalClientSystems 列出单页操作系统信息
func (this *FormalClientSystemService) ListFormalClientSystems(ctx context.Context, req *pb.ListFormalClientSystemsRequest) (*pb.ListFormalClientSystemsResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
systems, err := models.SharedFormalClientSystemDAO.ListSystems(tx, req.Keyword, req.Offset, req.Size)
if err != nil {
return nil, err
}
var pbSystems = []*pb.FormalClientSystem{}
for _, system := range systems {
pbSystems = append(pbSystems, &pb.FormalClientSystem{
Id: int64(system.Id),
Name: system.Name,
Codes: system.DecodeCodes(),
DataId: system.DataId,
State: types.Int32(system.State),
})
}
return &pb.ListFormalClientSystemsResponse{
FormalClientSystems: pbSystems,
}, nil
}
// UpdateFormalClientSystem 修改操作系统信息
func (this *FormalClientSystemService) UpdateFormalClientSystem(ctx context.Context, req *pb.UpdateFormalClientSystemRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
if len(req.DataId) == 0 {
return nil, errors.New("invalid dataId")
}
var tx = this.NullTx()
// 检查dataId是否已经被使用
oldSystem, err := models.SharedFormalClientSystemDAO.FindSystemWithDataId(tx, req.DataId)
if err != nil {
return nil, err
}
if oldSystem != nil && int64(oldSystem.Id) != req.FormalClientSystemId {
return nil, errors.New("the dataId '" + req.DataId + "' already has been used")
}
err = models.SharedFormalClientSystemDAO.UpdateSystem(tx, req.FormalClientSystemId, req.Name, req.Codes, req.DataId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindFormalClientSystemWithDataId 通过dataId查询操作系统信息
func (this *FormalClientSystemService) FindFormalClientSystemWithDataId(ctx context.Context, req *pb.FindFormalClientSystemWithDataIdRequest) (*pb.FindFormalClientSystemWithDataIdResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
system, err := models.SharedFormalClientSystemDAO.FindSystemWithDataId(tx, req.DataId)
if err != nil {
return nil, err
}
if system == nil {
return &pb.FindFormalClientSystemWithDataIdResponse{
FormalClientSystem: nil,
}, nil
}
return &pb.FindFormalClientSystemWithDataIdResponse{
FormalClientSystem: &pb.FormalClientSystem{
Id: int64(system.Id),
Name: system.Name,
Codes: system.DecodeCodes(),
DataId: system.DataId,
State: types.Int32(system.State),
}}, nil
}

View File

@@ -0,0 +1,263 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package nameservers
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/nameservers"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
timeutil "github.com/iwind/TeaGo/utils/time"
"time"
)
// NSService 域名服务
type NSService struct {
services.BaseService
}
// ComposeNSBoard 组合看板数据
func (this *NSService) ComposeNSBoard(ctx context.Context, req *pb.ComposeNSBoardRequest) (*pb.ComposeNSBoardResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
var result = &pb.ComposeNSBoardResponse{}
// 域名
countDomains, err := nameservers.SharedNSDomainDAO.CountAllEnabledDomains(tx, 0, 0, 0, dnsconfigs.NSDomainStatusVerified, "")
if err != nil {
return nil, err
}
result.CountNSDomains = countDomains
// 记录
countRecords, err := nameservers.SharedNSRecordDAO.CountAllEnabledRecords(tx)
if err != nil {
return nil, err
}
result.CountNSRecords = countRecords
// 集群数
countClusters, err := models.SharedNSClusterDAO.CountAllEnabledClusters(tx)
if err != nil {
return nil, err
}
result.CountNSClusters = countClusters
// 节点数
countNodes, err := models.SharedNSNodeDAO.CountAllEnabledNodes(tx)
if err != nil {
return nil, err
}
result.CountNSNodes = countNodes
// 离线节点数
countOfflineNodes, err := models.SharedNSNodeDAO.CountAllOfflineNodes(tx)
if err != nil {
return nil, err
}
result.CountOfflineNSNodes = countOfflineNodes
// 按小时统计
var hourFrom = timeutil.Format("YmdH", time.Now().Add(-23*time.Hour))
var hourTo = timeutil.Format("YmdH")
hourlyStats, err := nameservers.SharedNSRecordHourlyStatDAO.FindHourlyStats(tx, hourFrom, hourTo)
if err != nil {
return nil, err
}
for _, stat := range hourlyStats {
result.HourlyTrafficStats = append(result.HourlyTrafficStats, &pb.ComposeNSBoardResponse_HourlyTrafficStat{
Hour: stat.Hour,
Bytes: int64(stat.Bytes),
CountRequests: int64(stat.CountRequests),
})
}
// 按天统计
var dayFrom = timeutil.Format("Ymd", time.Now().AddDate(0, 0, -14))
var dayTo = timeutil.Format("Ymd")
dailyStats, err := nameservers.SharedNSRecordHourlyStatDAO.FindDailyStats(tx, dayFrom, dayTo)
if err != nil {
return nil, err
}
for _, stat := range dailyStats {
result.DailyTrafficStats = append(result.DailyTrafficStats, &pb.ComposeNSBoardResponse_DailyTrafficStat{
Day: stat.Day,
Bytes: int64(stat.Bytes),
CountRequests: int64(stat.CountRequests),
})
}
// 域名排行
topDomainStats, err := nameservers.SharedNSRecordHourlyStatDAO.ListTopDomains(tx, 0, hourFrom, hourTo, 10)
if err != nil {
return nil, err
}
for _, stat := range topDomainStats {
domainName, err := nameservers.SharedNSDomainDAO.FindNSDomainName(tx, int64(stat.DomainId))
if err != nil {
return nil, err
}
if len(domainName) == 0 {
continue
}
result.TopNSDomainStats = append(result.TopNSDomainStats, &pb.ComposeNSBoardResponse_DomainStat{
NsDomainId: int64(stat.DomainId),
NsDomainName: domainName,
CountRequests: int64(stat.CountRequests),
Bytes: int64(stat.Bytes),
})
}
// 节点排行
topNodeStats, err := nameservers.SharedNSRecordHourlyStatDAO.ListTopNodes(tx, hourFrom, hourTo, 10)
if err != nil {
return nil, err
}
for _, stat := range topNodeStats {
nodeName, err := models.SharedNSNodeDAO.FindEnabledNSNodeName(tx, int64(stat.NodeId))
if err != nil {
return nil, err
}
if len(nodeName) == 0 {
continue
}
result.TopNSNodeStats = append(result.TopNSNodeStats, &pb.ComposeNSBoardResponse_NodeStat{
NsClusterId: int64(stat.ClusterId),
NsNodeId: int64(stat.NodeId),
NsNodeName: nodeName,
CountRequests: int64(stat.CountRequests),
Bytes: int64(stat.Bytes),
})
}
// CPU、内存、负载
cpuValues, err := models.SharedNodeValueDAO.ListValuesForNSNodes(tx, nodeconfigs.NodeValueItemCPU, "usage", nodeconfigs.NodeValueRangeMinute)
if err != nil {
return nil, err
}
for _, v := range cpuValues {
result.CpuNodeValues = append(result.CpuNodeValues, &pb.NodeValue{
ValueJSON: v.Value,
CreatedAt: int64(v.CreatedAt),
})
}
memoryValues, err := models.SharedNodeValueDAO.ListValuesForNSNodes(tx, nodeconfigs.NodeValueItemMemory, "usage", nodeconfigs.NodeValueRangeMinute)
if err != nil {
return nil, err
}
for _, v := range memoryValues {
result.MemoryNodeValues = append(result.MemoryNodeValues, &pb.NodeValue{
ValueJSON: v.Value,
CreatedAt: int64(v.CreatedAt),
})
}
loadValues, err := models.SharedNodeValueDAO.ListValuesForNSNodes(tx, nodeconfigs.NodeValueItemLoad, "load1m", nodeconfigs.NodeValueRangeMinute)
if err != nil {
return nil, err
}
for _, v := range loadValues {
result.LoadNodeValues = append(result.LoadNodeValues, &pb.NodeValue{
ValueJSON: v.Value,
CreatedAt: int64(v.CreatedAt),
})
}
return result, nil
}
// ComposeNSUserBoard 组合用户看板数据
func (this *NSService) ComposeNSUserBoard(ctx context.Context, req *pb.ComposeNSUserBoardRequest) (*pb.ComposeNSUserBoardResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
var result = &pb.ComposeNSUserBoardResponse{}
var tx = this.NullTx()
countDomains, err := nameservers.SharedNSDomainDAO.CountAllEnabledDomains(tx, 0, req.UserId, 0, "", "")
if err != nil {
return nil, err
}
result.CountNSDomains = countDomains
countRecords, err := nameservers.SharedNSRecordDAO.CountAllUserRecords(tx, req.UserId)
if err != nil {
return nil, err
}
result.CountNSRecords = countRecords
countRoutes, err := nameservers.SharedNSRouteDAO.CountAllEnabledRoutes(tx, 0, 0, req.UserId)
if err != nil {
return nil, err
}
result.CountNSRoutes = countRoutes
// 用户套餐
userPlan, err := nameservers.SharedNSUserPlanDAO.FindUserPlan(tx, req.UserId)
if err != nil {
return nil, err
}
if userPlan != nil {
if userPlan.PlanId > 0 && userPlan.DayTo >= timeutil.Format("Ymd") {
plan, err := nameservers.SharedNSPlanDAO.FindEnabledNSPlan(tx, int64(userPlan.PlanId))
if err != nil {
return nil, err
}
if plan != nil && plan.IsOn {
result.NsUserPlan = &pb.NSUserPlan{
Id: int64(userPlan.Id),
NsPlanId: int64(userPlan.PlanId),
DayFrom: userPlan.DayFrom,
DayTo: userPlan.DayTo,
PeriodUnit: userPlan.PeriodUnit,
NsPlan: &pb.NSPlan{
Id: int64(plan.Id),
Name: plan.Name,
IsOn: plan.IsOn,
},
User: nil,
}
}
}
}
// 域名排行
var hourFrom = timeutil.Format("YmdH", time.Now().Add(-23*time.Hour))
var hourTo = timeutil.Format("YmdH")
topDomainStats, err := nameservers.SharedNSRecordHourlyStatDAO.ListTopDomains(tx, userId, hourFrom, hourTo, 10)
if err != nil {
return nil, err
}
for _, stat := range topDomainStats {
domainName, err := nameservers.SharedNSDomainDAO.FindNSDomainName(tx, int64(stat.DomainId))
if err != nil {
return nil, err
}
if len(domainName) == 0 {
continue
}
result.TopNSDomainStats = append(result.TopNSDomainStats, &pb.ComposeNSUserBoardResponse_DomainStat{
NsDomainId: int64(stat.DomainId),
NsDomainName: domainName,
CountRequests: int64(stat.CountRequests),
Bytes: int64(stat.Bytes),
})
}
return result, nil
}

View File

@@ -0,0 +1,127 @@
//go:build plus
package nameservers
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/nameservers"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/types"
)
// NSAccessLogService 访问日志相关服务
type NSAccessLogService struct {
services.BaseService
}
// CreateNSAccessLogs 创建访问日志
func (this *NSAccessLogService) CreateNSAccessLogs(ctx context.Context, req *pb.CreateNSAccessLogsRequest) (*pb.CreateNSAccessLogsResponse, error) {
// 校验请求
_, _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeDNS)
if err != nil {
return nil, err
}
if len(req.NsAccessLogs) == 0 {
return &pb.CreateNSAccessLogsResponse{}, nil
}
var tx = this.NullTx()
err = models.SharedNSAccessLogDAO.CreateNSAccessLogs(tx, req.NsAccessLogs)
if err != nil {
return nil, err
}
return &pb.CreateNSAccessLogsResponse{}, nil
}
// ListNSAccessLogs 列出单页访问日志
func (this *NSAccessLogService) ListNSAccessLogs(ctx context.Context, req *pb.ListNSAccessLogsRequest) (*pb.ListNSAccessLogsResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查服务ID
if userId > 0 {
// TODO 检查权限
}
accessLogs, requestId, hasMore, err := models.SharedNSAccessLogDAO.ListAccessLogs(tx, req.RequestId, req.Size, req.Day, req.NsClusterId, req.NsNodeId, req.NsDomainId, req.NsRecordId, req.RecordType, req.Keyword, req.Reverse)
if err != nil {
return nil, err
}
var result = []*pb.NSAccessLog{}
for _, accessLog := range accessLogs {
a, err := accessLog.ToPB()
if err != nil {
return nil, err
}
// 线路
if len(a.NsRouteCodes) > 0 {
for _, routeCode := range a.NsRouteCodes {
route, err := nameservers.SharedNSRouteDAO.FindEnabledRouteWithCode(nil, routeCode)
if err != nil {
return nil, err
}
if route != nil {
a.NsRoutes = append(a.NsRoutes, &pb.NSRoute{
Id: types.Int64(route.Id),
IsOn: route.IsOn,
Name: route.Name,
Code: routeCode,
NsCluster: nil,
NsDomain: nil,
})
}
}
}
result = append(result, a)
}
return &pb.ListNSAccessLogsResponse{
NsAccessLogs: result,
HasMore: hasMore,
RequestId: requestId,
}, nil
}
// FindNSAccessLog 查找单个日志
func (this *NSAccessLogService) FindNSAccessLog(ctx context.Context, req *pb.FindNSAccessLogRequest) (*pb.FindNSAccessLogResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
accessLog, err := models.SharedNSAccessLogDAO.FindAccessLogWithRequestId(tx, req.RequestId)
if err != nil {
return nil, err
}
if accessLog == nil {
return &pb.FindNSAccessLogResponse{NsAccessLog: nil}, nil
}
// 检查权限
if userId > 0 {
// TODO
}
a, err := accessLog.ToPB()
if err != nil {
return nil, err
}
return &pb.FindNSAccessLogResponse{NsAccessLog: a}, nil
}

View File

@@ -0,0 +1,679 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package nameservers
import (
"context"
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
"github.com/iwind/TeaGo/dbs"
)
// NSClusterService 域名服务集群相关服务
type NSClusterService struct {
services.BaseService
}
// CreateNSCluster 创建集群
func (this *NSClusterService) CreateNSCluster(ctx context.Context, req *pb.CreateNSClusterRequest) (*pb.CreateNSClusterResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// SOA
var soaConfig = dnsconfigs.DefaultNSSOAConfig()
if len(req.SoaJSON) > 0 {
err = json.Unmarshal(req.SoaJSON, soaConfig)
if err != nil {
return nil, err
}
err = soaConfig.Init()
if err != nil {
return nil, errors.New("validate SOA config failed: " + err.Error())
}
}
// 校验管理员邮箱
if len(req.Email) == 0 {
return nil, errors.New("required 'email'")
}
if !utils.ValidateEmail(req.Email) {
return nil, errors.New("invalid email format '" + req.Email + "'")
}
// 校验访问日志配置
var accessLogRef = &dnsconfigs.NSAccessLogRef{}
if len(req.AccessLogJSON) > 0 {
err = json.Unmarshal(req.AccessLogJSON, accessLogRef)
if err != nil {
return nil, errors.New("invalid accessLogJSON: " + err.Error())
}
err = accessLogRef.Init()
if err != nil {
return nil, errors.New("validate accessLogJSON failed: " + err.Error())
}
}
clusterId, err := models.SharedNSClusterDAO.CreateCluster(tx, req.Name, req.Email, req.AccessLogJSON, req.Hosts, soaConfig)
if err != nil {
return nil, err
}
return &pb.CreateNSClusterResponse{NsClusterId: clusterId}, nil
}
// UpdateNSCluster 修改集群
func (this *NSClusterService) UpdateNSCluster(ctx context.Context, req *pb.UpdateNSClusterRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 校验管理员邮箱
if len(req.Email) == 0 {
return nil, errors.New("required 'email'")
}
if !utils.ValidateEmail(req.Email) {
return nil, errors.New("invalid email format '" + req.Email + "'")
}
err = models.SharedNSClusterDAO.UpdateCluster(tx, req.NsClusterId, req.Name, req.Email, req.Hosts, req.IsOn, req.TimeZone, req.AutoRemoteStart, req.DetectAgents, req.CheckingPorts)
if err != nil {
return nil, err
}
return this.Success()
}
// FindNSClusterAccessLog 查找集群访问日志配置
func (this *NSClusterService) FindNSClusterAccessLog(ctx context.Context, req *pb.FindNSClusterAccessLogRequest) (*pb.FindNSClusterAccessLogResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
accessLogJSON, err := models.SharedNSClusterDAO.FindClusterAccessLog(tx, req.NsClusterId)
if err != nil {
return nil, err
}
return &pb.FindNSClusterAccessLogResponse{AccessLogJSON: accessLogJSON}, nil
}
// UpdateNSClusterAccessLog 修改集群访问日志配置
func (this *NSClusterService) UpdateNSClusterAccessLog(ctx context.Context, req *pb.UpdateNSClusterAccessLogRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 校验访问日志配置
var accessLogRef = &dnsconfigs.NSAccessLogRef{}
if len(req.AccessLogJSON) > 0 {
err = json.Unmarshal(req.AccessLogJSON, accessLogRef)
if err != nil {
return nil, errors.New("invalid accessLogJSON: " + err.Error())
}
err = accessLogRef.Init()
if err != nil {
return nil, errors.New("validate accessLogJSON failed: " + err.Error())
}
}
err = models.SharedNSClusterDAO.UpdateClusterAccessLog(tx, req.NsClusterId, req.AccessLogJSON)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteNSCluster 删除集群
func (this *NSClusterService) DeleteNSCluster(ctx context.Context, req *pb.DeleteNSCluster) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedNSClusterDAO.DisableNSCluster(tx, req.NsClusterId)
if err != nil {
return nil, err
}
// 删除任务
err = models.SharedNodeTaskDAO.DeleteAllClusterTasks(tx, nodeconfigs.NodeRoleDNS, req.NsClusterId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindNSCluster 查找单个可用集群信息
func (this *NSClusterService) FindNSCluster(ctx context.Context, req *pb.FindNSClusterRequest) (*pb.FindNSClusterResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
cluster, err := models.SharedNSClusterDAO.FindEnabledNSCluster(tx, req.NsClusterId)
if err != nil {
return nil, err
}
if cluster == nil {
return &pb.FindNSClusterResponse{NsCluster: nil}, nil
}
return &pb.FindNSClusterResponse{
NsCluster: &pb.NSCluster{
Id: int64(cluster.Id),
IsOn: cluster.IsOn,
Name: cluster.Name,
Email: cluster.Email,
Hosts: cluster.DecodeHosts(),
InstallDir: cluster.InstallDir,
TcpJSON: cluster.Tcp,
TlsJSON: cluster.Tls,
UdpJSON: cluster.Udp,
DohJSON: cluster.Doh,
TimeZone: cluster.TimeZone,
AutoRemoteStart: cluster.AutoRemoteStart,
AnswerJSON: cluster.Answer,
SoaJSON: cluster.Soa,
DetectAgents: cluster.DetectAgents,
CheckingPorts: cluster.CheckingPorts,
},
}, nil
}
// CountAllNSClusters 计算所有可用集群的数量
func (this *NSClusterService) CountAllNSClusters(ctx context.Context, req *pb.CountAllNSClustersRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedNSClusterDAO.CountAllEnabledClusters(tx)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListNSClusters 列出单页可用集群
func (this *NSClusterService) ListNSClusters(ctx context.Context, req *pb.ListNSClustersRequest) (*pb.ListNSClustersResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
clusters, err := models.SharedNSClusterDAO.ListEnabledClusters(tx, req.Offset, req.Size)
if err != nil {
return nil, err
}
var pbClusters = []*pb.NSCluster{}
for _, cluster := range clusters {
pbClusters = append(pbClusters, &pb.NSCluster{
Id: int64(cluster.Id),
IsOn: cluster.IsOn,
Name: cluster.Name,
Hosts: cluster.DecodeHosts(),
InstallDir: cluster.InstallDir,
})
}
return &pb.ListNSClustersResponse{NsClusters: pbClusters}, nil
}
// FindAllNSClusters 查找所有可用集群
func (this *NSClusterService) FindAllNSClusters(ctx context.Context, req *pb.FindAllNSClustersRequest) (*pb.FindAllNSClustersResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
clusters, err := models.SharedNSClusterDAO.FindAllEnabledClusters(tx)
if err != nil {
return nil, err
}
var pbClusters = []*pb.NSCluster{}
for _, cluster := range clusters {
pbClusters = append(pbClusters, &pb.NSCluster{
Id: int64(cluster.Id),
IsOn: cluster.IsOn,
Name: cluster.Name,
InstallDir: cluster.InstallDir,
})
}
return &pb.FindAllNSClustersResponse{NsClusters: pbClusters}, nil
}
// UpdateNSClusterRecursionConfig 设置递归DNS配置
func (this *NSClusterService) UpdateNSClusterRecursionConfig(ctx context.Context, req *pb.UpdateNSClusterRecursionConfigRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
// 校验配置
var config = &dnsconfigs.NSRecursionConfig{}
err = json.Unmarshal(req.RecursionJSON, config)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedNSClusterDAO.UpdateRecursion(tx, req.NsClusterId, req.RecursionJSON)
if err != nil {
return nil, err
}
return this.Success()
}
// FindNSClusterRecursionConfig 读取递归DNS配置
func (this *NSClusterService) FindNSClusterRecursionConfig(ctx context.Context, req *pb.FindNSClusterRecursionConfigRequest) (*pb.FindNSClusterRecursionConfigResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
recursion, err := models.SharedNSClusterDAO.FindClusterRecursion(tx, req.NsClusterId)
if err != nil {
return nil, err
}
return &pb.FindNSClusterRecursionConfigResponse{
RecursionJSON: recursion,
}, nil
}
// FindNSClusterTCPConfig 查找集群的TCP设置
func (this *NSClusterService) FindNSClusterTCPConfig(ctx context.Context, req *pb.FindNSClusterTCPConfigRequest) (*pb.FindNSClusterTCPConfigResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
tcpJSON, err := models.SharedNSClusterDAO.FindClusterTCP(tx, req.NsClusterId)
if err != nil {
return nil, err
}
return &pb.FindNSClusterTCPConfigResponse{
TcpJSON: tcpJSON,
}, nil
}
// UpdateNSClusterTCP 修改集群的TCP设置
func (this *NSClusterService) UpdateNSClusterTCP(ctx context.Context, req *pb.UpdateNSClusterTCPRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var config = &serverconfigs.TCPProtocolConfig{}
err = json.Unmarshal(req.TcpJSON, config)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedNSClusterDAO.UpdateClusterTCP(tx, req.NsClusterId, config)
if err != nil {
return nil, err
}
return this.Success()
}
// FindNSClusterTLSConfig 查找集群的TLS设置
func (this *NSClusterService) FindNSClusterTLSConfig(ctx context.Context, req *pb.FindNSClusterTLSConfigRequest) (*pb.FindNSClusterTLSConfigResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
tlsJSON, err := models.SharedNSClusterDAO.FindClusterTLS(tx, req.NsClusterId)
if err != nil {
return nil, err
}
return &pb.FindNSClusterTLSConfigResponse{
TlsJSON: tlsJSON,
}, nil
}
// UpdateNSClusterTLS 修改集群的TLS设置
func (this *NSClusterService) UpdateNSClusterTLS(ctx context.Context, req *pb.UpdateNSClusterTLSRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var config = &serverconfigs.TLSProtocolConfig{}
err = json.Unmarshal(req.TlsJSON, config)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedNSClusterDAO.UpdateClusterTLS(tx, req.NsClusterId, config)
if err != nil {
return nil, err
}
return this.Success()
}
// FindNSClusterUDPConfig 查找集群的UDP设置
func (this *NSClusterService) FindNSClusterUDPConfig(ctx context.Context, req *pb.FindNSClusterUDPConfigRequest) (*pb.FindNSClusterUDPConfigResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
udpJSON, err := models.SharedNSClusterDAO.FindClusterUDP(tx, req.NsClusterId)
if err != nil {
return nil, err
}
return &pb.FindNSClusterUDPConfigResponse{
UdpJSON: udpJSON,
}, nil
}
// UpdateNSClusterUDP 修改集群的UDP设置
func (this *NSClusterService) UpdateNSClusterUDP(ctx context.Context, req *pb.UpdateNSClusterUDPRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var config = &serverconfigs.UDPProtocolConfig{}
err = json.Unmarshal(req.UdpJSON, config)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedNSClusterDAO.UpdateClusterUDP(tx, req.NsClusterId, config)
if err != nil {
return nil, err
}
return this.Success()
}
// FindNSClusterDoHConfig 查找集群的DoH设置
func (this *NSClusterService) FindNSClusterDoHConfig(ctx context.Context, req *pb.FindNSClusterDoHConfigRequest) (*pb.FindNSClusterDoHConfigResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
dohJSON, err := models.SharedNSClusterDAO.FindClusterDoH(tx, req.NsClusterId)
if err != nil {
return nil, err
}
return &pb.FindNSClusterDoHConfigResponse{
DohJSON: dohJSON,
}, nil
}
// UpdateNSClusterDoH 修改集群的DoH设置
func (this *NSClusterService) UpdateNSClusterDoH(ctx context.Context, req *pb.UpdateNSClusterDoHRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var config = dnsconfigs.NewNSDoHConfig()
err = json.Unmarshal(req.DohJSON, config)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedNSClusterDAO.UpdateClusterDoH(tx, req.NsClusterId, config)
if err != nil {
return nil, err
}
return this.Success()
}
// CountAllNSClustersWithSSLCertId 计算使用某个SSL证书的集群数量
func (this *NSClusterService) CountAllNSClustersWithSSLCertId(ctx context.Context, req *pb.CountAllNSClustersWithSSLCertIdRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
policyIds, err := models.SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(tx, req.SslCertId)
if err != nil {
return nil, err
}
if len(policyIds) == 0 {
return this.SuccessCount(0)
}
count, err := models.SharedNSClusterDAO.CountAllClustersWithSSLPolicyIds(tx, policyIds)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// FindNSClusterDDoSProtection 获取集群的DDoS设置
func (this *NSClusterService) FindNSClusterDDoSProtection(ctx context.Context, req *pb.FindNSClusterDDoSProtectionRequest) (*pb.FindNSClusterDDoSProtectionResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx *dbs.Tx
ddosProtection, err := models.SharedNSClusterDAO.FindClusterDDoSProtection(tx, req.NsClusterId)
if err != nil {
return nil, err
}
if ddosProtection == nil {
ddosProtection = ddosconfigs.DefaultProtectionConfig()
}
ddosProtectionJSON, err := json.Marshal(ddosProtection)
if err != nil {
return nil, err
}
var result = &pb.FindNSClusterDDoSProtectionResponse{
DdosProtectionJSON: ddosProtectionJSON,
}
return result, nil
}
// UpdateNSClusterDDoSProtection 修改集群的DDoS设置
func (this *NSClusterService) UpdateNSClusterDDoSProtection(ctx context.Context, req *pb.UpdateNSClusterDDoSProtectionRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var ddosProtection = &ddosconfigs.ProtectionConfig{}
err = json.Unmarshal(req.DdosProtectionJSON, ddosProtection)
if err != nil {
return nil, err
}
var tx *dbs.Tx
err = models.SharedNSClusterDAO.UpdateClusterDDoSProtection(tx, req.NsClusterId, ddosProtection)
if err != nil {
return nil, err
}
return this.Success()
}
// FindNSClusterHosts 查找NS集群的主机地址
func (this *NSClusterService) FindNSClusterHosts(ctx context.Context, req *pb.FindNSClusterHostsRequest) (*pb.FindNSClusterHostsResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
hosts, err := models.SharedNSClusterDAO.FindClusterHosts(tx, req.NsClusterId)
if err != nil {
return nil, err
}
return &pb.FindNSClusterHostsResponse{
Hosts: hosts,
}, nil
}
// FindAvailableNSHostsForUser 查找用户可以使用的主机地址
func (this *NSClusterService) FindAvailableNSHostsForUser(ctx context.Context, req *pb.FindAvailableNSHostsForUserRequest) (*pb.FindAvailableNSHostsForUserResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
if req.UserId <= 0 {
return &pb.FindAvailableNSHostsForUserResponse{
Hosts: nil,
}, nil
}
// 所属集群
var tx = this.NullTx()
userConfig, err := models.SharedSysSettingDAO.ReadNSUserConfig(tx)
if err != nil {
return nil, err
}
if userConfig == nil || userConfig.DefaultClusterId <= 0 {
return &pb.FindAvailableNSHostsForUserResponse{
Hosts: nil,
}, nil
}
hosts, err := models.SharedNSClusterDAO.FindClusterHosts(tx, userConfig.DefaultClusterId)
if err != nil {
return nil, err
}
return &pb.FindAvailableNSHostsForUserResponse{
Hosts: hosts,
}, nil
}
// FindNSClusterAnswerConfig 查找应答模式
func (this *NSClusterService) FindNSClusterAnswerConfig(ctx context.Context, req *pb.FindNSClusterAnswerConfigRequest) (*pb.FindNSClusterAnswerConfigResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
config, err := models.SharedNSClusterDAO.FindClusterAnswer(tx, req.NsClusterId)
if err != nil {
return nil, err
}
configJSON, err := json.Marshal(config)
if err != nil {
return nil, err
}
return &pb.FindNSClusterAnswerConfigResponse{
AnswerJSON: configJSON,
}, nil
}
// UpdateNSClusterAnswerConfig 设置应答模式
func (this *NSClusterService) UpdateNSClusterAnswerConfig(ctx context.Context, req *pb.UpdateNSClusterAnswerConfigRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var config = dnsconfigs.DefaultNSAnswerConfig()
if len(req.AnswerJSON) > 0 {
err = json.Unmarshal(req.AnswerJSON, config)
if err != nil {
return nil, err
}
}
var tx = this.NullTx()
err = models.SharedNSClusterDAO.UpdateClusterAnswer(tx, req.NsClusterId, config)
if err != nil {
return nil, err
}
return this.Success()
}
// FindNSClusterSOAConfig 查询SOA配置
func (this *NSClusterService) FindNSClusterSOAConfig(ctx context.Context, req *pb.FindNSClusterSOAConfigRequest) (*pb.FindNSClusterSOAConfigResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
config, err := models.SharedNSClusterDAO.FindClusterSOA(tx, req.NsClusterId)
if err != nil {
return nil, err
}
configJSON, err := json.Marshal(config)
if err != nil {
return nil, err
}
return &pb.FindNSClusterSOAConfigResponse{
SoaJSON: configJSON,
}, nil
}
// UpdateNSClusterSOAConfig 修改SOA配置
func (this *NSClusterService) UpdateNSClusterSOAConfig(ctx context.Context, req *pb.UpdateNSClusterSOAConfigRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var config = dnsconfigs.DefaultNSSOAConfig()
if len(req.SoaJSON) > 0 {
err = json.Unmarshal(req.SoaJSON, config)
if err != nil {
return nil, err
}
}
var tx = this.NullTx()
err = models.SharedNSClusterDAO.UpdateClusterSOA(tx, req.NsClusterId, config)
if err != nil {
return nil, err
}
return this.Success()
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,194 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package nameservers
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/nameservers"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// NSDomainGroupService 域名分组服务
type NSDomainGroupService struct {
services.BaseService
}
// CreateNSDomainGroup 创建分组
func (this *NSDomainGroupService) CreateNSDomainGroup(ctx context.Context, req *pb.CreateNSDomainGroupRequest) (*pb.CreateNSDomainGroupResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
groupId, err := nameservers.SharedNSDomainGroupDAO.CreateGroup(tx, userId, req.Name)
if err != nil {
return nil, err
}
return &pb.CreateNSDomainGroupResponse{
NsDomainGroupId: groupId,
}, nil
}
// UpdateNSDomainGroup 修改分组
func (this *NSDomainGroupService) UpdateNSDomainGroup(ctx context.Context, req *pb.UpdateNSDomainGroupRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
err = nameservers.SharedNSDomainGroupDAO.CheckUserGroup(tx, userId, req.NsDomainGroupId)
if err != nil {
return nil, err
}
}
err = nameservers.SharedNSDomainGroupDAO.UpdateGroup(tx, req.NsDomainGroupId, req.Name, req.IsOn)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteNSDomainGroup 删除分组
func (this *NSDomainGroupService) DeleteNSDomainGroup(ctx context.Context, req *pb.DeleteNSDomainGroupRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
err = nameservers.SharedNSDomainGroupDAO.CheckUserGroup(tx, userId, req.NsDomainGroupId)
if err != nil {
return nil, err
}
}
err = nameservers.SharedNSDomainGroupDAO.DisableNSDomainGroup(tx, req.NsDomainGroupId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindAllNSDomainGroups 查询所有分组
func (this *NSDomainGroupService) FindAllNSDomainGroups(ctx context.Context, req *pb.FindAllNSDomainGroupsRequest) (*pb.FindAllNSDomainGroupsResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
req.UserId = userId
}
groups, err := nameservers.SharedNSDomainGroupDAO.FindAllGroups(tx, req.UserId)
if err != nil {
return nil, err
}
var pbGroups = []*pb.NSDomainGroup{}
for _, group := range groups {
pbGroups = append(pbGroups, &pb.NSDomainGroup{
Id: int64(group.Id),
Name: group.Name,
IsOn: group.IsOn,
UserId: int64(group.UserId),
})
}
return &pb.FindAllNSDomainGroupsResponse{
NsDomainGroups: pbGroups,
}, nil
}
// CountAllAvailableNSDomainGroups 查询可用分组数量
func (this *NSDomainGroupService) CountAllAvailableNSDomainGroups(ctx context.Context, req *pb.CountAllAvailableNSDomainGroupsRequest) (*pb.RPCCountResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
var tx = this.NullTx()
count, err := nameservers.SharedNSDomainGroupDAO.CountAllAvailableGroups(tx, req.UserId)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// FindAllAvailableNSDomainGroups 查询所有分组
func (this *NSDomainGroupService) FindAllAvailableNSDomainGroups(ctx context.Context, req *pb.FindAllAvailableNSDomainGroupsRequest) (*pb.FindAllAvailableNSDomainGroupsResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
var tx = this.NullTx()
groups, err := nameservers.SharedNSDomainGroupDAO.FindAllAvailableGroups(tx, req.UserId)
if err != nil {
return nil, err
}
var pbGroups = []*pb.NSDomainGroup{}
for _, group := range groups {
pbGroups = append(pbGroups, &pb.NSDomainGroup{
Id: int64(group.Id),
Name: group.Name,
IsOn: group.IsOn,
UserId: int64(group.UserId),
})
}
return &pb.FindAllAvailableNSDomainGroupsResponse{
NsDomainGroups: pbGroups,
}, nil
}
// FindNSDomainGroup 查找单个分组
func (this *NSDomainGroupService) FindNSDomainGroup(ctx context.Context, req *pb.FindNSDomainGroupRequest) (*pb.FindNSDomainGroupResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
group, err := nameservers.SharedNSDomainGroupDAO.FindEnabledNSDomainGroup(tx, req.NsDomainGroupId)
if err != nil {
return nil, err
}
if group == nil {
return &pb.FindNSDomainGroupResponse{
NsDomainGroup: nil,
}, nil
}
if int64(group.UserId) != userId {
return &pb.FindNSDomainGroupResponse{
NsDomainGroup: nil,
}, nil
}
return &pb.FindNSDomainGroupResponse{
NsDomainGroup: &pb.NSDomainGroup{
Id: int64(group.Id),
Name: group.Name,
IsOn: group.IsOn,
UserId: int64(group.UserId),
},
}, nil
}

View File

@@ -0,0 +1,225 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package nameservers
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/nameservers"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// NSKeyService NS密钥相关服务
type NSKeyService struct {
services.BaseService
}
// CreateNSKey 创建密钥
func (this *NSKeyService) CreateNSKey(ctx context.Context, req *pb.CreateNSKeyRequest) (*pb.CreateNSKeyResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查用户权限
if userId > 0 {
err = nameservers.SharedNSDomainDAO.CheckUserDomain(tx, userId, req.NsDomainId)
if err != nil {
return nil, err
}
}
keyId, err := nameservers.SharedNSKeyDAO.CreateKey(tx, req.NsDomainId, req.NsZoneId, req.Name, req.Algo, req.Secret, req.SecretType)
if err != nil {
return nil, err
}
return &pb.CreateNSKeyResponse{NsKeyId: keyId}, nil
}
// UpdateNSKey 修改密钥
func (this *NSKeyService) UpdateNSKey(ctx context.Context, req *pb.UpdateNSKeyRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查用户权限
if userId > 0 {
err = nameservers.SharedNSKeyDAO.CheckUserKey(tx, userId, req.NsKeyId)
if err != nil {
return nil, err
}
}
err = nameservers.SharedNSKeyDAO.UpdateKey(tx, req.NsKeyId, req.Name, req.Algo, req.Secret, req.SecretType, req.IsOn)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteNSKey 删除密钥
func (this *NSKeyService) DeleteNSKey(ctx context.Context, req *pb.DeleteNSKeyRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查用户权限
if userId > 0 {
err = nameservers.SharedNSKeyDAO.CheckUserKey(tx, userId, req.NsKeyId)
if err != nil {
return nil, err
}
}
err = nameservers.SharedNSKeyDAO.DisableNSKey(tx, req.NsKeyId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindNSKey 查找单个密钥
func (this *NSKeyService) FindNSKey(ctx context.Context, req *pb.FindNSKeyRequest) (*pb.FindNSKeyResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查用户权限
if userId > 0 {
err = nameservers.SharedNSKeyDAO.CheckUserKey(tx, userId, req.NsKeyId)
if err != nil {
return nil, err
}
}
key, err := nameservers.SharedNSKeyDAO.FindEnabledNSKey(tx, req.NsKeyId)
if err != nil {
return nil, err
}
if key == nil {
return &pb.FindNSKeyResponse{NsKey: nil}, nil
}
return &pb.FindNSKeyResponse{
NsKey: &pb.NSKey{
Id: int64(key.Id),
IsOn: key.IsOn,
Name: key.Name,
Algo: key.Algo,
Secret: key.Secret,
SecretType: key.SecretType,
},
}, nil
}
// CountAllNSKeys 计算密钥数量
func (this *NSKeyService) CountAllNSKeys(ctx context.Context, req *pb.CountAllNSKeysRequest) (*pb.RPCCountResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查用户权限
if userId > 0 {
err = nameservers.SharedNSDomainDAO.CheckUserDomain(tx, userId, req.NsDomainId)
if err != nil {
return nil, err
}
}
count, err := nameservers.SharedNSKeyDAO.CountEnabledKeys(tx, req.NsDomainId, req.NsZoneId)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListNSKeys 列出单页密钥
func (this *NSKeyService) ListNSKeys(ctx context.Context, req *pb.ListNSKeysRequest) (*pb.ListNSKeysResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查用户权限
if userId > 0 {
err = nameservers.SharedNSDomainDAO.CheckUserDomain(tx, userId, req.NsDomainId)
if err != nil {
return nil, err
}
}
keys, err := nameservers.SharedNSKeyDAO.ListEnabledKeys(tx, req.NsDomainId, req.NsZoneId, req.Offset, req.Size)
if err != nil {
return nil, err
}
var pbKeys = []*pb.NSKey{}
for _, key := range keys {
pbKeys = append(pbKeys, &pb.NSKey{
Id: int64(key.Id),
IsOn: key.IsOn,
Name: key.Name,
Algo: key.Algo,
Secret: key.Secret,
SecretType: key.SecretType,
})
}
return &pb.ListNSKeysResponse{NsKeys: pbKeys}, nil
}
// ListNSKeysAfterVersion 根据版本列出一组密钥
func (this *NSKeyService) ListNSKeysAfterVersion(ctx context.Context, req *pb.ListNSKeysAfterVersionRequest) (*pb.ListNSKeysAfterVersionResponse, error) {
_, err := this.ValidateNSNode(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if req.Size <= 0 {
req.Size = 2000
}
keys, err := nameservers.SharedNSKeyDAO.ListKeysAfterVersion(tx, req.Version, req.Size)
if err != nil {
return nil, err
}
var pbKeys = []*pb.NSKey{}
for _, key := range keys {
var pbDomain *pb.NSDomain
var pbZone *pb.NSZone
if key.DomainId > 0 {
pbDomain = &pb.NSDomain{Id: int64(key.DomainId)}
}
if key.ZoneId > 0 {
pbZone = &pb.NSZone{Id: int64(key.ZoneId)}
}
pbKeys = append(pbKeys, &pb.NSKey{
Id: int64(key.Id),
IsOn: key.IsOn,
Name: "",
Algo: key.Algo,
Secret: key.Secret,
SecretType: key.SecretType,
IsDeleted: key.State == nameservers.NSKeyStateDisabled,
Version: int64(key.Version),
NsDomain: pbDomain,
NsZone: pbZone,
})
}
return &pb.ListNSKeysAfterVersionResponse{NsKeys: pbKeys}, nil
}

View File

@@ -0,0 +1,740 @@
// Copyright 2021-2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package nameservers
import (
"context"
"encoding/json"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/goman"
"github.com/TeaOSLab/EdgeAPI/internal/installers"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/logs"
stringutil "github.com/iwind/TeaGo/utils/string"
"io"
"path/filepath"
"time"
)
// NSNodeService 域名服务器节点服务
type NSNodeService struct {
services.BaseService
}
// FindAllNSNodesWithNSClusterId 根据集群查找所有节点
func (this *NSNodeService) FindAllNSNodesWithNSClusterId(ctx context.Context, req *pb.FindAllNSNodesWithNSClusterIdRequest) (*pb.FindAllNSNodesWithNSClusterIdResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
nodes, err := models.SharedNSNodeDAO.FindAllEnabledNodesWithClusterId(tx, req.NsClusterId)
if err != nil {
return nil, err
}
pbNodes := []*pb.NSNode{}
for _, node := range nodes {
pbNodes = append(pbNodes, &pb.NSNode{
Id: int64(node.Id),
Name: node.Name,
IsOn: node.IsOn,
UniqueId: node.UniqueId,
Secret: node.Secret,
IsInstalled: node.IsInstalled,
InstallDir: node.InstallDir,
IsUp: node.IsUp,
ConnectedAPINodeIds: node.DecodeConnectedAPINodes(),
NsCluster: nil,
})
}
return &pb.FindAllNSNodesWithNSClusterIdResponse{NsNodes: pbNodes}, nil
}
// CountAllNSNodes 所有可用的节点数量
func (this *NSNodeService) CountAllNSNodes(ctx context.Context, req *pb.CountAllNSNodesRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedNSNodeDAO.CountAllEnabledNodes(tx)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// CountAllNSNodesMatch 计算匹配的节点数量
func (this *NSNodeService) CountAllNSNodesMatch(ctx context.Context, req *pb.CountAllNSNodesMatchRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedNSNodeDAO.CountAllEnabledNodesMatch(tx, req.NsClusterId, configutils.ToBoolState(req.InstallState), configutils.ToBoolState(req.ActiveState), req.Keyword)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListNSNodesMatch 列出单页节点
func (this *NSNodeService) ListNSNodesMatch(ctx context.Context, req *pb.ListNSNodesMatchRequest) (*pb.ListNSNodesMatchResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
nodes, err := models.SharedNSNodeDAO.ListAllEnabledNodesMatch(tx, req.NsClusterId, configutils.ToBoolState(req.InstallState), configutils.ToBoolState(req.ActiveState), req.Keyword, req.Offset, req.Size)
if err != nil {
return nil, err
}
pbNodes := []*pb.NSNode{}
for _, node := range nodes {
// 安装信息
installStatus, err := node.DecodeInstallStatus()
if err != nil {
return nil, err
}
installStatusResult := &pb.NodeInstallStatus{}
if installStatus != nil {
installStatusResult = &pb.NodeInstallStatus{
IsRunning: installStatus.IsRunning,
IsFinished: installStatus.IsFinished,
IsOk: installStatus.IsOk,
Error: installStatus.Error,
ErrorCode: installStatus.ErrorCode,
UpdatedAt: installStatus.UpdatedAt,
}
}
pbNodes = append(pbNodes, &pb.NSNode{
Id: int64(node.Id),
Name: node.Name,
IsOn: node.IsOn,
UniqueId: node.UniqueId,
Secret: node.Secret,
IsActive: node.IsActive,
IsInstalled: node.IsInstalled,
InstallDir: node.InstallDir,
IsUp: node.IsUp,
StatusJSON: node.Status,
InstallStatus: installStatusResult,
NsCluster: nil,
})
}
return &pb.ListNSNodesMatchResponse{NsNodes: pbNodes}, nil
}
// CountAllUpgradeNSNodesWithNSClusterId 计算需要升级的节点数量
func (this *NSNodeService) CountAllUpgradeNSNodesWithNSClusterId(ctx context.Context, req *pb.CountAllUpgradeNSNodesWithNSClusterIdRequest) (*pb.RPCCountResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
deployFiles := installers.SharedDeployManager.LoadNSNodeFiles()
total := int64(0)
for _, deployFile := range deployFiles {
count, err := models.SharedNSNodeDAO.CountAllLowerVersionNodesWithClusterId(tx, req.NsClusterId, deployFile.OS, deployFile.Arch, deployFile.Version)
if err != nil {
return nil, err
}
total += count
}
return this.SuccessCount(total)
}
// CreateNSNode 创建节点
func (this *NSNodeService) CreateNSNode(ctx context.Context, req *pb.CreateNSNodeRequest) (*pb.CreateNSNodeResponse, error) {
adminId, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
nodeId, err := models.SharedNSNodeDAO.CreateNode(tx, adminId, req.Name, req.NodeClusterId)
if err != nil {
return nil, err
}
// 增加认证相关
if req.NodeLogin != nil {
_, err = models.SharedNodeLoginDAO.CreateNodeLogin(tx, nodeconfigs.NodeRoleDNS, nodeId, req.NodeLogin.Name, req.NodeLogin.Type, req.NodeLogin.Params)
if err != nil {
return nil, err
}
}
return &pb.CreateNSNodeResponse{
NsNodeId: nodeId,
}, nil
}
// DeleteNSNode 删除节点
func (this *NSNodeService) DeleteNSNode(ctx context.Context, req *pb.DeleteNSNodeRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedNSNodeDAO.DisableNSNode(tx, req.NsNodeId)
if err != nil {
return nil, err
}
// 删除任务
err = models.SharedNodeTaskDAO.DeleteNodeTasks(tx, nodeconfigs.NodeRoleDNS, req.NsNodeId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindNSNode 查询单个节点信息
func (this *NSNodeService) FindNSNode(ctx context.Context, req *pb.FindNSNodeRequest) (*pb.FindNSNodeResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
node, err := models.SharedNSNodeDAO.FindEnabledNSNode(tx, req.NsNodeId)
if err != nil {
return nil, err
}
if node == nil {
return &pb.FindNSNodeResponse{NsNode: nil}, nil
}
// 集群信息
clusterName, err := models.SharedNSClusterDAO.FindEnabledNSClusterName(tx, int64(node.ClusterId))
if err != nil {
return nil, err
}
// 认证信息
login, err := models.SharedNodeLoginDAO.FindEnabledNodeLoginWithNodeId(tx, nodeconfigs.NodeRoleDNS, req.NsNodeId)
if err != nil {
return nil, err
}
var respLogin *pb.NodeLogin = nil
if login != nil {
respLogin = &pb.NodeLogin{
Id: int64(login.Id),
Name: login.Name,
Type: login.Type,
Params: login.Params,
}
}
// 安装信息
installStatus, err := node.DecodeInstallStatus()
if err != nil {
return nil, err
}
var installStatusResult = &pb.NodeInstallStatus{}
if installStatus != nil {
installStatusResult = &pb.NodeInstallStatus{
IsRunning: installStatus.IsRunning,
IsFinished: installStatus.IsFinished,
IsOk: installStatus.IsOk,
Error: installStatus.Error,
ErrorCode: installStatus.ErrorCode,
UpdatedAt: installStatus.UpdatedAt,
}
}
return &pb.FindNSNodeResponse{NsNode: &pb.NSNode{
Id: int64(node.Id),
Name: node.Name,
StatusJSON: node.Status,
UniqueId: node.UniqueId,
Secret: node.Secret,
IsInstalled: node.IsInstalled,
InstallDir: node.InstallDir,
ApiNodeAddrsJSON: node.ApiNodeAddrs,
NsCluster: &pb.NSCluster{
Id: int64(node.ClusterId),
Name: clusterName,
},
InstallStatus: installStatusResult,
IsOn: node.IsOn,
IsActive: node.IsActive,
NodeLogin: respLogin,
}}, nil
}
// UpdateNSNode 修改节点
func (this *NSNodeService) UpdateNSNode(ctx context.Context, req *pb.UpdateNSNodeRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedNSNodeDAO.UpdateNode(tx, req.NsNodeId, req.Name, req.NsClusterId, req.IsOn)
if err != nil {
return nil, err
}
// 登录信息
if req.NodeLogin == nil {
err = models.SharedNodeLoginDAO.DisableNodeLogins(tx, nodeconfigs.NodeRoleDNS, req.NsNodeId)
if err != nil {
return nil, err
}
} else {
if req.NodeLogin.Id > 0 {
err = models.SharedNodeLoginDAO.UpdateNodeLogin(tx, req.NodeLogin.Id, req.NodeLogin.Name, req.NodeLogin.Type, req.NodeLogin.Params)
if err != nil {
return nil, err
}
} else {
_, err = models.SharedNodeLoginDAO.CreateNodeLogin(tx, nodeconfigs.NodeRoleDNS, req.NsNodeId, req.NodeLogin.Name, req.NodeLogin.Type, req.NodeLogin.Params)
if err != nil {
return nil, err
}
}
}
return this.Success()
}
// InstallNSNode 安装节点
func (this *NSNodeService) InstallNSNode(ctx context.Context, req *pb.InstallNSNodeRequest) (*pb.InstallNSNodeResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
goman.New(func() {
err = installers.SharedNSNodeQueue().InstallNodeProcess(req.NsNodeId, false)
if err != nil {
logs.Println("[RPC]install dns node:" + err.Error())
}
})
return &pb.InstallNSNodeResponse{}, nil
}
// FindNSNodeInstallStatus 读取节点安装状态
func (this *NSNodeService) FindNSNodeInstallStatus(ctx context.Context, req *pb.FindNSNodeInstallStatusRequest) (*pb.FindNSNodeInstallStatusResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
installStatus, err := models.SharedNSNodeDAO.FindNodeInstallStatus(tx, req.NsNodeId)
if err != nil {
return nil, err
}
if installStatus == nil {
return &pb.FindNSNodeInstallStatusResponse{InstallStatus: nil}, nil
}
pbInstallStatus := &pb.NodeInstallStatus{
IsRunning: installStatus.IsRunning,
IsFinished: installStatus.IsFinished,
IsOk: installStatus.IsOk,
Error: installStatus.Error,
ErrorCode: installStatus.ErrorCode,
UpdatedAt: installStatus.UpdatedAt,
}
return &pb.FindNSNodeInstallStatusResponse{InstallStatus: pbInstallStatus}, nil
}
// UpdateNSNodeIsInstalled 修改节点安装状态
func (this *NSNodeService) UpdateNSNodeIsInstalled(ctx context.Context, req *pb.UpdateNSNodeIsInstalledRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedNSNodeDAO.UpdateNodeIsInstalled(tx, req.NsNodeId, req.IsInstalled)
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateNSNodeStatus 更新节点状态
func (this *NSNodeService) UpdateNSNodeStatus(ctx context.Context, req *pb.UpdateNSNodeStatusRequest) (*pb.RPCSuccess, error) {
// 校验节点
_, nodeId, err := this.ValidateNodeId(ctx, rpcutils.UserTypeDNS)
if err != nil {
return nil, err
}
if req.NodeId > 0 {
nodeId = req.NodeId
}
if nodeId <= 0 {
return nil, errors.New("'nodeId' should be greater than 0")
}
var tx = this.NullTx()
// 修改时间戳
var nodeStatus = &nodeconfigs.NodeStatus{}
err = json.Unmarshal(req.StatusJSON, nodeStatus)
if err != nil {
return nil, errors.New("decode node status json failed: " + err.Error())
}
nodeStatus.UpdatedAt = time.Now().Unix()
// 保存
err = models.SharedNSNodeDAO.UpdateNodeStatus(tx, nodeId, nodeStatus)
if err != nil {
return nil, err
}
return this.Success()
}
// FindCurrentNSNodeConfig 获取当前节点信息
func (this *NSNodeService) FindCurrentNSNodeConfig(ctx context.Context, req *pb.FindCurrentNSNodeConfigRequest) (*pb.FindCurrentNSNodeConfigResponse, error) {
// 校验节点
_, nodeId, err := this.ValidateNodeId(ctx, rpcutils.UserTypeDNS)
if err != nil {
return nil, err
}
var tx = this.NullTx()
config, err := models.SharedNSNodeDAO.ComposeNodeConfig(tx, nodeId)
if err != nil {
return nil, err
}
if config == nil {
return &pb.FindCurrentNSNodeConfigResponse{NsNodeJSON: nil}, nil
}
configJSON, err := json.Marshal(config)
if err != nil {
return nil, err
}
return &pb.FindCurrentNSNodeConfigResponse{NsNodeJSON: configJSON}, nil
}
// CheckNSNodeLatestVersion 检查新版本
func (this *NSNodeService) CheckNSNodeLatestVersion(ctx context.Context, req *pb.CheckNSNodeLatestVersionRequest) (*pb.CheckNSNodeLatestVersionResponse, error) {
_, _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeDNS)
if err != nil {
return nil, err
}
deployFiles := installers.SharedDeployManager.LoadNSNodeFiles()
for _, file := range deployFiles {
if file.OS == req.Os && file.Arch == req.Arch && stringutil.VersionCompare(file.Version, req.CurrentVersion) > 0 {
return &pb.CheckNSNodeLatestVersionResponse{
HasNewVersion: true,
NewVersion: file.Version,
}, nil
}
}
return &pb.CheckNSNodeLatestVersionResponse{HasNewVersion: false}, nil
}
// FindLatestNSNodeVersion 获取NS节点最新版本
func (this *NSNodeService) FindLatestNSNodeVersion(ctx context.Context, req *pb.FindLatestNSNodeVersionRequest) (*pb.FindLatestNSNodeVersionResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
return &pb.FindLatestNSNodeVersionResponse{Version: teaconst.DNSNodeVersion}, nil
}
// DownloadNSNodeInstallationFile 下载最新DNS节点安装文件
func (this *NSNodeService) DownloadNSNodeInstallationFile(ctx context.Context, req *pb.DownloadNSNodeInstallationFileRequest) (*pb.DownloadNSNodeInstallationFileResponse, error) {
nodeId, err := this.ValidateNSNode(ctx)
if err != nil {
return nil, err
}
var file = installers.SharedDeployManager.FindNSNodeFile(req.Os, req.Arch)
if file == nil {
return &pb.DownloadNSNodeInstallationFileResponse{}, nil
}
sum, err := file.Sum()
if err != nil {
return nil, err
}
data, offset, err := file.Read(req.ChunkOffset)
if err != nil && err != io.EOF {
return nil, err
}
// 增加下载速度监控
installers.SharedUpgradeLimiter.UpdateNodeBytes(nodeconfigs.NodeRoleDNS, nodeId, int64(len(data)))
return &pb.DownloadNSNodeInstallationFileResponse{
Sum: sum,
Offset: offset,
ChunkData: data,
Version: file.Version,
Filename: filepath.Base(file.Path),
}, nil
}
// UpdateNSNodeConnectedAPINodes 更改节点连接的API节点信息
func (this *NSNodeService) UpdateNSNodeConnectedAPINodes(ctx context.Context, req *pb.UpdateNSNodeConnectedAPINodesRequest) (*pb.RPCSuccess, error) {
// 校验节点
_, _, nodeId, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeDNS)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedNSNodeDAO.UpdateNodeConnectedAPINodes(tx, nodeId, req.ApiNodeIds)
if err != nil {
return nil, errors.Wrap(err)
}
return this.Success()
}
// UpdateNSNodeLogin 修改节点登录信息
func (this *NSNodeService) UpdateNSNodeLogin(ctx context.Context, req *pb.UpdateNSNodeLoginRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if req.NodeLogin.Id <= 0 {
loginId, err := models.SharedNodeLoginDAO.CreateNodeLogin(tx, nodeconfigs.NodeRoleDNS, req.NsNodeId, req.NodeLogin.Name, req.NodeLogin.Type, req.NodeLogin.Params)
if err != nil {
return nil, err
}
req.NodeLogin.Id = loginId
}
err = models.SharedNodeLoginDAO.UpdateNodeLogin(tx, req.NodeLogin.Id, req.NodeLogin.Name, req.NodeLogin.Type, req.NodeLogin.Params)
if err != nil {
return nil, err
}
return this.Success()
}
// StartNSNode 启动节点
func (this *NSNodeService) StartNSNode(ctx context.Context, req *pb.StartNSNodeRequest) (*pb.StartNSNodeResponse, error) {
// 校验节点
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
err = installers.SharedNSNodeQueue().StartNode(req.NsNodeId)
if err != nil {
return &pb.StartNSNodeResponse{
IsOk: false,
Error: err.Error(),
}, nil
}
// 修改状态
var tx = this.NullTx()
err = models.SharedNSNodeDAO.UpdateNodeActive(tx, req.NsNodeId, true)
if err != nil {
return nil, err
}
return &pb.StartNSNodeResponse{IsOk: true}, nil
}
// StopNSNode 停止节点
func (this *NSNodeService) StopNSNode(ctx context.Context, req *pb.StopNSNodeRequest) (*pb.StopNSNodeResponse, error) {
// 校验节点
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
err = installers.SharedNSNodeQueue().StopNode(req.NsNodeId)
if err != nil {
return &pb.StopNSNodeResponse{
IsOk: false,
Error: err.Error(),
}, nil
}
// 修改状态
var tx = this.NullTx()
err = models.SharedNSNodeDAO.UpdateNodeActive(tx, req.NsNodeId, false)
if err != nil {
return nil, err
}
return &pb.StopNSNodeResponse{IsOk: true}, nil
}
// FindNSNodeDDoSProtection 获取集群的DDoS设置
func (this *NSNodeService) FindNSNodeDDoSProtection(ctx context.Context, req *pb.FindNSNodeDDoSProtectionRequest) (*pb.FindNSNodeDDoSProtectionResponse, error) {
var nodeId = req.NsNodeId
var isFromNode = false
_, err := this.ValidateAdmin(ctx)
if err != nil {
// 检查是否来自节点
currentNodeId, err2 := this.ValidateNSNode(ctx)
if err2 != nil {
return nil, err
}
if nodeId > 0 && currentNodeId != nodeId {
return nil, errors.New("invalid 'nsNodeId'")
}
nodeId = currentNodeId
isFromNode = true
}
var tx *dbs.Tx
ddosProtection, err := models.SharedNSNodeDAO.FindNodeDDoSProtection(tx, nodeId)
if err != nil {
return nil, err
}
if ddosProtection == nil {
ddosProtection = ddosconfigs.DefaultProtectionConfig()
}
// 组合父级节点配置
// 只有从节点读取配置时才需要组合
if isFromNode {
clusterId, err := models.SharedNSNodeDAO.FindNodeClusterId(tx, nodeId)
if err != nil {
return nil, err
}
if clusterId > 0 {
clusterDDoSProtection, err := models.SharedNSClusterDAO.FindClusterDDoSProtection(tx, clusterId)
if err != nil {
return nil, err
}
if clusterDDoSProtection == nil {
clusterDDoSProtection = ddosconfigs.DefaultProtectionConfig()
}
clusterDDoSProtection.Merge(ddosProtection)
ddosProtection = clusterDDoSProtection
}
}
ddosProtectionJSON, err := json.Marshal(ddosProtection)
if err != nil {
return nil, err
}
var result = &pb.FindNSNodeDDoSProtectionResponse{
DdosProtectionJSON: ddosProtectionJSON,
}
return result, nil
}
// UpdateNSNodeDDoSProtection 修改集群的DDoS设置
func (this *NSNodeService) UpdateNSNodeDDoSProtection(ctx context.Context, req *pb.UpdateNSNodeDDoSProtectionRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var ddosProtection = &ddosconfigs.ProtectionConfig{}
err = json.Unmarshal(req.DdosProtectionJSON, ddosProtection)
if err != nil {
return nil, err
}
var tx *dbs.Tx
err = models.SharedNSNodeDAO.UpdateNodeDDoSProtection(tx, req.NsNodeId, ddosProtection)
if err != nil {
return nil, err
}
return this.Success()
}
// FindNSNodeAPIConfig 查找单个节点的API相关配置
func (this *NSNodeService) FindNSNodeAPIConfig(ctx context.Context, req *pb.FindNSNodeAPIConfigRequest) (*pb.FindNSNodeAPIConfigResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
node, err := models.SharedNSNodeDAO.FindNodeAPIConfig(tx, req.NsNodeId)
if err != nil {
return nil, err
}
if node == nil {
return &pb.FindNSNodeAPIConfigResponse{
ApiNodeAddrsJSON: nil,
}, nil
}
return &pb.FindNSNodeAPIConfigResponse{
ApiNodeAddrsJSON: node.ApiNodeAddrs,
}, nil
}
// UpdateNSNodeAPIConfig 修改某个节点的API相关配置
func (this *NSNodeService) UpdateNSNodeAPIConfig(ctx context.Context, req *pb.UpdateNSNodeAPIConfigRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
var apiNodeAddrs = []*serverconfigs.NetworkAddressConfig{}
if len(req.ApiNodeAddrsJSON) > 0 {
err = json.Unmarshal(req.ApiNodeAddrsJSON, &apiNodeAddrs)
if err != nil {
return nil, err
}
}
err = models.SharedNSNodeDAO.UpdateNodeAPIConfig(tx, req.NsNodeId, apiNodeAddrs)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,331 @@
//go:build plus
package nameservers
import (
"context"
"fmt"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/goman"
"github.com/TeaOSLab/EdgeAPI/internal/remotelogs"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/messageconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/types"
"strconv"
"sync"
"sync/atomic"
"time"
)
// CommandRequest 命令请求相关
type CommandRequest struct {
Id int64
Code string
CommandJSON []byte
}
type CommandRequestWaiting struct {
Timestamp int64
Chan chan *pb.NSNodeStreamMessage
}
func (this *CommandRequestWaiting) Close() {
defer func() {
_ = recover()
}()
close(this.Chan)
}
var responseChanMap = map[int64]*CommandRequestWaiting{} // request id => response
var commandRequestId = int64(0)
var nodeLocker = &sync.Mutex{}
var requestChanMap = map[int64]chan *CommandRequest{} // node id => chan
func NextCommandRequestId() int64 {
return atomic.AddInt64(&commandRequestId, 1)
}
func init() {
dbs.OnReadyDone(func() {
// 清理WaitingChannelMap
goman.New(func() {
ticker := time.NewTicker(30 * time.Second)
for range ticker.C {
nodeLocker.Lock()
for requestId, request := range responseChanMap {
if time.Now().Unix()-request.Timestamp > 3600 {
responseChanMap[requestId].Close()
delete(responseChanMap, requestId)
}
}
nodeLocker.Unlock()
}
})
// 自动同步连接到本API节点的NS节点任务
goman.New(func() {
defer func() {
_ = recover()
}()
// TODO 未来支持同步边缘节点
var ticker = time.NewTicker(3 * time.Second)
for range ticker.C {
nodeIds, err := models.SharedNodeTaskDAO.FindAllDoingNodeIds(nil, nodeconfigs.NodeRoleDNS)
if err != nil {
remotelogs.Error("NSNodeService_SYNC", err.Error())
continue
}
nodeLocker.Lock()
for _, nodeId := range nodeIds {
c, ok := requestChanMap[nodeId]
if ok {
select {
case c <- &CommandRequest{
Id: NextCommandRequestId(),
Code: messageconfigs.NSMessageCodeNewNodeTask,
CommandJSON: nil,
}:
default:
}
}
}
nodeLocker.Unlock()
}
})
})
}
// NsNodeStream 节点stream
func (this *NSNodeService) NsNodeStream(server pb.NSNodeService_NsNodeStreamServer) error {
// TODO 使用此stream快速通知NS节点更新
// 校验节点
_, _, nodeId, err := rpcutils.ValidateRequest(server.Context(), rpcutils.UserTypeDNS)
if err != nil {
return err
}
// 返回连接成功
err = models.SharedNSNodeDAO.UpdateNodeConnectedAPINodes(nil, nodeId, []int64{teaconst.NodeId})
if err != nil {
return err
}
if Tea.IsTesting() {
remotelogs.Println("NSNodeService", "accepted ns node '"+types.String(nodeId)+"' connection")
}
var tx = this.NullTx()
// 是否发送恢复通知
oldIsActive, err := models.SharedNSNodeDAO.FindNodeActive(tx, nodeId)
if err != nil {
return err
}
if !oldIsActive {
inactiveNotifiedAt, err := models.SharedNSNodeDAO.FindNodeInactiveNotifiedAt(tx, nodeId)
if err != nil {
return err
}
// 设置为活跃
err = models.SharedNSNodeDAO.UpdateNodeActive(tx, nodeId, true)
if err != nil {
return err
}
if inactiveNotifiedAt > 0 {
// 发送恢复消息
clusterId, err := models.SharedNSNodeDAO.FindNodeClusterId(tx, nodeId)
if err != nil {
return err
}
nodeName, err := models.SharedNSNodeDAO.FindEnabledNSNodeName(tx, nodeId)
if err != nil {
return err
}
subject := "NS节点\"" + nodeName + "\"已经恢复在线"
msg := "NS节点\"" + nodeName + "\"已经恢复在线"
err = models.SharedMessageDAO.CreateNodeMessage(tx, nodeconfigs.NodeRoleDNS, clusterId, nodeId, models.MessageTypeNSNodeActive, models.MessageLevelSuccess, subject, msg, nil, false)
if err != nil {
return err
}
}
}
nodeLocker.Lock()
requestChan, ok := requestChanMap[nodeId]
if !ok {
requestChan = make(chan *CommandRequest, 1024)
requestChanMap[nodeId] = requestChan
}
nodeLocker.Unlock()
defer func() {
nodeLocker.Lock()
delete(requestChanMap, nodeId)
nodeLocker.Unlock()
}()
// 发送请求
goman.New(func() {
for {
select {
case <-server.Context().Done():
return
case commandRequest := <-requestChan:
// logs.Println("[RPC]sending command '" + commandRequest.Code + "' to node '" + strconv.FormatInt(nodeId, 10) + "'")
retries := 3 // 错误重试次数
for i := 0; i < retries; i++ {
err := server.Send(&pb.NSNodeStreamMessage{
RequestId: commandRequest.Id,
Code: commandRequest.Code,
DataJSON: commandRequest.CommandJSON,
})
if err != nil {
if i == retries-1 {
logs.Println("[RPC]send command '" + commandRequest.Code + "' failed: " + err.Error())
} else {
time.Sleep(1 * time.Second)
}
} else {
break
}
}
}
}
})
// 接受请求
for {
req, err := server.Recv()
if err != nil {
// 修改节点状态
err1 := models.SharedNSNodeDAO.UpdateNodeActive(tx, nodeId, false)
if err1 != nil {
logs.Println(err1.Error())
}
return err
}
func(req *pb.NSNodeStreamMessage) {
// 因为 responseChan.Chan 有被关闭的风险所以我们使用recover防止panic
defer func() {
_ = recover()
}()
nodeLocker.Lock()
responseChan, ok := responseChanMap[req.RequestId]
if ok {
select {
case responseChan.Chan <- req:
default:
}
}
nodeLocker.Unlock()
}(req)
}
}
// SendCommandToNSNode 向节点发送命令
func (this *NSNodeService) SendCommandToNSNode(ctx context.Context, req *pb.NSNodeStreamMessage) (*pb.NSNodeStreamMessage, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
nodeId := req.NsNodeId
if nodeId <= 0 {
return nil, errors.New("node id should not be less than 0")
}
nodeLocker.Lock()
requestChan, ok := requestChanMap[nodeId]
nodeLocker.Unlock()
if !ok {
return &pb.NSNodeStreamMessage{
RequestId: req.RequestId,
IsOk: false,
Message: "node '" + strconv.FormatInt(nodeId, 10) + "' not connected yet",
}, nil
}
req.RequestId = NextCommandRequestId()
select {
case requestChan <- &CommandRequest{
Id: req.RequestId,
Code: req.Code,
CommandJSON: req.DataJSON,
}:
// 加入到等待队列中
respChan := make(chan *pb.NSNodeStreamMessage, 1)
waiting := &CommandRequestWaiting{
Timestamp: time.Now().Unix(),
Chan: respChan,
}
nodeLocker.Lock()
responseChanMap[req.RequestId] = waiting
nodeLocker.Unlock()
// 等待响应
timeoutSeconds := req.TimeoutSeconds
if timeoutSeconds <= 0 {
timeoutSeconds = 10
}
timeout := time.NewTimer(time.Duration(timeoutSeconds) * time.Second)
select {
case resp := <-respChan:
// 从队列中删除
nodeLocker.Lock()
delete(responseChanMap, req.RequestId)
waiting.Close()
nodeLocker.Unlock()
if resp == nil {
return &pb.NSNodeStreamMessage{
RequestId: req.RequestId,
Code: req.Code,
Message: "response timeout",
IsOk: false,
}, nil
}
return resp, nil
case <-timeout.C:
// 从队列中删除
nodeLocker.Lock()
delete(responseChanMap, req.RequestId)
waiting.Close()
nodeLocker.Unlock()
return &pb.NSNodeStreamMessage{
RequestId: req.RequestId,
Code: req.Code,
Message: "response timeout over " + fmt.Sprintf("%d", timeoutSeconds) + " seconds",
IsOk: false,
}, nil
}
default:
return &pb.NSNodeStreamMessage{
RequestId: req.RequestId,
Code: req.Code,
Message: "command queue is full over " + strconv.Itoa(len(requestChan)),
IsOk: false,
}, nil
}
}

View File

@@ -0,0 +1,191 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package nameservers
import (
"context"
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/nameservers"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// NSPlanService DNS套餐服务
type NSPlanService struct {
services.BaseService
}
// CreateNSPlan 创建DNS套餐
func (this *NSPlanService) CreateNSPlan(ctx context.Context, req *pb.CreateNSPlanRequest) (*pb.CreateNSPlanResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if len(req.ConfigJSON) == 0 {
return nil, errors.New("invalid 'configJSON'")
}
var config = dnsconfigs.DefaultNSPlanConfig()
err = json.Unmarshal(req.ConfigJSON, config)
if err != nil {
return nil, errors.New("decode 'configJSON' failed: " + err.Error())
}
planId, err := nameservers.SharedNSPlanDAO.CreatePlan(tx, req.Name, req.MonthlyPrice, req.YearlyPrice, config)
if err != nil {
return nil, err
}
return &pb.CreateNSPlanResponse{NsPlanId: planId}, nil
}
// UpdateNSPlan 修改DNS套餐
func (this *NSPlanService) UpdateNSPlan(ctx context.Context, req *pb.UpdateNSPlanRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if len(req.ConfigJSON) == 0 {
return nil, errors.New("invalid 'configJSON'")
}
var config = dnsconfigs.DefaultNSPlanConfig()
err = json.Unmarshal(req.ConfigJSON, config)
if err != nil {
return nil, errors.New("decode 'configJSON' failed: " + err.Error())
}
err = nameservers.SharedNSPlanDAO.UpdatePlan(tx, req.NsPlanId, req.Name, req.IsOn, req.MonthlyPrice, req.YearlyPrice, config)
if err != nil {
return nil, err
}
return this.Success()
}
// SortNSPlanOrders 修改DNS套餐顺序
func (this *NSPlanService) SortNSPlanOrders(ctx context.Context, req *pb.SortNSPlansRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = nameservers.SharedNSPlanDAO.UpdatePlanOrders(tx, req.NsPlanIds)
if err != nil {
return nil, err
}
return this.Success()
}
// FindAllNSPlans 查找所有DNS套餐
func (this *NSPlanService) FindAllNSPlans(ctx context.Context, req *pb.FindAllNSPlansRequest) (*pb.FindAllNSPlansResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
var pbPlans = []*pb.NSPlan{}
plans, err := nameservers.SharedNSPlanDAO.FindAllPlans(tx)
if err != nil {
return nil, err
}
for _, plan := range plans {
pbPlans = append(pbPlans, &pb.NSPlan{
Id: int64(plan.Id),
Name: plan.Name,
IsOn: plan.IsOn,
MonthlyPrice: float32(plan.MonthlyPrice),
YearlyPrice: float32(plan.YearlyPrice),
ConfigJSON: plan.Config,
})
}
return &pb.FindAllNSPlansResponse{NsPlans: pbPlans}, nil
}
// FindAllEnabledNSPlans 查找所有可用DNS套餐
func (this *NSPlanService) FindAllEnabledNSPlans(ctx context.Context, req *pb.FindAllEnabledNSPlansRequest) (*pb.FindAllEnabledNSPlansResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
var pbPlans = []*pb.NSPlan{}
plans, err := nameservers.SharedNSPlanDAO.FindAllEnabledPlans(tx)
if err != nil {
return nil, err
}
for _, plan := range plans {
pbPlans = append(pbPlans, &pb.NSPlan{
Id: int64(plan.Id),
Name: plan.Name,
IsOn: plan.IsOn,
MonthlyPrice: float32(plan.MonthlyPrice),
YearlyPrice: float32(plan.YearlyPrice),
ConfigJSON: plan.Config,
})
}
return &pb.FindAllEnabledNSPlansResponse{NsPlans: pbPlans}, nil
}
// FindNSPlan 查找DNS套餐
func (this *NSPlanService) FindNSPlan(ctx context.Context, req *pb.FindNSPlanRequest) (*pb.FindNSPlanResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
plan, err := nameservers.SharedNSPlanDAO.FindEnabledNSPlan(tx, req.NsPlanId)
if err != nil {
return nil, err
}
if plan == nil {
return &pb.FindNSPlanResponse{}, nil
}
return &pb.FindNSPlanResponse{
NsPlan: &pb.NSPlan{
Id: int64(plan.Id),
Name: plan.Name,
IsOn: plan.IsOn,
MonthlyPrice: float32(plan.MonthlyPrice),
YearlyPrice: float32(plan.YearlyPrice),
ConfigJSON: plan.Config,
},
}, nil
}
// DeleteNSPlan 删除DNS套餐
func (this *NSPlanService) DeleteNSPlan(ctx context.Context, req *pb.DeleteNSPlanRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = nameservers.SharedNSPlanDAO.DisableNSPlan(tx, req.NsPlanId)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,78 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package nameservers
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/nameservers"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/maps"
)
// NSQuestionOptionService DNS查询选项
type NSQuestionOptionService struct {
services.BaseService
}
// CreateNSQuestionOption 创建选项
func (this *NSQuestionOptionService) CreateNSQuestionOption(ctx context.Context, req *pb.CreateNSQuestionOptionRequest) (*pb.CreateNSQuestionOptionResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
var values = maps.Map{}
if len(req.ValuesJSON) > 0 {
err = json.Unmarshal(req.ValuesJSON, &values)
if err != nil {
return nil, err
}
}
optionId, err := nameservers.SharedNSQuestionOptionDAO.CreateOption(tx, req.Name, values)
if err != nil {
return nil, err
}
return &pb.CreateNSQuestionOptionResponse{NsQuestionOptionId: optionId}, nil
}
// FindNSQuestionOption 读取选项
func (this *NSQuestionOptionService) FindNSQuestionOption(ctx context.Context, req *pb.FindNSQuestionOptionRequest) (*pb.FindNSQuestionOptionResponse, error) {
_, err := this.ValidateNSNode(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
option, err := nameservers.SharedNSQuestionOptionDAO.FindOption(tx, req.NsQuestionOptionId)
if err != nil {
return nil, err
}
if option == nil {
return &pb.FindNSQuestionOptionResponse{NsQuestionOption: nil}, nil
}
return &pb.FindNSQuestionOptionResponse{NsQuestionOption: &pb.NSQuestionOption{
Id: int64(option.Id),
Name: option.Name,
ValuesJSON: option.Values,
}}, nil
}
// DeleteNSQuestionOption 删除选项
func (this *NSQuestionOptionService) DeleteNSQuestionOption(ctx context.Context, req *pb.DeleteNSQuestionOptionRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = nameservers.SharedNSQuestionOptionDAO.DeleteOption(tx, req.NsQuestionOptionId)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,993 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package nameservers
import (
"context"
"encoding/json"
"errors"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/nameservers"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/types"
"regexp"
"strings"
)
// NSRecordService 域名记录相关服务
type NSRecordService struct {
services.BaseService
}
// CreateNSRecord 创建记录
func (this *NSRecordService) CreateNSRecord(ctx context.Context, req *pb.CreateNSRecordRequest) (*pb.CreateNSRecordResponse, error) {
// 检查是否为商业用户
if !teaconst.IsPlus {
return nil, errors.New("non commercial user")
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = nameservers.SharedNSDomainDAO.CheckUserDomain(tx, userId, req.NsDomainId)
if err != nil {
return nil, err
}
}
// 检查线路代号
if len(req.NsRouteCodes) > 0 {
err = nameservers.SharedNSRouteDAO.CheckRouteCodes(tx, req.NsRouteCodes, userId)
if err != nil {
return nil, err
}
}
recordId, err := nameservers.SharedNSRecordDAO.CreateRecord(tx, req.NsDomainId, req.Description, req.Name, req.Type, req.Value, req.MxPriority, req.SrvPriority, req.SrvWeight, req.SrvPort, req.CaaFlag, req.CaaTag, req.Ttl, req.NsRouteCodes, req.Weight)
if err != nil {
return nil, err
}
return &pb.CreateNSRecordResponse{NsRecordId: recordId}, nil
}
// CreateNSRecords 创建记录
func (this *NSRecordService) CreateNSRecords(ctx context.Context, req *pb.CreateNSRecordsRequest) (*pb.CreateNSRecordsResponse, error) {
// 检查是否为商业用户
if !teaconst.IsPlus {
return nil, errors.New("non commercial user")
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
// 检查权限
var tx = this.NullTx()
if userId > 0 {
err = nameservers.SharedNSDomainDAO.CheckUserDomain(tx, userId, req.NsDomainId)
if err != nil {
return nil, err
}
}
var recordIds = []int64{}
err = this.RunTx(func(tx *dbs.Tx) error {
for _, name := range req.Names {
// 检查线路代号
if len(req.NsRouteCodes) > 0 {
err = nameservers.SharedNSRouteDAO.CheckRouteCodes(tx, req.NsRouteCodes, userId)
if err != nil {
return err
}
}
recordId, err := nameservers.SharedNSRecordDAO.CreateRecord(tx, req.NsDomainId, req.Description, name, req.Type, req.Value, req.MxPriority, req.SrvPriority, req.SrvWeight, req.SrvPort, req.CaaFlag, req.CaaTag, req.Ttl, req.NsRouteCodes, req.Weight)
if err != nil {
return err
}
recordIds = append(recordIds, recordId)
}
return nil
})
if err != nil {
return nil, err
}
return &pb.CreateNSRecordsResponse{NsRecordIds: recordIds}, nil
}
// CreateNSRecordsWithDomainNames 为一组域名批量创建记录
func (this *NSRecordService) CreateNSRecordsWithDomainNames(ctx context.Context, req *pb.CreateNSRecordsWithDomainNamesRequest) (*pb.RPCSuccess, error) {
// 检查是否为商业用户
if !teaconst.IsPlus {
return nil, errors.New("non commercial user")
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
if len(req.RecordsJSON) == 0 {
return this.Success()
}
type recordItem struct {
Name string `json:"name"`
Type string `json:"type"`
Value string `json:"value"`
MxPriority int32 `json:"mxPriority"`
SRVPriority int32 `json:"srvPriority"`
SRVWeight int32 `json:"srvWeight"`
SRVPort int32 `json:"srvPort"`
CAAFlag int32 `json:"caaFlag"`
CAATag string `json:"caaTag"`
RouteCodes []string `json:"routeCodes"`
TTL int32 `json:"ttl"`
Weight int32 `json:"weight"`
}
var records = []*recordItem{}
err = json.Unmarshal(req.RecordsJSON, &records)
if err != nil {
return nil, err
}
if len(records) == 0 {
return this.Success()
}
err = this.RunTx(func(tx *dbs.Tx) error {
for _, domainName := range req.NsDomainNames {
domainName = strings.ToLower(strings.TrimSpace(domainName))
if len(domainName) == 0 {
continue
}
domainId, err := nameservers.SharedNSDomainDAO.FindDomainIdWithName(tx, 0, req.UserId, domainName, false)
if err != nil {
return err
}
if domainId <= 0 {
continue
}
// 是否删除所有以往记录
if req.RemoveAll {
err = nameservers.SharedNSRecordDAO.DisableRecordsInDomain(tx, domainId)
if err != nil {
return err
}
}
for _, record := range records {
record.Type = strings.ToLower(record.Type)
if !req.RemoveAll && req.RemoveOld {
err = nameservers.SharedNSRecordDAO.DisableRecordsInDomainWithNameAndType(tx, domainId, record.Name, record.Type)
if err != nil {
return err
}
}
// 检查线路代号
if len(record.RouteCodes) > 0 {
err = nameservers.SharedNSRouteDAO.CheckRouteCodes(tx, record.RouteCodes, userId)
if err != nil {
return err
}
}
_, err = nameservers.SharedNSRecordDAO.CreateRecord(tx, domainId, "批量创建", record.Name, strings.ToUpper(record.Type), record.Value, record.MxPriority, record.SRVPriority, record.SRVWeight, record.SRVPort, record.CAAFlag, record.CAATag, record.TTL, record.RouteCodes, record.Weight)
if err != nil {
return err
}
}
}
return nil
})
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateNSRecordsWithDomainNames 批量修改一组域名的一组记录
func (this *NSRecordService) UpdateNSRecordsWithDomainNames(ctx context.Context, req *pb.UpdateNSRecordsWithDomainNamesRequest) (*pb.RPCSuccess, error) {
// 检查是否为商业用户
if !teaconst.IsPlus {
return nil, errors.New("non commercial user")
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
err = this.RunTx(func(tx *dbs.Tx) error {
for _, domainName := range req.NsDomainNames {
domainName = strings.ToLower(strings.TrimSpace(domainName))
if len(domainName) == 0 {
continue
}
domainId, err := nameservers.SharedNSDomainDAO.FindDomainIdWithName(tx, 0, req.UserId, domainName, false)
if err != nil {
return err
}
if domainId <= 0 {
continue
}
err = nameservers.SharedNSRecordDAO.UpdateRecordsWithDomainId(tx, domainId, req.SearchName, req.SearchType, req.SearchValue, req.SearchNSRouteCodes, req.NewName, req.NewType, req.NewValue, req.NewNSRouteCodes)
if err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteNSRecordsWithDomainNames 批量删除一组域名的一组记录
func (this *NSRecordService) DeleteNSRecordsWithDomainNames(ctx context.Context, req *pb.DeleteNSRecordsWithDomainNamesRequest) (*pb.RPCSuccess, error) {
// 检查是否为商业用户
if !teaconst.IsPlus {
return nil, errors.New("non commercial user")
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
err = this.RunTx(func(tx *dbs.Tx) error {
for _, domainName := range req.NsDomainNames {
domainName = strings.ToLower(strings.TrimSpace(domainName))
if len(domainName) == 0 {
continue
}
domainId, err := nameservers.SharedNSDomainDAO.FindDomainIdWithName(tx, 0, req.UserId, domainName, false)
if err != nil {
return err
}
if domainId <= 0 {
continue
}
err = nameservers.SharedNSRecordDAO.DisableRecordsWithDomainId(tx, domainId, req.SearchName, req.SearchType, req.SearchValue, req.SearchNSRouteCodes)
if err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateNSRecordsIsOnWithDomainNames 批量一组域名的一组记录启用状态
func (this *NSRecordService) UpdateNSRecordsIsOnWithDomainNames(ctx context.Context, req *pb.UpdateNSRecordsIsOnWithDomainNamesRequest) (*pb.RPCSuccess, error) {
// 检查是否为商业用户
if !teaconst.IsPlus {
return nil, errors.New("non commercial user")
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
err = this.RunTx(func(tx *dbs.Tx) error {
for _, domainName := range req.NsDomainNames {
domainName = strings.ToLower(strings.TrimSpace(domainName))
if len(domainName) == 0 {
continue
}
domainId, err := nameservers.SharedNSDomainDAO.FindDomainIdWithName(tx, 0, req.UserId, domainName, false)
if err != nil {
return err
}
if domainId <= 0 {
continue
}
err = nameservers.SharedNSRecordDAO.UpdateRecordsIsOnWithDomainId(tx, domainId, req.SearchName, req.SearchType, req.SearchValue, req.SearchNSRouteCodes, req.IsOn)
if err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
return this.Success()
}
// ImportNSRecords 导入域名解析
func (this *NSRecordService) ImportNSRecords(ctx context.Context, req *pb.ImportNSRecordsRequest) (*pb.RPCSuccess, error) {
// 检查是否为商业用户
if !teaconst.IsPlus {
return nil, errors.New("non commercial user")
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
err = this.RunTx(func(tx *dbs.Tx) error {
for _, record := range req.NsRecords {
var domainName = strings.ToLower(strings.TrimSpace(record.NsDomainName))
if len(domainName) == 0 {
continue
}
domainId, err := nameservers.SharedNSDomainDAO.FindDomainIdWithName(tx, 0, req.UserId, domainName, false)
if err != nil {
return err
}
if domainId <= 0 {
continue
}
if record.Ttl <= 0 {
record.Ttl = 600
}
_, err = nameservers.SharedNSRecordDAO.CreateRecord(tx, domainId, "批量导入", record.Name, record.Type, record.Value, record.MxPriority, record.SrvPriority, record.SrvWeight, record.SrvPort, record.CaaFlag, record.CaaTag, record.Ttl, nil, record.Weight)
if err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateNSRecord 修改记录
func (this *NSRecordService) UpdateNSRecord(ctx context.Context, req *pb.UpdateNSRecordRequest) (*pb.RPCSuccess, error) {
// 检查是否为商业用户
if !teaconst.IsPlus {
return nil, errors.New("non commercial user")
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = nameservers.SharedNSRecordDAO.CheckUserRecord(tx, userId, req.NsRecordId)
if err != nil {
return nil, err
}
}
err = nameservers.SharedNSRecordDAO.UpdateRecord(tx, req.NsRecordId, req.Description, req.Name, req.Type, req.Value, req.MxPriority, req.SrvPriority, req.SrvWeight, req.SrvPort, req.CaaFlag, req.CaaTag, req.Ttl, req.NsRouteCodes, req.Weight, req.IsOn)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteNSRecord 删除记录
func (this *NSRecordService) DeleteNSRecord(ctx context.Context, req *pb.DeleteNSRecordRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = nameservers.SharedNSRecordDAO.CheckUserRecord(tx, userId, req.NsRecordId)
if err != nil {
return nil, err
}
}
err = nameservers.SharedNSRecordDAO.DisableNSRecord(tx, req.NsRecordId)
if err != nil {
return nil, err
}
return this.Success()
}
// CountAllNSRecords 计算记录数量
func (this *NSRecordService) CountAllNSRecords(ctx context.Context, req *pb.CountAllNSRecordsRequest) (*pb.RPCCountResponse, error) {
// 检查是否为商业用户
if !teaconst.IsPlus {
return nil, errors.New("non commercial user")
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = nameservers.SharedNSDomainDAO.CheckUserDomain(tx, userId, req.NsDomainId)
if err != nil {
return nil, err
}
}
count, err := nameservers.SharedNSRecordDAO.CountAllEnabledDomainRecords(tx, req.NsDomainId, req.Type, req.Keyword, req.NsRouteCode)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// CountAllNSRecordsWithName 查询相同记录名的记录数
func (this *NSRecordService) CountAllNSRecordsWithName(ctx context.Context, req *pb.CountAllNSRecordsWithNameRequest) (*pb.RPCCountResponse, error) {
// 检查是否为商业用户
if !teaconst.IsPlus {
return nil, errors.New("non commercial user")
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = nameservers.SharedNSDomainDAO.CheckUserDomain(tx, userId, req.NsDomainId)
if err != nil {
return nil, err
}
}
count, err := nameservers.SharedNSRecordDAO.CountAllRecordsWithName(tx, req.NsDomainId, req.Type, req.Name)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListNSRecords 读取单页记录
func (this *NSRecordService) ListNSRecords(ctx context.Context, req *pb.ListNSRecordsRequest) (*pb.ListNSRecordsResponse, error) {
// 检查是否为商业用户
if !teaconst.IsPlus {
return nil, errors.New("non commercial user")
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = nameservers.SharedNSDomainDAO.CheckUserDomain(tx, userId, req.NsDomainId)
if err != nil {
return nil, err
}
}
records, err := nameservers.SharedNSRecordDAO.ListEnabledRecords(tx, req.NsDomainId, req.Type, req.Keyword, req.NsRouteCode, req.NameAsc, req.NameDesc, req.TypeAsc, req.TypeDesc, req.TtlAsc, req.TtlDesc, req.UpAsc, req.UpDesc, req.Offset, req.Size)
if err != nil {
return nil, err
}
var pbRecords = []*pb.NSRecord{}
for _, record := range records {
// 线路
var pbRoutes = []*pb.NSRoute{}
for _, routeCode := range record.DecodeRouteIds() {
route, err := nameservers.SharedNSRouteDAO.FindEnabledRouteWithCode(tx, routeCode)
if err != nil {
return nil, err
}
if route == nil {
continue
}
pbRoutes = append(pbRoutes, &pb.NSRoute{
Id: int64(route.Id),
Name: route.Name,
Code: route.Code,
})
// TODO 读取其他线路
}
pbRecords = append(pbRecords, &pb.NSRecord{
Id: int64(record.Id),
Description: record.Description,
Name: record.Name,
Type: record.Type,
Value: record.Value,
MxPriority: int32(record.MxPriority),
Ttl: types.Int32(record.Ttl),
Weight: types.Int32(record.Weight),
CreatedAt: int64(record.CreatedAt),
IsOn: record.IsOn,
NsDomain: nil,
NsRoutes: pbRoutes,
HealthCheckJSON: record.HealthCheck,
IsUp: record.IsUp,
})
}
return &pb.ListNSRecordsResponse{NsRecords: pbRecords}, nil
}
// FindNSRecord 查询单个记录信息
func (this *NSRecordService) FindNSRecord(ctx context.Context, req *pb.FindNSRecordRequest) (*pb.FindNSRecordResponse, error) {
// 检查是否为商业用户
if !teaconst.IsPlus {
return nil, errors.New("non commercial user")
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = nameservers.SharedNSRecordDAO.CheckUserRecord(tx, userId, req.NsRecordId)
if err != nil {
return nil, err
}
}
record, err := nameservers.SharedNSRecordDAO.FindEnabledNSRecord(tx, req.NsRecordId)
if err != nil {
return nil, err
}
if record == nil {
return &pb.FindNSRecordResponse{NsRecord: nil}, nil
}
// 域名
domain, err := nameservers.SharedNSDomainDAO.FindEnabledNSDomain(tx, int64(record.DomainId))
if err != nil {
return nil, err
}
if domain == nil {
return &pb.FindNSRecordResponse{NsRecord: nil}, nil
}
var pbDomain = &pb.NSDomain{
Id: int64(domain.Id),
Name: domain.Name,
IsOn: domain.IsOn,
}
// 线路
var pbRoutes = []*pb.NSRoute{}
for _, routeCode := range record.DecodeRouteIds() {
route, err := nameservers.SharedNSRouteDAO.FindEnabledRouteWithCode(tx, routeCode)
if err != nil {
return nil, err
}
if route == nil {
continue
}
pbRoutes = append(pbRoutes, &pb.NSRoute{
Id: int64(route.Id),
Name: route.Name,
Code: route.Code,
})
}
// TODO 读取其他线路
return &pb.FindNSRecordResponse{NsRecord: &pb.NSRecord{
Id: int64(record.Id),
Description: record.Description,
Name: record.Name,
Type: record.Type,
Value: record.Value,
MxPriority: types.Int32(record.MxPriority),
SrvPort: types.Int32(record.SrvPort),
SrvPriority: types.Int32(record.SrvPriority),
SrvWeight: types.Int32(record.SrvWeight),
CaaFlag: types.Int32(record.CaaFlag),
CaaTag: record.CaaTag,
Ttl: types.Int32(record.Ttl),
Weight: types.Int32(record.Weight),
CreatedAt: int64(record.CreatedAt),
IsOn: record.IsOn,
NsDomain: pbDomain,
NsRoutes: pbRoutes,
HealthCheckJSON: record.HealthCheck,
IsUp: record.IsUp,
}}, nil
}
// FindNSRecordWithNameAndType 使用名称和类型查询单个记录信息
func (this *NSRecordService) FindNSRecordWithNameAndType(ctx context.Context, req *pb.FindNSRecordWithNameAndTypeRequest) (*pb.FindNSRecordWithNameAndTypeResponse, error) {
// 检查是否为商业用户
if !teaconst.IsPlus {
return nil, errors.New("non commercial user")
}
if req.NsDomainId <= 0 {
return &pb.FindNSRecordWithNameAndTypeResponse{
NsRecord: nil,
}, nil
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
// 检查权限
var tx = this.NullTx()
if userId > 0 {
err = nameservers.SharedNSDomainDAO.CheckUserDomain(tx, userId, req.NsDomainId)
if err != nil {
return nil, err
}
}
record, err := nameservers.SharedNSRecordDAO.FindEnabledRecordWithName(tx, req.NsDomainId, req.Name, req.Type)
if err != nil {
return nil, err
}
if record == nil {
return &pb.FindNSRecordWithNameAndTypeResponse{
NsRecord: nil,
}, nil
}
// 线路
var pbRoutes = []*pb.NSRoute{}
for _, routeCode := range record.DecodeRouteIds() {
route, err := nameservers.SharedNSRouteDAO.FindEnabledRouteWithCode(tx, routeCode)
if err != nil {
return nil, err
}
if route == nil {
continue
}
pbRoutes = append(pbRoutes, &pb.NSRoute{
Id: int64(route.Id),
Name: route.Name,
Code: route.Code,
})
}
return &pb.FindNSRecordWithNameAndTypeResponse{
NsRecord: &pb.NSRecord{
Id: int64(record.Id),
Description: record.Description,
Name: record.Name,
Type: record.Type,
Value: record.Value,
MxPriority: types.Int32(record.MxPriority),
SrvPriority: types.Int32(record.SrvPriority),
SrvWeight: types.Int32(record.SrvWeight),
SrvPort: types.Int32(record.SrvPort),
CaaFlag: types.Int32(record.CaaFlag),
CaaTag: record.CaaTag,
Ttl: types.Int32(record.Ttl),
Weight: types.Int32(record.Weight),
CreatedAt: int64(record.CreatedAt),
IsOn: record.IsOn,
NsRoutes: pbRoutes,
HealthCheckJSON: record.HealthCheck,
IsUp: record.IsUp,
},
}, nil
}
// FindNSRecordsWithNameAndType 使用名称和类型查询多个记录信息
func (this *NSRecordService) FindNSRecordsWithNameAndType(ctx context.Context, req *pb.FindNSRecordsWithNameAndTypeRequest) (*pb.FindNSRecordsWithNameAndTypeResponse, error) {
// 检查是否为商业用户
if !teaconst.IsPlus {
return nil, errors.New("non commercial user")
}
if req.NsDomainId <= 0 {
return &pb.FindNSRecordsWithNameAndTypeResponse{
NsRecords: nil,
}, nil
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
// 检查权限
var tx = this.NullTx()
if userId > 0 {
err = nameservers.SharedNSDomainDAO.CheckUserDomain(tx, userId, req.NsDomainId)
if err != nil {
return nil, err
}
}
records, err := nameservers.SharedNSRecordDAO.FindEnabledRecordsWithName(tx, req.NsDomainId, req.Name, req.Type)
if err != nil {
return nil, err
}
var pbRecords = []*pb.NSRecord{}
for _, record := range records {
// 线路
var pbRoutes = []*pb.NSRoute{}
for _, routeCode := range record.DecodeRouteIds() {
route, err := nameservers.SharedNSRouteDAO.FindEnabledRouteWithCode(tx, routeCode)
if err != nil {
return nil, err
}
if route == nil {
continue
}
pbRoutes = append(pbRoutes, &pb.NSRoute{
Id: int64(route.Id),
Name: route.Name,
Code: route.Code,
})
}
pbRecords = append(pbRecords, &pb.NSRecord{
Id: int64(record.Id),
Description: record.Description,
Name: record.Name,
Type: record.Type,
Value: record.Value,
MxPriority: int32(record.MxPriority),
SrvPriority: types.Int32(record.SrvPriority),
SrvWeight: types.Int32(record.SrvWeight),
SrvPort: types.Int32(record.SrvPort),
CaaFlag: types.Int32(record.CaaFlag),
CaaTag: record.CaaTag,
Ttl: types.Int32(record.Ttl),
Weight: types.Int32(record.Weight),
CreatedAt: int64(record.CreatedAt),
IsOn: record.IsOn,
NsRoutes: pbRoutes,
HealthCheckJSON: record.HealthCheck,
IsUp: record.IsUp,
})
}
return &pb.FindNSRecordsWithNameAndTypeResponse{
NsRecords: pbRecords,
}, nil
}
// ListNSRecordsAfterVersion 根据版本列出一组记录
func (this *NSRecordService) ListNSRecordsAfterVersion(ctx context.Context, req *pb.ListNSRecordsAfterVersionRequest) (*pb.ListNSRecordsAfterVersionResponse, error) {
// 检查是否为商业用户
if !teaconst.IsPlus {
return nil, errors.New("non commercial user")
}
_, _, err := this.ValidateNodeId(ctx, rpcutils.UserTypeDNS)
if err != nil {
return nil, err
}
// 检查是否为商业用户
if !teaconst.IsPlus {
return &pb.ListNSRecordsAfterVersionResponse{
NsRecords: nil,
}, nil
}
// 集群ID
var tx = this.NullTx()
if req.Size <= 0 {
req.Size = 2000
}
records, err := nameservers.SharedNSRecordDAO.ListRecordsAfterVersion(tx, req.Version, req.Size)
if err != nil {
return nil, err
}
var pbRecords []*pb.NSRecord
for _, record := range records {
// 线路
pbRoutes := []*pb.NSRoute{}
routeIds := record.DecodeRouteIds()
for _, routeId := range routeIds {
var routeIdInt int64 = 0
if regexp.MustCompile(`^id:\d+$`).MatchString(routeId) {
routeIdInt = types.Int64(routeId[strings.Index(routeId, ":")+1:])
}
pbRoutes = append(pbRoutes, &pb.NSRoute{
Id: routeIdInt,
Code: routeId,
})
}
// TODO 读取其他线路
pbRecords = append(pbRecords, &pb.NSRecord{
Id: int64(record.Id),
Description: "",
Name: record.Name,
Type: record.Type,
Value: record.Value,
MxPriority: int32(record.MxPriority),
SrvPriority: types.Int32(record.SrvPriority),
SrvWeight: types.Int32(record.SrvWeight),
SrvPort: types.Int32(record.SrvPort),
CaaFlag: types.Int32(record.CaaFlag),
CaaTag: record.CaaTag,
Ttl: types.Int32(record.Ttl),
Weight: types.Int32(record.Weight),
IsDeleted: record.State == nameservers.NSRecordStateDisabled,
IsOn: record.IsOn && record.IsUp,
Version: int64(record.Version),
NsDomain: &pb.NSDomain{Id: int64(record.DomainId)},
NsRoutes: pbRoutes,
HealthCheckJSON: record.HealthCheck,
IsUp: record.IsUp,
})
}
return &pb.ListNSRecordsAfterVersionResponse{NsRecords: pbRecords}, nil
}
// FindNSRecordHealthCheck 查询记录健康检查设置
func (this *NSRecordService) FindNSRecordHealthCheck(ctx context.Context, req *pb.FindNSRecordHealthCheckRequest) (*pb.FindNSRecordHealthCheckResponse, error) {
// 检查是否为商业用户
if !teaconst.IsPlus {
return nil, errors.New("non commercial user")
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
// 检查权限
var tx = this.NullTx()
if userId > 0 {
err = nameservers.SharedNSRecordDAO.CheckUserRecord(tx, userId, req.NsRecordId)
if err != nil {
return nil, err
}
// TODO 检查套餐
}
config, err := nameservers.SharedNSRecordDAO.FindRecordHealthCheckConfig(tx, req.NsRecordId)
if err != nil {
return nil, err
}
configJSON, err := json.Marshal(config)
if err != nil {
return nil, err
}
return &pb.FindNSRecordHealthCheckResponse{NsRecordHealthCheckJSON: configJSON}, nil
}
// UpdateNSRecordHealthCheck 修改记录健康检查设置
func (this *NSRecordService) UpdateNSRecordHealthCheck(ctx context.Context, req *pb.UpdateNSRecordHealthCheckRequest) (*pb.RPCSuccess, error) {
// 检查是否为商业用户
if !teaconst.IsPlus {
return nil, errors.New("non commercial user")
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
// 检查权限
var tx = this.NullTx()
if userId > 0 {
err = nameservers.SharedNSRecordDAO.CheckUserRecord(tx, userId, req.NsRecordId)
if err != nil {
return nil, err
}
// TODO 检查套餐
}
if len(req.NsRecordHealthCheckJSON) == 0 {
return nil, errors.New("invalid 'nsRecordHealthCheckJSON'")
}
var healthCheckConfig = dnsconfigs.NewNSRecordHealthCheckConfig()
err = json.Unmarshal(req.NsRecordHealthCheckJSON, healthCheckConfig)
if err != nil {
return nil, errors.New("decode 'nsRecordHealthCheckJSON' failed: " + err.Error())
}
err = nameservers.SharedNSRecordDAO.UpdateRecordHealthCheckConfig(tx, req.NsRecordId, healthCheckConfig)
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateNSRecordIsUp 手动修改记录在线状态
func (this *NSRecordService) UpdateNSRecordIsUp(ctx context.Context, req *pb.UpdateNSRecordIsUpRequest) (*pb.RPCSuccess, error) {
// 检查是否为商业用户
if !teaconst.IsPlus {
return nil, errors.New("non commercial user")
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
// 检查权限
var tx = this.NullTx()
if userId > 0 {
err = nameservers.SharedNSRecordDAO.CheckUserRecord(tx, userId, req.NsRecordId)
if err != nil {
return nil, err
}
// TODO 检查套餐
}
err = nameservers.SharedNSRecordDAO.UpdateRecordIsUp(tx, req.NsRecordId, req.IsUp, 0, 0)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,170 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package nameservers
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/nameservers"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
timeutil "github.com/iwind/TeaGo/utils/time"
"regexp"
"time"
)
// NSRecordHourlyStatService NS记录小时统计
type NSRecordHourlyStatService struct {
services.BaseService
}
// UploadNSRecordHourlyStats 上传统计
func (this *NSRecordHourlyStatService) UploadNSRecordHourlyStats(ctx context.Context, req *pb.UploadNSRecordHourlyStatsRequest) (*pb.RPCSuccess, error) {
_, nodeId, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeDNS)
if err != nil {
return nil, err
}
if nodeId <= 0 {
return nil, errors.New("invalid nodeId")
}
if len(req.Stats) == 0 {
return this.Success()
}
var tx = this.NullTx()
clusterId, err := models.SharedNSNodeDAO.FindNodeClusterId(tx, nodeId)
if err != nil {
return nil, err
}
// 增加小时统计
for _, stat := range req.Stats {
err := nameservers.SharedNSRecordHourlyStatDAO.IncreaseHourlyStat(tx, clusterId, nodeId, timeutil.FormatTime("YmdH", stat.CreatedAt), stat.NsDomainId, stat.NsRecordId, stat.CountRequests, stat.Bytes)
if err != nil {
return nil, err
}
}
return this.Success()
}
// FindNSRecordHourlyStat 获取单个记录单个小时的统计
func (this *NSRecordHourlyStatService) FindNSRecordHourlyStat(ctx context.Context, req *pb.FindNSRecordHourlyStatRequest) (*pb.FindNSRecordHourlyStatResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if len(req.Hour) == 0 {
req.Hour = timeutil.Format("YmdH")
} else if !regexp.MustCompile(`^\d{10}$`).MatchString(req.Hour) {
return nil, errors.New("invalid hour '" + req.Hour + "'")
}
var tx = this.NullTx()
if userId > 0 {
err = nameservers.SharedNSRecordDAO.CheckUserRecord(tx, userId, req.NsRecordId)
if err != nil {
return nil, err
}
}
stat, err := nameservers.SharedNSRecordHourlyStatDAO.FindHourlyStatWithRecordId(tx, req.NsRecordId, req.Hour)
if err != nil {
return nil, err
}
if stat == nil {
return &pb.FindNSRecordHourlyStatResponse{NsRecordHourlyStat: nil}, nil
}
return &pb.FindNSRecordHourlyStatResponse{
NsRecordHourlyStat: &pb.NSRecordHourlyStat{
NsRecordId: req.NsRecordId,
Bytes: int64(stat.Bytes),
CountRequests: int64(stat.CountRequests),
Hour: req.Hour,
},
}, nil
}
// FindLatestNSRecordsHourlyStats 获取单个记录24小时内的统计
func (this *NSRecordHourlyStatService) FindLatestNSRecordsHourlyStats(ctx context.Context, req *pb.FindLatestNSRecordsHourlyStatsRequest) (*pb.FindLatestNSRecordsHourlyStatsResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
err = nameservers.SharedNSRecordDAO.CheckUserRecord(tx, userId, req.NsRecordId)
if err != nil {
return nil, err
}
}
stats, err := nameservers.SharedNSRecordHourlyStatDAO.FindHourlyStatsWithRecordId(tx, req.NsRecordId, timeutil.Format("YmdH", time.Now().Add(-23*time.Hour)), timeutil.Format("YmdH"))
if err != nil {
return nil, err
}
var pbStats = []*pb.NSRecordHourlyStat{}
for _, stat := range stats {
pbStats = append(pbStats, &pb.NSRecordHourlyStat{
NsRecordId: req.NsRecordId,
Bytes: int64(stat.Bytes),
CountRequests: int64(stat.CountRequests),
Hour: stat.Hour,
})
}
return &pb.FindLatestNSRecordsHourlyStatsResponse{
NsRecordHourlyStats: pbStats,
}, nil
}
// FindNSRecordHourlyStatWithRecordIds 批量获取一组记录的统计
func (this *NSRecordHourlyStatService) FindNSRecordHourlyStatWithRecordIds(ctx context.Context, req *pb.FindNSRecordHourlyStatWithRecordIdsRequest) (*pb.FindNSRecordHourlyStatWithRecordIdsResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if len(req.Hour) == 0 {
req.Hour = timeutil.Format("YmdH")
} else if !regexp.MustCompile(`^\d{10}$`).MatchString(req.Hour) {
return nil, errors.New("invalid hour '" + req.Hour + "'")
}
var tx = this.NullTx()
if userId > 0 {
for _, recordId := range req.NsRecordIds {
err = nameservers.SharedNSRecordDAO.CheckUserRecord(tx, userId, recordId)
if err != nil {
return nil, err
}
}
}
var pbStats = []*pb.NSRecordHourlyStat{}
for _, recordId := range req.NsRecordIds {
stat, err := nameservers.SharedNSRecordHourlyStatDAO.FindHourlyStatWithRecordId(tx, recordId, req.Hour)
if err != nil {
return nil, err
}
if stat == nil {
continue
}
pbStats = append(pbStats, &pb.NSRecordHourlyStat{
NsRecordId: recordId,
Bytes: int64(stat.Bytes),
CountRequests: int64(stat.CountRequests),
Hour: stat.Hour,
})
}
return &pb.FindNSRecordHourlyStatWithRecordIdsResponse{
NsRecordHourlyStats: pbStats,
}, nil
}

View File

@@ -0,0 +1,614 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package nameservers
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/clients"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/nameservers"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/types"
"sort"
)
// NSRouteService 线路相关服务
type NSRouteService struct {
services.BaseService
}
// CreateNSRoute 创建自定义线路
func (this *NSRouteService) CreateNSRoute(ctx context.Context, req *pb.CreateNSRouteRequest) (*pb.CreateNSRouteResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
req.UserId = userId
// 暂时不允许在集群和域名下创建线路
req.NsClusterId = 0
req.NsDomainId = 0
}
// TODO 检查线路数限制
// 检查分类是否存在
if req.NsRouteCategoryId > 0 {
if userId > 0 {
err = nameservers.SharedNSRouteCategoryDAO.CheckUserCategory(tx, userId, req.NsRouteCategoryId)
if err != nil {
return nil, err
}
} else {
exists, err := nameservers.SharedNSRouteCategoryDAO.Exist(tx, req.NsRouteCategoryId)
if err != nil {
return nil, err
}
if !exists {
return nil, errors.New("route category id '" + types.String(req.NsRouteCategoryId) + "' not found")
}
}
} else {
req.NsRouteCategoryId = 0
}
routeId, err := nameservers.SharedNSRouteDAO.CreateRoute(tx, req.NsClusterId, req.NsDomainId, req.UserId, req.Name, req.RangesJSON, req.NsRouteCategoryId, req.Priority, req.IsPublic)
if err != nil {
return nil, err
}
return &pb.CreateNSRouteResponse{NsRouteId: routeId}, nil
}
// UpdateNSRoute 修改自定义线路
func (this *NSRouteService) UpdateNSRoute(ctx context.Context, req *pb.UpdateNSRouteRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = nameservers.SharedNSRouteDAO.CheckUserRoute(tx, userId, req.NsRouteId)
if err != nil {
return nil, err
}
}
// 检查分类是否存在
if req.NsRouteCategoryId > 0 {
if userId > 0 {
err = nameservers.SharedNSRouteCategoryDAO.CheckUserCategory(tx, userId, req.NsRouteCategoryId)
if err != nil {
return nil, err
}
} else {
exists, err := nameservers.SharedNSRouteCategoryDAO.Exist(tx, req.NsRouteCategoryId)
if err != nil {
return nil, err
}
if !exists {
return nil, errors.New("route category id '" + types.String(req.NsRouteCategoryId) + "' not found")
}
}
} else {
req.NsRouteCategoryId = 0
}
err = nameservers.SharedNSRouteDAO.UpdateRoute(tx, req.NsRouteId, req.Name, req.RangesJSON, req.NsRouteCategoryId, req.Priority, req.IsPublic, req.IsOn)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteNSRoute 删除自定义线路
func (this *NSRouteService) DeleteNSRoute(ctx context.Context, req *pb.DeleteNSRouteRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = nameservers.SharedNSRouteDAO.CheckUserRoute(tx, userId, req.NsRouteId)
if err != nil {
return nil, err
}
}
err = nameservers.SharedNSRouteDAO.DisableNSRoute(tx, req.NsRouteId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindNSRoute 获取单个自定义路线信息
func (this *NSRouteService) FindNSRoute(ctx context.Context, req *pb.FindNSRouteRequest) (*pb.FindNSRouteResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = nameservers.SharedNSRouteDAO.CheckUserRoute(tx, userId, req.NsRouteId)
if err != nil {
return nil, err
}
}
route, err := nameservers.SharedNSRouteDAO.FindEnabledNSRoute(tx, req.NsRouteId)
if err != nil {
return nil, err
}
if route == nil {
return &pb.FindNSRouteResponse{NsRoute: nil}, nil
}
// 集群
var pbCluster *pb.NSCluster
if route.ClusterId > 0 {
cluster, err := models.SharedNSClusterDAO.FindEnabledNSCluster(tx, int64(route.ClusterId))
if err != nil {
return nil, err
}
if cluster != nil {
pbCluster = &pb.NSCluster{
Id: int64(cluster.Id),
IsOn: cluster.IsOn,
Name: cluster.Name,
}
}
}
// 域名
var pbDomain *pb.NSDomain
if route.DomainId > 0 {
domain, err := nameservers.SharedNSDomainDAO.FindEnabledNSDomain(tx, int64(route.DomainId))
if err != nil {
return nil, err
}
if domain != nil {
pbDomain = &pb.NSDomain{
Id: int64(domain.Id),
Name: domain.Name,
IsOn: domain.IsOn,
}
}
}
// 分类
var pbCategory *pb.NSRouteCategory
if route.CategoryId > 0 {
category, err := nameservers.SharedNSRouteCategoryDAO.FindCategory(tx, int64(route.CategoryId))
if err != nil {
return nil, err
}
if category != nil {
pbCategory = &pb.NSRouteCategory{
Id: int64(category.Id),
Name: category.Name,
IsOn: category.IsOn,
}
}
}
return &pb.FindNSRouteResponse{NsRoute: &pb.NSRoute{
Id: int64(route.Id),
IsOn: route.IsOn,
Name: route.Name,
RangesJSON: route.Ranges,
IsPublic: route.IsPublic,
Priority: types.Int32(route.Priority),
NsCluster: pbCluster,
NsDomain: pbDomain,
NsRouteCategory: pbCategory,
}}, nil
}
// CountAllNSRoutes 查询自定义线路数量
func (this *NSRouteService) CountAllNSRoutes(ctx context.Context, req *pb.CountAllNSRoutesRequest) (*pb.RPCCountResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
var tx = this.NullTx()
countRoutes, err := nameservers.SharedNSRouteDAO.CountAllEnabledRoutes(tx, req.NsClusterId, req.NsClusterId, req.UserId)
if err != nil {
return nil, err
}
return this.SuccessCount(countRoutes)
}
// FindAllNSRoutes 读取所有自定义线路
func (this *NSRouteService) FindAllNSRoutes(ctx context.Context, req *pb.FindAllNSRoutesRequest) (*pb.FindAllNSRoutesResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
req.UserId = userId
}
routes, err := nameservers.SharedNSRouteDAO.FindAllEnabledRoutes(tx, req.NsClusterId, req.NsDomainId, req.UserId, req.NsRouteCategoryId)
if err != nil {
return nil, err
}
var pbRoutes = []*pb.NSRoute{}
for _, route := range routes {
// 集群
var pbCluster *pb.NSCluster
if route.ClusterId > 0 {
cluster, err := models.SharedNSClusterDAO.FindEnabledNSCluster(tx, int64(route.ClusterId))
if err != nil {
return nil, err
}
if cluster != nil {
pbCluster = &pb.NSCluster{
Id: int64(cluster.Id),
IsOn: cluster.IsOn,
Name: cluster.Name,
}
}
}
// 域名
var pbDomain *pb.NSDomain
if route.DomainId > 0 {
domain, err := nameservers.SharedNSDomainDAO.FindEnabledNSDomain(tx, int64(route.DomainId))
if err != nil {
return nil, err
}
if domain != nil {
pbDomain = &pb.NSDomain{
Id: int64(domain.Id),
Name: domain.Name,
IsOn: domain.IsOn,
}
}
}
// 分类
var pbCategory *pb.NSRouteCategory
if route.CategoryId > 0 {
category, err := nameservers.SharedNSRouteCategoryDAO.FindCategory(tx, int64(route.CategoryId))
if err != nil {
return nil, err
}
if category != nil {
pbCategory = &pb.NSRouteCategory{
Id: int64(category.Id),
Name: category.Name,
IsOn: category.IsOn,
Order: types.Int32(category.Order),
}
}
}
pbRoutes = append(pbRoutes, &pb.NSRoute{
Id: int64(route.Id),
IsOn: route.IsOn,
Code: "id:" + types.String(route.Id),
Name: route.Name,
IsPublic: route.IsPublic,
RangesJSON: route.Ranges,
Order: types.Int32(route.Order),
Priority: types.Int32(route.Priority),
NsCluster: pbCluster,
NsDomain: pbDomain,
NsRouteCategory: pbCategory,
})
}
// 按照分类排序
if len(pbRoutes) > 0 {
sort.Slice(pbRoutes, func(i, j int) bool {
var route1 = pbRoutes[i]
var route2 = pbRoutes[j]
// route1.category = nil
if route1.NsRouteCategory == nil {
if route2.NsRouteCategory == nil {
if route1.Order == route2.Order {
return route1.Id < route2.Id
}
return route1.Order > route2.Order
}
return true
}
// route1.category != nil && route2.category = nil
if route2.NsRouteCategory == nil {
return false
}
// 同一个分类
if route1.NsRouteCategory.Id == route2.NsRouteCategory.Id {
if route1.Order == route2.Order {
return route1.Id < route2.Id
}
return route1.Order > route2.Order
}
if route1.NsRouteCategory.Order == route2.NsRouteCategory.Order {
return route1.NsRouteCategory.Id < route2.NsRouteCategory.Id
}
return route1.NsRouteCategory.Order > route2.NsRouteCategory.Order
})
}
return &pb.FindAllNSRoutesResponse{NsRoutes: pbRoutes}, nil
}
// FindAllPublicNSRoutes 读取所有公用的自定义线路
// 目前只允许读取系统管理员设置的公用自定义线路
func (this *NSRouteService) FindAllPublicNSRoutes(ctx context.Context, req *pb.FindAllPublicRoutesRequest) (*pb.FindAllPublicRoutesResponse, error) {
_, err := this.ValidateUserNode(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
routes, err := nameservers.SharedNSRouteDAO.FindAllPublicRoutes(tx)
if err != nil {
return nil, err
}
var pbRoutes = []*pb.NSRoute{}
for _, route := range routes {
// 分类
var pbCategory *pb.NSRouteCategory
if route.CategoryId > 0 {
category, err := nameservers.SharedNSRouteCategoryDAO.FindCategory(tx, int64(route.CategoryId))
if err != nil {
return nil, err
}
if category != nil {
// 如果分类未启用,则当前分类下面的线路也不显示
if !category.IsOn {
continue
}
pbCategory = &pb.NSRouteCategory{
Id: int64(category.Id),
Name: category.Name,
IsOn: category.IsOn,
Order: types.Int32(category.Order),
}
}
}
pbRoutes = append(pbRoutes, &pb.NSRoute{
Id: int64(route.Id),
IsOn: route.IsOn,
Code: "id:" + types.String(route.Id),
Name: route.Name,
IsPublic: route.IsPublic,
RangesJSON: route.Ranges,
Order: types.Int32(route.Order),
Priority: types.Int32(route.Priority),
NsCluster: nil,
NsDomain: nil,
NsRouteCategory: pbCategory,
})
}
// 按照分类排序
if len(pbRoutes) > 0 {
sort.Slice(pbRoutes, func(i, j int) bool {
var route1 = pbRoutes[i]
var route2 = pbRoutes[j]
// route1.category = nil
if route1.NsRouteCategory == nil {
if route2.NsRouteCategory == nil {
if route1.Order == route2.Order {
return route1.Id < route2.Id
}
return route1.Order > route2.Order
}
return true
}
// route1.category != nil && route2.category = nil
if route2.NsRouteCategory == nil {
return false
}
// 同一个分类
if route1.NsRouteCategory.Id == route2.NsRouteCategory.Id {
if route1.Order == route2.Order {
return route1.Id < route2.Id
}
return route1.Order > route2.Order
}
if route1.NsRouteCategory.Order == route2.NsRouteCategory.Order {
return route1.NsRouteCategory.Id < route2.NsRouteCategory.Id
}
return route1.NsRouteCategory.Order > route2.NsRouteCategory.Order
})
}
return &pb.FindAllPublicRoutesResponse{NsRoutes: pbRoutes}, nil
}
// UpdateNSRouteOrders 设置自定义线路排序
func (this *NSRouteService) UpdateNSRouteOrders(ctx context.Context, req *pb.UpdateNSRouteOrdersRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
for _, routeId := range req.NsRouteIds {
err = nameservers.SharedNSRouteDAO.CheckUserRoute(tx, userId, routeId)
if err != nil {
return nil, err
}
}
}
err = nameservers.SharedNSRouteDAO.UpdateRouteOrders(tx, req.NsRouteIds)
if err != nil {
return nil, err
}
return this.Success()
}
// ListNSRoutesAfterVersion 根据版本列出一组自定义线路
func (this *NSRouteService) ListNSRoutesAfterVersion(ctx context.Context, req *pb.ListNSRoutesAfterVersionRequest) (*pb.ListNSRoutesAfterVersionResponse, error) {
_, _, err := this.ValidateNodeId(ctx, rpcutils.UserTypeDNS)
if err != nil {
return nil, err
}
// 集群ID
var tx = this.NullTx()
routes, err := nameservers.SharedNSRouteDAO.ListRoutesAfterVersion(tx, req.Version, 2000)
if err != nil {
return nil, err
}
var pbRoutes []*pb.NSRoute
for _, route := range routes {
// 集群
var pbCluster *pb.NSCluster
if route.ClusterId > 0 {
pbCluster = &pb.NSCluster{Id: int64(route.ClusterId)}
}
// 域名
var pbDomain *pb.NSDomain
if route.DomainId > 0 {
pbDomain = &pb.NSDomain{Id: int64(route.DomainId)}
}
pbRoutes = append(pbRoutes, &pb.NSRoute{
Id: int64(route.Id),
IsOn: route.IsOn,
Name: "",
RangesJSON: route.Ranges,
IsDeleted: route.State == nameservers.NSRouteStateDisabled,
Order: types.Int32(route.Order),
Priority: types.Int32(route.Priority),
Version: int64(route.Version),
UserId: types.Int64(route.UserId),
NsCluster: pbCluster,
NsDomain: pbDomain,
})
}
return &pb.ListNSRoutesAfterVersionResponse{NsRoutes: pbRoutes}, nil
}
// FindAllDefaultWorldRegionRoutes 查找默认的世界区域线路
func (this *NSRouteService) FindAllDefaultWorldRegionRoutes(ctx context.Context, req *pb.FindAllDefaultWorldRegionRoutesRequest) (*pb.FindAllDefaultWorldRegionRoutesResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var pbRoutes = []*pb.NSRoute{}
for _, route := range dnsconfigs.AllDefaultWorldRegionRoutes {
pbRoutes = append(pbRoutes, &pb.NSRoute{
Code: route.Code,
Name: route.Name,
})
}
return &pb.FindAllDefaultWorldRegionRoutesResponse{
NsRoutes: pbRoutes,
}, nil
}
// FindAllDefaultChinaProvinceRoutes 查找默认的中国省份线路
func (this *NSRouteService) FindAllDefaultChinaProvinceRoutes(ctx context.Context, req *pb.FindAllDefaultChinaProvinceRoutesRequest) (*pb.FindAllDefaultChinaProvinceRoutesResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var pbRoutes = []*pb.NSRoute{}
for _, route := range dnsconfigs.AllDefaultChinaProvinceRoutes {
pbRoutes = append(pbRoutes, &pb.NSRoute{
Code: route.Code,
Name: route.Name,
})
}
return &pb.FindAllDefaultChinaProvinceRoutesResponse{
NsRoutes: pbRoutes,
}, nil
}
// FindAllDefaultISPRoutes 查找默认的ISP线路
func (this *NSRouteService) FindAllDefaultISPRoutes(ctx context.Context, req *pb.FindAllDefaultISPRoutesRequest) (*pb.FindAllDefaultISPRoutesResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var pbRoutes = []*pb.NSRoute{}
for _, route := range dnsconfigs.AllDefaultISPRoutes {
pbRoutes = append(pbRoutes, &pb.NSRoute{
Code: route.Code,
Name: route.Name,
})
}
return &pb.FindAllDefaultISPRoutesResponse{
NsRoutes: pbRoutes,
}, nil
}
// FindAllAgentNSRoutes 查找默认的搜索引擎线路
func (this *NSRouteService) FindAllAgentNSRoutes(ctx context.Context, req *pb.FindAllAgentNSRoutesRequest) (*pb.FindAllAgentNSRoutesResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
agents, err := clients.SharedClientAgentDAO.FindAllNSAgents(tx)
if err != nil {
return nil, err
}
var pbRoutes = []*pb.NSRoute{}
for _, agent := range agents {
pbRoutes = append(pbRoutes, &pb.NSRoute{
Code: agent.NSRouteCode(),
Name: agent.Name,
})
}
pbRoutes = append(pbRoutes, &pb.NSRoute{
Code: "agent",
Name: "搜索引擎",
})
return &pb.FindAllAgentNSRoutesResponse{NsRoutes: pbRoutes}, nil
}

View File

@@ -0,0 +1,171 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nameservers
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/nameservers"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// NSRouteCategoryService 线路分类服务
type NSRouteCategoryService struct {
services.BaseService
}
// CreateNSRouteCategory 创建线路分类
func (this *NSRouteCategoryService) CreateNSRouteCategory(ctx context.Context, req *pb.CreateNSRouteCategoryRequest) (*pb.CreateNSRouteCategoryResponse, error) {
// TODO 需要防止用户恶意创建非常多的分类
adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
categoryId, err := nameservers.SharedNSRouteCategoryDAO.CreateCategory(tx, adminId, userId, req.Name)
if err != nil {
return nil, err
}
return &pb.CreateNSRouteCategoryResponse{NsRouteCategoryId: categoryId}, nil
}
// UpdateNSRouteCategory 修改线路分类
func (this *NSRouteCategoryService) UpdateNSRouteCategory(ctx context.Context, req *pb.UpdateNSRouteCategoryRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
err = nameservers.SharedNSRouteCategoryDAO.CheckUserCategory(tx, userId, req.NsRouteCategoryId)
if err != nil {
return nil, err
}
}
err = nameservers.SharedNSRouteCategoryDAO.UpdateCategory(tx, req.NsRouteCategoryId, req.Name, req.IsOn)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteNSRouteCategory 删除线路分类
func (this *NSRouteCategoryService) DeleteNSRouteCategory(ctx context.Context, req *pb.DeleteNSRouteCategoryRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
err = nameservers.SharedNSRouteCategoryDAO.CheckUserCategory(tx, userId, req.NsRouteCategoryId)
if err != nil {
return nil, err
}
}
err = nameservers.SharedNSRouteCategoryDAO.DisableNSRouteCategory(tx, req.NsRouteCategoryId)
if err != nil {
return nil, err
}
// 重置线路
err = nameservers.SharedNSRouteDAO.ResetRoutesCategory(tx, req.NsRouteCategoryId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindAllNSRouteCategories 列出所有线路分类
func (this *NSRouteCategoryService) FindAllNSRouteCategories(ctx context.Context, req *pb.FindAllNSRouteCategoriesRequest) (*pb.FindAllNSRouteCategoriesResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
categories, err := nameservers.SharedNSRouteCategoryDAO.FindAllCategories(tx, userId)
if err != nil {
return nil, err
}
var pbCategories = []*pb.NSRouteCategory{}
for _, category := range categories {
pbCategories = append(pbCategories, &pb.NSRouteCategory{
Id: int64(category.Id),
Name: category.Name,
IsOn: category.IsOn,
})
}
return &pb.FindAllNSRouteCategoriesResponse{
NsRouteCategories: pbCategories,
}, nil
}
// UpdateNSRouteCategoryOrders 对线路分类进行排序
func (this *NSRouteCategoryService) UpdateNSRouteCategoryOrders(ctx context.Context, req *pb.UpdateNSRouteCategoryOrders) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
for _, categoryId := range req.NsRouteCategoryIds {
err = nameservers.SharedNSRouteCategoryDAO.CheckUserCategory(tx, userId, categoryId)
if err != nil {
return nil, err
}
}
}
err = nameservers.SharedNSRouteCategoryDAO.UpdateCategoryOrders(tx, userId, req.NsRouteCategoryIds)
if err != nil {
return nil, err
}
return this.Success()
}
// FindNSRouteCategory 查找单个线路分类
func (this *NSRouteCategoryService) FindNSRouteCategory(ctx context.Context, req *pb.FindNSRouteCategoryRequest) (*pb.FindNSRouteCategoryResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
err = nameservers.SharedNSRouteCategoryDAO.CheckUserCategory(tx, userId, req.NsRouteCategoryId)
if err != nil {
return nil, err
}
}
category, err := nameservers.SharedNSRouteCategoryDAO.FindCategory(tx, req.NsRouteCategoryId)
if err != nil {
return nil, err
}
if category == nil {
return &pb.FindNSRouteCategoryResponse{
NsRouteCategory: nil,
}, nil
}
return &pb.FindNSRouteCategoryResponse{
NsRouteCategory: &pb.NSRouteCategory{
Id: int64(category.Id),
Name: category.Name,
IsOn: category.IsOn,
},
}, nil
}

View File

@@ -0,0 +1,370 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package nameservers
import (
"context"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/accounts"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/nameservers"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeAPI/internal/utils/regexputils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/userconfigs"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time"
"time"
)
// NSUserPlanService 用户DNS套餐服务
type NSUserPlanService struct {
services.BaseService
}
// CreateNSUserPlan 创建用户套餐
func (this *NSUserPlanService) CreateNSUserPlan(ctx context.Context, req *pb.CreateNSUserPlanRequest) (*pb.CreateNSUserPlanResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var dayFrom = req.DayFrom
var dayTo = req.DayTo
if !regexputils.YYYYMMDD.MatchString(dayFrom) {
return nil, errors.New("invalid dayFrom: " + dayFrom)
}
if !regexputils.YYYYMMDD.MatchString(dayTo) {
return nil, errors.New("invalid dayTo: " + dayTo)
}
if !lists.ContainsString([]string{nameservers.NSUserPlanPeriodUnitMonthly, nameservers.NSUserPlanPeriodUnitYearly}, req.PeriodUnit) {
return nil, errors.New("invalid periodUnit: " + req.PeriodUnit)
}
// 检查plan是否存在
var tx = this.NullTx()
existPlan, err := nameservers.SharedNSPlanDAO.ExistPlan(tx, req.NsPlanId)
if err != nil {
return nil, err
}
if !existPlan {
return nil, errors.New("plan '" + types.String(req.NsPlanId) + "' not found")
}
// 用户Plan是否存在
var resultUserPlanId int64
err = this.RunTx(func(tx *dbs.Tx) error {
userPlan, err := nameservers.SharedNSUserPlanDAO.FindUserPlan(tx, req.UserId)
if err != nil {
return err
}
if userPlan == nil {
userPlanId, err := nameservers.SharedNSUserPlanDAO.CreateUserPlan(tx, req.UserId, req.NsPlanId, dayFrom, dayTo, req.PeriodUnit)
if err != nil {
return err
}
resultUserPlanId = userPlanId
} else {
err = nameservers.SharedNSUserPlanDAO.UpdateUserPlan(tx, int64(userPlan.Id), req.NsPlanId, dayFrom, dayTo, req.PeriodUnit)
if err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
return &pb.CreateNSUserPlanResponse{NsUserPlanId: resultUserPlanId}, nil
}
// UpdateNSUserPlan 修改用户套餐
func (this *NSUserPlanService) UpdateNSUserPlan(ctx context.Context, req *pb.UpdateNSUserPlanRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var dayFrom = req.DayFrom
var dayTo = req.DayTo
if !regexputils.YYYYMMDD.MatchString(dayFrom) {
return nil, errors.New("invalid dayFrom: " + dayFrom)
}
if !regexputils.YYYYMMDD.MatchString(dayTo) {
return nil, errors.New("invalid dayTo: " + dayTo)
}
if !lists.ContainsString([]string{nameservers.NSUserPlanPeriodUnitMonthly, nameservers.NSUserPlanPeriodUnitYearly}, req.PeriodUnit) {
return nil, errors.New("invalid periodUnit: " + req.PeriodUnit)
}
var tx = this.NullTx()
err = nameservers.SharedNSUserPlanDAO.UpdateUserPlan(tx, req.NsUserPlanId, req.NsPlanId, dayFrom, dayTo, req.PeriodUnit)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteNSUserPlan 删除用户套餐
func (this *NSUserPlanService) DeleteNSUserPlan(ctx context.Context, req *pb.DeleteNSUserPlanRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = nameservers.SharedNSUserPlanDAO.DisableNSUserPlan(tx, req.NsUserPlanId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindNSUserPlan 读取用户套餐
func (this *NSUserPlanService) FindNSUserPlan(ctx context.Context, req *pb.FindNSUserPlanRequest) (*pb.FindNSUserPlanResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
req.UserId = userId
}
var userPlan *nameservers.NSUserPlan
if req.NsUserPlanId > 0 {
userPlan, err = nameservers.SharedNSUserPlanDAO.FindEnabledNSUserPlan(tx, req.NsUserPlanId)
} else if req.UserId > 0 {
userPlan, err = nameservers.SharedNSUserPlanDAO.FindUserPlan(tx, req.UserId)
} else {
return &pb.FindNSUserPlanResponse{NsUserPlan: nil}, nil
}
if err != nil {
return nil, err
}
if userPlan == nil {
return &pb.FindNSUserPlanResponse{NsUserPlan: nil}, nil
}
plan, err := nameservers.SharedNSPlanDAO.FindEnabledNSPlan(tx, int64(userPlan.PlanId))
if err != nil {
return nil, err
}
if plan == nil {
return &pb.FindNSUserPlanResponse{NsUserPlan: nil}, nil
}
// user
user, err := models.SharedUserDAO.FindEnabledBasicUser(tx, int64(userPlan.UserId))
if err != nil {
return nil, err
}
var pbUser *pb.User
if user != nil {
pbUser = &pb.User{
Id: int64(user.Id),
Username: user.Username,
Fullname: user.Fullname,
}
}
return &pb.FindNSUserPlanResponse{
NsUserPlan: &pb.NSUserPlan{
Id: int64(userPlan.Id),
NsPlanId: int64(userPlan.PlanId),
DayFrom: userPlan.DayFrom,
DayTo: userPlan.DayTo,
PeriodUnit: userPlan.PeriodUnit,
NsPlan: &pb.NSPlan{
Id: int64(plan.Id),
Name: plan.Name,
IsOn: plan.IsOn,
MonthlyPrice: float32(plan.MonthlyPrice),
YearlyPrice: float32(plan.YearlyPrice),
ConfigJSON: plan.Config,
},
User: pbUser,
},
}, nil
}
// CountNSUserPlans 计算用户套餐数量
func (this *NSUserPlanService) CountNSUserPlans(ctx context.Context, req *pb.CountNSUserPlansRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := nameservers.SharedNSUserPlanDAO.CountUserPlans(tx, req.UserId, req.NsPlanId, req.PeriodUnit, req.IsExpired, req.ExpireDays)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListNSUserPlans 列出单页套餐
func (this *NSUserPlanService) ListNSUserPlans(ctx context.Context, req *pb.ListNSUserPlansRequest) (*pb.ListNSUserPlansResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
userPlans, err := nameservers.SharedNSUserPlanDAO.ListUserPlans(tx, req.UserId, req.NsPlanId, req.PeriodUnit, req.IsExpired, req.ExpireDays, req.Offset, req.Size)
if err != nil {
return nil, err
}
var pbUserPlans = []*pb.NSUserPlan{}
for _, userPlan := range userPlans {
// plan
plan, err := nameservers.SharedNSPlanDAO.FindEnabledNSPlan(tx, int64(userPlan.PlanId))
if err != nil {
return nil, err
}
if plan == nil {
continue
}
// user
user, err := models.SharedUserDAO.FindEnabledBasicUser(tx, int64(userPlan.UserId))
if err != nil {
return nil, err
}
var pbUser *pb.User
if user != nil {
pbUser = &pb.User{
Id: int64(user.Id),
Username: user.Username,
Fullname: user.Fullname,
}
}
pbUserPlans = append(pbUserPlans, &pb.NSUserPlan{
Id: int64(userPlan.Id),
NsPlanId: int64(userPlan.PlanId),
UserId: int64(userPlan.UserId),
DayFrom: userPlan.DayFrom,
DayTo: userPlan.DayTo,
PeriodUnit: userPlan.PeriodUnit,
NsPlan: &pb.NSPlan{
Id: int64(plan.Id),
Name: plan.Name,
IsOn: plan.IsOn,
MonthlyPrice: float32(plan.MonthlyPrice),
YearlyPrice: float32(plan.YearlyPrice),
ConfigJSON: plan.Config,
},
User: pbUser,
})
}
return &pb.ListNSUserPlansResponse{
NsUserPlans: pbUserPlans,
}, nil
}
// BuyNSUserPlan 使用余额购买用户套餐
func (this *NSUserPlanService) BuyNSUserPlan(ctx context.Context, req *pb.BuyNSUserPlanRequest) (*pb.BuyNSUserPlanResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
userId = req.UserId
// 查询套餐
var tx *dbs.Tx
plan, err := nameservers.SharedNSPlanDAO.FindEnabledNSPlan(tx, req.PlanId)
if err != nil {
return nil, err
}
if plan == nil || !plan.IsOn {
return nil, errors.New("could not find plan with id '" + types.String(req.PlanId) + "'")
}
var dayFrom = timeutil.Format("Ymd")
var dayTo = ""
var price float64
switch req.Period {
case "yearly":
price = plan.YearlyPrice
dayTo = timeutil.Format("Ymd", time.Now().AddDate(1, 0, 0))
case "monthly":
price = plan.MonthlyPrice
dayTo = timeutil.Format("Ymd", time.Now().AddDate(0, 1, 0))
default:
return nil, errors.New("invalid period '" + req.Period + "'")
}
var userPlanId int64
err = this.RunTx(func(tx *dbs.Tx) error {
// 当前是否有套餐在有效期
userPlan, err := nameservers.SharedNSUserPlanDAO.FindUserPlan(tx, userId)
if err != nil {
return err
}
if userPlan != nil && userPlan.IsAvailable() {
return errors.New("there is an available user plan yet, you can not buy again")
}
// 如果是0价格则不允许购买
if price <= 0 {
return errors.New("invalid plan price")
}
// 先减少余额
account, err := accounts.SharedUserAccountDAO.FindUserAccountWithUserId(tx, userId)
if err != nil {
return err
}
if account == nil || account.Total < price {
return errors.New("no enough balance to buy the plan")
}
err = accounts.SharedUserAccountDAO.UpdateUserAccount(tx, int64(account.Id), -price, userconfigs.AccountEventTypeBuyNSPlan, "购买DNS套餐\""+plan.Name+"\"", maps.Map{
"nsPlanId": plan.Id,
})
if err != nil {
return err
}
// 创建套餐
userPlanId, err = nameservers.SharedNSUserPlanDAO.CreateUserPlan(tx, userId, req.PlanId, dayFrom, dayTo, req.Period)
if err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
return &pb.BuyNSUserPlanResponse{UserPlanId: userPlanId}, nil
}

View File

@@ -0,0 +1,229 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package posts
import (
"context"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/posts"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// PostService 文章服务
type PostService struct {
services.BaseService
}
// CreatePost 创建文章
func (this *PostService) CreatePost(ctx context.Context, req *pb.CreatePostRequest) (*pb.CreatePostResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if !posts.SharedPostDAO.IsValidType(req.Type) {
return nil, errors.New("invalid 'type' value: " + req.Type)
}
postId, err := posts.SharedPostDAO.CreatePost(tx, req.ProductCode, req.PostCategoryId, req.Subject, req.Type, req.Body, req.Url)
if err != nil {
return nil, err
}
return &pb.CreatePostResponse{PostId: postId}, nil
}
// UpdatePost 修改文章
func (this *PostService) UpdatePost(ctx context.Context, req *pb.UpdatePostRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if !posts.SharedPostDAO.IsValidType(req.Type) {
return nil, errors.New("invalid 'type' value: " + req.Type)
}
err = posts.SharedPostDAO.UpdatePost(tx, req.PostId, req.ProductCode, req.PostCategoryId, req.Subject, req.Type, req.Body, req.Url)
if err != nil {
return nil, err
}
return this.Success()
}
// DeletePost 删除文章
func (this *PostService) DeletePost(ctx context.Context, req *pb.DeletePostRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = posts.SharedPostDAO.DisablePost(tx, req.PostId)
if err != nil {
return nil, err
}
return this.Success()
}
// PublishPost 发布文章
func (this *PostService) PublishPost(ctx context.Context, req *pb.PublishPostRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = posts.SharedPostDAO.PublishPost(tx, req.PostId)
if err != nil {
return nil, err
}
return this.Success()
}
// CountPosts 计算文章数量
func (this *PostService) CountPosts(ctx context.Context, req *pb.CountPostsRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := posts.SharedPostDAO.CountPosts(tx, req.ProductCode, req.PostCategoryId, req.PublishedOnly)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListPosts 列出单页文章
func (this *PostService) ListPosts(ctx context.Context, req *pb.ListPostsRequest) (*pb.ListPostsResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
// check permission
if userId > 0 {
req.PublishedOnly = true
}
var tx = this.NullTx()
var categoryId int64
if len(req.PostCategoryCode) > 0 {
categoryId, err = posts.SharedPostCategoryDAO.FindCategoryIdWithCode(tx, req.PostCategoryCode)
if err != nil {
return nil, err
}
}
if req.PostCategoryId > 0 {
categoryId = req.PostCategoryId
}
var excludingCategoryId int64
if len(req.ExcludingPostCategoryCode) > 0 {
excludingCategoryId, err = posts.SharedPostCategoryDAO.FindCategoryIdWithCode(tx, req.ExcludingPostCategoryCode)
if err != nil {
return nil, err
}
}
postList, err := posts.SharedPostDAO.ListPosts(tx, req.ProductCode, categoryId, excludingCategoryId, req.PublishedOnly, req.ContainsBody, req.Offset, req.Size)
if err != nil {
return nil, err
}
var pbPosts []*pb.Post
for _, post := range postList {
category, err := posts.SharedPostCategoryDAO.FindEnabledPostCategory(tx, int64(post.CategoryId))
if err != nil {
return nil, err
}
var pbCategory *pb.PostCategory
if category != nil {
pbCategory = &pb.PostCategory{
Id: int64(category.Id),
Name: category.Name,
Code: category.Code,
IsOn: category.IsOn,
}
}
pbPosts = append(pbPosts, &pb.Post{
Id: int64(post.Id),
ProductCode: post.ProductCode,
PostCategoryId: int64(post.CategoryId),
Type: post.Type,
Subject: post.Subject,
Url: post.Url,
Body: post.Body,
CreatedAt: int64(post.CreatedAt),
IsPublished: post.IsPublished,
PublishedAt: int64(post.PublishedAt),
PostCategory: pbCategory,
})
}
return &pb.ListPostsResponse{
Posts: pbPosts,
}, nil
}
// FindPost 查询单篇文章
func (this *PostService) FindPost(ctx context.Context, req *pb.FindPostRequest) (*pb.FindPostResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
post, err := posts.SharedPostDAO.FindEnabledPost(tx, req.PostId)
if err != nil {
return nil, err
}
if post == nil {
return &pb.FindPostResponse{Post: nil}, nil
}
// check permission
if userId > 0 && !post.IsPublished {
return &pb.FindPostResponse{Post: nil}, nil
}
category, err := posts.SharedPostCategoryDAO.FindEnabledPostCategory(tx, int64(post.CategoryId))
if err != nil {
return nil, err
}
var pbCategory *pb.PostCategory
if category != nil {
pbCategory = &pb.PostCategory{
Id: int64(category.Id),
Name: category.Name,
Code: category.Code,
IsOn: category.IsOn,
}
}
return &pb.FindPostResponse{
Post: &pb.Post{
Id: int64(post.Id),
ProductCode: post.ProductCode,
PostCategoryId: int64(post.CategoryId),
Type: post.Type,
Subject: post.Subject,
Url: post.Url,
Body: post.Body,
CreatedAt: int64(post.CreatedAt),
IsPublished: post.IsPublished,
PublishedAt: int64(post.PublishedAt),
PostCategory: pbCategory,
},
}, nil
}

View File

@@ -0,0 +1,166 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package posts
import (
"context"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/posts"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// PostCategoryService 文章分类服务
type PostCategoryService struct {
services.BaseService
}
// CreatePostCategory 创建分类
func (this *PostCategoryService) CreatePostCategory(ctx context.Context, req *pb.CreatePostCategoryRequest) (*pb.CreatePostCategoryResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
if len(req.Name) == 0 {
return nil, errors.New("require 'name'")
}
var tx = this.NullTx()
categoryId, err := posts.SharedPostCategoryDAO.CreateCategory(tx, req.Name, req.Code)
if err != nil {
return nil, err
}
return &pb.CreatePostCategoryResponse{PostCategoryId: categoryId}, nil
}
// UpdatePostCategory 修改分类
func (this *PostCategoryService) UpdatePostCategory(ctx context.Context, req *pb.UpdatePostCategoryRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if req.PostCategoryId <= 0 {
return nil, errors.New("invalid 'postCategoryId'")
}
if len(req.Name) == 0 {
return nil, errors.New("require 'name'")
}
err = posts.SharedPostCategoryDAO.UpdateCategory(tx, req.PostCategoryId, req.Name, req.Code, req.IsOn)
if err != nil {
return nil, err
}
return this.Success()
}
// DeletePostCategory 删除分类
func (this *PostCategoryService) DeletePostCategory(ctx context.Context, req *pb.DeletePostCategoryRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = posts.SharedPostCategoryDAO.DisablePostCategory(tx, req.PostCategoryId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindAllPostCategories 列出所有分类
func (this *PostCategoryService) FindAllPostCategories(ctx context.Context, req *pb.FindAllPostCategoriesRequest) (*pb.FindAllPostCategoriesResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
categories, err := posts.SharedPostCategoryDAO.FindAllCategories(tx)
if err != nil {
return nil, err
}
var pbCategories []*pb.PostCategory
for _, category := range categories {
pbCategories = append(pbCategories, &pb.PostCategory{
Id: int64(category.Id),
Name: category.Name,
Code: category.Code,
IsOn: category.IsOn,
})
}
return &pb.FindAllPostCategoriesResponse{PostCategories: pbCategories}, nil
}
// FindAllAvailablePostCategories 列出所有可用分类
func (this *PostCategoryService) FindAllAvailablePostCategories(ctx context.Context, req *pb.FindAllAvailablePostCategoriesRequest) (*pb.FindAllAvailablePostCategoriesResponse, error) {
_, err := this.ValidateUserNode(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
categories, err := posts.SharedPostCategoryDAO.FindAllAvailableCategories(tx)
if err != nil {
return nil, err
}
var pbCategories []*pb.PostCategory
for _, category := range categories {
pbCategories = append(pbCategories, &pb.PostCategory{
Id: int64(category.Id),
Name: category.Name,
Code: category.Code,
IsOn: category.IsOn,
})
}
return &pb.FindAllAvailablePostCategoriesResponse{PostCategories: pbCategories}, nil
}
// FindPostCategory 查询单个分类
func (this *PostCategoryService) FindPostCategory(ctx context.Context, req *pb.FindPostCategoryRequest) (*pb.FindPostCategoryResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
category, err := posts.SharedPostCategoryDAO.FindEnabledPostCategory(tx, req.PostCategoryId)
if err != nil {
return nil, err
}
if category == nil {
return &pb.FindPostCategoryResponse{
PostCategory: nil,
}, nil
}
return &pb.FindPostCategoryResponse{
PostCategory: &pb.PostCategory{
Id: int64(category.Id),
Name: category.Name,
Code: category.Code,
IsOn: category.IsOn,
},
}, nil
}
// SortPostCategories 对分类进行排序
func (this *PostCategoryService) SortPostCategories(ctx context.Context, req *pb.SortPostCategoriesRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = posts.SharedPostCategoryDAO.UpdateCategoryOrders(tx, req.PostCategoryIds)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,123 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package reporters
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// ReportNodeGroupService 监控节点分组
type ReportNodeGroupService struct {
services.BaseService
}
// CreateReportNodeGroup 创建分组
func (this *ReportNodeGroupService) CreateReportNodeGroup(ctx context.Context, req *pb.CreateReportNodeGroupRequest) (*pb.CreateReportNodeGroupResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
groupId, err := models.SharedReportNodeGroupDAO.CreateGroup(tx, req.Name)
if err != nil {
return nil, err
}
return &pb.CreateReportNodeGroupResponse{ReportNodeGroupId: groupId}, nil
}
// UpdateReportNodeGroup 修改分组
func (this *ReportNodeGroupService) UpdateReportNodeGroup(ctx context.Context, req *pb.UpdateReportNodeGroupRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedReportNodeGroupDAO.UpdateGroup(tx, req.ReportNodeGroupId, req.Name)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteReportNodeGroup 删除分组
func (this *ReportNodeGroupService) DeleteReportNodeGroup(ctx context.Context, req *pb.DeleteReportNodeGroupRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedReportNodeGroupDAO.DisableReportNodeGroup(tx, req.ReportNodeGroupId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindAllEnabledReportNodeGroups 查找所有分组
func (this *ReportNodeGroupService) FindAllEnabledReportNodeGroups(ctx context.Context, req *pb.FindAllEnabledReportNodeGroupsRequest) (*pb.FindAllEnabledReportNodeGroupsResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
groups, err := models.SharedReportNodeGroupDAO.FindAllEnabledGroups(tx)
if err != nil {
return nil, err
}
var pbGroups = []*pb.ReportNodeGroup{}
for _, group := range groups {
pbGroups = append(pbGroups, &pb.ReportNodeGroup{
Id: int64(group.Id),
Name: group.Name,
IsOn: group.IsOn,
})
}
return &pb.FindAllEnabledReportNodeGroupsResponse{ReportNodeGroups: pbGroups}, nil
}
// FindEnabledReportNodeGroup 查找单个分组
func (this *ReportNodeGroupService) FindEnabledReportNodeGroup(ctx context.Context, req *pb.FindEnabledReportNodeGroupRequest) (*pb.FindEnabledReportNodeGroupResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
group, err := models.SharedReportNodeGroupDAO.FindEnabledReportNodeGroup(tx, req.ReportNodeGroupId)
if err != nil {
return nil, err
}
if group == nil {
return &pb.FindEnabledReportNodeGroupResponse{ReportNodeGroup: nil}, nil
}
return &pb.FindEnabledReportNodeGroupResponse{
ReportNodeGroup: &pb.ReportNodeGroup{
Id: int64(group.Id),
Name: group.Name,
IsOn: group.IsOn,
},
}, nil
}
// CountAllEnabledReportNodeGroups 计算所有分组数量
func (this *ReportNodeGroupService) CountAllEnabledReportNodeGroups(ctx context.Context, req *pb.CountAllEnabledReportNodeGroupsRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedReportNodeGroupDAO.CountAllEnabledGroups(tx)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}

View File

@@ -0,0 +1,467 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
// +build plus
package reporters
import (
"context"
"encoding/json"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/reporterconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/systemconfigs"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/types"
"google.golang.org/grpc/peer"
"net"
"time"
)
// ReportNodeService 监控终端服务
type ReportNodeService struct {
services.BaseService
}
// CreateReportNode 添加终端
func (this *ReportNodeService) CreateReportNode(ctx context.Context, req *pb.CreateReportNodeRequest) (*pb.CreateReportNodeResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
reporterId, err := models.SharedReportNodeDAO.CreateReportNode(tx, req.Name, req.Location, req.Isp, req.AllowIPs, req.ReportNodeGroupIds)
if err != nil {
return nil, err
}
return &pb.CreateReportNodeResponse{ReportNodeId: reporterId}, nil
}
// DeleteReportNode 删除终端
func (this *ReportNodeService) DeleteReportNode(ctx context.Context, req *pb.DeleteReportNodeRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedReportNodeDAO.DisableReportNode(tx, req.ReportNodeId)
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateReportNode 修改终端
func (this *ReportNodeService) UpdateReportNode(ctx context.Context, req *pb.UpdateReportNodeRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedReportNodeDAO.UpdateReportNode(tx, req.ReportNodeId, req.Name, req.Location, req.Isp, req.AllowIPs, req.ReportNodeGroupIds, req.IsOn)
if err != nil {
return nil, err
}
return this.Success()
}
// CountAllEnabledReportNodes 计算终端数量
func (this *ReportNodeService) CountAllEnabledReportNodes(ctx context.Context, req *pb.CountAllEnabledReportNodesRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedReportNodeDAO.CountAllEnabledReportNodes(tx, req.ReportNodeGroupId, req.Keyword)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListEnabledReportNodes 列出单页终端
func (this *ReportNodeService) ListEnabledReportNodes(ctx context.Context, req *pb.ListEnabledReportNodesRequest) (*pb.ListEnabledReportNodesResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
ones, err := models.SharedReportNodeDAO.ListEnabledReportNodes(tx, req.ReportNodeGroupId, req.Keyword, req.Offset, req.Size)
if err != nil {
return nil, err
}
var pbNodes = []*pb.ReportNode{}
for _, one := range ones {
var pbGroups = []*pb.ReportNodeGroup{}
var groupIds = one.DecodeGroupIds()
for _, groupId := range groupIds {
group, err := models.SharedReportNodeGroupDAO.FindEnabledReportNodeGroup(tx, groupId)
if err != nil {
return nil, err
}
if group == nil {
continue
}
pbGroups = append(pbGroups, &pb.ReportNodeGroup{
Id: int64(group.Id),
Name: group.Name,
IsOn: group.IsOn,
})
}
pbNodes = append(pbNodes, &pb.ReportNode{
Id: int64(one.Id),
UniqueId: one.UniqueId,
Secret: one.Secret,
IsOn: one.IsOn,
Name: one.Name,
Location: one.Location,
Isp: one.Isp,
IsActive: one.IsActive,
StatusJSON: one.Status,
AllowIPs: one.DecodeAllowIPs(),
ReportNodeGroups: pbGroups,
})
}
return &pb.ListEnabledReportNodesResponse{
ReportNodes: pbNodes,
}, nil
}
// FindEnabledReportNode 查找单个终端
func (this *ReportNodeService) FindEnabledReportNode(ctx context.Context, req *pb.FindEnabledReportNodeRequest) (*pb.FindEnabledReportNodeResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
node, err := models.SharedReportNodeDAO.FindEnabledReportNode(tx, req.ReportNodeId)
if err != nil {
return nil, err
}
if node == nil {
return &pb.FindEnabledReportNodeResponse{ReportNode: nil}, nil
}
var pbGroups = []*pb.ReportNodeGroup{}
var groupIds = node.DecodeGroupIds()
for _, groupId := range groupIds {
group, err := models.SharedReportNodeGroupDAO.FindEnabledReportNodeGroup(tx, groupId)
if err != nil {
return nil, err
}
if group == nil {
continue
}
pbGroups = append(pbGroups, &pb.ReportNodeGroup{
Id: int64(group.Id),
Name: group.Name,
IsOn: group.IsOn,
})
}
return &pb.FindEnabledReportNodeResponse{ReportNode: &pb.ReportNode{
Id: int64(node.Id),
UniqueId: node.UniqueId,
Secret: node.Secret,
IsOn: node.IsOn,
Name: node.Name,
Location: node.Location,
Isp: node.Isp,
IsActive: node.IsActive,
StatusJSON: node.Status,
AllowIPs: node.DecodeAllowIPs(),
ReportNodeGroups: pbGroups,
}}, nil
}
// UpdateReportNodeStatus 更新节点状态
func (this *ReportNodeService) UpdateReportNodeStatus(ctx context.Context, req *pb.UpdateReportNodeStatusRequest) (*pb.RPCSuccess, error) {
_, nodeId, err := this.ValidateNodeId(ctx, rpcutils.UserTypeReport)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = validateClient(tx, nodeId, ctx)
if err != nil {
return nil, err
}
var status = &reporterconfigs.Status{}
err = json.Unmarshal(req.StatusJSON, status)
if err != nil {
return nil, err
}
status.UpdatedAt = time.Now().Unix()
p, ok := peer.FromContext(ctx)
if ok {
host, _, _ := net.SplitHostPort(p.Addr.String())
if len(host) > 0 {
status.IP = host
var result = iplibrary.LookupIP(host)
if result != nil && result.IsOk() {
status.Location = result.Summary()
status.ISP = result.ProviderName()
}
}
}
err = models.SharedReportNodeDAO.UpdateNodeStatus(tx, nodeId, status)
if err != nil {
return nil, err
}
return this.Success()
}
// FindCurrentReportNodeConfig 获取当前节点信息
func (this *ReportNodeService) FindCurrentReportNodeConfig(ctx context.Context, req *pb.FindCurrentReportNodeConfigRequest) (*pb.FindCurrentReportNodeConfigResponse, error) {
_, nodeId, err := this.ValidateNodeId(ctx, rpcutils.UserTypeReport)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = validateClient(tx, nodeId, ctx)
if err != nil {
return nil, err
}
config, err := models.SharedReportNodeDAO.ComposeConfig(tx, nodeId)
if err != nil {
return nil, err
}
configJSON, err := json.Marshal(config)
if err != nil {
return nil, err
}
return &pb.FindCurrentReportNodeConfigResponse{ReportNodeJSON: configJSON}, nil
}
// FindReportNodeTasks 读取任务
func (this *ReportNodeService) FindReportNodeTasks(ctx context.Context, req *pb.FindReportNodeTasksRequest) (*pb.FindReportNodeTasksResponse, error) {
_, nodeId, err := this.ValidateNodeId(ctx, rpcutils.UserTypeReport)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = validateClient(tx, nodeId, ctx)
if err != nil {
return nil, err
}
var result = &pb.FindReportNodeTasksResponse{}
var ipTasks = []*reporterconfigs.IPTask{}
// 所有的集群
// TODO 将来支持NS节点
clusters, err := models.SharedNodeClusterDAO.FindAllEnableClusters(tx)
if err != nil {
return nil, err
}
for _, cluster := range clusters {
if !cluster.IsOn {
continue
}
var clusterId = int64(cluster.Id)
port, err := models.SharedServerDAO.FindFirstHTTPOrHTTPSPortWithClusterId(tx, clusterId)
if err != nil {
return nil, err
}
if port <= 0 {
continue
}
// 读取所有IP地址
addrList, err := models.SharedNodeIPAddressDAO.FindAllAccessibleIPAddressesWithClusterId(tx, nodeconfigs.NodeRoleNode, clusterId, nil)
if err != nil {
return nil, err
}
for _, addr := range addrList {
if !addr.IsOn {
continue
}
var addrIP = addr.Ip
var backupIP = addr.DecodeBackupIP()
if len(backupIP) > 0 {
addrIP = backupIP
}
ipTasks = append(ipTasks, &reporterconfigs.IPTask{
AddrId: int64(addr.Id),
IP: addrIP,
Port: port,
})
}
}
ipTasksJSON, err := json.Marshal(ipTasks)
if err != nil {
return nil, err
}
result.IpAddrTasksJSON = ipTasksJSON
return result, nil
}
// FindLatestReportNodeVersion 取得最新的版本号
func (this *ReportNodeService) FindLatestReportNodeVersion(ctx context.Context, req *pb.FindLatestReportNodeVersionRequest) (*pb.FindLatestReportNodeVersionResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
return &pb.FindLatestReportNodeVersionResponse{
Version: teaconst.ReportNodeVersion,
}, nil
}
// CountAllReportNodeTasks 计算任务数量
func (this *ReportNodeService) CountAllReportNodeTasks(ctx context.Context, req *pb.CountAllReportNodeTasksRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var count int64
var tx *dbs.Tx
switch req.Type {
case reporterconfigs.TaskTypeIPAddr:
count, err = models.SharedNodeIPAddressDAO.CountAllAccessibleIPAddressesWithClusterId(tx, req.Role, req.NodeClusterId)
}
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListReportNodeTasks 列出单页任务
func (this *ReportNodeService) ListReportNodeTasks(ctx context.Context, req *pb.ListReportNodeTasksRequest) (*pb.ListReportNodeTasksResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx *dbs.Tx
switch req.Type {
case reporterconfigs.TaskTypeIPAddr:
port, err := models.SharedServerDAO.FindFirstHTTPOrHTTPSPortWithClusterId(tx, req.NodeClusterId)
if err != nil {
return nil, err
}
addrs, err := models.SharedNodeIPAddressDAO.ListAccessibleIPAddressesWithClusterId(tx, req.Role, req.NodeClusterId, req.Offset, req.Size)
if err != nil {
return nil, err
}
var pbTasks = []*pb.IPAddrReportTask{}
for _, addr := range addrs {
var addrIP = addr.Ip
var backupIP = addr.DecodeBackupIP()
if len(backupIP) > 0 {
addrIP = backupIP
}
// 地址
var pbAddr = &pb.NodeIPAddress{
Id: int64(addr.Id),
NodeId: int64(addr.NodeId),
Name: addr.Name,
Ip: addrIP,
Description: addr.Description,
CanAccess: addr.CanAccess,
IsOn: addr.IsOn,
IsUp: addr.IsUp,
Role: addr.Role,
}
var connectivity = addr.DecodeConnectivity()
pbTasks = append(pbTasks, &pb.IPAddrReportTask{
Ip: addr.Ip,
Port: types.Int32(port),
NodeIPAddress: pbAddr,
CostMs: float32(connectivity.CostMs),
Level: connectivity.Level,
Connectivity: float32(connectivity.Percent),
})
}
return &pb.ListReportNodeTasksResponse{
IpAddrReportTasks: pbTasks,
}, nil
}
return &pb.ListReportNodeTasksResponse{}, nil
}
// UpdateReportNodeGlobalSetting 修改全局设置
func (this *ReportNodeService) UpdateReportNodeGlobalSetting(ctx context.Context, req *pb.UpdateReportNodeGlobalSetting) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedSysSettingDAO.UpdateSetting(tx, systemconfigs.SettingCodeReportNodeGlobalSetting, req.SettingJSON)
if err != nil {
return nil, err
}
return this.Success()
}
// ReadReportNodeGlobalSetting 读取全局设置
func (this *ReportNodeService) ReadReportNodeGlobalSetting(ctx context.Context, req *pb.ReadReportNodeGlobalSettingRequest) (*pb.ReadReportNodeGlobalSettingResponse, error) {
_, _, err := this.ValidateNodeId(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeReport)
if err != nil {
return nil, err
}
var tx = this.NullTx()
valueJSON, err := models.SharedSysSettingDAO.ReadSetting(tx, systemconfigs.SettingCodeReportNodeGlobalSetting)
if err != nil {
return nil, err
}
var setting = reporterconfigs.DefaultGlobalSetting()
if len(valueJSON) > 0 {
err = json.Unmarshal(valueJSON, setting)
if err != nil {
return nil, err
}
}
// 重新编码
valueJSON, err = json.Marshal(setting)
if err != nil {
return nil, err
}
return &pb.ReadReportNodeGlobalSettingResponse{
SettingJSON: valueJSON,
}, nil
}

View File

@@ -0,0 +1,335 @@
//go:build plus
// +build plus
package reporters
import (
"context"
"encoding/json"
"fmt"
"github.com/TeaOSLab/EdgeAPI/internal/configs"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/goman"
"github.com/TeaOSLab/EdgeAPI/internal/remotelogs"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/reporterconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/logs"
"strconv"
"sync"
"sync/atomic"
"time"
)
// CommandRequest 命令请求相关
type CommandRequest struct {
Id int64
Code string
CommandJSON []byte
}
type CommandRequestWaiting struct {
Timestamp int64
Chan chan *pb.ReportNodeStreamMessage
}
func (this *CommandRequestWaiting) Close() {
defer func() {
_ = recover()
}()
close(this.Chan)
}
var responseChanMap = map[int64]*CommandRequestWaiting{} // request id => response
var commandRequestId = int64(0)
var nodeLocker = &sync.Mutex{}
var requestChanMap = map[int64]chan *CommandRequest{} // node id => chan
func NextCommandRequestId() int64 {
return atomic.AddInt64(&commandRequestId, 1)
}
func init() {
dbs.OnReadyDone(func() {
// 清理WaitingChannelMap
goman.New(func() {
ticker := time.NewTicker(30 * time.Second)
for range ticker.C {
nodeLocker.Lock()
for requestId, request := range responseChanMap {
if time.Now().Unix()-request.Timestamp > 3600 {
responseChanMap[requestId].Close()
delete(responseChanMap, requestId)
}
}
nodeLocker.Unlock()
}
})
// 自动同步连接到本API节点的Report节点任务
goman.New(func() {
defer func() {
_ = recover()
}()
// TODO 未来支持同步边缘节点
var ticker = time.NewTicker(3 * time.Second)
for range ticker.C {
nodeIds, err := models.SharedNodeTaskDAO.FindAllDoingNodeIds(nil, nodeconfigs.NodeRoleReport)
if err != nil {
remotelogs.Error("ReportNodeService_SYNC", err.Error())
continue
}
nodeLocker.Lock()
for _, nodeId := range nodeIds {
c, ok := requestChanMap[nodeId]
if ok {
select {
case c <- &CommandRequest{
Id: NextCommandRequestId(),
Code: reporterconfigs.MessageCodeNewNodeTask,
CommandJSON: nil,
}:
default:
}
}
}
nodeLocker.Unlock()
}
})
})
}
// ReportNodeStream 节点stream
func (this *ReportNodeService) ReportNodeStream(server pb.ReportNodeService_ReportNodeStreamServer) error {
// TODO 使用此stream快速通知Reporter节点更新
// 校验节点
_, nodeId, err := this.ValidateNodeId(server.Context(), rpcutils.UserTypeReport)
if err != nil {
return err
}
var tx = this.NullTx()
err = validateClient(tx, nodeId, server.Context())
if err != nil {
return err
}
// 返回连接成功
{
apiConfig, err := configs.SharedAPIConfig()
if err != nil {
return err
}
connectedMessage := &reporterconfigs.ConnectedAPINodeMessage{APINodeId: apiConfig.NumberId()}
connectedMessageJSON, err := json.Marshal(connectedMessage)
if err != nil {
return errors.Wrap(err)
}
err = server.Send(&pb.ReportNodeStreamMessage{
Code: reporterconfigs.MessageCodeConnectedAPINode,
DataJSON: connectedMessageJSON,
})
if err != nil {
return err
}
}
//logs.Println("[RPC]accepted ns node '" + types.String(nodeId) + "' connection")
// 标记为活跃状态
oldIsActive, err := models.SharedReportNodeDAO.FindNodeActive(tx, nodeId)
if err != nil {
return err
}
if !oldIsActive {
err = models.SharedReportNodeDAO.UpdateNodeActive(tx, nodeId, true)
if err != nil {
return err
}
// 发送恢复消息
nodeName, err := models.SharedReportNodeDAO.FindReportNodeName(tx, nodeId)
if err != nil {
return err
}
subject := "区域监控节点\"" + nodeName + "\"已经恢复在线"
msg := "区域监控节点\"" + nodeName + "\"已经恢复在线"
err = models.SharedMessageDAO.CreateNodeMessage(tx, nodeconfigs.NodeRoleReport, 0, nodeId, models.MessageTypeReportNodeActive, models.MessageLevelSuccess, subject, msg, nil, false)
if err != nil {
return err
}
}
nodeLocker.Lock()
requestChan, ok := requestChanMap[nodeId]
if !ok {
requestChan = make(chan *CommandRequest, 1024)
requestChanMap[nodeId] = requestChan
}
nodeLocker.Unlock()
defer func() {
nodeLocker.Lock()
delete(requestChanMap, nodeId)
nodeLocker.Unlock()
}()
// 发送请求
goman.New(func() {
for {
select {
case <-server.Context().Done():
return
case commandRequest := <-requestChan:
// logs.Println("[RPC]sending command '" + commandRequest.Code + "' to node '" + strconv.FormatInt(nodeId, 10) + "'")
retries := 3 // 错误重试次数
for i := 0; i < retries; i++ {
err := server.Send(&pb.ReportNodeStreamMessage{
RequestId: commandRequest.Id,
Code: commandRequest.Code,
DataJSON: commandRequest.CommandJSON,
})
if err != nil {
if i == retries-1 {
logs.Println("[RPC]send command '" + commandRequest.Code + "' failed: " + err.Error())
} else {
time.Sleep(1 * time.Second)
}
} else {
break
}
}
}
}
})
// 接受请求
for {
req, err := server.Recv()
if err != nil {
// 修改节点状态
err1 := models.SharedReportNodeDAO.UpdateNodeActive(tx, nodeId, false)
if err1 != nil {
logs.Println(err1.Error())
}
return err
}
func(req *pb.ReportNodeStreamMessage) {
// 因为 responseChan.Chan 有被关闭的风险所以我们使用recover防止panic
defer func() {
_ = recover()
}()
nodeLocker.Lock()
responseChan, ok := responseChanMap[req.RequestId]
if ok {
select {
case responseChan.Chan <- req:
default:
}
}
nodeLocker.Unlock()
}(req)
}
}
// SendCommandToReportNode 向节点发送命令
func (this *ReportNodeService) SendCommandToReportNode(ctx context.Context, req *pb.ReportNodeStreamMessage) (*pb.ReportNodeStreamMessage, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
nodeId := req.ReportNodeId
if nodeId <= 0 {
return nil, errors.New("node id should not be less than 0")
}
nodeLocker.Lock()
requestChan, ok := requestChanMap[nodeId]
nodeLocker.Unlock()
if !ok {
return &pb.ReportNodeStreamMessage{
RequestId: req.RequestId,
IsOk: false,
Message: "node '" + strconv.FormatInt(nodeId, 10) + "' not connected yet",
}, nil
}
req.RequestId = NextCommandRequestId()
select {
case requestChan <- &CommandRequest{
Id: req.RequestId,
Code: req.Code,
CommandJSON: req.DataJSON,
}:
// 加入到等待队列中
respChan := make(chan *pb.ReportNodeStreamMessage, 1)
waiting := &CommandRequestWaiting{
Timestamp: time.Now().Unix(),
Chan: respChan,
}
nodeLocker.Lock()
responseChanMap[req.RequestId] = waiting
nodeLocker.Unlock()
// 等待响应
timeoutSeconds := req.TimeoutSeconds
if timeoutSeconds <= 0 {
timeoutSeconds = 10
}
timeout := time.NewTimer(time.Duration(timeoutSeconds) * time.Second)
select {
case resp := <-respChan:
// 从队列中删除
nodeLocker.Lock()
delete(responseChanMap, req.RequestId)
waiting.Close()
nodeLocker.Unlock()
if resp == nil {
return &pb.ReportNodeStreamMessage{
RequestId: req.RequestId,
Code: req.Code,
Message: "response timeout",
IsOk: false,
}, nil
}
return resp, nil
case <-timeout.C:
// 从队列中删除
nodeLocker.Lock()
delete(responseChanMap, req.RequestId)
waiting.Close()
nodeLocker.Unlock()
return &pb.ReportNodeStreamMessage{
RequestId: req.RequestId,
Code: req.Code,
Message: "response timeout over " + fmt.Sprintf("%d", timeoutSeconds) + " seconds",
IsOk: false,
}, nil
}
default:
return &pb.ReportNodeStreamMessage{
RequestId: req.RequestId,
Code: req.Code,
Message: "command queue is full over " + strconv.Itoa(len(requestChan)),
IsOk: false,
}, nil
}
}

View File

@@ -0,0 +1,237 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package reporters
import (
"context"
"encoding/json"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/remotelogs"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/reporterconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/systemconfigs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"net/http"
"strings"
"time"
)
// ReportResultService 区域监控报告结果
type ReportResultService struct {
services.BaseService
}
// CountAllReportResults 计算监控结果数量
func (this *ReportResultService) CountAllReportResults(ctx context.Context, req *pb.CountAllReportResultsRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedReportResultDAO.CountAllResults(tx, req.ReportNodeId, req.Level, types.Int8(req.OkState))
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListReportResults 列出单页监控结果
func (this *ReportResultService) ListReportResults(ctx context.Context, req *pb.ListReportResultsRequest) (*pb.ListReportResultsResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
results, err := models.SharedReportResultDAO.ListResults(tx, req.ReportNodeId, types.Int8(req.OkState), req.Level, req.Offset, req.Size)
if err != nil {
return nil, err
}
var pbResults = []*pb.ReportResult{}
for _, result := range results {
pbResults = append(pbResults, &pb.ReportResult{
Id: int64(result.Id),
Type: result.Type,
TargetId: int64(result.TargetId),
TargetDesc: result.TargetDesc,
ReportNodeId: int64(result.ReportNodeId),
IsOk: result.IsOk,
CostMs: float32(result.CostMs),
Error: result.Error,
UpdatedAt: int64(result.UpdatedAt),
Level: result.Level,
})
}
return &pb.ListReportResultsResponse{
ReportResults: pbResults,
}, nil
}
// UpdateReportResults 上传报告结果
func (this *ReportResultService) UpdateReportResults(ctx context.Context, req *pb.UpdateReportResultsRequest) (*pb.RPCSuccess, error) {
_, nodeId, err := this.ValidateNodeId(ctx, rpcutils.UserTypeReport)
if err != nil {
return nil, err
}
if !teaconst.IsPlus {
return nil, errors.New("the commercial version is expired.")
}
var tx = this.NullTx()
err = validateClient(tx, nodeId, ctx)
if err != nil {
return nil, err
}
// 设置
var setting = reporterconfigs.DefaultGlobalSetting()
settingJSON, err := models.SharedSysSettingDAO.ReadSetting(tx, systemconfigs.SettingCodeReportNodeGlobalSetting)
if err != nil {
return nil, err
}
if len(settingJSON) > 0 {
err = json.Unmarshal(settingJSON, setting)
if err != nil {
return nil, err
}
}
for _, result := range req.ReportResults {
// 更新数据
err := models.SharedReportResultDAO.UpdateResult(tx, result.Type, result.TargetId, result.TargetDesc, nodeId, result.Level, result.IsOk, float64(result.CostMs), result.Error)
if err != nil {
return nil, err
}
// 更新对象状态
costMs, err := models.SharedReportResultDAO.FindAvgCostMsWithTarget(tx, result.Type, result.TargetId)
if err != nil {
return nil, err
}
level, err := models.SharedReportResultDAO.FindAvgLevelWithTarget(tx, result.Type, result.TargetId)
if err != nil {
return nil, err
}
percent, err := models.SharedReportResultDAO.FindConnectivityWithTargetPercent(tx, result.Type, result.TargetId, 0)
if err != nil {
return nil, err
}
// 是否应该通知
if setting != nil && percent < setting.MinNotifyConnectivity {
switch result.Type {
case reporterconfigs.TaskTypeIPAddr:
addr, err := models.SharedNodeIPAddressDAO.FindEnabledAddress(tx, result.TargetId)
if err != nil {
return nil, err
}
if addr != nil {
var nodeId = int64(addr.NodeId)
clusterId, err := models.SharedNodeDAO.FindNodeClusterId(tx, nodeId)
if err != nil {
return nil, err
}
var messageSubject = "IP地址" + addr.Ip + "连通性低于" + types.String(setting.MinNotifyConnectivity) + "%"
err = models.SharedMessageDAO.CreateNodeMessage(tx, addr.Role, clusterId, nodeId, models.MessageTypeConnectivity, models.LevelError, messageSubject, messageSubject, maps.Map{"addrId": addr.Id}.AsJSON(), false)
if err != nil {
return nil, err
}
err = models.SharedMessageTaskDAO.CreateMessageTasks(tx, addr.Role, clusterId, nodeId, 0, models.MessageTypeConnectivity, messageSubject, messageSubject)
if err != nil {
return nil, err
}
// 发送外部通知
if len(setting.NotifyWebHookURL) > 0 {
var client = utils.SharedHttpClient(10 * time.Second)
var url = setting.NotifyWebHookURL
var args = "role=" + addr.Role + "&clusterId=" + types.String(clusterId) + "&nodeId=" + types.String(nodeId) + "&addressId=" + types.String(addr.Id) + "&ip=" + addr.Ip
var hasQuestionMark = strings.Contains(url, "?")
if hasQuestionMark {
url += "&" + args
} else {
url += "?" + args
}
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
// 不阻断执行
remotelogs.Error("ReportResultService.UpdateReportResults", "notify url '"+url+"' failed: "+err.Error())
} else {
req.Header.Set("User-Agent", teaconst.ProductName+"/"+teaconst.Version)
resp, err := client.Do(req)
if err != nil {
// 不阻断执行
remotelogs.Error("ReportResultService.UpdateReportResults", "notify url '"+url+"' failed: "+err.Error())
} else {
_ = resp.Body.Close()
}
}
}
}
}
}
// 保存
switch result.Type {
case reporterconfigs.TaskTypeIPAddr:
err = models.SharedNodeIPAddressDAO.UpdateAddressConnectivity(tx, result.TargetId, &nodeconfigs.Connectivity{
CostMs: costMs,
Level: level,
Percent: percent * 100,
UpdatedAt: time.Now().Unix(),
})
if err != nil {
return nil, err
}
}
}
return this.Success()
}
// FindAllReportResults 查询某个对象的监控结果
func (this *ReportResultService) FindAllReportResults(ctx context.Context, req *pb.FindAllReportResultsRequest) (*pb.FindAllReportResultsResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
results, err := models.SharedReportResultDAO.FindAllResults(tx, req.Type, req.TargetId)
if err != nil {
return nil, err
}
var pbResults = []*pb.ReportResult{}
for _, result := range results {
pbResults = append(pbResults, &pb.ReportResult{
Id: int64(result.Id),
Type: result.Type,
TargetId: int64(result.TargetId),
TargetDesc: result.TargetDesc,
ReportNodeId: int64(result.ReportNodeId),
IsOk: result.IsOk,
CostMs: float32(result.CostMs),
Error: result.Error,
UpdatedAt: int64(result.UpdatedAt),
Level: result.Level,
})
}
return &pb.FindAllReportResultsResponse{
ReportResults: pbResults,
}, nil
}

View File

@@ -0,0 +1,40 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package reporters
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/iwind/TeaGo/dbs"
"google.golang.org/grpc/peer"
"net"
)
// 校验客户端IP
func validateClient(tx *dbs.Tx, nodeId int64, ctx context.Context) error {
allowIPs, err := models.SharedReportNodeDAO.FindNodeAllowIPs(tx, nodeId)
if err != nil {
return err
}
if len(allowIPs) == 0 {
return nil
}
p, ok := peer.FromContext(ctx)
if ok {
host, _, _ := net.SplitHostPort(p.Addr.String())
if len(host) > 0 {
for _, ip := range allowIPs {
r, err := shared.ParseIPRange(ip)
if err == nil && r != nil {
if r.Contains(host) {
return nil
}
}
}
}
}
return errors.New("client was not allowed")
}

View File

@@ -0,0 +1,35 @@
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/acme"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// ACME认证相关
type ACMEAuthenticationService struct {
BaseService
}
// 获取Key
func (this *ACMEAuthenticationService) FindACMEAuthenticationKeyWithToken(ctx context.Context, req *pb.FindACMEAuthenticationKeyWithTokenRequest) (*pb.FindACMEAuthenticationKeyWithTokenResponse, error) {
_, err := this.ValidateNode(ctx)
if err != nil {
return nil, err
}
if len(req.Token) == 0 {
return nil, errors.New("'token' should not be empty")
}
var tx = this.NullTx()
auth, err := acme.SharedACMEAuthenticationDAO.FindAuthWithToken(tx, req.Token)
if err != nil {
return nil, err
}
if auth == nil {
return &pb.FindACMEAuthenticationKeyWithTokenResponse{Key: ""}, nil
}
return &pb.FindACMEAuthenticationKeyWithTokenResponse{Key: auth.Key}, nil
}

View File

@@ -0,0 +1,62 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/acme"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// ACMEProviderService ACME服务商
type ACMEProviderService struct {
BaseService
}
// FindAllACMEProviders 查找所有的服务商
func (this *ACMEProviderService) FindAllACMEProviders(ctx context.Context, req *pb.FindAllACMEProvidersRequest) (*pb.FindAllACMEProvidersResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var pbProviders = []*pb.ACMEProvider{}
for _, provider := range acme.FindAllProviders() {
pbProviders = append(pbProviders, &pb.ACMEProvider{
Name: provider.Name,
Code: provider.Code,
Description: provider.Description,
ApiURL: provider.APIURL,
RequireEAB: provider.RequireEAB,
EabDescription: provider.EABDescription,
})
}
return &pb.FindAllACMEProvidersResponse{AcmeProviders: pbProviders}, nil
}
// FindACMEProviderWithCode 根据代号查找服务商
func (this *ACMEProviderService) FindACMEProviderWithCode(ctx context.Context, req *pb.FindACMEProviderWithCodeRequest) (*pb.FindACMEProviderWithCodeResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var provider = acme.FindProviderWithCode(req.AcmeProviderCode)
if provider == nil {
return &pb.FindACMEProviderWithCodeResponse{
AcmeProvider: nil,
}, nil
}
return &pb.FindACMEProviderWithCodeResponse{
AcmeProvider: &pb.ACMEProvider{
Name: provider.Name,
Code: provider.Code,
Description: provider.Description,
ApiURL: provider.APIURL,
RequireEAB: provider.RequireEAB,
EabDescription: provider.EABDescription,
},
}, nil
}

View File

@@ -0,0 +1,230 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package services
import (
"context"
acmeutils "github.com/TeaOSLab/EdgeAPI/internal/acme"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/acme"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// ACMEProviderAccountService ACME服务商账号服务
type ACMEProviderAccountService struct {
BaseService
}
// CreateACMEProviderAccount 创建服务商账号
func (this *ACMEProviderAccountService) CreateACMEProviderAccount(ctx context.Context, req *pb.CreateACMEProviderAccountRequest) (*pb.CreateACMEProviderAccountResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
accountId, err := acme.SharedACMEProviderAccountDAO.CreateAccount(tx, userId, req.Name, req.ProviderCode, req.EabKid, req.EabKey)
if err != nil {
return nil, err
}
return &pb.CreateACMEProviderAccountResponse{
AcmeProviderAccountId: accountId,
}, nil
}
// FindAllACMEProviderAccountsWithProviderCode 使用代号查找服务商账号
func (this *ACMEProviderAccountService) FindAllACMEProviderAccountsWithProviderCode(ctx context.Context, req *pb.FindAllACMEProviderAccountsWithProviderCodeRequest) (*pb.FindAllACMEProviderAccountsWithProviderCodeResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
accounts, err := acme.SharedACMEProviderAccountDAO.FindAllEnabledAccountsWithProviderCode(tx, userId, req.AcmeProviderCode)
if err != nil {
return nil, err
}
var pbAccounts = []*pb.ACMEProviderAccount{}
for _, account := range accounts {
var pbProvider *pb.ACMEProvider
provider := acmeutils.FindProviderWithCode(account.ProviderCode)
if provider != nil {
pbProvider = &pb.ACMEProvider{
Name: provider.Name,
Code: provider.Code,
Description: provider.Description,
ApiURL: provider.APIURL,
RequireEAB: provider.RequireEAB,
}
}
pbAccounts = append(pbAccounts, &pb.ACMEProviderAccount{
Id: int64(account.Id),
Name: account.Name,
ProviderCode: account.ProviderCode,
IsOn: account.IsOn,
EabKid: account.EabKid,
EabKey: account.EabKey,
Error: account.Error,
AcmeProvider: pbProvider,
})
}
return &pb.FindAllACMEProviderAccountsWithProviderCodeResponse{
AcmeProviderAccounts: pbAccounts,
}, nil
}
// UpdateACMEProviderAccount 修改服务商账号
func (this *ACMEProviderAccountService) UpdateACMEProviderAccount(ctx context.Context, req *pb.UpdateACMEProviderAccountRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = acme.SharedACMEProviderAccountDAO.CheckUserAccount(tx, userId, req.AcmeProviderAccountId)
if err != nil {
return nil, err
}
}
err = acme.SharedACMEProviderAccountDAO.UpdateAccount(tx, req.AcmeProviderAccountId, req.Name, req.EabKid, req.EabKey)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteACMEProviderAccount 删除服务商账号
func (this *ACMEProviderAccountService) DeleteACMEProviderAccount(ctx context.Context, req *pb.DeleteACMEProviderAccountRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = acme.SharedACMEProviderAccountDAO.CheckUserAccount(tx, userId, req.AcmeProviderAccountId)
if err != nil {
return nil, err
}
}
err = acme.SharedACMEProviderAccountDAO.DisableACMEProviderAccount(tx, req.AcmeProviderAccountId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindEnabledACMEProviderAccount 查找单个服务商账号
func (this *ACMEProviderAccountService) FindEnabledACMEProviderAccount(ctx context.Context, req *pb.FindEnabledACMEProviderAccountRequest) (*pb.FindEnabledACMEProviderAccountResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = acme.SharedACMEProviderAccountDAO.CheckUserAccount(tx, userId, req.AcmeProviderAccountId)
if err != nil {
return nil, err
}
}
account, err := acme.SharedACMEProviderAccountDAO.FindEnabledACMEProviderAccount(tx, req.AcmeProviderAccountId)
if err != nil {
return nil, err
}
if account == nil {
return &pb.FindEnabledACMEProviderAccountResponse{AcmeProviderAccount: nil}, nil
}
var pbProvider *pb.ACMEProvider
provider := acmeutils.FindProviderWithCode(account.ProviderCode)
if provider != nil {
pbProvider = &pb.ACMEProvider{
Name: provider.Name,
Code: provider.Code,
Description: provider.Description,
ApiURL: provider.APIURL,
RequireEAB: provider.RequireEAB,
EabDescription: provider.EABDescription,
}
}
return &pb.FindEnabledACMEProviderAccountResponse{AcmeProviderAccount: &pb.ACMEProviderAccount{
Id: int64(account.Id),
Name: account.Name,
ProviderCode: account.ProviderCode,
IsOn: account.IsOn,
EabKid: account.EabKid,
EabKey: account.EabKey,
Error: account.Error,
AcmeProvider: pbProvider,
}}, nil
}
// CountAllEnabledACMEProviderAccounts 计算所有服务商账号数量
func (this *ACMEProviderAccountService) CountAllEnabledACMEProviderAccounts(ctx context.Context, req *pb.CountAllEnabledACMEProviderAccountsRequest) (*pb.RPCCountResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := acme.SharedACMEProviderAccountDAO.CountAllEnabledAccounts(tx, userId)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListEnabledACMEProviderAccounts 列出单页服务商账号
func (this *ACMEProviderAccountService) ListEnabledACMEProviderAccounts(ctx context.Context, req *pb.ListEnabledACMEProviderAccountsRequest) (*pb.ListEnabledACMEProviderAccountsResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
accounts, err := acme.SharedACMEProviderAccountDAO.ListEnabledAccounts(tx, userId, req.Offset, req.Size)
if err != nil {
return nil, err
}
var pbAccounts = []*pb.ACMEProviderAccount{}
for _, account := range accounts {
var pbProvider *pb.ACMEProvider
provider := acmeutils.FindProviderWithCode(account.ProviderCode)
if provider != nil {
pbProvider = &pb.ACMEProvider{
Name: provider.Name,
Code: provider.Code,
Description: provider.Description,
ApiURL: provider.APIURL,
RequireEAB: provider.RequireEAB,
EabDescription: provider.EABDescription,
}
}
pbAccounts = append(pbAccounts, &pb.ACMEProviderAccount{
Id: int64(account.Id),
Name: account.Name,
ProviderCode: account.ProviderCode,
IsOn: account.IsOn,
EabKid: account.EabKid,
EabKey: account.EabKey,
Error: account.Error,
AcmeProvider: pbProvider,
})
}
return &pb.ListEnabledACMEProviderAccountsResponse{AcmeProviderAccounts: pbAccounts}, nil
}

View File

@@ -0,0 +1,479 @@
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/acme"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
acmemodels "github.com/TeaOSLab/EdgeAPI/internal/db/models/acme"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns"
"github.com/TeaOSLab/EdgeAPI/internal/dnsclients"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// ACMETaskService ACME任务相关服务
type ACMETaskService struct {
BaseService
}
// CountAllEnabledACMETasksWithACMEUserId 计算某个ACME用户相关的任务数量
func (this *ACMETaskService) CountAllEnabledACMETasksWithACMEUserId(ctx context.Context, req *pb.CountAllEnabledACMETasksWithACMEUserIdRequest) (*pb.RPCCountResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
// TODO 校验权限
}
count, err := acmemodels.SharedACMETaskDAO.CountACMETasksWithACMEUserId(tx, req.AcmeUserId)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// CountEnabledACMETasksWithDNSProviderId 计算跟某个DNS服务商相关的任务数量
func (this *ACMETaskService) CountEnabledACMETasksWithDNSProviderId(ctx context.Context, req *pb.CountEnabledACMETasksWithDNSProviderIdRequest) (*pb.RPCCountResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
// TODO 校验权限
var tx = this.NullTx()
count, err := acmemodels.SharedACMETaskDAO.CountACMETasksWithDNSProviderId(tx, req.DnsProviderId)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// CountAllEnabledACMETasks 计算所有任务数量
func (this *ACMETaskService) CountAllEnabledACMETasks(ctx context.Context, req *pb.CountAllEnabledACMETasksRequest) (*pb.RPCCountResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
req.UserId = userId
}
count, err := acmemodels.SharedACMETaskDAO.CountAllEnabledACMETasks(tx, req.UserId, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, req.UserOnly)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListEnabledACMETasks 列出单页任务
func (this *ACMETaskService) ListEnabledACMETasks(ctx context.Context, req *pb.ListEnabledACMETasksRequest) (*pb.ListEnabledACMETasksResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
req.UserId = userId
}
tasks, err := acmemodels.SharedACMETaskDAO.ListEnabledACMETasks(tx, req.UserId, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, req.UserOnly, req.Offset, req.Size)
if err != nil {
return nil, err
}
var result = []*pb.ACMETask{}
for _, task := range tasks {
// ACME用户
acmeUser, err := acmemodels.SharedACMEUserDAO.FindEnabledACMEUser(tx, int64(task.AcmeUserId))
if err != nil {
return nil, err
}
if acmeUser == nil {
continue
}
pbACMEUser := &pb.ACMEUser{
Id: int64(acmeUser.Id),
Email: acmeUser.Email,
Description: acmeUser.Description,
CreatedAt: int64(acmeUser.CreatedAt),
}
// 服务商
if len(acmeUser.ProviderCode) == 0 {
acmeUser.ProviderCode = acme.DefaultProviderCode
}
var provider = acme.FindProviderWithCode(acmeUser.ProviderCode)
if provider != nil {
pbACMEUser.AcmeProvider = &pb.ACMEProvider{
Name: provider.Name,
Code: provider.Code,
Description: provider.Description,
RequireEAB: provider.RequireEAB,
EabDescription: provider.EABDescription,
}
}
// 账号
if acmeUser.AccountId > 0 {
account, err := acmemodels.SharedACMEProviderAccountDAO.FindEnabledACMEProviderAccount(tx, int64(acmeUser.AccountId))
if err != nil {
return nil, err
}
if account != nil {
pbACMEUser.AcmeProviderAccount = &pb.ACMEProviderAccount{
Id: int64(account.Id),
Name: account.Name,
IsOn: account.IsOn,
ProviderCode: account.ProviderCode,
AcmeProvider: nil,
}
var provider = acme.FindProviderWithCode(account.ProviderCode)
if provider != nil {
pbACMEUser.AcmeProviderAccount.AcmeProvider = &pb.ACMEProvider{
Name: provider.Name,
Code: provider.Code,
Description: provider.Description,
RequireEAB: provider.RequireEAB,
EabDescription: provider.EABDescription,
}
}
}
}
var pbDNSProvider *pb.DNSProvider
if task.AuthType == acme.AuthTypeDNS {
// DNS
provider, err := dns.SharedDNSProviderDAO.FindEnabledDNSProvider(tx, int64(task.DnsProviderId))
if err != nil {
return nil, err
}
if provider == nil {
continue
}
pbDNSProvider = &pb.DNSProvider{
Id: int64(provider.Id),
Name: provider.Name,
Type: provider.Type,
TypeName: dnsclients.FindProviderTypeName(provider.Type),
}
}
// 证书
var pbCert *pb.SSLCert = nil
if task.CertId > 0 {
cert, err := models.SharedSSLCertDAO.FindEnabledSSLCert(tx, int64(task.CertId))
if err != nil {
return nil, err
}
if cert == nil {
continue
}
pbCert = &pb.SSLCert{
Id: int64(cert.Id),
IsOn: cert.IsOn,
Name: cert.Name,
TimeBeginAt: int64(cert.TimeBeginAt),
TimeEndAt: int64(cert.TimeEndAt),
}
}
// 最近一条日志
var pbTaskLog *pb.ACMETaskLog = nil
taskLog, err := acmemodels.SharedACMETaskLogDAO.FindLatestACMETasKLog(tx, int64(task.Id))
if err != nil {
return nil, err
}
if taskLog != nil {
pbTaskLog = &pb.ACMETaskLog{
Id: int64(taskLog.Id),
IsOk: taskLog.IsOk,
Error: taskLog.Error,
CreatedAt: int64(taskLog.CreatedAt),
}
}
result = append(result, &pb.ACMETask{
Id: int64(task.Id),
IsOn: task.IsOn,
DnsDomain: task.DnsDomain,
Domains: task.DecodeDomains(),
CreatedAt: int64(task.CreatedAt),
AutoRenew: task.AutoRenew == 1,
AcmeUser: pbACMEUser,
DnsProvider: pbDNSProvider,
SslCert: pbCert,
LatestACMETaskLog: pbTaskLog,
AuthType: task.AuthType,
AuthURL: task.AuthURL,
})
}
return &pb.ListEnabledACMETasksResponse{AcmeTasks: result}, nil
}
// CreateACMETask 创建任务
func (this *ACMETaskService) CreateACMETask(ctx context.Context, req *pb.CreateACMETaskRequest) (*pb.CreateACMETaskResponse, error) {
adminId, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
if len(req.AuthType) == 0 {
req.AuthType = acme.AuthTypeDNS
}
if adminId > 0 {
if req.UserId > 0 {
userId = req.UserId
}
}
var tx = this.NullTx()
taskId, err := acmemodels.SharedACMETaskDAO.CreateACMETask(tx, adminId, userId, req.AuthType, req.AcmeUserId, req.DnsProviderId, req.DnsDomain, req.Domains, req.AutoRenew, req.AuthURL)
if err != nil {
return nil, err
}
return &pb.CreateACMETaskResponse{AcmeTaskId: taskId}, nil
}
// UpdateACMETask 修改任务
func (this *ACMETaskService) UpdateACMETask(ctx context.Context, req *pb.UpdateACMETaskRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
canAccess, err := acmemodels.SharedACMETaskDAO.CheckUserACMETask(tx, userId, req.AcmeTaskId)
if err != nil {
return nil, err
}
if !canAccess {
return nil, this.PermissionError()
}
err = acmemodels.SharedACMETaskDAO.UpdateACMETask(tx, req.AcmeTaskId, req.AcmeUserId, req.DnsProviderId, req.DnsDomain, req.Domains, req.AutoRenew, req.AuthURL)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteACMETask 删除任务
func (this *ACMETaskService) DeleteACMETask(ctx context.Context, req *pb.DeleteACMETaskRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
canAccess, err := acmemodels.SharedACMETaskDAO.CheckUserACMETask(tx, userId, req.AcmeTaskId)
if err != nil {
return nil, err
}
if !canAccess {
return nil, this.PermissionError()
}
err = acmemodels.SharedACMETaskDAO.DisableACMETask(tx, req.AcmeTaskId)
if err != nil {
return nil, err
}
return this.Success()
}
// RunACMETask 运行某个任务
func (this *ACMETaskService) RunACMETask(ctx context.Context, req *pb.RunACMETaskRequest) (*pb.RunACMETaskResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
canAccess, err := acmemodels.SharedACMETaskDAO.CheckUserACMETask(tx, userId, req.AcmeTaskId)
if err != nil {
return nil, err
}
if !canAccess {
return nil, this.PermissionError()
}
isOk, msg, certId := acmemodels.SharedACMETaskDAO.RunTask(tx, req.AcmeTaskId)
return &pb.RunACMETaskResponse{
IsOk: isOk,
Error: msg,
SslCertId: certId,
}, nil
}
// FindEnabledACMETask 查找单个任务信息
func (this *ACMETaskService) FindEnabledACMETask(ctx context.Context, req *pb.FindEnabledACMETaskRequest) (*pb.FindEnabledACMETaskResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
canAccess, err := acmemodels.SharedACMETaskDAO.CheckUserACMETask(tx, userId, req.AcmeTaskId)
if err != nil {
return nil, err
}
if !canAccess {
return nil, this.PermissionError()
}
task, err := acmemodels.SharedACMETaskDAO.FindEnabledACMETask(tx, req.AcmeTaskId)
if err != nil {
return nil, err
}
if task == nil {
return &pb.FindEnabledACMETaskResponse{AcmeTask: nil}, nil
}
// 用户
var pbACMEUser *pb.ACMEUser = nil
if task.AcmeUserId > 0 {
acmeUser, err := acmemodels.SharedACMEUserDAO.FindEnabledACMEUser(tx, int64(task.AcmeUserId))
if err != nil {
return nil, err
}
if acmeUser != nil {
pbACMEUser = &pb.ACMEUser{
Id: int64(acmeUser.Id),
Email: acmeUser.Email,
Description: acmeUser.Description,
CreatedAt: int64(acmeUser.CreatedAt),
}
// 服务商
if len(acmeUser.ProviderCode) == 0 {
acmeUser.ProviderCode = acme.DefaultProviderCode
}
var provider = acme.FindProviderWithCode(acmeUser.ProviderCode)
if provider != nil {
pbACMEUser.AcmeProvider = &pb.ACMEProvider{
Name: provider.Name,
Code: provider.Code,
Description: provider.Description,
RequireEAB: provider.RequireEAB,
EabDescription: provider.EABDescription,
}
}
// 账号
if acmeUser.AccountId > 0 {
account, err := acmemodels.SharedACMEProviderAccountDAO.FindEnabledACMEProviderAccount(tx, int64(acmeUser.AccountId))
if err != nil {
return nil, err
}
if account != nil {
pbACMEUser.AcmeProviderAccount = &pb.ACMEProviderAccount{
Id: int64(account.Id),
Name: account.Name,
IsOn: account.IsOn,
ProviderCode: account.ProviderCode,
AcmeProvider: nil,
}
var provider = acme.FindProviderWithCode(account.ProviderCode)
if provider != nil {
pbACMEUser.AcmeProviderAccount.AcmeProvider = &pb.ACMEProvider{
Name: provider.Name,
Code: provider.Code,
Description: provider.Description,
RequireEAB: provider.RequireEAB,
EabDescription: provider.EABDescription,
}
}
}
}
}
}
// DNS
var pbProvider *pb.DNSProvider
provider, err := dns.SharedDNSProviderDAO.FindEnabledDNSProvider(tx, int64(task.DnsProviderId))
if err != nil {
return nil, err
}
if provider != nil {
pbProvider = &pb.DNSProvider{
Id: int64(provider.Id),
Name: provider.Name,
Type: provider.Type,
TypeName: dnsclients.FindProviderTypeName(provider.Type),
}
}
// 证书
var pbCert *pb.SSLCert
if task.CertId > 0 {
pbCert = &pb.SSLCert{Id: int64(task.CertId)}
}
return &pb.FindEnabledACMETaskResponse{AcmeTask: &pb.ACMETask{
Id: int64(task.Id),
IsOn: task.IsOn,
DnsDomain: task.DnsDomain,
Domains: task.DecodeDomains(),
CreatedAt: int64(task.CreatedAt),
AutoRenew: task.AutoRenew == 1,
DnsProvider: pbProvider,
AcmeUser: pbACMEUser,
AuthType: task.AuthType,
AuthURL: task.AuthURL,
SslCert: pbCert,
}}, nil
}
// FindACMETaskUser 查找任务所属用户
func (this *ACMETaskService) FindACMETaskUser(ctx context.Context, req *pb.FindACMETaskUserRequest) (*pb.FindACMETaskUserResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
userId, err := acmemodels.SharedACMETaskDAO.FindACMETaskUserId(tx, req.AcmeTaskId)
if err != nil {
return nil, err
}
if userId <= 0 {
return &pb.FindACMETaskUserResponse{User: nil}, nil
}
user, err := models.SharedUserDAO.FindEnabledBasicUser(tx, userId)
if err != nil {
return nil, err
}
if user == nil {
return &pb.FindACMETaskUserResponse{
User: &pb.User{
Id: userId,
},
}, nil
}
return &pb.FindACMETaskUserResponse{
User: &pb.User{
Id: userId,
Username: user.Username,
Fullname: user.Fullname,
},
}, nil
}

View File

@@ -0,0 +1,295 @@
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/acme"
acmemodels "github.com/TeaOSLab/EdgeAPI/internal/db/models/acme"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// ACMEUserService 用户服务
type ACMEUserService struct {
BaseService
}
// CreateACMEUser 创建用户
func (this *ACMEUserService) CreateACMEUser(ctx context.Context, req *pb.CreateACMEUserRequest) (*pb.CreateACMEUserResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if adminId > 0 {
if req.UserId > 0 {
userId = req.UserId
}
}
acmeUserId, err := acmemodels.SharedACMEUserDAO.CreateACMEUser(tx, adminId, userId, req.AcmeProviderCode, req.AcmeProviderAccountId, req.Email, req.Description)
if err != nil {
return nil, err
}
return &pb.CreateACMEUserResponse{AcmeUserId: acmeUserId}, nil
}
// UpdateACMEUser 修改用户
func (this *ACMEUserService) UpdateACMEUser(ctx context.Context, req *pb.UpdateACMEUserRequest) (*pb.RPCSuccess, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查是否有权限
b, err := acmemodels.SharedACMEUserDAO.CheckACMEUser(tx, req.AcmeUserId, adminId, userId)
if err != nil {
return nil, err
}
if !b {
return nil, this.PermissionError()
}
err = acmemodels.SharedACMEUserDAO.UpdateACMEUser(tx, req.AcmeUserId, req.Description)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteACMEUser 删除用户
func (this *ACMEUserService) DeleteACMEUser(ctx context.Context, req *pb.DeleteACMEUserRequest) (*pb.RPCSuccess, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查是否有权限
b, err := acmemodels.SharedACMEUserDAO.CheckACMEUser(tx, req.AcmeUserId, adminId, userId)
if err != nil {
return nil, err
}
if !b {
return nil, this.PermissionError()
}
err = acmemodels.SharedACMEUserDAO.DisableACMEUser(tx, req.AcmeUserId)
if err != nil {
return nil, err
}
return this.Success()
}
// CountACMEUsers 计算用户数量
func (this *ACMEUserService) CountACMEUsers(ctx context.Context, req *pb.CountAcmeUsersRequest) (*pb.RPCCountResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
req.UserId = userId
}
count, err := acmemodels.SharedACMEUserDAO.CountACMEUsersWithAdminId(tx, adminId, req.UserId, req.AcmeProviderAccountId)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListACMEUsers 列出单页用户
func (this *ACMEUserService) ListACMEUsers(ctx context.Context, req *pb.ListACMEUsersRequest) (*pb.ListACMEUsersResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
req.UserId = userId
}
acmeUsers, err := acmemodels.SharedACMEUserDAO.ListACMEUsers(tx, adminId, req.UserId, req.Offset, req.Size)
if err != nil {
return nil, err
}
var result = []*pb.ACMEUser{}
for _, user := range acmeUsers {
var pbUser = &pb.ACMEUser{
Id: int64(user.Id),
Email: user.Email,
Description: user.Description,
CreatedAt: int64(user.CreatedAt),
AcmeProviderCode: user.ProviderCode,
}
// 服务商
if len(user.ProviderCode) == 0 {
user.ProviderCode = acme.DefaultProviderCode
}
var provider = acme.FindProviderWithCode(user.ProviderCode)
if provider != nil {
pbUser.AcmeProvider = &pb.ACMEProvider{
Name: provider.Name,
Code: provider.Code,
Description: provider.Description,
RequireEAB: provider.RequireEAB,
EabDescription: provider.EABDescription,
}
}
// 账号
if user.AccountId > 0 {
account, err := acmemodels.SharedACMEProviderAccountDAO.FindEnabledACMEProviderAccount(tx, int64(user.AccountId))
if err != nil {
return nil, err
}
if account != nil {
pbUser.AcmeProviderAccount = &pb.ACMEProviderAccount{
Id: int64(account.Id),
Name: account.Name,
IsOn: account.IsOn,
ProviderCode: account.ProviderCode,
AcmeProvider: nil,
}
var provider = acme.FindProviderWithCode(account.ProviderCode)
if provider != nil {
pbUser.AcmeProviderAccount.AcmeProvider = &pb.ACMEProvider{
Name: provider.Name,
Code: provider.Code,
Description: provider.Description,
RequireEAB: provider.RequireEAB,
EabDescription: provider.EABDescription,
}
}
}
}
result = append(result, pbUser)
}
return &pb.ListACMEUsersResponse{AcmeUsers: result}, nil
}
// FindEnabledACMEUser 查找单个用户
func (this *ACMEUserService) FindEnabledACMEUser(ctx context.Context, req *pb.FindEnabledACMEUserRequest) (*pb.FindEnabledACMEUserResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查是否有权限
b, err := acmemodels.SharedACMEUserDAO.CheckACMEUser(tx, req.AcmeUserId, adminId, userId)
if err != nil {
return nil, err
}
if !b {
return nil, this.PermissionError()
}
acmeUser, err := acmemodels.SharedACMEUserDAO.FindEnabledACMEUser(tx, req.AcmeUserId)
if err != nil {
return nil, err
}
if acmeUser == nil {
return &pb.FindEnabledACMEUserResponse{AcmeUser: nil}, nil
}
// 服务商
var pbACMEUser = &pb.ACMEUser{
Id: int64(acmeUser.Id),
Email: acmeUser.Email,
Description: acmeUser.Description,
CreatedAt: int64(acmeUser.CreatedAt),
AcmeProviderCode: acmeUser.ProviderCode,
}
if len(acmeUser.ProviderCode) == 0 {
acmeUser.ProviderCode = acme.DefaultProviderCode
}
var provider = acme.FindProviderWithCode(acmeUser.ProviderCode)
if provider != nil {
pbACMEUser.AcmeProvider = &pb.ACMEProvider{
Name: provider.Name,
Code: provider.Code,
Description: provider.Description,
RequireEAB: provider.RequireEAB,
EabDescription: provider.EABDescription,
}
}
// 账号
if acmeUser.AccountId > 0 {
account, err := acmemodels.SharedACMEProviderAccountDAO.FindEnabledACMEProviderAccount(tx, int64(acmeUser.AccountId))
if err != nil {
return nil, err
}
if account != nil {
pbACMEUser.AcmeProviderAccount = &pb.ACMEProviderAccount{
Id: int64(account.Id),
Name: account.Name,
IsOn: account.IsOn,
ProviderCode: account.ProviderCode,
AcmeProvider: nil,
}
var provider = acme.FindProviderWithCode(account.ProviderCode)
if provider != nil {
pbACMEUser.AcmeProviderAccount.AcmeProvider = &pb.ACMEProvider{
Name: provider.Name,
Code: provider.Code,
Description: provider.Description,
RequireEAB: provider.RequireEAB,
EabDescription: provider.EABDescription,
}
}
}
}
return &pb.FindEnabledACMEUserResponse{AcmeUser: pbACMEUser}, nil
}
// FindAllACMEUsers 查找所有用户
func (this *ACMEUserService) FindAllACMEUsers(ctx context.Context, req *pb.FindAllACMEUsersRequest) (*pb.FindAllACMEUsersResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
req.UserId = userId
}
acmeUsers, err := acmemodels.SharedACMEUserDAO.FindAllACMEUsers(tx, adminId, req.UserId, req.AcmeProviderCode)
if err != nil {
return nil, err
}
var result = []*pb.ACMEUser{}
for _, user := range acmeUsers {
result = append(result, &pb.ACMEUser{
Id: int64(user.Id),
Email: user.Email,
Description: user.Description,
CreatedAt: int64(user.CreatedAt),
AcmeProviderCode: user.ProviderCode,
})
}
return &pb.FindAllACMEUsersResponse{AcmeUsers: result}, nil
}

View File

@@ -0,0 +1,787 @@
package services
import (
"context"
"encoding/json"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/stats"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/tasks"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/systemconfigs"
timeutil "github.com/iwind/TeaGo/utils/time"
"time"
)
// AdminService 管理员相关服务
type AdminService struct {
BaseService
debug bool
}
// LoginAdmin 登录
func (this *AdminService) LoginAdmin(ctx context.Context, req *pb.LoginAdminRequest) (*pb.LoginAdminResponse, error) {
_, _, _, err := rpcutils.ValidateRequest(ctx)
if err != nil {
return nil, err
}
if len(req.Username) == 0 || len(req.Password) == 0 {
return &pb.LoginAdminResponse{
AdminId: 0,
IsOk: false,
Message: "请输入正确的用户名密码",
}, nil
}
var tx = this.NullTx()
adminId, err := models.SharedAdminDAO.CheckAdminPassword(tx, req.Username, req.Password)
if err != nil {
utils.PrintError(err)
return nil, err
}
if adminId <= 0 {
return &pb.LoginAdminResponse{
AdminId: 0,
IsOk: false,
Message: "请输入正确的用户名密码",
}, nil
}
return &pb.LoginAdminResponse{
AdminId: adminId,
IsOk: true,
}, nil
}
// CheckAdminExists 检查管理员是否存在
func (this *AdminService) CheckAdminExists(ctx context.Context, req *pb.CheckAdminExistsRequest) (*pb.CheckAdminExistsResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
if req.AdminId <= 0 {
return &pb.CheckAdminExistsResponse{
IsOk: false,
}, nil
}
var tx = this.NullTx()
ok, err := models.SharedAdminDAO.ExistEnabledAdmin(tx, req.AdminId)
if err != nil {
return nil, err
}
return &pb.CheckAdminExistsResponse{
IsOk: ok,
}, nil
}
// CheckAdminUsername 检查用户名是否存在
func (this *AdminService) CheckAdminUsername(ctx context.Context, req *pb.CheckAdminUsernameRequest) (*pb.CheckAdminUsernameResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
exists, err := models.SharedAdminDAO.CheckAdminUsername(tx, req.AdminId, req.Username)
if err != nil {
return nil, err
}
return &pb.CheckAdminUsernameResponse{Exists: exists}, nil
}
// FindAdminWithUsername 使用用管理员户名查找管理员信息
func (this *AdminService) FindAdminWithUsername(ctx context.Context, req *pb.FindAdminWithUsernameRequest) (*pb.FindAdminWithUsernameResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if len(req.Username) == 0 {
return nil, errors.New("require 'username'")
}
admin, err := models.SharedAdminDAO.FindAdminWithUsername(tx, req.Username)
if err != nil {
return nil, err
}
if admin == nil {
return &pb.FindAdminWithUsernameResponse{Admin: nil}, nil
}
return &pb.FindAdminWithUsernameResponse{
Admin: &pb.Admin{
Id: int64(admin.Id),
Fullname: admin.Fullname,
Username: admin.Username,
IsOn: admin.IsOn,
IsSuper: admin.IsSuper,
CanLogin: admin.CanLogin,
},
}, nil
}
// FindAdminFullname 获取管理员名称
func (this *AdminService) FindAdminFullname(ctx context.Context, req *pb.FindAdminFullnameRequest) (*pb.FindAdminFullnameResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
fullname, err := models.SharedAdminDAO.FindAdminFullname(tx, req.AdminId)
if err != nil {
utils.PrintError(err)
return nil, err
}
return &pb.FindAdminFullnameResponse{
Fullname: fullname,
}, nil
}
// FindEnabledAdmin 获取管理员信息
func (this *AdminService) FindEnabledAdmin(ctx context.Context, req *pb.FindEnabledAdminRequest) (*pb.FindEnabledAdminResponse, error) {
adminId, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
// TODO 检查权限
var tx = this.NullTx()
// 超级管理员才能查看是否为弱密码
isSuperAdmin, err := models.SharedAdminDAO.CheckSuperAdmin(tx, adminId)
if err != nil {
return nil, err
}
admin, err := models.SharedAdminDAO.FindEnabledAdmin(tx, req.AdminId)
if err != nil {
return nil, err
}
if admin == nil {
return &pb.FindEnabledAdminResponse{Admin: nil}, nil
}
var pbModules = []*pb.AdminModule{}
modules := []*systemconfigs.AdminModule{}
if len(admin.Modules) > 0 {
err = json.Unmarshal(admin.Modules, &modules)
if err != nil {
return nil, err
}
for _, module := range modules {
pbModules = append(pbModules, &pb.AdminModule{
AllowAll: module.AllowAll,
Code: module.Code,
Actions: module.Actions,
})
}
}
// OTP认证
var pbOtpAuth *pb.Login = nil
{
adminAuth, err := models.SharedLoginDAO.FindEnabledLoginWithType(tx, int64(admin.Id), 0, models.LoginTypeOTP)
if err != nil {
return nil, err
}
if adminAuth != nil {
pbOtpAuth = &pb.Login{
Id: int64(adminAuth.Id),
Type: adminAuth.Type,
ParamsJSON: adminAuth.Params,
IsOn: adminAuth.IsOn,
}
}
}
result := &pb.Admin{
Id: int64(admin.Id),
Fullname: admin.Fullname,
Username: admin.Username,
IsOn: admin.IsOn,
IsSuper: admin.IsSuper,
Modules: pbModules,
OtpLogin: pbOtpAuth,
CanLogin: admin.CanLogin,
HasWeakPassword: isSuperAdmin && admin.HasWeakPassword(),
}
return &pb.FindEnabledAdminResponse{Admin: result}, nil
}
// CreateOrUpdateAdmin 创建或修改管理员
func (this *AdminService) CreateOrUpdateAdmin(ctx context.Context, req *pb.CreateOrUpdateAdminRequest) (*pb.CreateOrUpdateAdminResponse, error) {
// 校验请求
_, _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeAPI)
if err != nil {
return nil, err
}
var tx = this.NullTx()
adminId, err := models.SharedAdminDAO.FindAdminIdWithUsername(tx, req.Username)
if err != nil {
return nil, err
}
if adminId > 0 {
err = models.SharedAdminDAO.UpdateAdminPassword(tx, adminId, req.Password)
if err != nil {
return nil, err
}
return &pb.CreateOrUpdateAdminResponse{AdminId: adminId}, nil
}
adminId, err = models.SharedAdminDAO.CreateAdmin(tx, req.Username, true, req.Password, "管理员", true, nil)
if err != nil {
return nil, err
}
return &pb.CreateOrUpdateAdminResponse{AdminId: adminId}, nil
}
// UpdateAdminInfo 修改管理员信息
func (this *AdminService) UpdateAdminInfo(ctx context.Context, req *pb.UpdateAdminInfoRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeAPI)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedAdminDAO.UpdateAdminInfo(tx, req.AdminId, req.Fullname)
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateAdminLogin 修改管理员登录信息
func (this *AdminService) UpdateAdminLogin(ctx context.Context, req *pb.UpdateAdminLoginRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeAPI)
if err != nil {
return nil, err
}
var tx = this.NullTx()
exists, err := models.SharedAdminDAO.CheckAdminUsername(tx, req.AdminId, req.Username)
if err != nil {
return nil, err
}
if exists {
return nil, errors.New("username already been token")
}
err = models.SharedAdminDAO.UpdateAdminLogin(tx, req.AdminId, req.Username, req.Password)
if err != nil {
return nil, err
}
return this.Success()
}
// FindAllAdminModules 获取所有管理员的权限列表
func (this *AdminService) FindAllAdminModules(ctx context.Context, req *pb.FindAllAdminModulesRequest) (*pb.FindAllAdminModulesResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
// TODO 检查权限
var tx = this.NullTx()
admins, err := models.SharedAdminDAO.FindAllAdminModules(tx)
if err != nil {
return nil, err
}
var result = []*pb.AdminModuleList{}
for _, admin := range admins {
modules := []*systemconfigs.AdminModule{}
if len(admin.Modules) > 0 {
err = json.Unmarshal(admin.Modules, &modules)
if err != nil {
return nil, err
}
}
var pbModules = []*pb.AdminModule{}
for _, module := range modules {
pbModules = append(pbModules, &pb.AdminModule{
AllowAll: module.AllowAll,
Code: module.Code,
Actions: module.Actions,
})
}
var list = &pb.AdminModuleList{
AdminId: int64(admin.Id),
IsSuper: admin.IsSuper,
Fullname: admin.Fullname,
Theme: admin.Theme,
Lang: admin.Lang,
Modules: pbModules,
}
result = append(result, list)
}
return &pb.FindAllAdminModulesResponse{AdminModules: result}, nil
}
// CreateAdmin 创建管理员
func (this *AdminService) CreateAdmin(ctx context.Context, req *pb.CreateAdminRequest) (*pb.CreateAdminResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
// TODO 检查权限
var tx = this.NullTx()
adminId, err := models.SharedAdminDAO.CreateAdmin(tx, req.Username, req.CanLogin, req.Password, req.Fullname, req.IsSuper, req.ModulesJSON)
if err != nil {
return nil, err
}
return &pb.CreateAdminResponse{AdminId: adminId}, nil
}
// UpdateAdmin 修改管理员
func (this *AdminService) UpdateAdmin(ctx context.Context, req *pb.UpdateAdminRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
// TODO 检查权限
var tx = this.NullTx()
err = models.SharedAdminDAO.UpdateAdmin(tx, req.AdminId, req.Username, req.CanLogin, req.Password, req.Fullname, req.IsSuper, req.ModulesJSON, req.IsOn)
if err != nil {
return nil, err
}
return this.Success()
}
// CountAllEnabledAdmins 计算管理员数量
func (this *AdminService) CountAllEnabledAdmins(ctx context.Context, req *pb.CountAllEnabledAdminsRequest) (*pb.RPCCountResponse, error) {
adminId, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
// TODO 检查权限
var tx = this.NullTx()
// 超级管理员才能查看是否为弱密码
isSuperAdmin, err := models.SharedAdminDAO.CheckSuperAdmin(tx, adminId)
if err != nil {
return nil, err
}
if !isSuperAdmin && req.HasWeakPassword {
return this.SuccessCount(0)
}
count, err := models.SharedAdminDAO.CountAllEnabledAdmins(tx, req.Keyword, req.HasWeakPassword)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListEnabledAdmins 列出单页的管理员
func (this *AdminService) ListEnabledAdmins(ctx context.Context, req *pb.ListEnabledAdminsRequest) (*pb.ListEnabledAdminsResponse, error) {
adminId, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
// TODO 检查权限
var tx = this.NullTx()
// 超级管理员才能查看是否为弱密码
isSuperAdmin, err := models.SharedAdminDAO.CheckSuperAdmin(tx, adminId)
if err != nil {
return nil, err
}
if !isSuperAdmin && req.HasWeakPassword {
return &pb.ListEnabledAdminsResponse{Admins: nil}, nil
}
admins, err := models.SharedAdminDAO.ListEnabledAdmins(tx, req.Keyword, req.HasWeakPassword, req.Offset, req.Size)
if err != nil {
return nil, err
}
var result = []*pb.Admin{}
for _, admin := range admins {
var pbOtpAuth *pb.Login = nil
{
adminAuth, err := models.SharedLoginDAO.FindEnabledLoginWithType(tx, int64(admin.Id), 0, models.LoginTypeOTP)
if err != nil {
return nil, err
}
if adminAuth != nil {
pbOtpAuth = &pb.Login{
Id: int64(adminAuth.Id),
Type: adminAuth.Type,
ParamsJSON: adminAuth.Params,
IsOn: adminAuth.IsOn,
}
}
}
result = append(result, &pb.Admin{
Id: int64(admin.Id),
Fullname: admin.Fullname,
Username: admin.Username,
IsOn: admin.IsOn,
IsSuper: admin.IsSuper,
CreatedAt: int64(admin.CreatedAt),
OtpLogin: pbOtpAuth,
CanLogin: admin.CanLogin,
HasWeakPassword: isSuperAdmin && admin.HasWeakPassword(),
})
}
return &pb.ListEnabledAdminsResponse{Admins: result}, nil
}
// DeleteAdmin 删除管理员
func (this *AdminService) DeleteAdmin(ctx context.Context, req *pb.DeleteAdminRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
// TODO 检查权限
var tx = this.NullTx()
// TODO 超级管理员用户是不能删除的,或者要至少留一个超级管理员用户
err = models.SharedAdminDAO.DisableAdmin(tx, req.AdminId)
if err != nil {
return nil, err
}
return this.Success()
}
// CheckAdminOTPWithUsername 检查是否需要输入OTP
func (this *AdminService) CheckAdminOTPWithUsername(ctx context.Context, req *pb.CheckAdminOTPWithUsernameRequest) (*pb.CheckAdminOTPWithUsernameResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
if len(req.Username) == 0 {
return &pb.CheckAdminOTPWithUsernameResponse{RequireOTP: false}, nil
}
var tx = this.NullTx()
adminId, err := models.SharedAdminDAO.FindAdminIdWithUsername(tx, req.Username)
if err != nil {
return nil, err
}
if adminId <= 0 {
return &pb.CheckAdminOTPWithUsernameResponse{RequireOTP: false}, nil
}
otpIsOn, err := models.SharedLoginDAO.CheckLoginIsOn(tx, adminId, 0, "otp")
if err != nil {
return nil, err
}
return &pb.CheckAdminOTPWithUsernameResponse{RequireOTP: otpIsOn}, nil
}
// ComposeAdminDashboard 取得管理员Dashboard数据
func (this *AdminService) ComposeAdminDashboard(ctx context.Context, req *pb.ComposeAdminDashboardRequest) (*pb.ComposeAdminDashboardResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
result := &pb.ComposeAdminDashboardResponse{}
var tx = this.NullTx()
// 默认集群
this.BeginTag(ctx, "SharedNodeClusterDAO.ListEnabledClusters")
nodeClusters, err := models.SharedNodeClusterDAO.ListEnabledClusters(tx, "", true, false, 0, 1)
this.EndTag(ctx, "SharedNodeClusterDAO.ListEnabledClusters")
if err != nil {
return nil, err
}
if len(nodeClusters) > 0 {
result.DefaultNodeClusterId = int64(nodeClusters[0].Id)
}
// 集群数
this.BeginTag(ctx, "SharedNodeClusterDAO.CountAllEnabledClusters")
countClusters, err := models.SharedNodeClusterDAO.CountAllEnabledClusters(tx, "")
this.EndTag(ctx, "SharedNodeClusterDAO.CountAllEnabledClusters")
if err != nil {
return nil, err
}
result.CountNodeClusters = countClusters
// 节点数
{
this.BeginTag(ctx, "SharedNodeDAO.CountAllEnabledNodes")
countNodes, err := models.SharedNodeDAO.CountAllEnabledNodes(tx)
this.EndTag(ctx, "SharedNodeDAO.CountAllEnabledNodes")
if err != nil {
return nil, err
}
result.CountNodes = countNodes
}
// 离线节点
this.BeginTag(ctx, "SharedNodeDAO.CountAllEnabledOfflineNodes")
countOfflineNodes, err := models.SharedNodeDAO.CountAllEnabledOfflineNodes(tx)
this.EndTag(ctx, "SharedNodeDAO.CountAllEnabledOfflineNodes")
if err != nil {
return nil, err
}
result.CountOfflineNodes = countOfflineNodes
// 网站数
this.BeginTag(ctx, "SharedServerDAO.CountAllEnabledServers")
countServers, err := models.SharedServerDAO.CountAllEnabledServers(tx)
this.EndTag(ctx, "SharedServerDAO.CountAllEnabledServers")
if err != nil {
return nil, err
}
result.CountServers = countServers
this.BeginTag(ctx, "SharedServerDAO.CountAllEnabledServersMatch")
countAuditingServers, err := models.SharedServerDAO.CountAllEnabledServersMatch(tx, 0, "", 0, 0, configutils.BoolStateYes, nil, 0)
this.EndTag(ctx, "SharedServerDAO.CountAllEnabledServersMatch")
if err != nil {
return nil, err
}
result.CountAuditingServers = countAuditingServers
// 用户数
this.BeginTag(ctx, "SharedUserDAO.CountAllEnabledUsers")
countUsers, err := models.SharedUserDAO.CountAllEnabledUsers(tx, 0, "", false, -1)
this.EndTag(ctx, "SharedUserDAO.CountAllEnabledUsers")
if err != nil {
return nil, err
}
result.CountUsers = countUsers
// API节点数
this.BeginTag(ctx, "SharedAPINodeDAO.CountAllEnabledAndOnAPINodes")
countAPINodes, err := models.SharedAPINodeDAO.CountAllEnabledAndOnAPINodes(tx)
this.EndTag(ctx, "SharedAPINodeDAO.CountAllEnabledAndOnAPINodes")
if err != nil {
return nil, err
}
result.CountAPINodes = countAPINodes
// 离线API节点
this.BeginTag(ctx, "SharedAPINodeDAO.CountAllEnabledAndOnOfflineAPINodes")
countOfflineAPINodes, err := models.SharedAPINodeDAO.CountAllEnabledAndOnOfflineAPINodes(tx)
this.EndTag(ctx, "SharedAPINodeDAO.CountAllEnabledAndOnOfflineAPINodes")
if err != nil {
return nil, err
}
result.CountOfflineAPINodes = countOfflineAPINodes
// 数据库节点数
this.BeginTag(ctx, "SharedDBNodeDAO.CountAllEnabledNodes")
countDBNodes, err := models.SharedDBNodeDAO.CountAllEnabledNodes(tx)
this.EndTag(ctx, "SharedDBNodeDAO.CountAllEnabledNodes")
if err != nil {
return nil, err
}
result.CountDBNodes = countDBNodes
// 用户节点数
this.BeginTag(ctx, "SharedUserNodeDAO.CountAllEnabledAndOnUserNodes")
countUserNodes, err := models.SharedUserNodeDAO.CountAllEnabledAndOnUserNodes(tx)
this.EndTag(ctx, "SharedUserNodeDAO.CountAllEnabledAndOnUserNodes")
if err != nil {
return nil, err
}
result.CountUserNodes = countUserNodes
// 离线用户节点数
this.BeginTag(ctx, "SharedUserNodeDAO.CountAllEnabledAndOnOfflineNodes")
countOfflineUserNodes, err := models.SharedUserNodeDAO.CountAllEnabledAndOnOfflineNodes(tx)
this.EndTag(ctx, "SharedUserNodeDAO.CountAllEnabledAndOnOfflineNodes")
if err != nil {
return nil, err
}
result.CountOfflineUserNodes = countOfflineUserNodes
// 按日流量统计
this.BeginTag(ctx, "SharedTrafficDailyStatDAO.FindDailyStats")
dayFrom := timeutil.Format("Ymd", time.Now().AddDate(0, 0, -14))
dailyTrafficStats, err := stats.SharedTrafficDailyStatDAO.FindDailyStats(tx, dayFrom, timeutil.Format("Ymd"))
this.EndTag(ctx, "SharedTrafficDailyStatDAO.FindDailyStats")
if err != nil {
return nil, err
}
for _, stat := range dailyTrafficStats {
result.DailyTrafficStats = append(result.DailyTrafficStats, &pb.ComposeAdminDashboardResponse_DailyTrafficStat{
Day: stat.Day,
Bytes: int64(stat.Bytes),
CachedBytes: int64(stat.CachedBytes),
CountRequests: int64(stat.CountRequests),
CountCachedRequests: int64(stat.CountCachedRequests),
CountAttackRequests: int64(stat.CountAttackRequests),
AttackBytes: int64(stat.AttackBytes),
CountIPs: int64(stat.CountIPs),
})
}
// 小时流量统计
var hourFrom = timeutil.Format("YmdH", time.Now().Add(-23*time.Hour))
var hourTo = timeutil.Format("YmdH")
this.BeginTag(ctx, "SharedTrafficHourlyStatDAO.FindHourlyStats")
hourlyTrafficStats, err := stats.SharedTrafficHourlyStatDAO.FindHourlyStats(tx, hourFrom, hourTo)
this.EndTag(ctx, "SharedTrafficHourlyStatDAO.FindHourlyStats")
if err != nil {
return nil, err
}
for _, stat := range hourlyTrafficStats {
result.HourlyTrafficStats = append(result.HourlyTrafficStats, &pb.ComposeAdminDashboardResponse_HourlyTrafficStat{
Hour: stat.Hour,
Bytes: int64(stat.Bytes),
CachedBytes: int64(stat.CachedBytes),
CountRequests: int64(stat.CountRequests),
CountCachedRequests: int64(stat.CountCachedRequests),
CountAttackRequests: int64(stat.CountAttackRequests),
AttackBytes: int64(stat.AttackBytes),
})
}
// 边缘节点升级信息
{
upgradeInfo := &pb.ComposeAdminDashboardResponse_UpgradeInfo{
NewVersion: teaconst.NodeVersion,
}
this.BeginTag(ctx, "SharedNodeDAO.CountAllLowerVersionNodes")
countNodes, err := models.SharedNodeDAO.CountAllLowerVersionNodes(tx, upgradeInfo.NewVersion)
this.EndTag(ctx, "SharedNodeDAO.CountAllLowerVersionNodes")
if err != nil {
return nil, err
}
upgradeInfo.CountNodes = countNodes
result.NodeUpgradeInfo = upgradeInfo
}
// API节点升级信息
{
var apiVersion = req.ApiVersion
if len(apiVersion) == 0 {
apiVersion = teaconst.Version
}
upgradeInfo := &pb.ComposeAdminDashboardResponse_UpgradeInfo{
NewVersion: apiVersion,
}
this.BeginTag(ctx, "SharedAPINodeDAO.CountAllLowerVersionNodes")
countNodes, err := models.SharedAPINodeDAO.CountAllLowerVersionNodes(tx, upgradeInfo.NewVersion)
this.EndTag(ctx, "SharedAPINodeDAO.CountAllLowerVersionNodes")
if err != nil {
return nil, err
}
upgradeInfo.CountNodes = countNodes
result.ApiNodeUpgradeInfo = upgradeInfo
}
// 额外的检查节点版本
err = this.composeAdminDashboardExt(tx, ctx, result)
if err != nil {
return nil, err
}
// 域名排行
this.BeginTag(ctx, "SharedServerDomainHourlyStatDAO.FindTopDomainStats")
var topDomainStats []*stats.ServerDomainHourlyStat
topDomainStatsCache, ok := tasks.SharedCacheTaskManager.GetGlobalTopDomains()
if ok {
topDomainStats = topDomainStatsCache.([]*stats.ServerDomainHourlyStat)
}
this.EndTag(ctx, "SharedServerDomainHourlyStatDAO.FindTopDomainStats")
for _, stat := range topDomainStats {
result.TopDomainStats = append(result.TopDomainStats, &pb.ComposeAdminDashboardResponse_DomainStat{
ServerId: int64(stat.ServerId),
Domain: stat.Domain,
CountRequests: int64(stat.CountRequests),
Bytes: int64(stat.Bytes),
})
}
// 指标数据
this.BeginTag(ctx, "findMetricDataCharts")
var pbCharts []*pb.MetricDataChart
pbChartsCache, ok := tasks.SharedCacheTaskManager.Get(tasks.CacheKeyFindAllMetricDataCharts)
if ok {
pbCharts = pbChartsCache.([]*pb.MetricDataChart)
}
this.EndTag(ctx, "findMetricDataCharts")
result.MetricDataCharts = pbCharts
return result, nil
}
// UpdateAdminTheme 修改管理员使用的界面风格
func (this *AdminService) UpdateAdminTheme(ctx context.Context, req *pb.UpdateAdminThemeRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedAdminDAO.UpdateAdminTheme(tx, req.AdminId, req.Theme)
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateAdminLang 修改管理员使用的语言
func (this *AdminService) UpdateAdminLang(ctx context.Context, req *pb.UpdateAdminLangRequest) (*pb.RPCSuccess, error) {
adminId, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedAdminDAO.UpdateAdminLang(tx, adminId, req.LangCode)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,15 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package services
import (
"context"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/dbs"
)
// ComposeAdminDashboard方法扩展
func (this *AdminService) composeAdminDashboardExt(tx *dbs.Tx, ctx context.Context, result *pb.ComposeAdminDashboardResponse) error {
return nil
}

View File

@@ -0,0 +1,127 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package services
import (
"context"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/regions"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/stats"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/dbs"
timeutil "github.com/iwind/TeaGo/utils/time"
"time"
)
// 检查节点版本
func (this *AdminService) composeAdminDashboardExt(tx *dbs.Tx, ctx context.Context, result *pb.ComposeAdminDashboardResponse) error {
var isPlus = teaconst.IsPlus
// 用户节点升级信息
if isPlus {
upgradeInfo := &pb.ComposeAdminDashboardResponse_UpgradeInfo{
NewVersion: teaconst.UserNodeVersion,
}
this.BeginTag(ctx, "SharedUserNodeDAO.CountAllLowerVersionNodes")
countNodes, err := models.SharedUserNodeDAO.CountAllLowerVersionNodes(tx, upgradeInfo.NewVersion)
this.EndTag(ctx, "SharedUserNodeDAO.CountAllLowerVersionNodes")
if err != nil {
return err
}
upgradeInfo.CountNodes = countNodes
result.UserNodeUpgradeInfo = upgradeInfo
}
// DNS节点升级信息
if isPlus {
upgradeInfo := &pb.ComposeAdminDashboardResponse_UpgradeInfo{
NewVersion: teaconst.DNSNodeVersion,
}
this.BeginTag(ctx, "SharedNSNodeDAO.CountAllLowerVersionNodes")
countNodes, err := models.SharedNSNodeDAO.CountAllLowerVersionNodes(tx, upgradeInfo.NewVersion)
this.EndTag(ctx, "SharedNSNodeDAO.CountAllLowerVersionNodes")
if err != nil {
return err
}
upgradeInfo.CountNodes = countNodes
result.NsNodeUpgradeInfo = upgradeInfo
}
// Report节点升级信息
if isPlus {
upgradeInfo := &pb.ComposeAdminDashboardResponse_UpgradeInfo{
NewVersion: teaconst.ReportNodeVersion,
}
this.BeginTag(ctx, "SharedReportNodeDAO.CountAllLowerVersionNodes")
countNodes, err := models.SharedReportNodeDAO.CountAllLowerVersionNodes(tx, upgradeInfo.NewVersion)
this.EndTag(ctx, "SharedReportNodeDAO.CountAllLowerVersionNodes")
if err != nil {
return err
}
upgradeInfo.CountNodes = countNodes
result.ReportNodeUpgradeInfo = upgradeInfo
}
// 节点排行
var hourFrom = timeutil.Format("YmdH", time.Now().Add(-23*time.Hour))
var hourTo = timeutil.Format("YmdH")
if isPlus {
this.BeginTag(ctx, "SharedNodeTrafficHourlyStatDAO.FindTopNodeStats")
topNodeStats, err := stats.SharedNodeTrafficHourlyStatDAO.FindTopNodeStats(tx, "node", hourFrom, hourTo, 10)
this.EndTag(ctx, "SharedNodeTrafficHourlyStatDAO.FindTopNodeStats")
if err != nil {
return err
}
for _, stat := range topNodeStats {
nodeName, err := models.SharedNodeDAO.FindNodeName(tx, int64(stat.NodeId))
if err != nil {
return err
}
if len(nodeName) == 0 {
continue
}
result.TopNodeStats = append(result.TopNodeStats, &pb.ComposeAdminDashboardResponse_NodeStat{
NodeId: int64(stat.NodeId),
NodeName: nodeName,
CountRequests: int64(stat.CountRequests),
Bytes: int64(stat.Bytes),
})
}
}
// 地区流量排行
if isPlus {
this.BeginTag(ctx, "SharedServerRegionCountryDailyStatDAO.SumDailyTotalBytes")
totalCountryBytes, err := stats.SharedServerRegionCountryDailyStatDAO.SumDailyTotalBytes(tx, timeutil.Format("Ymd"))
this.EndTag(ctx, "SharedServerRegionCountryDailyStatDAO.SumDailyTotalBytes")
if err != nil {
return err
}
if totalCountryBytes > 0 {
topCountryStats, err := stats.SharedServerRegionCountryDailyStatDAO.ListSumStats(tx, timeutil.Format("Ymd"), "bytes", 0, 100)
if err != nil {
return err
}
for _, stat := range topCountryStats {
countryName, err := regions.SharedRegionCountryDAO.FindRegionCountryName(tx, int64(stat.CountryId))
if err != nil {
return err
}
result.TopCountryStats = append(result.TopCountryStats, &pb.ComposeAdminDashboardResponse_CountryStat{
CountryName: countryName,
Bytes: int64(stat.Bytes),
CountRequests: int64(stat.CountRequests),
AttackBytes: int64(stat.AttackBytes),
CountAttackRequests: int64(stat.CountAttackRequests),
Percent: float32(stat.Bytes*100) / float32(totalCountryBytes),
})
}
}
}
return nil
}

View File

@@ -0,0 +1,65 @@
package services
import (
"context"
"encoding/base64"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/encrypt"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/maps"
stringutil "github.com/iwind/TeaGo/utils/string"
"google.golang.org/grpc/metadata"
"testing"
"time"
)
func TestAdminService_Login(t *testing.T) {
a := assert.NewAssertion(t)
service := &AdminService{
debug: true,
}
resp, err := service.LoginAdmin(testCtx(t), &pb.LoginAdminRequest{
Username: "admin",
Password: stringutil.Md5("123456"),
})
if err != nil {
t.Fatal(err)
}
a.LogJSON(resp)
}
func TestAdminService_FindAdminFullname(t *testing.T) {
service := &AdminService{
debug: true,
}
resp, err := service.FindAdminFullname(testCtx(t), &pb.FindAdminFullnameRequest{AdminId: 1})
if err != nil {
t.Fatal(err)
}
t.Log(resp)
}
func testCtx(t *testing.T) context.Context {
ctx := context.Background()
nodeId := "H6sjDf779jimnVPnBFSgZxvr6Ca0wQ0z"
token := maps.Map{
"timestamp": time.Now().Unix(),
"adminId": 1,
}
data := token.AsJSON()
method, err := encrypt.NewMethodInstance(teaconst.EncryptMethod, "hMHjmEng0SIcT3yiA3HIoUjogwAC9cur", nodeId)
if err != nil {
t.Fatal(err)
}
data, err = method.Encrypt(data)
if err != nil {
t.Fatal(err)
}
tokenString := base64.StdEncoding.EncodeToString(data)
ctx = metadata.AppendToOutgoingContext(ctx, "nodeId", nodeId, "token", tokenString)
return ctx
}

View File

@@ -0,0 +1,83 @@
package services
import (
"context"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// APIAccessTokenService AccessToken相关服务
type APIAccessTokenService struct {
BaseService
}
// GetAPIAccessToken 获取AccessToken
func (this *APIAccessTokenService) GetAPIAccessToken(ctx context.Context, req *pb.GetAPIAccessTokenRequest) (*pb.GetAPIAccessTokenResponse, error) {
if req.Type != "user" && req.Type != "admin" {
return nil, errors.New("unsupported type '" + req.Type + "'")
}
var tx = this.NullTx()
accessKey, err := models.SharedUserAccessKeyDAO.FindAccessKeyWithUniqueId(tx, req.AccessKeyId)
if err != nil {
return nil, err
}
if accessKey == nil {
return nil, errors.New("access key not found")
}
if accessKey.Secret != req.AccessKey {
return nil, errors.New("access key not found")
}
// 检查数据
switch req.Type {
case "user":
// TODO 将来支持子用户
if accessKey.UserId == 0 {
return nil, errors.New("access key not found")
}
// 检查用户状态
user, err := models.SharedUserDAO.FindEnabledUser(tx, int64(accessKey.UserId), nil)
if err != nil {
return nil, err
}
if user == nil || !user.IsOn {
return nil, errors.New("the user is not available")
}
case "admin":
if accessKey.AdminId == 0 {
return nil, errors.New("access key not found")
}
// 检查管理员状态
admin, err := models.SharedAdminDAO.FindEnabledAdmin(tx, int64(accessKey.AdminId))
if err != nil {
return nil, err
}
if admin == nil || !admin.IsOn {
return nil, errors.New("the admin is not available")
}
default:
return nil, errors.New("invalid type '" + req.Type + "'")
}
// 更新AccessKey访问时间
err = models.SharedUserAccessKeyDAO.UpdateAccessKeyAccessedAt(tx, int64(accessKey.Id))
if err != nil {
return nil, err
}
// 创建AccessToken
token, expiresAt, err := models.SharedAPIAccessTokenDAO.GenerateAccessToken(tx, int64(accessKey.AdminId), int64(accessKey.UserId))
if err != nil {
return nil, err
}
return &pb.GetAPIAccessTokenResponse{
Token: token,
ExpiresAt: expiresAt,
}, nil
}

View File

@@ -0,0 +1,83 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
timeutil "github.com/iwind/TeaGo/utils/time"
)
// APIMethodStatService API方法统计服务
type APIMethodStatService struct {
BaseService
}
// FindAPIMethodStatsWithDay 查找某天的统计
func (this *APIMethodStatService) FindAPIMethodStatsWithDay(ctx context.Context, req *pb.FindAPIMethodStatsWithDayRequest) (*pb.FindAPIMethodStatsWithDayResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var day = req.Day
if len(day) == 0 {
day = timeutil.Format("Ymd")
}
var tx = this.NullTx()
stats, err := models.SharedAPIMethodStatDAO.FindAllStatsWithDay(tx, day)
if err != nil {
return nil, err
}
var pbStats = []*pb.APIMethodStat{}
var cacheMap = utils.NewCacheMap()
for _, stat := range stats {
apiNode, err := models.SharedAPINodeDAO.FindEnabledAPINode(tx, int64(stat.ApiNodeId), cacheMap)
if err != nil {
return nil, err
}
if apiNode == nil {
continue
}
pbStats = append(pbStats, &pb.APIMethodStat{
Id: int64(stat.Id),
ApiNodeId: int64(stat.ApiNodeId),
Method: stat.Method,
Tag: stat.Tag,
CostMs: float32(stat.CostMs),
PeekMs: float32(stat.PeekMs),
CountCalls: int64(stat.CountCalls),
ApiNode: &pb.APINode{
Id: int64(apiNode.Id),
Name: apiNode.Name,
},
})
}
return &pb.FindAPIMethodStatsWithDayResponse{
ApiMethodStats: pbStats,
}, nil
}
// CountAPIMethodStatsWithDay 检查是否有统计数据
func (this *APIMethodStatService) CountAPIMethodStatsWithDay(ctx context.Context, req *pb.CountAPIMethodStatsWithDayRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var day = req.Day
if len(day) == 0 {
day = timeutil.Format("Ymd")
}
var tx = this.NullTx()
count, err := models.SharedAPIMethodStatDAO.CountAllStatsWithDay(tx, day)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}

View File

@@ -0,0 +1,588 @@
package services
import (
"compress/gzip"
"context"
"crypto/md5"
"errors"
"fmt"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/installers"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
executils "github.com/TeaOSLab/EdgeAPI/internal/utils/exec"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
stringutil "github.com/iwind/TeaGo/utils/string"
"io"
"os"
"path/filepath"
"runtime"
)
type APINodeService struct {
BaseService
}
// CreateAPINode 创建API节点
func (this *APINodeService) CreateAPINode(ctx context.Context, req *pb.CreateAPINodeRequest) (*pb.CreateAPINodeResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
nodeId, err := models.SharedAPINodeDAO.CreateAPINode(tx, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.RestIsOn, req.RestHTTPJSON, req.RestHTTPSJSON, req.AccessAddrsJSON, req.IsOn)
if err != nil {
return nil, err
}
return &pb.CreateAPINodeResponse{ApiNodeId: nodeId}, nil
}
// UpdateAPINode 修改API节点
func (this *APINodeService) UpdateAPINode(ctx context.Context, req *pb.UpdateAPINodeRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedAPINodeDAO.UpdateAPINode(tx, req.ApiNodeId, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.RestIsOn, req.RestHTTPJSON, req.RestHTTPSJSON, req.AccessAddrsJSON, req.IsOn, req.IsPrimary)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteAPINode 删除API节点
func (this *APINodeService) DeleteAPINode(ctx context.Context, req *pb.DeleteAPINodeRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedAPINodeDAO.DisableAPINode(tx, req.ApiNodeId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindAllEnabledAPINodes 列出所有可用API节点
func (this *APINodeService) FindAllEnabledAPINodes(ctx context.Context, req *pb.FindAllEnabledAPINodesRequest) (*pb.FindAllEnabledAPINodesResponse, error) {
_, _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeUser, rpcutils.UserTypeNode, rpcutils.UserTypeDNS, rpcutils.UserTypeAuthority)
if err != nil {
return nil, err
}
var tx = this.NullTx()
nodes, err := models.SharedAPINodeDAO.FindAllEnabledAPINodes(tx)
if err != nil {
return nil, err
}
result := []*pb.APINode{}
for _, node := range nodes {
accessAddrs, err := node.DecodeAccessAddrStrings()
if err != nil {
return nil, err
}
result = append(result, &pb.APINode{
Id: int64(node.Id),
IsOn: node.IsOn,
NodeClusterId: int64(node.ClusterId),
UniqueId: node.UniqueId,
Secret: node.Secret,
Name: node.Name,
Description: node.Description,
HttpJSON: node.Http,
HttpsJSON: node.Https,
AccessAddrsJSON: node.AccessAddrs,
AccessAddrs: accessAddrs,
IsPrimary: node.IsPrimary,
})
}
return &pb.FindAllEnabledAPINodesResponse{ApiNodes: result}, nil
}
// CountAllEnabledAPINodes 计算API节点数量
func (this *APINodeService) CountAllEnabledAPINodes(ctx context.Context, req *pb.CountAllEnabledAPINodesRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedAPINodeDAO.CountAllEnabledAPINodes(tx)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// CountAllEnabledAndOnAPINodes 计算API节点数量
func (this *APINodeService) CountAllEnabledAndOnAPINodes(ctx context.Context, req *pb.CountAllEnabledAndOnAPINodesRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedAPINodeDAO.CountAllEnabledAndOnAPINodes(tx)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListEnabledAPINodes 列出单页的API节点
func (this *APINodeService) ListEnabledAPINodes(ctx context.Context, req *pb.ListEnabledAPINodesRequest) (*pb.ListEnabledAPINodesResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
nodes, err := models.SharedAPINodeDAO.ListEnabledAPINodes(tx, req.Offset, req.Size)
if err != nil {
return nil, err
}
result := []*pb.APINode{}
for _, node := range nodes {
accessAddrs, err := node.DecodeAccessAddrStrings()
if err != nil {
return nil, err
}
result = append(result, &pb.APINode{
Id: int64(node.Id),
IsOn: node.IsOn,
NodeClusterId: int64(node.ClusterId),
UniqueId: node.UniqueId,
Secret: node.Secret,
Name: node.Name,
Description: node.Description,
HttpJSON: node.Http,
HttpsJSON: node.Https,
RestIsOn: node.RestIsOn == 1,
RestHTTPJSON: node.RestHTTP,
RestHTTPSJSON: node.RestHTTPS,
AccessAddrsJSON: node.AccessAddrs,
AccessAddrs: accessAddrs,
StatusJSON: node.Status,
IsPrimary: node.IsPrimary,
})
}
return &pb.ListEnabledAPINodesResponse{ApiNodes: result}, nil
}
// FindEnabledAPINode 根据ID查找节点
func (this *APINodeService) FindEnabledAPINode(ctx context.Context, req *pb.FindEnabledAPINodeRequest) (*pb.FindEnabledAPINodeResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
node, err := models.SharedAPINodeDAO.FindEnabledAPINode(tx, req.ApiNodeId, nil)
if err != nil {
return nil, err
}
if node == nil {
return &pb.FindEnabledAPINodeResponse{ApiNode: nil}, nil
}
accessAddrs, err := node.DecodeAccessAddrStrings()
if err != nil {
return nil, err
}
result := &pb.APINode{
Id: int64(node.Id),
IsOn: node.IsOn,
NodeClusterId: int64(node.ClusterId),
UniqueId: node.UniqueId,
Secret: node.Secret,
Name: node.Name,
Description: node.Description,
HttpJSON: node.Http,
HttpsJSON: node.Https,
RestIsOn: node.RestIsOn == 1,
RestHTTPJSON: node.RestHTTP,
RestHTTPSJSON: node.RestHTTPS,
AccessAddrsJSON: node.AccessAddrs,
AccessAddrs: accessAddrs,
IsPrimary: node.IsPrimary,
StatusJSON: node.Status,
}
return &pb.FindEnabledAPINodeResponse{ApiNode: result}, nil
}
// FindCurrentAPINodeVersion 获取当前API节点的版本
func (this *APINodeService) FindCurrentAPINodeVersion(ctx context.Context, req *pb.FindCurrentAPINodeVersionRequest) (*pb.FindCurrentAPINodeVersionResponse, error) {
role, _, _, err := rpcutils.ValidateRequest(ctx)
if err != nil {
return nil, err
}
return &pb.FindCurrentAPINodeVersionResponse{
Version: teaconst.Version,
Os: runtime.GOOS,
Arch: runtime.GOARCH,
Role: role,
}, nil
}
// FindCurrentAPINode 获取当前API节点的信息
func (this *APINodeService) FindCurrentAPINode(ctx context.Context, req *pb.FindCurrentAPINodeRequest) (*pb.FindCurrentAPINodeResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var nodeId = teaconst.NodeId
var tx *dbs.Tx
node, err := models.SharedAPINodeDAO.FindEnabledAPINode(tx, nodeId, nil)
if err != nil {
return nil, err
}
if node == nil {
return &pb.FindCurrentAPINodeResponse{ApiNode: nil}, nil
}
accessAddrs, err := node.DecodeAccessAddrStrings()
if err != nil {
return nil, err
}
return &pb.FindCurrentAPINodeResponse{ApiNode: &pb.APINode{
Id: int64(node.Id),
IsOn: node.IsOn,
NodeClusterId: 0,
UniqueId: "",
Secret: "",
Name: "",
Description: "",
HttpJSON: nil,
HttpsJSON: nil,
RestIsOn: false,
RestHTTPJSON: nil,
RestHTTPSJSON: nil,
AccessAddrsJSON: node.AccessAddrs,
AccessAddrs: accessAddrs,
StatusJSON: node.Status,
IsPrimary: node.IsPrimary,
InstanceCode: teaconst.InstanceCode,
}}, nil
}
// CountAllEnabledAPINodesWithSSLCertId 计算使用某个SSL证书的API节点数量
func (this *APINodeService) CountAllEnabledAPINodesWithSSLCertId(ctx context.Context, req *pb.CountAllEnabledAPINodesWithSSLCertIdRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
policyIds, err := models.SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(tx, req.SslCertId)
if err != nil {
return nil, err
}
if len(policyIds) == 0 {
return this.SuccessCount(0)
}
count, err := models.SharedAPINodeDAO.CountAllEnabledAPINodesWithSSLPolicyIds(tx, policyIds)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// DebugAPINode 修改调试模式状态
func (this *APINodeService) DebugAPINode(ctx context.Context, req *pb.DebugAPINodeRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
teaconst.Debug = req.Debug
return this.Success()
}
// UploadAPINodeFile 上传新版API节点文件
func (this *APINodeService) UploadAPINodeFile(ctx context.Context, req *pb.UploadAPINodeFileRequest) (*pb.UploadAPINodeFileResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
exe, err := os.Executable()
if err != nil {
return nil, errors.New("can not find executable file: " + err.Error())
}
var targetDir = filepath.Dir(exe)
var targetFilename = teaconst.ProcessName // 这里不使用 filepath.Base() 是因为文件名可能变成修改后的临时文件名
var targetCompressedFile = targetDir + "/." + targetFilename + ".gz"
var targetFile = targetDir + "/." + targetFilename
if req.IsFirstChunk {
_ = os.Remove(targetCompressedFile)
_ = os.Remove(targetFile)
}
if len(req.ChunkData) > 0 {
err = func() error {
var flags = os.O_CREATE | os.O_WRONLY
if req.IsFirstChunk {
flags |= os.O_TRUNC
} else {
flags |= os.O_APPEND
}
fp, err := os.OpenFile(targetCompressedFile, flags, 0666)
if err != nil {
return err
}
defer func() {
_ = fp.Close()
}()
_, err = fp.Write(req.ChunkData)
return err
}()
if err != nil {
return nil, errors.New("write file failed: " + err.Error())
}
}
if req.IsLastChunk {
err = func() error {
// 删除压缩文件
defer func() {
_ = os.Remove(targetCompressedFile)
}()
// 检查SUM
fp, err := os.Open(targetCompressedFile)
if err != nil {
return err
}
defer func() {
_ = fp.Close()
}()
var hash = md5.New()
_, err = io.Copy(hash, fp)
if err != nil {
return err
}
var sum = fmt.Sprintf("%x", hash.Sum(nil))
if sum != req.Sum {
return errors.New("check sum failed: '" + sum + "' expected: '" + req.Sum + "'")
}
// 解压
fp2, err := os.Open(targetCompressedFile)
if err != nil {
return err
}
defer func() {
_ = fp2.Close()
}()
gzipReader, err := gzip.NewReader(fp2)
if err != nil {
return err
}
defer func() {
_ = gzipReader.Close()
}()
targetWriter, err := os.OpenFile(targetFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0777)
if err != nil {
return err
}
defer func() {
_ = targetWriter.Close()
}()
_, err = io.Copy(targetWriter, gzipReader)
if err != nil {
return err
}
return nil
}()
if err != nil {
return nil, errors.New("extract file failed: " + err.Error())
}
// 检查文件是否可执行
var versionCmd = executils.NewCmd(targetFile, "-V").WithStdout().WithStderr()
err = versionCmd.Run()
if err != nil {
return nil, errors.New("test file failed: " + versionCmd.Stderr())
}
var newVersion = versionCmd.Stdout()
if len(newVersion) == 0 {
return nil, errors.New("test file failed, new version should not be empty")
}
// 检查版本
if stringutil.VersionCompare(newVersion, teaconst.Version) <= 0 {
return &pb.UploadAPINodeFileResponse{}, nil
}
// 替换文件
err = os.Remove(exe)
if err != nil {
return nil, errors.New("remove old file failed: " + err.Error())
}
err = os.Rename(targetFile, exe)
if err != nil {
return nil, errors.New("rename file failed: " + err.Error())
}
// 执行升级
if !Tea.IsTesting() { // 开发环境下防止破坏本地数据库
var upgradeCmd = executils.NewCmd(exe, "upgrade").WithStderr()
err = upgradeCmd.Run()
if err != nil {
return nil, errors.New("execute 'upgrade' command failed: " + upgradeCmd.Stderr())
}
}
// 重启
var restartCmd = executils.NewCmd(exe, "restart").WithStderr()
err = restartCmd.Start()
if err != nil {
return nil, errors.New("start new process failed: " + restartCmd.Stderr())
}
}
return &pb.UploadAPINodeFileResponse{}, nil
}
// UploadDeployFileToAPINode 上传节点安装文件
func (this *APINodeService) UploadDeployFileToAPINode(ctx context.Context, req *pb.UploadDeployFileToAPINodeRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var targetDir = Tea.Root + "/deploy/"
var targetTmpFile = targetDir + "/" + req.Filename + ".tmp"
var targetFile = targetDir + "/" + req.Filename
if req.IsFirstChunk {
_ = os.Remove(targetTmpFile)
}
if len(req.ChunkData) > 0 {
err = func() error {
var flags = os.O_CREATE | os.O_WRONLY
if req.IsFirstChunk {
flags |= os.O_TRUNC
} else {
flags |= os.O_APPEND
}
fp, err := os.OpenFile(targetTmpFile, flags, 0666)
if err != nil {
return err
}
defer func() {
_ = fp.Close()
}()
_, err = fp.Write(req.ChunkData)
return err
}()
if err != nil {
return nil, errors.New("write file failed: " + err.Error())
}
}
if req.IsLastChunk {
// 检查SUM
fp, err := os.Open(targetTmpFile)
if err != nil {
return nil, err
}
var hash = md5.New()
_, err = io.Copy(hash, fp)
_ = fp.Close()
if err != nil {
return nil, err
}
var tmpSum = fmt.Sprintf("%x", hash.Sum(nil))
if tmpSum != req.Sum {
_ = os.Remove(targetTmpFile)
return nil, errors.New("check sum failed")
}
// 正式改名
err = os.Rename(targetTmpFile, targetFile)
if err != nil {
return nil, errors.New("rename failed: " + err.Error())
}
// 重载数据
installers.SharedDeployManager.Reload()
}
return this.Success()
}
// FindLatestDeployFiles 查找已有节点安装文件信息
func (this *APINodeService) FindLatestDeployFiles(ctx context.Context, req *pb.FindLatestDeployFilesRequest) (*pb.FindLatestDeployFilesResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var pbNodeFiles = []*pb.FindLatestDeployFilesResponse_DeployFile{}
var nodeFiles = installers.SharedDeployManager.LoadNodeFiles()
for _, nodeFile := range nodeFiles {
pbNodeFiles = append(pbNodeFiles, &pb.FindLatestDeployFilesResponse_DeployFile{
Os: nodeFile.OS,
Arch: nodeFile.Arch,
Version: nodeFile.Version,
})
}
var pbNSNodeFiles = []*pb.FindLatestDeployFilesResponse_DeployFile{}
var nsNodeFiles = installers.SharedDeployManager.LoadNSNodeFiles()
for _, nodeFile := range nsNodeFiles {
pbNSNodeFiles = append(pbNSNodeFiles, &pb.FindLatestDeployFilesResponse_DeployFile{
Os: nodeFile.OS,
Arch: nodeFile.Arch,
Version: nodeFile.Version,
})
}
return &pb.FindLatestDeployFilesResponse{
NodeDeployFiles: pbNodeFiles,
NsNodeDeployFiles: pbNSNodeFiles,
}, nil
}

View File

@@ -0,0 +1,40 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// APITokenService API令牌服务
type APITokenService struct {
BaseService
}
// FindAllEnabledAPITokens 获取API令牌
func (this *APITokenService) FindAllEnabledAPITokens(ctx context.Context, req *pb.FindAllEnabledAPITokensRequest) (*pb.FindAllEnabledAPITokensResponse, error) {
// 这里为了安全只允许通过API节点信息获取
_, _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAPI)
if err != nil {
return nil, err
}
var tx = this.NullTx()
apiTokens, err := models.SharedApiTokenDAO.FindAllEnabledAPITokens(tx, req.Role)
if err != nil {
return nil, err
}
var pbTokens = []*pb.APIToken{}
for _, token := range apiTokens {
pbTokens = append(pbTokens, &pb.APIToken{
Id: int64(token.Id),
NodeId: token.NodeId,
Secret: token.Secret,
Role: token.Role,
})
}
return &pb.FindAllEnabledAPITokensResponse{ApiTokens: pbTokens}, nil
}

View File

@@ -0,0 +1,212 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package services
import (
"context"
"encoding/json"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/authority"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/systemconfigs"
plusutils "github.com/TeaOSLab/EdgePlus/pkg/utils"
"github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time"
)
// AuthorityKeyService 版本认证
type AuthorityKeyService struct {
BaseService
}
// UpdateAuthorityKey 设置Key
func (this *AuthorityKeyService) UpdateAuthorityKey(ctx context.Context, req *pb.UpdateAuthorityKeyRequest) (*pb.RPCSuccess, error) {
_, _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeAuthority)
if err != nil {
return nil, err
}
var tx = this.NullTx()
key, err := plusutils.DecodeKey([]byte(req.Value))
if err != nil {
return nil, err
}
// 检查节点数是否超出限制
if key.Nodes > 0 {
countAuthorityNodes, err := models.SharedNodeDAO.CountAllAuthorityNodes(tx)
if err != nil {
return nil, err
}
if countAuthorityNodes > int64(key.Nodes) {
return nil, errors.New("nodes count in system (" + types.String(countAuthorityNodes) + ") is over limit (" + types.String(key.Nodes) + ")(系统内节点数(" + types.String(countAuthorityNodes) + ")已超出授权节点数(" + types.String(key.Nodes) + ")限制,请删除一些节点后再激活;如有疑问,请联系授权商)")
}
}
var addresses = []string{}
var macAddresses = key.MacAddresses
for _, addr := range macAddresses {
addresses = append(addresses, types.String(addr))
}
err = authority.SharedAuthorityKeyDAO.UpdateKey(tx, req.Value, req.RequestCode, key.DayFrom, key.DayTo, key.Hostname, addresses, key.Company)
if err != nil {
return nil, err
}
// 设置显示财务管理
if key.IsValid() {
adminConfig, err := models.SharedSysSettingDAO.ReadAdminUIConfig(tx, nil)
if err != nil {
return nil, err
}
if adminConfig != nil {
adminConfig.ShowFinance = true
adminConfigJSON, err := json.Marshal(adminConfig)
if err != nil {
return nil, err
}
err = models.SharedSysSettingDAO.UpdateSetting(tx, systemconfigs.SettingCodeAdminUIConfig, adminConfigJSON)
if err != nil {
return nil, err
}
}
}
return this.Success()
}
// ReadAuthorityKey 读取Key
func (this *AuthorityKeyService) ReadAuthorityKey(ctx context.Context, req *pb.ReadAuthorityKeyRequest) (*pb.ReadAuthorityKeyResponse, error) {
_, _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeProvider, rpcutils.UserTypeDNS)
if err != nil {
return nil, err
}
var tx = this.NullTx()
key, err := authority.SharedAuthorityKeyDAO.ReadKey(tx)
if err != nil {
return nil, err
}
if key == nil {
return &pb.ReadAuthorityKeyResponse{AuthorityKey: nil}, nil
}
if len(key.Value) == 0 {
return &pb.ReadAuthorityKeyResponse{AuthorityKey: nil}, nil
}
m, err := plusutils.DecodeKey([]byte(key.Value))
if err != nil {
return nil, err
}
if m.IsValid() {
teaconst.MaxNodes = int32(m.Nodes)
} else {
teaconst.MaxNodes = teaconst.DefaultMaxNodes
}
if len(m.Components) == 0 {
m.Components = []string{"*"}
}
return &pb.ReadAuthorityKeyResponse{AuthorityKey: &pb.AuthorityKey{
Value: key.Value,
DayFrom: m.DayFrom,
DayTo: m.DayTo,
Nodes: int32(m.Nodes),
Hostname: m.Hostname,
MacAddresses: m.MacAddresses,
Company: m.Company,
UpdatedAt: m.UpdatedAt,
Components: m.Components,
Edition: m.Edition,
RequestCode: m.RequestCode,
}}, nil
}
// ResetAuthorityKey 重置Key
func (this *AuthorityKeyService) ResetAuthorityKey(ctx context.Context, req *pb.ResetAuthorityKeyRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
err = authority.SharedAuthorityKeyDAO.ResetKey(nil, true)
if err != nil {
return nil, err
}
return this.Success()
}
// ValidateAuthorityKey 校验Key
func (this *AuthorityKeyService) ValidateAuthorityKey(ctx context.Context, req *pb.ValidateAuthorityKeyRequest) (*pb.ValidateAuthorityKeyResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
m, err := plusutils.DecodeKey([]byte(req.Key))
if err != nil {
return &pb.ValidateAuthorityKeyResponse{IsOk: false, Error: "错误的注册码"}, nil
}
var dayTo = m.DayTo
if dayTo < timeutil.Format("Y-m-d") {
return &pb.ValidateAuthorityKeyResponse{IsOk: false, Error: "注册码已于" + dayTo + "过期"}, nil
}
// remote activation
if m.Method == plusutils.MethodRemote {
ok, errorCode := authority.SharedAuthorityKeyDAO.ActivateRemotely(req.Key, req.RequestCode)
if ok {
return &pb.ValidateAuthorityKeyResponse{IsOk: true}, nil
} else {
if len(errorCode) > 0 {
return &pb.ValidateAuthorityKeyResponse{IsOk: false, Error: "远程校验失败(" + errorCode + "),请联系软件开发者"}, nil
}
return &pb.ValidateAuthorityKeyResponse{IsOk: false, Error: "远程校验失败,请联系软件开发者"}, nil
}
}
return &pb.ValidateAuthorityKeyResponse{IsOk: true}, nil
}
// CheckAuthority 检查版本信息
func (this *AuthorityKeyService) CheckAuthority(ctx context.Context, req *pb.CheckAuthorityRequest) (*pb.CheckAuthorityResponse, error) {
_, err := this.ValidateNode(ctx) // 目前仅支持边缘节点查询
if err != nil {
return nil, err
}
return &pb.CheckAuthorityResponse{
IsPlus: teaconst.IsPlus,
Edition: teaconst.Edition,
}, nil
}
// FindAuthorityQuota 查询授权容量
func (this *AuthorityKeyService) FindAuthorityQuota(ctx context.Context, req *pb.FindAuthorityQuotaRequest) (*pb.FindAuthorityQuotaResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var maxNodes = teaconst.DefaultMaxNodes
if teaconst.IsPlus && teaconst.MaxNodes > 0 {
maxNodes = teaconst.MaxNodes
}
var tx = this.NullTx()
countNodes, err := models.SharedNodeDAO.CountAllAuthorityNodes(tx)
if err != nil {
return nil, err
}
return &pb.FindAuthorityQuotaResponse{
MaxNodes: maxNodes,
CountNodes: types.Int32(countNodes),
}, nil
}

View File

@@ -0,0 +1,245 @@
package services
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/authority"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"google.golang.org/grpc/metadata"
"time"
)
type AuthorityNodeService struct {
BaseService
}
// CreateAuthorityNode 创建认证节点
func (this *AuthorityNodeService) CreateAuthorityNode(ctx context.Context, req *pb.CreateAuthorityNodeRequest) (*pb.CreateAuthorityNodeResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
nodeId, err := authority.SharedAuthorityNodeDAO.CreateAuthorityNode(tx, req.Name, req.Description, req.IsOn)
if err != nil {
return nil, err
}
return &pb.CreateAuthorityNodeResponse{AuthorityNodeId: nodeId}, nil
}
// UpdateAuthorityNode 修改认证节点
func (this *AuthorityNodeService) UpdateAuthorityNode(ctx context.Context, req *pb.UpdateAuthorityNodeRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = authority.SharedAuthorityNodeDAO.UpdateAuthorityNode(tx, req.AuthorityNodeId, req.Name, req.Description, req.IsOn)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteAuthorityNode 删除认证节点
func (this *AuthorityNodeService) DeleteAuthorityNode(ctx context.Context, req *pb.DeleteAuthorityNodeRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = authority.SharedAuthorityNodeDAO.DisableAuthorityNode(tx, req.AuthorityNodeId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindAllEnabledAuthorityNodes 列出所有可用认证节点
func (this *AuthorityNodeService) FindAllEnabledAuthorityNodes(ctx context.Context, req *pb.FindAllEnabledAuthorityNodesRequest) (*pb.FindAllEnabledAuthorityNodesResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
nodes, err := authority.SharedAuthorityNodeDAO.FindAllEnabledAuthorityNodes(tx)
if err != nil {
return nil, err
}
result := []*pb.AuthorityNode{}
for _, node := range nodes {
result = append(result, &pb.AuthorityNode{
Id: int64(node.Id),
IsOn: node.IsOn,
UniqueId: node.UniqueId,
Secret: node.Secret,
Name: node.Name,
Description: node.Description,
})
}
return &pb.FindAllEnabledAuthorityNodesResponse{AuthorityNodes: result}, nil
}
// CountAllEnabledAuthorityNodes 计算认证节点数量
func (this *AuthorityNodeService) CountAllEnabledAuthorityNodes(ctx context.Context, req *pb.CountAllEnabledAuthorityNodesRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := authority.SharedAuthorityNodeDAO.CountAllEnabledAuthorityNodes(tx)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListEnabledAuthorityNodes 列出单页的认证节点
func (this *AuthorityNodeService) ListEnabledAuthorityNodes(ctx context.Context, req *pb.ListEnabledAuthorityNodesRequest) (*pb.ListEnabledAuthorityNodesResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
nodes, err := authority.SharedAuthorityNodeDAO.ListEnabledAuthorityNodes(tx, req.Offset, req.Size)
if err != nil {
return nil, err
}
result := []*pb.AuthorityNode{}
for _, node := range nodes {
result = append(result, &pb.AuthorityNode{
Id: int64(node.Id),
IsOn: node.IsOn,
UniqueId: node.UniqueId,
Secret: node.Secret,
Name: node.Name,
Description: node.Description,
StatusJSON: node.Status,
})
}
return &pb.ListEnabledAuthorityNodesResponse{AuthorityNodes: result}, nil
}
// FindEnabledAuthorityNode 根据ID查找节点
func (this *AuthorityNodeService) FindEnabledAuthorityNode(ctx context.Context, req *pb.FindEnabledAuthorityNodeRequest) (*pb.FindEnabledAuthorityNodeResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
node, err := authority.SharedAuthorityNodeDAO.FindEnabledAuthorityNode(tx, req.AuthorityNodeId)
if err != nil {
return nil, err
}
if node == nil {
return &pb.FindEnabledAuthorityNodeResponse{AuthorityNode: nil}, nil
}
result := &pb.AuthorityNode{
Id: int64(node.Id),
IsOn: node.IsOn,
UniqueId: node.UniqueId,
Secret: node.Secret,
Name: node.Name,
Description: node.Description,
}
return &pb.FindEnabledAuthorityNodeResponse{AuthorityNode: result}, nil
}
// FindCurrentAuthorityNode 获取当前认证节点的版本
func (this *AuthorityNodeService) FindCurrentAuthorityNode(ctx context.Context, req *pb.FindCurrentAuthorityNodeRequest) (*pb.FindCurrentAuthorityNodeResponse, error) {
_, err := this.ValidateAuthorityNode(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, errors.New("context: need 'nodeId'")
}
nodeIds := md.Get("nodeid")
if len(nodeIds) == 0 {
return nil, errors.New("invalid 'nodeId'")
}
nodeId := nodeIds[0]
node, err := authority.SharedAuthorityNodeDAO.FindEnabledAuthorityNodeWithUniqueId(tx, nodeId)
if err != nil {
return nil, err
}
if node == nil {
return &pb.FindCurrentAuthorityNodeResponse{AuthorityNode: nil}, nil
}
result := &pb.AuthorityNode{
Id: int64(node.Id),
IsOn: node.IsOn,
UniqueId: node.UniqueId,
Secret: node.Secret,
Name: node.Name,
Description: node.Description,
}
return &pb.FindCurrentAuthorityNodeResponse{AuthorityNode: result}, nil
}
// UpdateAuthorityNodeStatus 更新节点状态
func (this *AuthorityNodeService) UpdateAuthorityNodeStatus(ctx context.Context, req *pb.UpdateAuthorityNodeStatusRequest) (*pb.RPCSuccess, error) {
// 校验节点
_, nodeId, err := this.ValidateNodeId(ctx, rpcutils.UserTypeAuthority)
if err != nil {
return nil, err
}
if req.AuthorityNodeId > 0 {
nodeId = req.AuthorityNodeId
}
if nodeId <= 0 {
return nil, errors.New("'nodeId' should be greater than 0")
}
var tx = this.NullTx()
// 修改时间戳
var nodeStatus = &nodeconfigs.NodeStatus{}
err = json.Unmarshal(req.StatusJSON, nodeStatus)
if err != nil {
return nil, errors.New("decode node status json failed: " + err.Error())
}
nodeStatus.UpdatedAt = time.Now().Unix()
// 保存
err = authority.SharedAuthorityNodeDAO.UpdateNodeStatus(tx, nodeId, nodeStatus)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,263 @@
package services
import (
"context"
"encoding/base64"
"encoding/json"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/authority"
"github.com/TeaOSLab/EdgeAPI/internal/encrypt"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/rpc"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
type BaseService struct {
}
// ValidateAdmin 校验管理员
func (this *BaseService) ValidateAdmin(ctx context.Context) (adminId int64, err error) {
_, _, reqUserId, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return
}
return reqUserId, nil
}
// ValidateAdminAndUser 校验管理员和用户
func (this *BaseService) ValidateAdminAndUser(ctx context.Context, canRest bool) (adminId int64, userId int64, err error) {
reqUserType, _, reqUserId, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeUser)
if err != nil {
return
}
adminId = int64(0)
userId = int64(0)
switch reqUserType {
case rpcutils.UserTypeAdmin:
adminId = reqUserId
if adminId < 0 { // 允许AdminId = 0
err = errors.New("invalid 'adminId'")
return
}
case rpcutils.UserTypeUser:
userId = reqUserId
if userId < 0 { // 允许等于0
err = errors.New("invalid 'userId'")
return
}
default:
err = errors.New("invalid user type")
}
if err != nil {
return
}
if userId > 0 && !canRest && rpcutils.IsRest(ctx) {
err = errors.New("can not be called by rest")
return
}
return
}
// ValidateNode 校验边缘节点
func (this *BaseService) ValidateNode(ctx context.Context) (nodeId int64, err error) {
_, _, nodeId, err = rpcutils.ValidateRequest(ctx, rpcutils.UserTypeNode)
return
}
// ValidateNSNode 校验DNS节点
func (this *BaseService) ValidateNSNode(ctx context.Context) (nodeId int64, err error) {
_, _, nodeId, err = rpcutils.ValidateRequest(ctx, rpcutils.UserTypeDNS)
return
}
// ValidateUserNode 校验用户节点
func (this *BaseService) ValidateUserNode(ctx context.Context, canRest bool) (userId int64, err error) {
// 不允许REST调用
if !canRest && rpcutils.IsRest(ctx) {
err = errors.New("can not be called by rest")
return
}
_, _, userId, err = rpcutils.ValidateRequest(ctx, rpcutils.UserTypeUser)
return
}
// ValidateAuthorityNode 校验认证节点
func (this *BaseService) ValidateAuthorityNode(ctx context.Context) (nodeId int64, err error) {
_, _, nodeId, err = rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAuthority)
return
}
// ValidateNodeId 获取节点ID
func (this *BaseService) ValidateNodeId(ctx context.Context, roles ...rpcutils.UserType) (role rpcutils.UserType, nodeIntId int64, err error) {
// 默认包含大部分节点
if len(roles) == 0 {
roles = []rpcutils.UserType{rpcutils.UserTypeNode, rpcutils.UserTypeCluster, rpcutils.UserTypeAdmin, rpcutils.UserTypeUser, rpcutils.UserTypeDNS, rpcutils.UserTypeReport, rpcutils.UserTypeLog, rpcutils.UserTypeAPI}
}
if ctx == nil {
err = errors.New("context should not be nil")
role = rpcutils.UserTypeNone
return
}
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return rpcutils.UserTypeNone, 0, errors.New("context: need 'nodeId'")
}
nodeIds := md.Get("nodeid")
if len(nodeIds) == 0 || len(nodeIds[0]) == 0 {
return rpcutils.UserTypeNone, 0, errors.New("context: need 'nodeId'")
}
nodeId := nodeIds[0]
// 获取角色Node信息
// TODO 缓存节点ID相关信息
apiToken, err := models.SharedApiTokenDAO.FindEnabledTokenWithNode(nil, nodeId)
if err != nil {
return rpcutils.UserTypeNone, 0, err
}
if apiToken == nil {
return rpcutils.UserTypeNone, 0, errors.New("context: can not find api token for node '" + nodeId + "'")
}
if !lists.ContainsString(roles, apiToken.Role) {
return rpcutils.UserTypeNone, 0, errors.New("context: unsupported role '" + apiToken.Role + "'")
}
tokens := md.Get("token")
if len(tokens) == 0 || len(tokens[0]) == 0 {
return rpcutils.UserTypeNone, 0, errors.New("context: need 'token'")
}
token := tokens[0]
data, err := base64.StdEncoding.DecodeString(token)
if err != nil {
return rpcutils.UserTypeNone, 0, err
}
method, err := encrypt.NewMethodInstance(teaconst.EncryptMethod, apiToken.Secret, nodeId)
if err != nil {
utils.PrintError(err)
return rpcutils.UserTypeNone, 0, err
}
data, err = method.Decrypt(data)
if err != nil {
return rpcutils.UserTypeNone, 0, err
}
if len(data) == 0 {
return rpcutils.UserTypeNone, 0, errors.New("invalid token")
}
m := maps.Map{}
err = json.Unmarshal(data, &m)
if err != nil {
return rpcutils.UserTypeNone, 0, errors.New("decode token error: " + err.Error())
}
role = apiToken.Role
switch apiToken.Role {
case rpcutils.UserTypeNode:
nodeIntId, err = models.SharedNodeDAO.FindEnabledNodeIdWithUniqueId(nil, nodeId)
if err != nil {
return rpcutils.UserTypeNode, 0, errors.New("context: " + err.Error())
}
if nodeIntId <= 0 {
return rpcutils.UserTypeNode, 0, errors.New("context: not found node with id '" + nodeId + "'")
}
case rpcutils.UserTypeCluster:
nodeIntId, err = models.SharedNodeClusterDAO.FindEnabledClusterIdWithUniqueId(nil, nodeId)
if err != nil {
return rpcutils.UserTypeCluster, 0, errors.New("context: " + err.Error())
}
if nodeIntId <= 0 {
return rpcutils.UserTypeCluster, 0, errors.New("context: not found cluster with id '" + nodeId + "'")
}
case rpcutils.UserTypeUser:
nodeIntId, err = models.SharedUserNodeDAO.FindEnabledUserNodeIdWithUniqueId(nil, nodeId)
case rpcutils.UserTypeAdmin:
nodeIntId = 0
case rpcutils.UserTypeDNS:
nodeIntId, err = models.SharedNSNodeDAO.FindEnabledNodeIdWithUniqueId(nil, nodeId)
case rpcutils.UserTypeReport:
nodeIntId, err = models.SharedReportNodeDAO.FindEnabledNodeIdWithUniqueId(nil, nodeId)
case rpcutils.UserTypeAuthority:
nodeIntId, err = authority.SharedAuthorityNodeDAO.FindEnabledAuthorityNodeIdWithUniqueId(nil, nodeId)
default:
err = errors.New("unsupported user role '" + apiToken.Role + "'")
}
return
}
// Success 返回成功
func (this *BaseService) Success() (*pb.RPCSuccess, error) {
return &pb.RPCSuccess{}, nil
}
// SuccessCount 返回数字
func (this *BaseService) SuccessCount(count int64) (*pb.RPCCountResponse, error) {
return &pb.RPCCountResponse{Count: count}, nil
}
// Exists 返回是否存在
func (this *BaseService) Exists(b bool) (*pb.RPCExists, error) {
return &pb.RPCExists{Exists: b}, nil
}
// PermissionError 返回权限错误
func (this *BaseService) PermissionError() error {
return errors.New("Permission Denied")
}
func (this *BaseService) NotImplementedYet() error {
return status.Error(codes.Unimplemented, "not implemented yet")
}
// NullTx 空的数据库事务
func (this *BaseService) NullTx() *dbs.Tx {
return nil
}
// RunTx 在当前数据中执行一个事务
func (this *BaseService) RunTx(callback func(tx *dbs.Tx) error) error {
db, err := dbs.Default()
if err != nil {
return err
}
return db.RunTx(callback)
}
// BeginTag 开始标签统计
func (this *BaseService) BeginTag(ctx context.Context, name string) {
if !teaconst.Debug {
return
}
traceCtx, ok := ctx.(*rpc.Context)
if ok {
traceCtx.Begin(name)
}
}
// EndTag 结束标签统计
func (this *BaseService) EndTag(ctx context.Context, name string) {
if !teaconst.Debug {
return
}
traceCtx, ok := ctx.(*rpc.Context)
if ok {
traceCtx.End(name)
}
}

View File

@@ -0,0 +1,104 @@
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists"
"strings"
)
// DBService 数据库相关服务
type DBService struct {
BaseService
}
// FindAllDBTables 获取所有表信息
func (this *DBService) FindAllDBTables(ctx context.Context, req *pb.FindAllDBTablesRequest) (*pb.FindAllDBTablesResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
db, err := dbs.Default()
if err != nil {
return nil, err
}
ones, _, err := db.FindPreparedOnes("SELECT * FROM information_schema.`TABLES` WHERE TABLE_SCHEMA=?", db.Name())
if err != nil {
return nil, err
}
pbTables := []*pb.DBTable{}
for _, one := range ones {
lowerTableName := strings.ToLower(one.GetString("TABLE_NAME"))
canDelete := false
canClean := false
if strings.HasPrefix(lowerTableName, "edgehttpaccesslogs_") {
canDelete = true
canClean = true
} else if lists.ContainsString([]string{"edgemessages", "edgelogs", "edgenodelogs", "edgemetricstats", "edgemetricsumstats", "edgeserverdomainhourlystats", "edgeserverregionprovincemonthlystats", "edgeserverregionprovidermonthlystats", "edgeserverregioncountrymonthlystats", "edgeserverregioncountrydailystats", "edgeserverregioncitymonthlystats", "edgeserverhttpfirewallhourlystats", "edgeserverhttpfirewalldailystats", "edgenodeclustertrafficdailystats", "edgenodetrafficdailystats", "edgenodetraffichourlystats", "edgensrecordhourlystats", "edgeserverclientbrowsermonthlystats", "edgeserverclientsystemmonthlystats"}, lowerTableName) || strings.HasPrefix(lowerTableName, "edgeserverdomainhourlystats_") || strings.HasPrefix(lowerTableName, "edgemetricstats_") || strings.HasPrefix(lowerTableName, "edgemetricsumstats_") {
canClean = true
}
pbTables = append(pbTables, &pb.DBTable{
Name: one.GetString("TABLE_NAME"),
Schema: one.GetString("TABLE_SCHEMA"),
Type: one.GetString("TABLE_TYPE"),
Engine: one.GetString("ENGINE"),
Rows: one.GetInt64("TABLE_ROWS"),
DataLength: one.GetInt64("DATA_LENGTH"),
IndexLength: one.GetInt64("INDEX_LENGTH"),
Comment: one.GetString("TABLE_COMMENT"),
Collation: one.GetString("TABLE_COLLATION"),
IsBaseTable: one.GetString("TABLE_TYPE") == "BASE TABLE",
CanClean: canClean,
CanDelete: canDelete,
})
}
return &pb.FindAllDBTablesResponse{DbTables: pbTables}, nil
}
// DeleteDBTable 删除表
func (this *DBService) DeleteDBTable(ctx context.Context, req *pb.DeleteDBTableRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
db, err := dbs.Default()
if err != nil {
return nil, err
}
// 检查是否能够删除
if !strings.HasPrefix(strings.ToLower(req.DbTable), "edgehttpaccesslogs_") {
return nil, errors.New("forbidden to delete the table")
}
_, err = db.Exec("DROP TABLE `" + req.DbTable + "`")
if err != nil {
return nil, err
}
return this.Success()
}
// TruncateDBTable 清空表
func (this *DBService) TruncateDBTable(ctx context.Context, req *pb.TruncateDBTableRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
db, err := dbs.Default()
if err != nil {
return nil, err
}
_, err = db.Exec("TRUNCATE TABLE `" + req.DbTable + "`")
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,346 @@
package services
import (
"context"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
"strings"
)
// DBNodeService 数据库节点相关服务
type DBNodeService struct {
BaseService
}
// CreateDBNode 创建数据库节点
func (this *DBNodeService) CreateDBNode(ctx context.Context, req *pb.CreateDBNodeRequest) (*pb.CreateDBNodeResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
nodeId, err := models.SharedDBNodeDAO.CreateDBNode(tx, req.IsOn, req.Name, req.Description, req.Host, req.Port, req.Database, req.Username, req.Password, req.Charset)
if err != nil {
return nil, err
}
return &pb.CreateDBNodeResponse{DbNodeId: nodeId}, nil
}
// UpdateDBNode 修改数据库节点
func (this *DBNodeService) UpdateDBNode(ctx context.Context, req *pb.UpdateDBNodeRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedDBNodeDAO.UpdateNode(tx, req.DbNodeId, req.IsOn, req.Name, req.Description, req.Host, req.Port, req.Database, req.Username, req.Password, req.Charset)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteDBNode 删除节点
func (this *DBNodeService) DeleteDBNode(ctx context.Context, req *pb.DeleteDBNodeRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedDBNodeDAO.DisableDBNode(tx, req.DbNodeId)
if err != nil {
return nil, err
}
return this.Success()
}
// CountAllEnabledDBNodes 计算可用的数据库节点数量
func (this *DBNodeService) CountAllEnabledDBNodes(ctx context.Context, req *pb.CountAllEnabledDBNodesRequest) (*pb.RPCCountResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedDBNodeDAO.CountAllEnabledNodes(tx)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListEnabledDBNodes 列出单页的数据库节点
func (this *DBNodeService) ListEnabledDBNodes(ctx context.Context, req *pb.ListEnabledDBNodesRequest) (*pb.ListEnabledDBNodesResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
nodes, err := models.SharedDBNodeDAO.ListEnabledNodes(tx, req.Offset, req.Size)
if err != nil {
return nil, err
}
result := []*pb.DBNode{}
for _, node := range nodes {
status := &pb.DBNodeStatus{}
// 是否能够连接
if node.IsOn {
db, err := dbs.NewInstanceFromConfig(node.DBConfig())
if err != nil {
status.Error = err.Error()
} else {
// 版本
version, _ := db.FindCol(0, "SELECT VERSION()")
status.Version = types.String(version)
one, err := db.FindOne("SELECT SUM(DATA_LENGTH+INDEX_LENGTH) AS size FROM information_schema.`TABLES` WHERE TABLE_SCHEMA=?", db.Name())
if err != nil {
status.Error = err.Error()
_ = db.Close()
} else if one == nil {
status.Error = "unable to read size from database server"
_ = db.Close()
} else {
status.IsOk = true
status.Size = one.GetInt64("size")
_ = db.Close()
}
}
}
result = append(result, &pb.DBNode{
Id: int64(node.Id),
Name: node.Name,
Description: node.Description,
IsOn: node.IsOn,
Host: node.Host,
Port: types.Int32(node.Port),
Database: node.Database,
Username: node.Username,
Password: node.Password,
Charset: node.Charset,
Status: status,
})
}
return &pb.ListEnabledDBNodesResponse{DbNodes: result}, nil
}
// FindEnabledDBNode 根据ID查找可用的数据库节点
func (this *DBNodeService) FindEnabledDBNode(ctx context.Context, req *pb.FindEnabledDBNodeRequest) (*pb.FindEnabledDBNodeResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
node, err := models.SharedDBNodeDAO.FindEnabledDBNode(tx, req.DbNodeId)
if err != nil {
return nil, err
}
if node == nil {
return &pb.FindEnabledDBNodeResponse{DbNode: nil}, nil
}
return &pb.FindEnabledDBNodeResponse{DbNode: &pb.DBNode{
Id: int64(node.Id),
Name: node.Name,
Description: node.Description,
IsOn: node.IsOn,
Host: node.Host,
Port: types.Int32(node.Port),
Database: node.Database,
Username: node.Username,
Password: node.Password,
Charset: node.Charset,
}}, nil
}
// FindAllDBNodeTables 获取所有表信息
func (this *DBNodeService) FindAllDBNodeTables(ctx context.Context, req *pb.FindAllDBNodeTablesRequest) (*pb.FindAllDBNodeTablesResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
node, err := models.SharedDBNodeDAO.FindEnabledDBNode(tx, req.DbNodeId)
if err != nil {
return nil, err
}
if node == nil {
return nil, dbs.ErrNotFound
}
db, err := dbs.NewInstanceFromConfig(node.DBConfig())
if err != nil {
return nil, err
}
defer func() {
_ = db.Close()
}()
ones, _, err := db.FindPreparedOnes("SELECT * FROM information_schema.`TABLES` WHERE TABLE_SCHEMA=?", db.Name())
if err != nil {
return nil, err
}
pbTables := []*pb.DBTable{}
for _, one := range ones {
lowerTableName := strings.ToLower(one.GetString("TABLE_NAME"))
canDelete := false
canClean := false
if strings.HasPrefix(lowerTableName, "edgehttpaccesslogs_") || strings.HasPrefix(lowerTableName, "edgensaccesslogs_") {
canDelete = true
canClean = true
} else if lists.ContainsString([]string{"edgemessages", "edgelogs", "edgenodelogs", "edgemetricstats", "edgemetricsumstats", "edgeserverdomainhourlystats", "edgeserverregionprovincemonthlystats", "edgeserverregionprovidermonthlystats", "edgeserverregioncountrymonthlystats", "edgeserverregioncountrydailystats", "edgeserverregioncitymonthlystats", "edgeserverhttpfirewallhourlystats", "edgeserverhttpfirewalldailystats", "edgenodeclustertrafficdailystats", "edgenodetrafficdailystats", "edgenodetraffichourlystats", "edgensrecordhourlystats", "edgeserverclientbrowsermonthlystats", "edgeserverclientsystemmonthlystats"}, lowerTableName) || strings.HasPrefix(lowerTableName, "edgeserverdomainhourlystats_") || strings.HasPrefix(lowerTableName, "edgemetricstats_") || strings.HasPrefix(lowerTableName, "edgemetricsumstats_") {
canClean = true
}
pbTables = append(pbTables, &pb.DBTable{
Name: one.GetString("TABLE_NAME"),
Schema: one.GetString("TABLE_SCHEMA"),
Type: one.GetString("TABLE_TYPE"),
Engine: one.GetString("ENGINE"),
Rows: one.GetInt64("TABLE_ROWS"),
DataLength: one.GetInt64("DATA_LENGTH"),
IndexLength: one.GetInt64("INDEX_LENGTH"),
Comment: one.GetString("TABLE_COMMENT"),
Collation: one.GetString("TABLE_COLLATION"),
IsBaseTable: one.GetString("TABLE_TYPE") == "BASE TABLE",
CanClean: canClean,
CanDelete: canDelete,
})
}
return &pb.FindAllDBNodeTablesResponse{DbNodeTables: pbTables}, nil
}
// DeleteDBNodeTable 删除表
func (this *DBNodeService) DeleteDBNodeTable(ctx context.Context, req *pb.DeleteDBNodeTableRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
node, err := models.SharedDBNodeDAO.FindEnabledDBNode(tx, req.DbNodeId)
if err != nil {
return nil, err
}
if node == nil {
return nil, dbs.ErrNotFound
}
db, err := dbs.NewInstanceFromConfig(node.DBConfig())
if err != nil {
return nil, err
}
defer func() {
_ = db.Close()
}()
// 检查是否能够删除
if !strings.HasPrefix(strings.ToLower(req.DbNodeTable), "edgehttpaccesslogs_") && !strings.HasPrefix(strings.ToLower(req.DbNodeTable), "edgensaccesslogs_") {
return nil, errors.New("unable to delete the table")
}
_, err = db.Exec("DROP TABLE `" + req.DbNodeTable + "`")
if err != nil {
return nil, err
}
return this.Success()
}
// TruncateDBNodeTable 清空表
func (this *DBNodeService) TruncateDBNodeTable(ctx context.Context, req *pb.TruncateDBNodeTableRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
node, err := models.SharedDBNodeDAO.FindEnabledDBNode(tx, req.DbNodeId)
if err != nil {
return nil, err
}
if node == nil {
return nil, dbs.ErrNotFound
}
db, err := dbs.NewInstanceFromConfig(node.DBConfig())
if err != nil {
return nil, err
}
defer func() {
_ = db.Close()
}()
_, err = db.Exec("TRUNCATE TABLE `" + req.DbNodeTable + "`")
if err != nil {
return nil, err
}
return this.Success()
}
// CheckDBNodeStatus 检查数据库节点状态
func (this *DBNodeService) CheckDBNodeStatus(ctx context.Context, req *pb.CheckDBNodeStatusRequest) (*pb.CheckDBNodeStatusResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
node, err := models.SharedDBNodeDAO.FindEnabledDBNode(tx, req.DbNodeId)
if err != nil {
return nil, err
}
if node == nil {
return &pb.CheckDBNodeStatusResponse{DbNodeStatus: nil}, nil
}
status := &pb.DBNodeStatus{}
// 是否能够连接
if node.IsOn {
db, err := dbs.NewInstanceFromConfig(node.DBConfig())
if err != nil {
status.Error = err.Error()
} else {
// 版本
version, _ := db.FindCol(0, "SELECT VERSION()")
status.Version = types.String(version)
one, err := db.FindOne("SELECT SUM(DATA_LENGTH+INDEX_LENGTH) AS size FROM information_schema.`TABLES` WHERE TABLE_SCHEMA=?", db.Name())
if err != nil {
status.Error = err.Error()
_ = db.Close()
} else if one == nil {
status.Error = "unable to read size from database server"
_ = db.Close()
} else {
status.IsOk = true
status.Size = one.GetInt64("size")
_ = db.Close()
}
}
}
return &pb.CheckDBNodeStatusResponse{DbNodeStatus: status}, nil
}

View File

@@ -0,0 +1,19 @@
package services
import (
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/logs"
"testing"
)
func TestDBService_FindAllDBTables(t *testing.T) {
db, err := dbs.Default()
if err != nil {
t.Fatal(err)
}
ones, _, err := db.FindPreparedOnes("SELECT * FROM information_schema.`TABLES` WHERE TABLE_SCHEMA=?", db.Name())
if err != nil {
t.Fatal(err)
}
logs.PrintAsJSON(ones, t)
}

View File

@@ -0,0 +1,53 @@
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns/dnsutils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// DNSService DNS相关服务
type DNSService struct {
BaseService
}
// FindAllDNSIssues 查找问题
func (this *DNSService) FindAllDNSIssues(ctx context.Context, req *pb.FindAllDNSIssuesRequest) (*pb.FindAllDNSIssuesResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var result = []*pb.DNSIssue{}
var tx = this.NullTx()
var clusters []*models.NodeCluster
if req.NodeClusterId <= 0 {
clusters, err = models.SharedNodeClusterDAO.FindAllEnabledClustersHaveDNSDomain(tx)
if err != nil {
return nil, err
}
} else {
cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(tx, req.NodeClusterId)
if err != nil {
return nil, err
}
if cluster == nil {
return &pb.FindAllDNSIssuesResponse{Issues: nil}, nil
}
clusters = []*models.NodeCluster{cluster}
}
for _, cluster := range clusters {
issues, err := dnsutils.CheckClusterDNS(tx, cluster, true)
if err != nil {
return nil, err
}
if len(issues) > 0 {
result = append(result, issues...)
}
}
return &pb.FindAllDNSIssuesResponse{Issues: result}, nil
}

View File

@@ -0,0 +1,932 @@
package services
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns/dnsutils"
"github.com/TeaOSLab/EdgeAPI/internal/dnsclients"
"github.com/TeaOSLab/EdgeAPI/internal/dnsclients/dnstypes"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/goman"
"github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/iputils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"net"
)
// DNSDomainService DNS域名相关服务
type DNSDomainService struct {
BaseService
}
// CreateDNSDomain 创建域名
func (this *DNSDomainService) CreateDNSDomain(ctx context.Context, req *pb.CreateDNSDomainRequest) (*pb.CreateDNSDomainResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 查询Provider
provider, err := dns.SharedDNSProviderDAO.FindEnabledDNSProvider(tx, req.DnsProviderId)
if err != nil {
return nil, err
}
if provider == nil {
return nil, errors.New("can not find provider")
}
apiParams, err := provider.DecodeAPIParams()
if err != nil {
return nil, err
}
domainId, err := dns.SharedDNSDomainDAO.CreateDomain(tx, adminId, userId, req.DnsProviderId, req.Name)
if err != nil {
return nil, err
}
// 更新数据,且不提示错误
goman.New(func() {
domainName := req.Name
providerInterface := dnsclients.FindProvider(provider.Type, int64(provider.Id))
if providerInterface == nil {
return
}
err = providerInterface.Auth(apiParams)
if err != nil {
// 这里我们刻意不提示错误
return
}
routes, err := providerInterface.GetRoutes(domainName)
if err != nil {
return
}
routesJSON, err := json.Marshal(routes)
if err != nil {
return
}
err = dns.SharedDNSDomainDAO.UpdateDomainRoutes(tx, domainId, routesJSON)
if err != nil {
return
}
records, err := providerInterface.GetRecords(domainName)
if err != nil {
return
}
recordsJSON, err := json.Marshal(records)
if err != nil {
return
}
err = dns.SharedDNSDomainDAO.UpdateDomainRecords(tx, domainId, recordsJSON)
if err != nil {
return
}
})
return &pb.CreateDNSDomainResponse{DnsDomainId: domainId}, nil
}
// UpdateDNSDomain 修改域名
func (this *DNSDomainService) UpdateDNSDomain(ctx context.Context, req *pb.UpdateDNSDomainRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = dns.SharedDNSDomainDAO.UpdateDomain(tx, req.DnsDomainId, req.Name, req.IsOn)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteDNSDomain 删除域名
func (this *DNSDomainService) DeleteDNSDomain(ctx context.Context, req *pb.DeleteDNSDomainRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = dns.SharedDNSDomainDAO.UpdateDomainIsDeleted(tx, req.DnsDomainId, true)
if err != nil {
return nil, err
}
return this.Success()
}
// RecoverDNSDomain 恢复删除的域名
func (this *DNSDomainService) RecoverDNSDomain(ctx context.Context, req *pb.RecoverDNSDomainRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = dns.SharedDNSDomainDAO.UpdateDomainIsDeleted(tx, req.DnsDomainId, false)
if err != nil {
return nil, err
}
return this.Success()
}
// FindDNSDomain 查询单个域名完整信息
func (this *DNSDomainService) FindDNSDomain(ctx context.Context, req *pb.FindDNSDomainRequest) (*pb.FindDNSDomainResponse, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
domain, err := dns.SharedDNSDomainDAO.FindEnabledDNSDomain(tx, req.DnsDomainId, nil)
if err != nil {
return nil, err
}
if domain == nil {
return &pb.FindDNSDomainResponse{DnsDomain: nil}, nil
}
pbDomain, err := this.convertDomainToPB(tx, domain)
return &pb.FindDNSDomainResponse{DnsDomain: pbDomain}, nil
}
// FindBasicDNSDomain 查询单个域名基础信息
func (this *DNSDomainService) FindBasicDNSDomain(ctx context.Context, req *pb.FindBasicDNSDomainRequest) (*pb.FindBasicDNSDomainResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
domain, err := dns.SharedDNSDomainDAO.FindEnabledDNSDomain(tx, req.DnsDomainId, nil)
if err != nil {
return nil, err
}
if domain == nil {
return &pb.FindBasicDNSDomainResponse{DnsDomain: nil}, nil
}
return &pb.FindBasicDNSDomainResponse{DnsDomain: &pb.DNSDomain{
Id: int64(domain.Id),
Name: domain.Name,
IsOn: domain.IsOn,
ProviderId: int64(domain.ProviderId),
}}, nil
}
// CountAllDNSDomainsWithDNSProviderId 计算服务商下的域名数量
func (this *DNSDomainService) CountAllDNSDomainsWithDNSProviderId(ctx context.Context, req *pb.CountAllDNSDomainsWithDNSProviderIdRequest) (*pb.RPCCountResponse, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := dns.SharedDNSDomainDAO.CountAllEnabledDomainsWithProviderId(tx, req.DnsProviderId, req.IsDeleted, !req.IsDown)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// FindAllDNSDomainsWithDNSProviderId 列出服务商下的所有域名
func (this *DNSDomainService) FindAllDNSDomainsWithDNSProviderId(ctx context.Context, req *pb.FindAllDNSDomainsWithDNSProviderIdRequest) (*pb.FindAllDNSDomainsWithDNSProviderIdResponse, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
domains, err := dns.SharedDNSDomainDAO.FindAllEnabledDomainsWithProviderId(tx, req.DnsProviderId)
if err != nil {
return nil, err
}
result := []*pb.DNSDomain{}
for _, domain := range domains {
pbDomain, err := this.convertDomainToPB(tx, domain)
if err != nil {
return nil, err
}
result = append(result, pbDomain)
}
return &pb.FindAllDNSDomainsWithDNSProviderIdResponse{DnsDomains: result}, nil
}
// FindAllBasicDNSDomainsWithDNSProviderId 列出服务商下的所有域名基本信息
func (this *DNSDomainService) FindAllBasicDNSDomainsWithDNSProviderId(ctx context.Context, req *pb.FindAllBasicDNSDomainsWithDNSProviderIdRequest) (*pb.FindAllBasicDNSDomainsWithDNSProviderIdResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
domains, err := dns.SharedDNSDomainDAO.FindAllEnabledDomainsWithProviderId(tx, req.DnsProviderId)
if err != nil {
return nil, err
}
var result = []*pb.DNSDomain{}
for _, domain := range domains {
result = append(result, &pb.DNSDomain{
Id: int64(domain.Id),
Name: domain.Name,
IsOn: domain.IsOn,
IsUp: domain.IsUp,
IsDeleted: domain.IsDeleted,
})
}
return &pb.FindAllBasicDNSDomainsWithDNSProviderIdResponse{DnsDomains: result}, nil
}
// ListBasicDNSDomainsWithDNSProviderId 列出服务商下的单页域名信息
func (this *DNSDomainService) ListBasicDNSDomainsWithDNSProviderId(ctx context.Context, req *pb.ListBasicDNSDomainsWithDNSProviderIdRequest) (*pb.ListDNSDomainsWithDNSProviderIdResponse, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
domains, err := dns.SharedDNSDomainDAO.ListDomains(tx, req.DnsProviderId, req.IsDeleted, !req.IsDown, req.Offset, req.Size)
if err != nil {
return nil, err
}
var result = []*pb.DNSDomain{}
for _, domain := range domains {
pbDomain, err := this.convertDomainToPB(tx, domain)
if err != nil {
return nil, err
}
result = append(result, pbDomain)
}
return &pb.ListDNSDomainsWithDNSProviderIdResponse{DnsDomains: result}, nil
}
// SyncDNSDomainData 同步域名数据
func (this *DNSDomainService) SyncDNSDomainData(ctx context.Context, req *pb.SyncDNSDomainDataRequest) (*pb.SyncDNSDomainDataResponse, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var latestVersion = dns.SharedDNSTaskDAO.GenerateVersion()
resp, err := this.syncClusterDNS(req)
if err != nil {
return resp, err
}
// 标记集群所有任务已完成
if req.NodeClusterId > 0 && resp != nil && resp.IsOk {
var tx = this.NullTx()
err = dns.SharedDNSTaskDAO.UpdateClusterDNSTasksDone(tx, req.NodeClusterId, latestVersion)
if err != nil {
return resp, err
}
}
return resp, err
}
// FindAllDNSDomainRoutes 查看支持的线路
func (this *DNSDomainService) FindAllDNSDomainRoutes(ctx context.Context, req *pb.FindAllDNSDomainRoutesRequest) (*pb.FindAllDNSDomainRoutesResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
routes, err := dns.SharedDNSDomainDAO.FindDomainRoutes(tx, req.DnsDomainId)
if err != nil {
return nil, err
}
pbRoutes := []*pb.DNSRoute{}
for _, route := range routes {
pbRoutes = append(pbRoutes, &pb.DNSRoute{
Name: route.Name,
Code: route.Code,
})
}
return &pb.FindAllDNSDomainRoutesResponse{Routes: pbRoutes}, nil
}
// ExistAvailableDomains 判断是否有域名可选
func (this *DNSDomainService) ExistAvailableDomains(ctx context.Context, req *pb.ExistAvailableDomainsRequest) (*pb.ExistAvailableDomainsResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
exist, err := dns.SharedDNSDomainDAO.ExistAvailableDomains(tx)
if err != nil {
return nil, err
}
return &pb.ExistAvailableDomainsResponse{Exist: exist}, nil
}
// 转换域名信息
func (this *DNSDomainService) convertDomainToPB(tx *dbs.Tx, domain *dns.DNSDomain) (*pb.DNSDomain, error) {
var domainId = int64(domain.Id)
defaultRoute, err := dnsutils.FindDefaultDomainRoute(tx, domain)
if err != nil {
return nil, err
}
records := []*dnstypes.Record{}
if models.IsNotNull(domain.Records) {
err := json.Unmarshal(domain.Records, &records)
if err != nil {
return nil, err
}
}
// 集群域名
countNodeRecords := 0
nodesChanged := false
// 服务域名
countServerRecords := 0
serversChanged := false
// 检查是否所有的集群都已经被解析
clusters, err := models.SharedNodeClusterDAO.FindAllEnabledClustersWithDNSDomainId(tx, domainId)
if err != nil {
return nil, err
}
countClusters := len(clusters)
countAllNodes1 := int64(0)
countAllServers1 := int64(0)
for _, cluster := range clusters {
_, nodeRecords, serverRecords, countAllNodes, countAllServers, nodesChanged2, serversChanged2, err := this.findClusterDNSChanges(cluster, records, domain.Name, defaultRoute)
if err != nil {
return nil, err
}
countNodeRecords += len(nodeRecords)
countServerRecords += len(serverRecords)
countAllNodes1 += countAllNodes
countAllServers1 += countAllServers
if nodesChanged2 {
nodesChanged = true
}
if serversChanged2 {
serversChanged = true
}
}
// 线路
routes, err := domain.DecodeRoutes()
if err != nil {
return nil, err
}
pbRoutes := []*pb.DNSRoute{}
for _, route := range routes {
pbRoutes = append(pbRoutes, &pb.DNSRoute{
Name: route.Name,
Code: route.Code,
})
}
return &pb.DNSDomain{
Id: int64(domain.Id),
ProviderId: int64(domain.ProviderId),
Name: domain.Name,
IsOn: domain.IsOn,
IsUp: domain.IsUp,
IsDeleted: domain.IsDeleted,
DataUpdatedAt: int64(domain.DataUpdatedAt),
CountNodeRecords: int64(countNodeRecords),
NodesChanged: nodesChanged,
CountServerRecords: int64(countServerRecords),
ServersChanged: serversChanged,
Routes: pbRoutes,
CountNodeClusters: int64(countClusters),
CountAllNodes: countAllNodes1,
CountAllServers: countAllServers1,
}, nil
}
// 检查集群节点变化
func (this *DNSDomainService) findClusterDNSChanges(cluster *models.NodeCluster, records []*dnstypes.Record, domainName string, defaultRoute string) (result []maps.Map, doneNodeRecords []*dnstypes.Record, doneServerRecords []*dnstypes.Record, countAllNodes int64, countAllServers int64, nodesChanged bool, serversChanged bool, err error) {
var clusterId = int64(cluster.Id)
var clusterDnsName = cluster.DnsName
var clusterDomain = clusterDnsName + "." + domainName
dnsConfig, err := cluster.DecodeDNSConfig()
if err != nil {
return nil, nil, nil, 0, 0, false, false, err
}
if dnsConfig == nil {
dnsConfig = dnsconfigs.DefaultClusterDNSConfig()
}
var tx = this.NullTx()
// 自动设置的cname记录
var ttl int32
var cnameRecords = dnsConfig.CNAMERecords
if dnsConfig.TTL > 0 {
ttl = dnsConfig.TTL
}
// 节点域名
nodes, err := models.SharedNodeDAO.FindAllEnabledNodesDNSWithClusterId(tx, clusterId, true, dnsConfig != nil && dnsConfig.IncludingLnNodes, true)
if err != nil {
return nil, nil, nil, 0, 0, false, false, err
}
countAllNodes = int64(len(nodes))
var nodeRecords = []*dnstypes.Record{} // 之所以用数组再存一遍是因为dnsName可能会重复
var nodeRecordMapping = map[string]*dnstypes.Record{} // value_route => *Record
for _, record := range records {
if (record.Type == dnstypes.RecordTypeA || record.Type == dnstypes.RecordTypeAAAA) && record.Name == clusterDnsName {
nodeRecords = append(nodeRecords, record)
nodeRecordMapping[record.Value+"_"+record.Route] = record
}
}
// 新增的节点域名
var nodeKeys = []string{}
var addingNodeRecordKeysMap = map[string]bool{} // clusterDnsName_type_ip_route
for _, node := range nodes {
shouldSkip, shouldOverwrite, ipAddressesStrings, err := models.SharedNodeDAO.CheckNodeIPAddresses(tx, node)
if err != nil {
return nil, nil, nil, 0, 0, false, false, err
}
if shouldSkip {
continue
}
routeCodes, err := node.DNSRouteCodesForDomainId(int64(cluster.DnsDomainId))
if err != nil {
return nil, nil, nil, 0, 0, false, false, err
}
if len(routeCodes) == 0 {
// 默认线路
if len(defaultRoute) > 0 {
routeCodes = []string{defaultRoute}
} else {
continue
}
}
if !shouldOverwrite {
ipAddresses, err := models.SharedNodeIPAddressDAO.FindNodeAccessAndUpIPAddresses(tx, int64(node.Id), nodeconfigs.NodeRoleNode)
if err != nil {
return nil, nil, nil, 0, 0, false, false, err
}
if len(ipAddresses) == 0 {
continue
}
for _, ipAddress := range ipAddresses {
// 检查专属节点
if !ipAddress.IsValidInCluster(clusterId) {
continue
}
var ip = ipAddress.DNSIP()
if len(ip) == 0 {
continue
}
if net.ParseIP(ip) == nil {
continue
}
ipAddressesStrings = append(ipAddressesStrings, ip)
}
}
if len(ipAddressesStrings) == 0 {
continue
}
for _, route := range routeCodes {
for _, ip := range ipAddressesStrings {
var key = ip + "_" + route
nodeKeys = append(nodeKeys, key)
record, ok := nodeRecordMapping[key]
if !ok {
var recordType = dnstypes.RecordTypeA
if iputils.IsIPv6(ip) {
recordType = dnstypes.RecordTypeAAAA
}
// 避免添加重复的记录
var fullKey = clusterDnsName + "_" + recordType + "_" + ip + "_" + route
if addingNodeRecordKeysMap[fullKey] {
continue
}
addingNodeRecordKeysMap[fullKey] = true
result = append(result, maps.Map{
"action": "create",
"record": &dnstypes.Record{
Id: "",
Name: clusterDnsName,
Type: recordType,
Value: ip,
Route: route,
TTL: ttl,
},
})
nodesChanged = true
} else {
doneNodeRecords = append(doneNodeRecords, record)
}
}
}
}
// 多余的节点域名
for _, record := range nodeRecords {
key := record.Value + "_" + record.Route
if !lists.ContainsString(nodeKeys, key) {
nodesChanged = true
result = append(result, maps.Map{
"action": "delete",
"record": record,
})
}
}
// 服务域名
servers, err := models.SharedServerDAO.FindAllServersDNSWithClusterId(tx, clusterId)
if err != nil {
return nil, nil, nil, 0, 0, false, false, err
}
countAllServers = int64(len(servers))
var serverRecords = []*dnstypes.Record{} // 之所以用数组再存一遍是因为dnsName可能会重复
var serverRecordsMap = map[string]*dnstypes.Record{} // dnsName => *Record
for _, record := range records {
if record.Type == dnstypes.RecordTypeCNAME && record.Value == clusterDomain+"." {
serverRecords = append(serverRecords, record)
serverRecordsMap[record.Name] = record
}
}
// 新增的域名
var serverDNSNames = []string{}
for _, server := range servers {
var dnsName = server.DnsName
if len(dnsName) == 0 {
return nil, nil, nil, 0, 0, false, false, errors.New("server '" + numberutils.FormatInt64(int64(server.Id)) + "' 'dnsName' should not empty")
}
serverDNSNames = append(serverDNSNames, dnsName)
record, ok := serverRecordsMap[dnsName]
if !ok {
serversChanged = true
result = append(result, maps.Map{
"action": "create",
"record": &dnstypes.Record{
Id: "",
Name: dnsName,
Type: dnstypes.RecordTypeCNAME,
Value: clusterDomain + ".",
Route: "", // 注意这里为空,需要在执行过程中获取默认值
TTL: ttl,
},
})
} else {
doneServerRecords = append(doneServerRecords, record)
}
}
// 自动设置的CNAME
for _, cnameRecord := range cnameRecords {
// 如果记录已存在,则跳过
if lists.ContainsString(serverDNSNames, cnameRecord) {
continue
}
serverDNSNames = append(serverDNSNames, cnameRecord)
record, ok := serverRecordsMap[cnameRecord]
if !ok {
serversChanged = true
result = append(result, maps.Map{
"action": "create",
"record": &dnstypes.Record{
Id: "",
Name: cnameRecord,
Type: dnstypes.RecordTypeCNAME,
Value: clusterDomain + ".",
Route: "", // 注意这里为空,需要在执行过程中获取默认值
TTL: ttl,
},
})
} else {
doneServerRecords = append(doneServerRecords, record)
}
}
// 多余的域名
for _, record := range serverRecords {
if !lists.ContainsString(serverDNSNames, record.Name) {
serversChanged = true
result = append(result, maps.Map{
"action": "delete",
"record": record,
})
}
}
return
}
// 执行同步
func (this *DNSDomainService) syncClusterDNS(req *pb.SyncDNSDomainDataRequest) (*pb.SyncDNSDomainDataResponse, error) {
var tx = this.NullTx()
// 查询集群信息
var err error
var clusters = []*models.NodeCluster{}
if req.NodeClusterId > 0 {
cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(tx, req.NodeClusterId)
if err != nil {
return nil, err
}
if cluster == nil {
return &pb.SyncDNSDomainDataResponse{
IsOk: false,
Error: "找不到要同步的集群",
ShouldFix: false,
}, nil
}
if int64(cluster.DnsDomainId) != req.DnsDomainId {
return &pb.SyncDNSDomainDataResponse{
IsOk: false,
Error: "集群设置的域名和参数不符",
ShouldFix: false,
}, nil
}
clusters = append(clusters, cluster)
} else {
clusters, err = models.SharedNodeClusterDAO.FindAllEnabledClustersWithDNSDomainId(tx, req.DnsDomainId)
if err != nil {
return nil, err
}
}
// 域名信息
domain, err := dns.SharedDNSDomainDAO.FindEnabledDNSDomain(tx, req.DnsDomainId, nil)
if err != nil {
return nil, err
}
if domain == nil {
return &pb.SyncDNSDomainDataResponse{IsOk: false, Error: "找不到要操作的域名"}, nil
}
var domainId = int64(domain.Id)
var domainName = domain.Name
// 服务商信息
provider, err := dns.SharedDNSProviderDAO.FindEnabledDNSProvider(tx, int64(domain.ProviderId))
if err != nil {
return nil, err
}
if provider == nil {
return &pb.SyncDNSDomainDataResponse{IsOk: false, Error: "域名没有设置服务商"}, nil
}
var apiParams = maps.Map{}
if models.IsNotNull(provider.ApiParams) {
err = json.Unmarshal(provider.ApiParams, &apiParams)
if err != nil {
return nil, err
}
}
// 开始同步
var manager = dnsclients.FindProvider(provider.Type, int64(provider.Id))
if manager == nil {
return &pb.SyncDNSDomainDataResponse{IsOk: false, Error: "目前不支持'" + provider.Type + "'"}, nil
}
err = manager.Auth(apiParams)
if err != nil {
return &pb.SyncDNSDomainDataResponse{IsOk: false, Error: "调用API认证失败" + err.Error()}, nil
}
// 更新线路
routes, err := manager.GetRoutes(domainName)
if err != nil {
return &pb.SyncDNSDomainDataResponse{IsOk: false, Error: "获取线路失败:" + err.Error()}, nil
}
routesJSON, err := json.Marshal(routes)
if err != nil {
return nil, err
}
err = dns.SharedDNSDomainDAO.UpdateDomainRoutes(tx, domainId, routesJSON)
if err != nil {
return nil, err
}
// 检查集群设置
for _, cluster := range clusters {
issues, err := dnsutils.CheckClusterDNS(tx, cluster, req.CheckNodeIssues)
if err != nil {
return nil, err
}
if len(issues) > 0 {
return &pb.SyncDNSDomainDataResponse{IsOk: false, Error: "发现问题需要修复", ShouldFix: true}, nil
}
}
// 所有记录
records, err := manager.GetRecords(domainName)
if err != nil {
return &pb.SyncDNSDomainDataResponse{IsOk: false, Error: "获取域名解析记录失败:" + err.Error()}, nil
}
recordsJSON, err := json.Marshal(records)
if err != nil {
return nil, err
}
err = dns.SharedDNSDomainDAO.UpdateDomainRecords(tx, domainId, recordsJSON)
if err != nil {
return nil, err
}
// 对比变化
var allChanges = []maps.Map{}
for _, cluster := range clusters {
changes, _, _, _, _, _, _, err := this.findClusterDNSChanges(cluster, records, domainName, manager.DefaultRoute())
if err != nil {
return nil, err
}
allChanges = append(allChanges, changes...)
}
for _, change := range allChanges {
action := change.GetString("action")
record := change.Get("record").(*dnstypes.Record)
if len(record.Route) == 0 {
record.Route = manager.DefaultRoute()
}
switch action {
case "create":
err = manager.AddRecord(domainName, record)
if err != nil {
return &pb.SyncDNSDomainDataResponse{IsOk: false, Error: "创建域名记录失败:" + err.Error()}, nil
}
case "delete":
err = manager.DeleteRecord(domainName, record)
if err != nil {
return &pb.SyncDNSDomainDataResponse{IsOk: false, Error: "删除域名记录失败:" + err.Error()}, nil
}
}
}
// 重新更新记录
if len(allChanges) > 0 {
records, err := manager.GetRecords(domainName)
if err != nil {
return &pb.SyncDNSDomainDataResponse{IsOk: false, Error: "重新获取域名解析记录失败:" + err.Error()}, nil
}
recordsJSON, err := json.Marshal(records)
if err != nil {
return nil, err
}
err = dns.SharedDNSDomainDAO.UpdateDomainRecords(tx, domainId, recordsJSON)
if err != nil {
return nil, err
}
}
return &pb.SyncDNSDomainDataResponse{
IsOk: true,
}, nil
}
// ExistDNSDomainRecord 检查域名是否在记录中
func (this *DNSDomainService) ExistDNSDomainRecord(ctx context.Context, req *pb.ExistDNSDomainRecordRequest) (*pb.ExistDNSDomainRecordResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
isOk, err := dns.SharedDNSDomainDAO.ExistDomainRecord(tx, req.DnsDomainId, req.Name, req.Type, req.Route, req.Value)
if err != nil {
return nil, err
}
return &pb.ExistDNSDomainRecordResponse{IsOk: isOk}, nil
}
// SyncDNSDomainsFromProvider 从服务商同步域名
func (this *DNSDomainService) SyncDNSDomainsFromProvider(ctx context.Context, req *pb.SyncDNSDomainsFromProviderRequest) (*pb.SyncDNSDomainsFromProviderResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
provider, err := dns.SharedDNSProviderDAO.FindEnabledDNSProvider(tx, req.DnsProviderId)
if err != nil {
return nil, err
}
if provider == nil {
return nil, errors.New("can not find provider")
}
// 下线不存在的域名
oldDomains, err := dns.SharedDNSDomainDAO.FindAllEnabledDomainsWithProviderId(tx, req.DnsProviderId)
if err != nil {
return nil, err
}
dnsProvider := dnsclients.FindProvider(provider.Type, int64(provider.Id))
if dnsProvider == nil {
return nil, errors.New("provider type '" + provider.Type + "' is not supported yet")
}
params, err := provider.DecodeAPIParams()
if err != nil {
return nil, errors.New("decode params failed: " + err.Error())
}
err = dnsProvider.Auth(params)
if err != nil {
return nil, errors.New("auth failed: " + err.Error())
}
domainNames, err := dnsProvider.GetDomains()
if err != nil {
return nil, err
}
var hasChanges = false
// 创建或上线域名
for _, domainName := range domainNames {
domain, err := dns.SharedDNSDomainDAO.FindEnabledDomainWithName(tx, req.DnsProviderId, domainName)
if err != nil {
return nil, err
}
if domain == nil {
_, err = dns.SharedDNSDomainDAO.CreateDomain(tx, 0, 0, req.DnsProviderId, domainName)
if err != nil {
return nil, err
}
hasChanges = true
} else if !domain.IsUp {
err = dns.SharedDNSDomainDAO.UpdateDomainIsUp(tx, int64(domain.Id), true)
if err != nil {
return nil, err
}
hasChanges = true
}
}
// 将老的域名置为下线
for _, oldDomain := range oldDomains {
var domainName = oldDomain.Name
if oldDomain.IsUp && !lists.ContainsString(domainNames, domainName) {
err = dns.SharedDNSDomainDAO.UpdateDomainIsUp(tx, int64(oldDomain.Id), false)
if err != nil {
return nil, err
}
hasChanges = true
}
}
return &pb.SyncDNSDomainsFromProviderResponse{
HasChanges: hasChanges,
}, nil
}

View File

@@ -0,0 +1,272 @@
package services
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns"
"github.com/TeaOSLab/EdgeAPI/internal/dnsclients"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/maps"
)
// DNSProviderService DNS服务商相关服务
type DNSProviderService struct {
BaseService
}
// CreateDNSProvider 创建服务商
func (this *DNSProviderService) CreateDNSProvider(ctx context.Context, req *pb.CreateDNSProviderRequest) (*pb.CreateDNSProviderResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
providerId, err := dns.SharedDNSProviderDAO.CreateDNSProvider(tx, adminId, userId, req.Type, req.Name, req.ApiParamsJSON, req.MinTTL)
if err != nil {
return nil, err
}
return &pb.CreateDNSProviderResponse{DnsProviderId: providerId}, nil
}
// UpdateDNSProvider 修改服务商
func (this *DNSProviderService) UpdateDNSProvider(ctx context.Context, req *pb.UpdateDNSProviderRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
// TODO 校验权限
var tx = this.NullTx()
provider, err := dns.SharedDNSProviderDAO.FindEnabledDNSProvider(tx, req.DnsProviderId)
if err != nil {
return nil, err
}
if provider == nil {
// do nothing here
return this.Success()
}
// 恢复被掩码的数据
req.ApiParamsJSON, err = dnsclients.UnmaskAPIParams(provider.ApiParams, req.ApiParamsJSON)
if err != nil {
return nil, err
}
err = dns.SharedDNSProviderDAO.UpdateDNSProvider(tx, req.DnsProviderId, req.Name, req.ApiParamsJSON, req.MinTTL)
if err != nil {
return nil, err
}
return this.Success()
}
// CountAllEnabledDNSProviders 计算服务商数量
func (this *DNSProviderService) CountAllEnabledDNSProviders(ctx context.Context, req *pb.CountAllEnabledDNSProvidersRequest) (*pb.RPCCountResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
req.UserId = userId
}
count, err := dns.SharedDNSProviderDAO.CountAllEnabledDNSProviders(tx, req.AdminId, req.UserId, req.Keyword, req.Domain, req.Type)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListEnabledDNSProviders 列出单页服务商信息
func (this *DNSProviderService) ListEnabledDNSProviders(ctx context.Context, req *pb.ListEnabledDNSProvidersRequest) (*pb.ListEnabledDNSProvidersResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
// TODO 校验权限
var tx = this.NullTx()
if userId > 0 {
req.UserId = userId
}
providers, err := dns.SharedDNSProviderDAO.ListEnabledDNSProviders(tx, req.AdminId, req.UserId, req.Keyword, req.Domain, req.Type, req.Offset, req.Size)
if err != nil {
return nil, err
}
result := []*pb.DNSProvider{}
for _, provider := range providers {
result = append(result, &pb.DNSProvider{
Id: int64(provider.Id),
Name: provider.Name,
Type: provider.Type,
TypeName: dnsclients.FindProviderTypeName(provider.Type),
ApiParamsJSON: provider.ApiParams,
DataUpdatedAt: int64(provider.DataUpdatedAt),
MinTTL: int32(provider.MinTTL),
})
}
return &pb.ListEnabledDNSProvidersResponse{DnsProviders: result}, nil
}
// FindAllEnabledDNSProviders 查找所有的DNS服务商
func (this *DNSProviderService) FindAllEnabledDNSProviders(ctx context.Context, req *pb.FindAllEnabledDNSProvidersRequest) (*pb.FindAllEnabledDNSProvidersResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
// TODO 校验权限
var tx = this.NullTx()
if userId > 0 {
req.UserId = userId
}
providers, err := dns.SharedDNSProviderDAO.FindAllEnabledDNSProviders(tx, req.AdminId, req.UserId)
if err != nil {
return nil, err
}
result := []*pb.DNSProvider{}
for _, provider := range providers {
result = append(result, &pb.DNSProvider{
Id: int64(provider.Id),
Name: provider.Name,
Type: provider.Type,
TypeName: dnsclients.FindProviderTypeName(provider.Type),
ApiParamsJSON: provider.ApiParams,
DataUpdatedAt: int64(provider.DataUpdatedAt),
MinTTL: int32(provider.MinTTL),
})
}
return &pb.FindAllEnabledDNSProvidersResponse{DnsProviders: result}, nil
}
// DeleteDNSProvider 删除服务商
func (this *DNSProviderService) DeleteDNSProvider(ctx context.Context, req *pb.DeleteDNSProviderRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
// TODO 校验权限
}
err = dns.SharedDNSProviderDAO.DisableDNSProvider(tx, req.DnsProviderId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindEnabledDNSProvider 查找单个服务商
func (this *DNSProviderService) FindEnabledDNSProvider(ctx context.Context, req *pb.FindEnabledDNSProviderRequest) (*pb.FindEnabledDNSProviderResponse, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
provider, err := dns.SharedDNSProviderDAO.FindEnabledDNSProvider(tx, req.DnsProviderId)
if err != nil {
return nil, err
}
if provider == nil {
return &pb.FindEnabledDNSProviderResponse{DnsProvider: nil}, nil
}
if req.MaskParams {
var providerObj = dnsclients.FindProvider(provider.Type, int64(provider.Id))
if providerObj != nil {
var paramsMap = maps.Map{}
if len(provider.ApiParams) > 0 {
err = json.Unmarshal(provider.ApiParams, &paramsMap)
if err != nil {
return nil, err
}
providerObj.MaskParams(paramsMap)
provider.ApiParams, err = json.Marshal(paramsMap)
if err != nil {
return nil, err
}
}
}
}
return &pb.FindEnabledDNSProviderResponse{
DnsProvider: &pb.DNSProvider{
Id: int64(provider.Id),
Name: provider.Name,
Type: provider.Type,
TypeName: dnsclients.FindProviderTypeName(provider.Type),
ApiParamsJSON: provider.ApiParams,
DataUpdatedAt: int64(provider.DataUpdatedAt),
MinTTL: int32(provider.MinTTL),
},
}, nil
}
// FindAllDNSProviderTypes 取得所有服务商类型
func (this *DNSProviderService) FindAllDNSProviderTypes(ctx context.Context, req *pb.FindAllDNSProviderTypesRequest) (*pb.FindAllDNSProviderTypesResponse, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
result := []*pb.DNSProviderType{}
for _, t := range dnsclients.FindAllProviderTypes() {
result = append(result, &pb.DNSProviderType{
Name: t.GetString("name"),
Code: t.GetString("code"),
Description: t.GetString("description"),
})
}
return &pb.FindAllDNSProviderTypesResponse{ProviderTypes: result}, nil
}
// FindAllEnabledDNSProvidersWithType 取得某个类型的所有服务商
func (this *DNSProviderService) FindAllEnabledDNSProvidersWithType(ctx context.Context, req *pb.FindAllEnabledDNSProvidersWithTypeRequest) (*pb.FindAllEnabledDNSProvidersWithTypeResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
providers, err := dns.SharedDNSProviderDAO.FindAllEnabledDNSProvidersWithType(tx, req.ProviderTypeCode)
if err != nil {
return nil, err
}
result := []*pb.DNSProvider{}
for _, provider := range providers {
result = append(result, &pb.DNSProvider{
Id: int64(provider.Id),
Name: provider.Name,
Type: provider.Type,
TypeName: dnsclients.FindProviderTypeName(provider.Type),
MinTTL: int32(provider.MinTTL),
})
}
return &pb.FindAllEnabledDNSProvidersWithTypeResponse{DnsProviders: result}, nil
}

View File

@@ -0,0 +1,135 @@
package services
import (
"context"
"fmt"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// DNSTaskService DNS同步相关任务
type DNSTaskService struct {
BaseService
}
// ExistsDNSTasks 检查是否有正在执行的任务
func (this *DNSTaskService) ExistsDNSTasks(ctx context.Context, req *pb.ExistsDNSTasksRequest) (*pb.ExistsDNSTasksResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
existDoingTasks, err := dns.SharedDNSTaskDAO.ExistDoingTasks(tx)
if err != nil {
return nil, err
}
existErrorTasks, err := dns.SharedDNSTaskDAO.ExistErrorTasks(tx)
if err != nil {
return nil, err
}
return &pb.ExistsDNSTasksResponse{
ExistTasks: existDoingTasks,
ExistError: existErrorTasks,
}, nil
}
// FindAllDoingDNSTasks 查找正在执行的所有任务
func (this *DNSTaskService) FindAllDoingDNSTasks(ctx context.Context, req *pb.FindAllDoingDNSTasksRequest) (*pb.FindAllDoingDNSTasksResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
tasks, err := dns.SharedDNSTaskDAO.FindAllDoingOrErrorTasks(tx, req.NodeClusterId)
if err != nil {
return nil, err
}
pbTasks := []*pb.DNSTask{}
for _, task := range tasks {
pbTask := &pb.DNSTask{
Id: int64(task.Id),
Type: task.Type,
IsDone: task.IsDone,
IsOk: task.IsOk,
Error: task.Error,
UpdatedAt: int64(task.UpdatedAt),
}
switch task.Type {
case dns.DNSTaskTypeClusterChange, dns.DNSTaskTypeClusterNodesChange, dns.DNSTaskTypeClusterRemoveDomain:
clusterName, err := models.SharedNodeClusterDAO.FindNodeClusterName(tx, int64(task.ClusterId))
if err != nil {
return nil, err
}
if len(clusterName) == 0 {
clusterName = "集群[" + fmt.Sprintf("%d", task.ClusterId) + "]"
}
pbTask.NodeCluster = &pb.NodeCluster{Id: int64(task.ClusterId), Name: clusterName}
case dns.DNSTaskTypeNodeChange:
nodeName, err := models.SharedNodeDAO.FindNodeName(tx, int64(task.NodeId))
if err != nil {
return nil, err
}
if len(nodeName) == 0 {
nodeName = "节点[" + fmt.Sprintf("%d", task.NodeId) + "]"
}
pbTask.Node = &pb.Node{Id: int64(task.NodeId), Name: nodeName}
case dns.DNSTaskTypeServerChange:
serverName, err := models.SharedServerDAO.FindEnabledServerName(tx, int64(task.ServerId))
if err != nil {
return nil, err
}
if len(serverName) == 0 {
serverName = "服务[" + fmt.Sprintf("%d", task.ServerId) + "]"
}
pbTask.Server = &pb.Server{Id: int64(task.ServerId), Name: serverName}
case dns.DNSTaskTypeDomainChange:
domainName, err := dns.SharedDNSDomainDAO.FindDNSDomainName(tx, int64(task.DomainId))
if err != nil {
return nil, err
}
if len(domainName) == 0 {
domainName = "域名[" + fmt.Sprintf("%d", task.DomainId) + "]"
}
pbTask.DnsDomain = &pb.DNSDomain{Id: int64(task.DomainId), Name: domainName}
}
pbTasks = append(pbTasks, pbTask)
}
return &pb.FindAllDoingDNSTasksResponse{DnsTasks: pbTasks}, nil
}
// DeleteDNSTask 删除任务
func (this *DNSTaskService) DeleteDNSTask(ctx context.Context, req *pb.DeleteDNSTaskRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
err = dns.SharedDNSTaskDAO.DeleteDNSTask(this.NullTx(), req.DnsTaskId)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteAllDNSTasks 删除所有同步任务
func (this *DNSTaskService) DeleteAllDNSTasks(ctx context.Context, req *pb.DeleteAllDNSTasksRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = dns.SharedDNSTaskDAO.DeleteAllDNSTasks(tx)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,87 @@
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// FileService 文件相关服务
type FileService struct {
BaseService
}
// FindEnabledFile 查找文件
func (this *FileService) FindEnabledFile(ctx context.Context, req *pb.FindEnabledFileRequest) (*pb.FindEnabledFileResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
file, err := models.SharedFileDAO.FindEnabledFile(tx, req.FileId)
if err != nil {
return nil, err
}
if file == nil {
return &pb.FindEnabledFileResponse{File: nil}, nil
}
if !file.IsPublic {
// 校验权限
if userId > 0 && int64(file.UserId) != userId {
return nil, this.PermissionError()
}
}
return &pb.FindEnabledFileResponse{
File: &pb.File{
Id: int64(file.Id),
Filename: file.Filename,
Size: int64(file.Size),
CreatedAt: int64(file.CreatedAt),
IsPublic: file.IsPublic,
MimeType: file.MimeType,
Type: file.Type,
},
}, nil
}
// CreateFile 创建文件
func (this *FileService) CreateFile(ctx context.Context, req *pb.CreateFileRequest) (*pb.CreateFileResponse, error) {
adminId, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
fileId, err := models.SharedFileDAO.CreateFile(tx, adminId, userId, req.Type, "", req.Filename, req.Size, req.MimeType, req.IsPublic)
if err != nil {
return nil, err
}
return &pb.CreateFileResponse{FileId: fileId}, nil
}
// UpdateFileFinished 将文件置为已完成
func (this *FileService) UpdateFileFinished(ctx context.Context, req *pb.UpdateFileFinishedRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
err = models.SharedFileDAO.CheckUserFile(tx, userId, req.FileId)
if err != nil {
return nil, err
}
}
err = models.SharedFileDAO.UpdateFileIsFinished(tx, req.FileId)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,80 @@
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// FileChunkService 文件片段相关服务
type FileChunkService struct {
BaseService
}
// CreateFileChunk 创建文件片段
func (this *FileChunkService) CreateFileChunk(ctx context.Context, req *pb.CreateFileChunkRequest) (*pb.CreateFileChunkResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = models.SharedFileDAO.CheckUserFile(tx, userId, req.FileId)
if err != nil {
return nil, err
}
}
chunkId, err := models.SharedFileChunkDAO.CreateFileChunk(tx, req.FileId, req.Data)
if err != nil {
return nil, err
}
return &pb.CreateFileChunkResponse{FileChunkId: chunkId}, nil
}
// FindAllFileChunkIds 获取的一个文件的所有片段IDs
func (this *FileChunkService) FindAllFileChunkIds(ctx context.Context, req *pb.FindAllFileChunkIdsRequest) (*pb.FindAllFileChunkIdsResponse, error) {
// 校验请求
_, _, err := this.ValidateNodeId(ctx, rpcutils.UserTypeNode, rpcutils.UserTypeDNS, rpcutils.UserTypeAdmin, rpcutils.UserTypeUser)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 校验用户
// TODO
chunkIds, err := models.SharedFileChunkDAO.FindAllFileChunkIds(tx, req.FileId)
if err != nil {
return nil, err
}
return &pb.FindAllFileChunkIdsResponse{FileChunkIds: chunkIds}, nil
}
// DownloadFileChunk 下载文件片段
func (this *FileChunkService) DownloadFileChunk(ctx context.Context, req *pb.DownloadFileChunkRequest) (*pb.DownloadFileChunkResponse, error) {
// 校验请求
_, _, err := this.ValidateNodeId(ctx, rpcutils.UserTypeNode, rpcutils.UserTypeDNS, rpcutils.UserTypeAdmin, rpcutils.UserTypeUser)
if err != nil {
return nil, err
}
// TODO 校验用户
var tx = this.NullTx()
chunk, err := models.SharedFileChunkDAO.FindFileChunk(tx, req.FileChunkId)
if err != nil {
return nil, err
}
if chunk == nil {
return &pb.DownloadFileChunkResponse{FileChunk: nil}, nil
}
return &pb.DownloadFileChunkResponse{FileChunk: &pb.FileChunk{Data: chunk.Data}}, nil
}

View File

@@ -0,0 +1,339 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/regions"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/stats"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time"
"sort"
"strconv"
"time"
)
// FirewallService 防火墙全局服务
type FirewallService struct {
BaseService
}
// ComposeFirewallGlobalBoard 组合看板数据
func (this *FirewallService) ComposeFirewallGlobalBoard(ctx context.Context, req *pb.ComposeFirewallGlobalBoardRequest) (*pb.ComposeFirewallGlobalBoardResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var now = time.Now()
var day = timeutil.Format("Ymd")
var w = types.Int(timeutil.Format("w"))
if w == 0 {
w = 7
}
weekFrom := timeutil.Format("Ymd", now.AddDate(0, 0, -w+1))
weekTo := timeutil.Format("Ymd", now.AddDate(0, 0, -w+7))
var result = &pb.ComposeFirewallGlobalBoardResponse{}
var tx = this.NullTx()
countDailyLog, err := stats.SharedServerHTTPFirewallDailyStatDAO.SumDailyCount(tx, 0, 0, "log", day, day)
if err != nil {
return nil, err
}
result.CountDailyLogs = countDailyLog
countDailyBlock, err := stats.SharedServerHTTPFirewallDailyStatDAO.SumDailyCount(tx, 0, 0, "block", day, day)
if err != nil {
return nil, err
}
result.CountDailyBlocks = countDailyBlock
countDailyCaptcha, err := stats.SharedServerHTTPFirewallDailyStatDAO.SumDailyCount(tx, 0, 0, "captcha", day, day)
if err != nil {
return nil, err
}
result.CountDailyCaptcha = countDailyCaptcha
countWeeklyBlock, err := stats.SharedServerHTTPFirewallDailyStatDAO.SumDailyCount(tx, 0, 0, "block", weekFrom, weekTo)
if err != nil {
return nil, err
}
result.CountWeeklyBlocks = countWeeklyBlock
// 24小时趋势
var hourFrom = timeutil.Format("YmdH", time.Now().Add(-23*time.Hour))
var hourTo = timeutil.Format("YmdH")
hours, err := utils.RangeHours(hourFrom, hourTo)
if err != nil {
return nil, err
}
{
statList, err := stats.SharedServerHTTPFirewallHourlyStatDAO.FindHourlyStats(tx, 0, 0, "log", hourFrom, hourTo)
if err != nil {
return nil, err
}
m := map[string]int64{} // day => count
for _, stat := range statList {
m[stat.Hour] = int64(stat.Count)
}
for _, hour := range hours {
result.HourlyStats = append(result.HourlyStats, &pb.ComposeFirewallGlobalBoardResponse_HourlyStat{Hour: hour, CountLogs: m[hour]})
}
}
{
statList, err := stats.SharedServerHTTPFirewallHourlyStatDAO.FindHourlyStats(tx, 0, 0, "captcha", hourFrom, hourTo)
if err != nil {
return nil, err
}
m := map[string]int64{} // day => count
for _, stat := range statList {
m[stat.Hour] = int64(stat.Count)
}
for index, hour := range hours {
result.HourlyStats[index].CountCaptcha = m[hour]
}
}
{
statList, err := stats.SharedServerHTTPFirewallHourlyStatDAO.FindHourlyStats(tx, 0, 0, "block", hourFrom, hourTo)
if err != nil {
return nil, err
}
m := map[string]int64{} // day => count
for _, stat := range statList {
m[stat.Hour] = int64(stat.Count)
}
for index, hour := range hours {
result.HourlyStats[index].CountBlocks = m[hour]
}
}
// 14天趋势
dayFrom := timeutil.Format("Ymd", now.AddDate(0, 0, -14))
days, err := utils.RangeDays(dayFrom, day)
if err != nil {
return nil, err
}
{
statList, err := stats.SharedServerHTTPFirewallDailyStatDAO.FindDailyStats(tx, 0, 0, []string{"log", "tag"}, dayFrom, day)
if err != nil {
return nil, err
}
m := map[string]int64{} // day => count
for _, stat := range statList {
m[stat.Day] = int64(stat.Count)
}
for _, day := range days {
result.DailyStats = append(result.DailyStats, &pb.ComposeFirewallGlobalBoardResponse_DailyStat{Day: day, CountLogs: m[day]})
}
}
{
statList, err := stats.SharedServerHTTPFirewallDailyStatDAO.FindDailyStats(tx, 0, 0, []string{"captcha"}, dayFrom, day)
if err != nil {
return nil, err
}
m := map[string]int64{} // day => count
for _, stat := range statList {
m[stat.Day] = int64(stat.Count)
}
for index, day := range days {
result.DailyStats[index].CountCaptcha = m[day]
}
}
{
statList, err := stats.SharedServerHTTPFirewallDailyStatDAO.FindDailyStats(tx, 0, 0, []string{"block", "page"}, dayFrom, day)
if err != nil {
return nil, err
}
m := map[string]int64{} // day => count
for _, stat := range statList {
m[stat.Day] = int64(stat.Count)
}
for index, day := range days {
result.DailyStats[index].CountBlocks = m[day]
}
}
// 规则分组
var today = timeutil.Format("Ymd")
groupStats, err := stats.SharedServerHTTPFirewallDailyStatDAO.GroupDailyCount(tx, 0, 0, today, today, 0, 20)
if err != nil {
return nil, err
}
// 合并同名
var groupNamedStatsMap = map[string]*stats.ServerHTTPFirewallDailyStat{} // name => *ServerHTTPFirewallDailyStat
for _, stat := range groupStats {
ruleGroupName, err := models.SharedHTTPFirewallRuleGroupDAO.FindHTTPFirewallRuleGroupName(tx, int64(stat.HttpFirewallRuleGroupId))
if err != nil {
return nil, err
}
if len(ruleGroupName) == 0 {
continue
}
namedStat, ok := groupNamedStatsMap[ruleGroupName]
if ok {
namedStat.Count += stat.Count
} else {
groupNamedStatsMap[ruleGroupName] = stat
}
}
for ruleGroupName, stat := range groupNamedStatsMap {
result.HttpFirewallRuleGroups = append(result.HttpFirewallRuleGroups, &pb.ComposeFirewallGlobalBoardResponse_HTTPFirewallRuleGroupStat{
HttpFirewallRuleGroup: &pb.HTTPFirewallRuleGroup{Id: int64(stat.HttpFirewallRuleGroupId), Name: ruleGroupName},
Count: int64(stat.Count),
})
}
sort.Slice(result.HttpFirewallRuleGroups, func(i, j int) bool {
return result.HttpFirewallRuleGroups[i].Count > result.HttpFirewallRuleGroups[j].Count
})
if len(result.HttpFirewallRuleGroups) > 10 {
result.HttpFirewallRuleGroups = result.HttpFirewallRuleGroups[:10]
}
// 节点排行
topNodeStats, err := stats.SharedNodeTrafficHourlyStatDAO.FindTopNodeStatsWithAttack(tx, "node", hourFrom, hourTo, 10)
if err != nil {
return nil, err
}
for _, stat := range topNodeStats {
nodeName, err := models.SharedNodeDAO.FindNodeName(tx, int64(stat.NodeId))
if err != nil {
return nil, err
}
if len(nodeName) == 0 {
continue
}
result.TopNodeStats = append(result.TopNodeStats, &pb.ComposeFirewallGlobalBoardResponse_NodeStat{
NodeId: int64(stat.NodeId),
NodeName: nodeName,
CountRequests: int64(stat.CountRequests),
Bytes: int64(stat.Bytes),
CountAttackRequests: int64(stat.CountAttackRequests),
AttackBytes: int64(stat.AttackBytes),
})
}
// 域名排行
topDomainStats, err := stats.SharedServerDomainHourlyStatDAO.FindTopDomainStatsWithAttack(tx, hourFrom, hourTo, 10)
if err != nil {
return nil, err
}
for _, stat := range topDomainStats {
result.TopDomainStats = append(result.TopDomainStats, &pb.ComposeFirewallGlobalBoardResponse_DomainStat{
ServerId: int64(stat.ServerId),
Domain: stat.Domain,
CountRequests: int64(stat.CountRequests),
Bytes: int64(stat.Bytes),
CountAttackRequests: int64(stat.CountAttackRequests),
AttackBytes: int64(stat.AttackBytes),
})
}
// 地区流量排行
totalCountryRequests, err := stats.SharedServerRegionCountryDailyStatDAO.SumDailyTotalAttackRequests(tx, timeutil.Format("Ymd"))
if err != nil {
return nil, err
}
if totalCountryRequests > 0 {
topCountryStats, err := stats.SharedServerRegionCountryDailyStatDAO.ListSumStats(tx, timeutil.Format("Ymd"), "countAttackRequests", 0, 100)
if err != nil {
return nil, err
}
for _, stat := range topCountryStats {
countryName, err := regions.SharedRegionCountryDAO.FindRegionCountryName(tx, int64(stat.CountryId))
if err != nil {
return nil, err
}
result.TopCountryStats = append(result.TopCountryStats, &pb.ComposeFirewallGlobalBoardResponse_CountryStat{
CountryName: countryName,
Bytes: int64(stat.Bytes),
CountRequests: int64(stat.CountRequests),
AttackBytes: int64(stat.AttackBytes),
CountAttackRequests: int64(stat.CountAttackRequests),
Percent: float32(stat.CountAttackRequests*100) / float32(totalCountryRequests),
})
}
}
return result, nil
}
// NotifyHTTPFirewallEvent 发送告警(notify)消息
func (this *FirewallService) NotifyHTTPFirewallEvent(ctx context.Context, req *pb.NotifyHTTPFirewallEventRequest) (*pb.RPCSuccess, error) {
nodeId, err := this.ValidateNode(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
clusterId, err := models.SharedNodeDAO.FindNodeClusterId(tx, nodeId)
if err != nil {
return nil, err
}
if clusterId <= 0 {
return this.Success()
}
clusterName, err := models.SharedNodeClusterDAO.FindNodeClusterName(tx, clusterId)
if err != nil {
return nil, err
}
nodeName, err := models.SharedNodeDAO.FindNodeName(tx, nodeId)
if err != nil {
return nil, err
}
serverName, err := models.SharedServerDAO.FindEnabledServerName(tx, req.ServerId)
if err != nil {
return nil, err
}
ruleGroupName, err := models.SharedHTTPFirewallRuleGroupDAO.FindHTTPFirewallRuleGroupName(tx, req.HttpFirewallRuleGroupId)
if err != nil {
return nil, err
}
ruleSetName, err := models.SharedHTTPFirewallRuleSetDAO.FindHTTPFirewallRuleSetName(tx, req.HttpFirewallRuleSetId)
if err != nil {
return nil, err
}
msg := "集群:" + clusterName + "ID" + strconv.FormatInt(clusterId, 10) + "" +
"\n节点" + nodeName + "ID" + strconv.FormatInt(nodeId, 10) + "" +
"\n服务" + serverName + "ID" + strconv.FormatInt(req.ServerId, 10) + "" +
"\n规则分组" + ruleGroupName +
"\n规则集" + ruleSetName +
"\n时间" + timeutil.FormatTime("Y-m-d H:i:s", req.CreatedAt)
err = models.SharedMessageTaskDAO.CreateMessageTasks(tx, nodeconfigs.NodeRoleNode, clusterId, nodeId, req.ServerId, models.MessageTypeFirewallEvent, "触发防火墙事件", msg)
if err != nil {
return nil, err
}
return this.Success()
}
// CountFirewallDailyBlocks 读取当前Block动作次数
func (this *FirewallService) CountFirewallDailyBlocks(ctx context.Context, req *pb.CountFirewallDailyBlocksRequest) (*pb.CountFirewallDailyBlocksResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
var day = timeutil.Format("Ymd")
countDailyBlock, err := stats.SharedServerHTTPFirewallDailyStatDAO.SumDailyCount(tx, 0, 0, "block", day, day)
if err != nil {
return nil, err
}
return &pb.CountFirewallDailyBlocksResponse{
CountBlocks: countDailyBlock,
}, nil
}

View File

@@ -0,0 +1,224 @@
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeAPI/internal/utils/regexputils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists"
"sync"
)
// HTTPAccessLogService 访问日志相关服务
type HTTPAccessLogService struct {
BaseService
}
// CreateHTTPAccessLogs 创建访问日志
func (this *HTTPAccessLogService) CreateHTTPAccessLogs(ctx context.Context, req *pb.CreateHTTPAccessLogsRequest) (*pb.CreateHTTPAccessLogsResponse, error) {
// 校验请求
_, _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeNode)
if err != nil {
return nil, err
}
if len(req.HttpAccessLogs) == 0 {
return &pb.CreateHTTPAccessLogsResponse{}, nil
}
var tx = this.NullTx()
if this.canWriteAccessLogsToDB() {
err = models.SharedHTTPAccessLogDAO.CreateHTTPAccessLogs(tx, req.HttpAccessLogs)
if err != nil {
return nil, err
}
}
err = this.writeAccessLogsToPolicy(req.HttpAccessLogs)
if err != nil {
return nil, err
}
return &pb.CreateHTTPAccessLogsResponse{}, nil
}
// ListHTTPAccessLogs 列出单页访问日志
func (this *HTTPAccessLogService) ListHTTPAccessLogs(ctx context.Context, req *pb.ListHTTPAccessLogsRequest) (*pb.ListHTTPAccessLogsResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查服务ID
if userId > 0 {
req.UserId = userId
// 这里不用担心serverId <= 0 的情况因为如果userId>0则只会查询当前用户下的服务不会产生安全问题
if req.ServerId > 0 {
err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId)
if err != nil {
return nil, err
}
}
}
accessLogs, requestId, hasMore, err := models.SharedHTTPAccessLogDAO.ListAccessLogs(tx, req.Partition, req.RequestId, req.Size, req.Day, req.HourFrom, req.HourTo, req.NodeClusterId, req.NodeId, req.ServerId, req.Reverse, req.HasError, req.FirewallPolicyId, req.FirewallRuleGroupId, req.FirewallRuleSetId, req.HasFirewallPolicy, req.UserId, req.Keyword, req.Ip, req.Domain)
if err != nil {
return nil, err
}
var result = []*pb.HTTPAccessLog{}
var pbNodeMap = map[int64]*pb.Node{}
var pbClusterMap = map[int64]*pb.NodeCluster{}
for _, accessLog := range accessLogs {
a, err := accessLog.ToPB()
if err != nil {
return nil, err
}
// 节点 & 集群
pbNode, ok := pbNodeMap[a.NodeId]
if ok {
a.Node = pbNode
} else {
node, err := models.SharedNodeDAO.FindEnabledNode(tx, a.NodeId)
if err != nil {
return nil, err
}
if node != nil {
pbNode = &pb.Node{Id: int64(node.Id), Name: node.Name}
var clusterId = int64(node.ClusterId)
pbCluster, ok := pbClusterMap[clusterId]
if ok {
pbNode.NodeCluster = pbCluster
} else {
cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(tx, clusterId)
if err != nil {
return nil, err
}
if cluster != nil {
pbCluster = &pb.NodeCluster{
Id: int64(cluster.Id),
Name: cluster.Name,
}
pbNode.NodeCluster = pbCluster
pbClusterMap[clusterId] = pbCluster
}
}
pbNodeMap[a.NodeId] = pbNode
a.Node = pbNode
}
}
result = append(result, a)
}
return &pb.ListHTTPAccessLogsResponse{
HttpAccessLogs: result,
AccessLogs: result, // TODO 仅仅为了兼容当用户节点版本大于0.0.8时可以删除
HasMore: hasMore,
RequestId: requestId,
}, nil
}
// FindHTTPAccessLog 查找单个日志
func (this *HTTPAccessLogService) FindHTTPAccessLog(ctx context.Context, req *pb.FindHTTPAccessLogRequest) (*pb.FindHTTPAccessLogResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
accessLog, err := models.SharedHTTPAccessLogDAO.FindAccessLogWithRequestId(tx, req.RequestId)
if err != nil {
return nil, err
}
if accessLog == nil {
return &pb.FindHTTPAccessLogResponse{HttpAccessLog: nil}, nil
}
// 检查权限
if userId > 0 {
err = models.SharedServerDAO.CheckUserServer(tx, userId, int64(accessLog.ServerId))
if err != nil {
return nil, err
}
}
a, err := accessLog.ToPB()
if err != nil {
return nil, err
}
return &pb.FindHTTPAccessLogResponse{HttpAccessLog: a}, nil
}
// FindHTTPAccessLogPartitions 查找日志分区
func (this *HTTPAccessLogService) FindHTTPAccessLogPartitions(ctx context.Context, req *pb.FindHTTPAccessLogPartitionsRequest) (*pb.FindHTTPAccessLogPartitionsResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
if !regexputils.YYYYMMDD.MatchString(req.Day) {
return nil, errors.New("invalid 'day': " + req.Day)
}
var dbList = models.AllAccessLogDBs()
if len(dbList) == 0 {
return &pb.FindHTTPAccessLogPartitionsResponse{
Partitions: nil,
}, nil
}
var partitions = []int32{}
var locker sync.Mutex
var wg = sync.WaitGroup{}
wg.Add(len(dbList))
var lastErr error
for _, db := range dbList {
go func(db *dbs.DB) {
defer wg.Done()
names, err := models.SharedHTTPAccessLogManager.FindTableNames(db, req.Day)
if err != nil {
lastErr = err
}
for _, name := range names {
var partition = models.SharedHTTPAccessLogManager.TablePartition(name)
locker.Lock()
if !lists.Contains(partitions, partition) {
partitions = append(partitions, partition)
}
locker.Unlock()
}
}(db)
}
wg.Wait()
if lastErr != nil {
return nil, lastErr
}
var reversePartitions = []int32{}
for i := len(partitions) - 1; i >= 0; i-- {
reversePartitions = append(reversePartitions, partitions[i])
}
return &pb.FindHTTPAccessLogPartitionsResponse{
Partitions: partitions,
ReversePartitions: reversePartitions,
}, nil
}

View File

@@ -0,0 +1,14 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package services
import "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
func (this *HTTPAccessLogService) canWriteAccessLogsToDB() bool {
return true
}
func (this *HTTPAccessLogService) writeAccessLogsToPolicy(pbAccessLogs []*pb.HTTPAccessLog) error {
return nil
}

View File

@@ -0,0 +1,39 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package services
import (
"github.com/TeaOSLab/EdgeAPI/internal/accesslogs"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/rands"
)
func (this *HTTPAccessLogService) canWriteAccessLogsToDB() bool {
return !accesslogs.SharedStorageManager.DisableDefaultDB()
}
func (this *HTTPAccessLogService) writeAccessLogsToPolicy(pbAccessLogs []*pb.HTTPAccessLog) error {
if len(pbAccessLogs) == 0 {
return nil
}
// 应用采样率
var percent = models.AccessLogQueuePercent()
if percent > 0 && percent < 99 {
var newAccessLogs = []*pb.HTTPAccessLog{}
for _, accessLog := range pbAccessLogs {
if rands.Int(1, 100) < percent {
newAccessLogs = append(newAccessLogs, accessLog)
}
}
if len(newAccessLogs) == 0 {
return nil
}
return accesslogs.SharedStorageManager.Write(newAccessLogs)
}
// 如果没有设置采样率,则写入全部
return accesslogs.SharedStorageManager.Write(pbAccessLogs)
}

View File

@@ -0,0 +1,169 @@
//go:build plus
package services
import (
"context"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/accesslogs"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
type HTTPAccessLogPolicyService struct {
BaseService
}
// CountAllHTTPAccessLogPolicies 计算访问日志策略数量
func (this *HTTPAccessLogPolicyService) CountAllHTTPAccessLogPolicies(ctx context.Context, req *pb.CountAllHTTPAccessLogPoliciesRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedHTTPAccessLogPolicyDAO.CountAllEnabledPolicies(tx)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListHTTPAccessLogPolicies 列出单页访问日志策略
func (this *HTTPAccessLogPolicyService) ListHTTPAccessLogPolicies(ctx context.Context, req *pb.ListHTTPAccessLogPoliciesRequest) (*pb.ListHTTPAccessLogPoliciesResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
policies, err := models.SharedHTTPAccessLogPolicyDAO.ListEnabledPolicies(tx, req.Offset, req.Size)
if err != nil {
return nil, err
}
var pbPolicies = []*pb.HTTPAccessLogPolicy{}
for _, policy := range policies {
pbPolicies = append(pbPolicies, &pb.HTTPAccessLogPolicy{
Id: int64(policy.Id),
Name: policy.Name,
IsOn: policy.IsOn,
Type: policy.Type,
OptionsJSON: policy.Options,
CondsJSON: policy.Conds,
IsPublic: policy.IsPublic,
FirewallOnly: policy.FirewallOnly == 1,
DisableDefaultDB: policy.DisableDefaultDB,
})
}
return &pb.ListHTTPAccessLogPoliciesResponse{HttpAccessLogPolicies: pbPolicies}, nil
}
// CreateHTTPAccessLogPolicy 创建访问日志策略
func (this *HTTPAccessLogPolicyService) CreateHTTPAccessLogPolicy(ctx context.Context, req *pb.CreateHTTPAccessLogPolicyRequest) (*pb.CreateHTTPAccessLogPolicyResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 取消别的Public
if req.IsPublic {
err = models.SharedHTTPAccessLogPolicyDAO.CancelAllPublicPolicies(tx)
if err != nil {
return nil, err
}
}
// 创建
policyId, err := models.SharedHTTPAccessLogPolicyDAO.CreatePolicy(tx, req.Name, req.Type, req.OptionsJSON, req.CondsJSON, req.IsPublic, req.FirewallOnly, req.DisableDefaultDB)
if err != nil {
return nil, err
}
return &pb.CreateHTTPAccessLogPolicyResponse{HttpAccessLogPolicyId: policyId}, nil
}
// UpdateHTTPAccessLogPolicy 修改访问日志策略
func (this *HTTPAccessLogPolicyService) UpdateHTTPAccessLogPolicy(ctx context.Context, req *pb.UpdateHTTPAccessLogPolicyRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 取消别的Public
if req.IsPublic {
err = models.SharedHTTPAccessLogPolicyDAO.CancelAllPublicPolicies(tx)
if err != nil {
return nil, err
}
}
// 保存修改
err = models.SharedHTTPAccessLogPolicyDAO.UpdatePolicy(tx, req.HttpAccessLogPolicyId, req.Name, req.OptionsJSON, req.CondsJSON, req.IsPublic, req.FirewallOnly, req.DisableDefaultDB, req.IsOn)
if err != nil {
return nil, err
}
return this.Success()
}
// FindHTTPAccessLogPolicy 查找单个访问日志策略
func (this *HTTPAccessLogPolicyService) FindHTTPAccessLogPolicy(ctx context.Context, req *pb.FindHTTPAccessLogPolicyRequest) (*pb.FindHTTPAccessLogPolicyResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
policy, err := models.SharedHTTPAccessLogPolicyDAO.FindEnabledHTTPAccessLogPolicy(tx, req.HttpAccessLogPolicyId)
if err != nil {
return nil, err
}
if policy == nil {
return &pb.FindHTTPAccessLogPolicyResponse{HttpAccessLogPolicy: nil}, nil
}
return &pb.FindHTTPAccessLogPolicyResponse{HttpAccessLogPolicy: &pb.HTTPAccessLogPolicy{
Id: int64(policy.Id),
Name: policy.Name,
IsOn: policy.IsOn,
Type: policy.Type,
OptionsJSON: policy.Options,
CondsJSON: policy.Conds,
IsPublic: policy.IsPublic,
FirewallOnly: policy.FirewallOnly == 1,
DisableDefaultDB: policy.DisableDefaultDB,
}}, nil
}
// DeleteHTTPAccessLogPolicy 删除访问日志策略
func (this *HTTPAccessLogPolicyService) DeleteHTTPAccessLogPolicy(ctx context.Context, req *pb.DeleteHTTPAccessLogPolicyRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedHTTPAccessLogPolicyDAO.DisableHTTPAccessLogPolicy(tx, req.HttpAccessLogPolicyId)
if err != nil {
return nil, err
}
return this.Success()
}
// WriteHTTPAccessLogPolicy 测试写入某个访问日志策略
func (this *HTTPAccessLogPolicyService) WriteHTTPAccessLogPolicy(ctx context.Context, req *pb.WriteHTTPAccessLogPolicyRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
success, failMessage, err := accesslogs.SharedStorageManager.WriteToPolicy(req.HttpAccessLogPolicyId, []*pb.HTTPAccessLog{req.HttpAccessLog})
if err != nil {
return nil, err
}
if !success {
return nil, errors.New("test failed: " + failMessage)
}
return this.Success()
}

View File

@@ -0,0 +1,87 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// HTTPAuthPolicyService 服务认证策略服务
type HTTPAuthPolicyService struct {
BaseService
}
// CreateHTTPAuthPolicy 创建策略
func (this *HTTPAuthPolicyService) CreateHTTPAuthPolicy(ctx context.Context, req *pb.CreateHTTPAuthPolicyRequest) (*pb.CreateHTTPAuthPolicyResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
policyId, err := models.SharedHTTPAuthPolicyDAO.CreateHTTPAuthPolicy(tx, userId, req.Name, req.Type, req.ParamsJSON)
if err != nil {
return nil, err
}
return &pb.CreateHTTPAuthPolicyResponse{HttpAuthPolicyId: policyId}, nil
}
// UpdateHTTPAuthPolicy 修改策略
func (this *HTTPAuthPolicyService) UpdateHTTPAuthPolicy(ctx context.Context, req *pb.UpdateHTTPAuthPolicyRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查用户权限
if userId > 0 {
err = models.SharedHTTPAuthPolicyDAO.CheckUserPolicy(tx, userId, req.HttpAuthPolicyId)
if err != nil {
return nil, err
}
}
err = models.SharedHTTPAuthPolicyDAO.UpdateHTTPAuthPolicy(tx, req.HttpAuthPolicyId, req.Name, req.ParamsJSON, req.IsOn)
if err != nil {
return nil, err
}
return this.Success()
}
// FindEnabledHTTPAuthPolicy 查找策略信息
func (this *HTTPAuthPolicyService) FindEnabledHTTPAuthPolicy(ctx context.Context, req *pb.FindEnabledHTTPAuthPolicyRequest) (*pb.FindEnabledHTTPAuthPolicyResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查用户权限
if userId > 0 {
err = models.SharedHTTPAuthPolicyDAO.CheckUserPolicy(tx, userId, req.HttpAuthPolicyId)
if err != nil {
return nil, err
}
}
policy, err := models.SharedHTTPAuthPolicyDAO.FindEnabledHTTPAuthPolicy(tx, req.HttpAuthPolicyId)
if err != nil {
return nil, err
}
if policy == nil {
return &pb.FindEnabledHTTPAuthPolicyResponse{HttpAuthPolicy: nil}, nil
}
return &pb.FindEnabledHTTPAuthPolicyResponse{HttpAuthPolicy: &pb.HTTPAuthPolicy{
Id: int64(policy.Id),
IsOn: policy.IsOn,
Name: policy.Name,
Type: policy.Type,
ParamsJSON: policy.Params,
}}, nil
}

View File

@@ -0,0 +1,225 @@
package services
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
)
type HTTPCachePolicyService struct {
BaseService
}
// FindAllEnabledHTTPCachePolicies 获取所有可用策略
func (this *HTTPCachePolicyService) FindAllEnabledHTTPCachePolicies(ctx context.Context, req *pb.FindAllEnabledHTTPCachePoliciesRequest) (*pb.FindAllEnabledHTTPCachePoliciesResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
policies, err := models.SharedHTTPCachePolicyDAO.FindAllEnabledCachePolicies(tx)
if err != nil {
return nil, err
}
result := []*pb.HTTPCachePolicy{}
for _, p := range policies {
result = append(result, &pb.HTTPCachePolicy{
Id: int64(p.Id),
Name: p.Name,
IsOn: p.IsOn,
})
}
return &pb.FindAllEnabledHTTPCachePoliciesResponse{CachePolicies: result}, nil
}
// CreateHTTPCachePolicy 创建缓存策略
func (this *HTTPCachePolicyService) CreateHTTPCachePolicy(ctx context.Context, req *pb.CreateHTTPCachePolicyRequest) (*pb.CreateHTTPCachePolicyResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if req.CapacityJSON != nil {
req.CapacityJSON, err = utils.JSONDecodeConfig(req.CapacityJSON, &shared.SizeCapacity{})
if err != nil {
return nil, err
}
}
if req.MaxSizeJSON != nil {
req.MaxSizeJSON, err = utils.JSONDecodeConfig(req.MaxSizeJSON, &shared.SizeCapacity{})
if err != nil {
return nil, err
}
}
if req.FetchTimeoutJSON != nil {
req.FetchTimeoutJSON, err = utils.JSONDecodeConfig(req.FetchTimeoutJSON, &shared.TimeDuration{})
if err != nil {
return nil, err
}
}
policyId, err := models.SharedHTTPCachePolicyDAO.CreateCachePolicy(tx, req.IsOn, req.Name, req.Description, req.CapacityJSON, req.MaxSizeJSON, req.Type, req.OptionsJSON, req.SyncCompressionCache, req.FetchTimeoutJSON)
if err != nil {
return nil, err
}
return &pb.CreateHTTPCachePolicyResponse{HttpCachePolicyId: policyId}, nil
}
// UpdateHTTPCachePolicy 修改缓存策略
func (this *HTTPCachePolicyService) UpdateHTTPCachePolicy(ctx context.Context, req *pb.UpdateHTTPCachePolicyRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if req.CapacityJSON != nil {
req.CapacityJSON, err = utils.JSONDecodeConfig(req.CapacityJSON, &shared.SizeCapacity{})
if err != nil {
return nil, err
}
}
if req.MaxSizeJSON != nil {
req.MaxSizeJSON, err = utils.JSONDecodeConfig(req.MaxSizeJSON, &shared.SizeCapacity{})
if err != nil {
return nil, err
}
}
if req.FetchTimeoutJSON != nil {
req.FetchTimeoutJSON, err = utils.JSONDecodeConfig(req.FetchTimeoutJSON, &shared.TimeDuration{})
if err != nil {
return nil, err
}
}
err = models.SharedHTTPCachePolicyDAO.UpdateCachePolicy(tx, req.HttpCachePolicyId, req.IsOn, req.Name, req.Description, req.CapacityJSON, req.MaxSizeJSON, req.Type, req.OptionsJSON, req.SyncCompressionCache, req.FetchTimeoutJSON)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteHTTPCachePolicy 删除缓存策略
func (this *HTTPCachePolicyService) DeleteHTTPCachePolicy(ctx context.Context, req *pb.DeleteHTTPCachePolicyRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedHTTPCachePolicyDAO.DisableHTTPCachePolicy(tx, req.HttpCachePolicyId)
if err != nil {
return nil, err
}
return this.Success()
}
// CountAllEnabledHTTPCachePolicies 计算缓存策略数量
func (this *HTTPCachePolicyService) CountAllEnabledHTTPCachePolicies(ctx context.Context, req *pb.CountAllEnabledHTTPCachePoliciesRequest) (*pb.RPCCountResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedHTTPCachePolicyDAO.CountAllEnabledHTTPCachePolicies(tx, req.NodeClusterId, req.Keyword, req.Type)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListEnabledHTTPCachePolicies 列出单页的缓存策略
func (this *HTTPCachePolicyService) ListEnabledHTTPCachePolicies(ctx context.Context, req *pb.ListEnabledHTTPCachePoliciesRequest) (*pb.ListEnabledHTTPCachePoliciesResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
cachePolicies, err := models.SharedHTTPCachePolicyDAO.ListEnabledHTTPCachePolicies(tx, req.NodeClusterId, req.Keyword, req.Type, req.Offset, req.Size)
if err != nil {
return nil, err
}
cachePoliciesJSON, err := json.Marshal(cachePolicies)
if err != nil {
return nil, err
}
return &pb.ListEnabledHTTPCachePoliciesResponse{HttpCachePoliciesJSON: cachePoliciesJSON}, nil
}
// FindEnabledHTTPCachePolicyConfig 查找单个缓存策略配置
func (this *HTTPCachePolicyService) FindEnabledHTTPCachePolicyConfig(ctx context.Context, req *pb.FindEnabledHTTPCachePolicyConfigRequest) (*pb.FindEnabledHTTPCachePolicyConfigResponse, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
cachePolicy, err := models.SharedHTTPCachePolicyDAO.ComposeCachePolicy(tx, req.HttpCachePolicyId, nil)
if err != nil {
return nil, err
}
cachePolicyJSON, err := json.Marshal(cachePolicy)
return &pb.FindEnabledHTTPCachePolicyConfigResponse{HttpCachePolicyJSON: cachePolicyJSON}, nil
}
// FindEnabledHTTPCachePolicy 查找单个缓存策略信息
func (this *HTTPCachePolicyService) FindEnabledHTTPCachePolicy(ctx context.Context, req *pb.FindEnabledHTTPCachePolicyRequest) (*pb.FindEnabledHTTPCachePolicyResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
policy, err := models.SharedHTTPCachePolicyDAO.FindEnabledHTTPCachePolicy(tx, req.HttpCachePolicyId)
if err != nil {
return nil, err
}
if policy == nil {
return &pb.FindEnabledHTTPCachePolicyResponse{HttpCachePolicy: nil}, nil
}
return &pb.FindEnabledHTTPCachePolicyResponse{HttpCachePolicy: &pb.HTTPCachePolicy{
Id: int64(policy.Id),
Name: policy.Name,
IsOn: policy.IsOn,
MaxBytesJSON: policy.MaxSize,
}}, nil
}
// UpdateHTTPCachePolicyRefs 设置缓存策略的默认条件
func (this *HTTPCachePolicyService) UpdateHTTPCachePolicyRefs(ctx context.Context, req *pb.UpdateHTTPCachePolicyRefsRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedHTTPCachePolicyDAO.UpdatePolicyRefs(tx, req.HttpCachePolicyId, req.RefsJSON)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,443 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/userconfigs"
"github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time"
"time"
)
// HTTPCacheTaskService 缓存任务管理
type HTTPCacheTaskService struct {
BaseService
}
// CreateHTTPCacheTask 创建任务
func (this *HTTPCacheTaskService) CreateHTTPCacheTask(ctx context.Context, req *pb.CreateHTTPCacheTaskRequest) (*pb.CreateHTTPCacheTaskResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查操作类型
if len(req.Type) == 0 {
return nil, errors.New("require 'type' parameter")
}
if req.Type != models.HTTPCacheTaskTypePurge && req.Type != models.HTTPCacheTaskTypeFetch {
return nil, errors.New("'type' must be 'purge' or 'fetch'")
}
// 检查Key类型
if len(req.KeyType) == 0 {
return nil, errors.New("require 'keyType' parameter")
}
if req.KeyType != "key" && req.KeyType != "prefix" {
return nil, errors.New("'keyType' must be 'key' or 'prefix'")
}
// 预热只能是Key
if req.Type == models.HTTPCacheTaskTypeFetch && req.KeyType != "key" {
return nil, errors.New("'keyType' should be 'key' when fetching cache")
}
// 检查key是否为空
if len(req.Keys) == 0 {
return nil, errors.New("'keys' should not be empty")
}
// 检查Key数量
var clusterId int64
if userId > 0 {
// 限制单次
var maxKeysPerTask = userconfigs.MaxCacheKeysPerTask
var maxKeysPerDay = userconfigs.MaxCacheKeysPerDay
serverConfig, err := models.SharedSysSettingDAO.ReadUserServerConfig(tx)
if err != nil {
return nil, err
}
if serverConfig != nil {
switch req.Type {
case models.HTTPCacheTaskTypePurge:
if serverConfig.HTTPCacheTaskPurgeConfig != nil {
if serverConfig.HTTPCacheTaskPurgeConfig.MaxKeysPerTask > 0 {
maxKeysPerTask = serverConfig.HTTPCacheTaskPurgeConfig.MaxKeysPerTask
}
if serverConfig.HTTPCacheTaskPurgeConfig.MaxKeysPerDay > 0 {
maxKeysPerDay = serverConfig.HTTPCacheTaskPurgeConfig.MaxKeysPerDay
}
}
case models.HTTPCacheTaskTypeFetch:
if serverConfig.HTTPCacheTaskFetchConfig != nil {
if serverConfig.HTTPCacheTaskFetchConfig.MaxKeysPerTask > 0 {
maxKeysPerTask = serverConfig.HTTPCacheTaskFetchConfig.MaxKeysPerTask
}
if serverConfig.HTTPCacheTaskFetchConfig.MaxKeysPerDay > 0 {
maxKeysPerDay = serverConfig.HTTPCacheTaskFetchConfig.MaxKeysPerDay
}
}
}
}
if maxKeysPerTask > 0 && len(req.Keys) > types.Int(maxKeysPerTask) {
return nil, errors.New("too many keys in task (current:" + types.String(len(req.Keys)) + ", max:" + types.String(maxKeysPerTask) + ")")
}
if maxKeysPerDay > 0 {
countInDay, err := models.SharedHTTPCacheTaskKeyDAO.CountUserTasksInDay(tx, userId, timeutil.Format("Ymd"), req.Type)
if err != nil {
return nil, err
}
if types.Int(countInDay)+len(req.Keys) > types.Int(maxKeysPerDay) {
return nil, errors.New("too many keys in today (current:" + types.String(types.Int(countInDay)+len(req.Keys)) + ", max:" + types.String(maxKeysPerDay) + ")")
}
}
clusterId, err = models.SharedUserDAO.FindUserClusterId(tx, userId)
if err != nil {
return nil, err
}
}
// 创建任务
taskId, err := models.SharedHTTPCacheTaskDAO.CreateTask(tx, userId, req.Type, req.KeyType, "")
if err != nil {
return nil, err
}
var countKeys = 0
var domainMap = map[string]*models.Server{} // domain name => *Server
for _, key := range req.Keys {
if len(key) == 0 {
continue
}
// 获取域名
var domain = utils.ParseDomainFromKey(key)
if len(domain) == 0 {
continue
}
// 查询所在集群
server, ok := domainMap[domain]
if !ok {
server, err = models.SharedServerDAO.FindEnabledServerWithDomain(tx, userId, domain)
if err != nil {
return nil, err
}
if server == nil {
continue
}
domainMap[domain] = server
}
// 检查用户
if userId > 0 {
if int64(server.UserId) != userId {
continue
}
}
var serverClusterId = int64(server.ClusterId)
if serverClusterId == 0 {
if clusterId > 0 {
serverClusterId = clusterId
} else {
continue
}
}
_, err = models.SharedHTTPCacheTaskKeyDAO.CreateKey(tx, taskId, key, req.Type, req.KeyType, serverClusterId)
if err != nil {
return nil, err
}
countKeys++
}
if countKeys == 0 {
// 如果没有有效的Key则直接完成
err = models.SharedHTTPCacheTaskDAO.UpdateTaskStatus(tx, taskId, true, true)
} else {
err = models.SharedHTTPCacheTaskDAO.UpdateTaskReady(tx, taskId)
}
if err != nil {
return nil, err
}
return &pb.CreateHTTPCacheTaskResponse{
HttpCacheTaskId: taskId,
CountKeys: int64(countKeys),
}, nil
}
// CountHTTPCacheTasks 计算任务数量
func (this *HTTPCacheTaskService) CountHTTPCacheTasks(ctx context.Context, req *pb.CountHTTPCacheTasksRequest) (*pb.RPCCountResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedHTTPCacheTaskDAO.CountTasks(tx, userId)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// CountDoingHTTPCacheTasks 计算正在执行的任务数量
func (this *HTTPCacheTaskService) CountDoingHTTPCacheTasks(ctx context.Context, req *pb.CountDoingHTTPCacheTasksRequest) (*pb.RPCCountResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedHTTPCacheTaskDAO.CountDoingTasks(tx, userId)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListHTTPCacheTasks 列出单页任务
func (this *HTTPCacheTaskService) ListHTTPCacheTasks(ctx context.Context, req *pb.ListHTTPCacheTasksRequest) (*pb.ListHTTPCacheTasksResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var isFromUser = userId > 0
var tx = this.NullTx()
tasks, err := models.SharedHTTPCacheTaskDAO.ListTasks(tx, userId, req.Offset, req.Size)
if err != nil {
return nil, err
}
var pbTasks = []*pb.HTTPCacheTask{}
var cacheMap = utils.NewCacheMap()
for _, task := range tasks {
var taskId = int64(task.Id)
// 查询所属用户
var pbUser = &pb.User{}
if task.UserId > 0 {
var taskUserId = int64(task.UserId)
if taskUserId > 0 {
taskUser, err := models.SharedUserDAO.FindEnabledUser(tx, taskUserId, cacheMap)
if err != nil {
return nil, err
}
if taskUser == nil {
// 找不到用户就删除
err = models.SharedHTTPCacheTaskDAO.DisableHTTPCacheTask(tx, taskUserId)
if err != nil {
return nil, err
}
} else {
pbUser = &pb.User{
Id: int64(taskUser.Id),
Username: taskUser.Username,
Fullname: taskUser.Fullname,
}
}
}
}
// 对用户而言超过Ns自动认为已完成
const timeoutSeconds = 300
if isFromUser && !task.IsDone && time.Now().Unix()-int64(task.CreatedAt) > timeoutSeconds {
task.IsOk = true
task.IsDone = true
task.DoneAt = task.CreatedAt + timeoutSeconds
}
pbTasks = append(pbTasks, &pb.HTTPCacheTask{
Id: taskId,
UserId: int64(task.UserId),
Type: task.Type,
KeyType: task.KeyType,
CreatedAt: int64(task.CreatedAt),
DoneAt: int64(task.DoneAt),
IsDone: task.IsDone,
IsOk: task.IsOk,
Description: task.Description,
User: pbUser,
HttpCacheTaskKeys: nil,
})
}
return &pb.ListHTTPCacheTasksResponse{
HttpCacheTasks: pbTasks,
}, nil
}
// FindEnabledHTTPCacheTask 查找单个任务
func (this *HTTPCacheTaskService) FindEnabledHTTPCacheTask(ctx context.Context, req *pb.FindEnabledHTTPCacheTaskRequest) (*pb.FindEnabledHTTPCacheTaskResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var isFromUser = userId > 0
var tx = this.NullTx()
if userId > 0 {
err = models.SharedHTTPCacheTaskDAO.CheckUserTask(tx, userId, req.HttpCacheTaskId)
if err != nil {
return nil, err
}
}
task, err := models.SharedHTTPCacheTaskDAO.FindEnabledHTTPCacheTask(tx, req.HttpCacheTaskId)
if err != nil {
return nil, err
}
if task == nil {
return &pb.FindEnabledHTTPCacheTaskResponse{HttpCacheTask: nil}, nil
}
// 对用户而言超过Ns自动认为已完成
const timeoutSeconds = 300
if isFromUser && !task.IsDone && time.Now().Unix()-int64(task.CreatedAt) > timeoutSeconds {
task.IsOk = true
task.IsDone = true
task.DoneAt = task.CreatedAt + timeoutSeconds
}
// 查询所属用户
var pbUser = &pb.User{}
if task.UserId > 0 {
var taskUserId = int64(task.UserId)
if taskUserId > 0 {
taskUser, err := models.SharedUserDAO.FindEnabledUser(tx, taskUserId, nil)
if err != nil {
return nil, err
}
if taskUser == nil {
// 找不到用户就删除
err = models.SharedHTTPCacheTaskDAO.DisableHTTPCacheTask(tx, taskUserId)
if err != nil {
return nil, err
}
} else {
pbUser = &pb.User{
Id: int64(taskUser.Id),
Username: taskUser.Username,
Fullname: taskUser.Fullname,
}
}
}
}
// Keys
keys, err := models.SharedHTTPCacheTaskKeyDAO.FindAllTaskKeys(tx, req.HttpCacheTaskId)
if err != nil {
return nil, err
}
var pbKeys = []*pb.HTTPCacheTaskKey{}
for _, key := range keys {
// 对用户而言超过Ns自动认为已完成
if isFromUser && task.IsDone {
key.IsDone = true
key.Errors = nil
}
// 集群信息
var pbNodeCluster *pb.NodeCluster
if !isFromUser && key.ClusterId > 0 {
clusterName, findClusterErr := models.SharedNodeClusterDAO.FindNodeClusterName(tx, int64(key.ClusterId))
if findClusterErr != nil {
return nil, findClusterErr
}
pbNodeCluster = &pb.NodeCluster{
Id: int64(key.ClusterId),
Name: clusterName,
}
}
pbKeys = append(pbKeys, &pb.HTTPCacheTaskKey{
Id: int64(key.Id),
TaskId: int64(key.TaskId),
Key: key.Key,
KeyType: key.KeyType,
IsDone: key.IsDone,
IsDoing: !key.IsDone && len(key.DecodeNodes()) > 0,
ErrorsJSON: key.Errors,
NodeCluster: pbNodeCluster,
})
}
return &pb.FindEnabledHTTPCacheTaskResponse{
HttpCacheTask: &pb.HTTPCacheTask{
Id: int64(task.Id),
UserId: int64(task.UserId),
Type: task.Type,
KeyType: task.KeyType,
CreatedAt: int64(task.CreatedAt),
DoneAt: int64(task.DoneAt),
IsDone: task.IsDone,
IsOk: task.IsOk,
User: pbUser,
HttpCacheTaskKeys: pbKeys,
},
}, nil
}
// DeleteHTTPCacheTask 删除任务
func (this *HTTPCacheTaskService) DeleteHTTPCacheTask(ctx context.Context, req *pb.DeleteHTTPCacheTaskRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
err = models.SharedHTTPCacheTaskDAO.CheckUserTask(tx, userId, req.HttpCacheTaskId)
if err != nil {
return nil, err
}
}
err = models.SharedHTTPCacheTaskDAO.DisableHTTPCacheTask(tx, req.HttpCacheTaskId)
if err != nil {
return nil, err
}
return this.Success()
}
// ResetHTTPCacheTask 重置任务状态
// 只允许管理员重置,用于调试
func (this *HTTPCacheTaskService) ResetHTTPCacheTask(ctx context.Context, req *pb.ResetHTTPCacheTaskRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 重置任务
err = models.SharedHTTPCacheTaskDAO.ResetTask(tx, req.HttpCacheTaskId)
if err != nil {
return nil, err
}
// 重置任务下的Key
err = models.SharedHTTPCacheTaskKeyDAO.ResetCacheKeysWithTaskId(tx, req.HttpCacheTaskId)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,201 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package services
import (
"context"
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeAPI/internal/utils/regexputils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/dbs"
)
// HTTPCacheTaskKeyService 缓存任务Key管理
type HTTPCacheTaskKeyService struct {
BaseService
}
// ValidateHTTPCacheTaskKeys 校验缓存Key
func (this *HTTPCacheTaskKeyService) ValidateHTTPCacheTaskKeys(ctx context.Context, req *pb.ValidateHTTPCacheTaskKeysRequest) (*pb.ValidateHTTPCacheTaskKeysResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx *dbs.Tx
// 检查Key数量
var clusterId int64
if userId > 0 {
clusterId, err = models.SharedUserDAO.FindUserClusterId(tx, userId)
if err != nil {
return nil, err
}
}
var pbFailResults = []*pb.ValidateHTTPCacheTaskKeysResponse_FailKey{}
var foundDomainMap = map[string]*models.Server{} // domain name => *Server
var missingDomainMap = map[string]bool{} // domain name => true
for _, key := range req.Keys {
if len(key) == 0 {
pbFailResults = append(pbFailResults, &pb.ValidateHTTPCacheTaskKeysResponse_FailKey{
Key: key,
ReasonCode: "requireKey",
})
continue
}
// 获取域名
var domain = utils.ParseDomainFromKey(key)
if len(domain) == 0 {
pbFailResults = append(pbFailResults, &pb.ValidateHTTPCacheTaskKeysResponse_FailKey{
Key: key,
ReasonCode: "requireDomain",
})
continue
}
// 是否不存在
if missingDomainMap[domain] {
pbFailResults = append(pbFailResults, &pb.ValidateHTTPCacheTaskKeysResponse_FailKey{
Key: key,
ReasonCode: "requireServer",
})
continue
}
// 查询所在集群
server, ok := foundDomainMap[domain]
if !ok {
server, err = models.SharedServerDAO.FindEnabledServerWithDomain(tx, userId, domain)
if err != nil {
return nil, err
}
if server == nil {
missingDomainMap[domain] = true
pbFailResults = append(pbFailResults, &pb.ValidateHTTPCacheTaskKeysResponse_FailKey{
Key: key,
ReasonCode: "requireServer",
})
continue
}
foundDomainMap[domain] = server
}
// 检查用户
if userId > 0 {
if int64(server.UserId) != userId {
pbFailResults = append(pbFailResults, &pb.ValidateHTTPCacheTaskKeysResponse_FailKey{
Key: key,
ReasonCode: "requireUser",
})
continue
}
}
var serverClusterId = int64(server.ClusterId)
if serverClusterId == 0 && clusterId <= 0 {
pbFailResults = append(pbFailResults, &pb.ValidateHTTPCacheTaskKeysResponse_FailKey{
Key: key,
ReasonCode: "requireClusterId",
})
continue
}
}
return &pb.ValidateHTTPCacheTaskKeysResponse{FailKeys: pbFailResults}, nil
}
// FindDoingHTTPCacheTaskKeys 查找需要执行的Key
func (this *HTTPCacheTaskKeyService) FindDoingHTTPCacheTaskKeys(ctx context.Context, req *pb.FindDoingHTTPCacheTaskKeysRequest) (*pb.FindDoingHTTPCacheTaskKeysResponse, error) {
nodeId, err := this.ValidateNode(ctx)
if err != nil {
return nil, err
}
if req.Size <= 0 {
req.Size = 100
}
var tx *dbs.Tx
keys, err := models.SharedHTTPCacheTaskKeyDAO.FindDoingTaskKeys(tx, nodeId, req.Size)
if err != nil {
return nil, err
}
var pbKeys = []*pb.HTTPCacheTaskKey{}
for _, key := range keys {
pbKeys = append(pbKeys, &pb.HTTPCacheTaskKey{
Id: int64(key.Id),
TaskId: int64(key.TaskId),
Key: key.Key,
Type: key.Type,
KeyType: key.KeyType,
NodeClusterId: int64(key.ClusterId),
})
}
return &pb.FindDoingHTTPCacheTaskKeysResponse{HttpCacheTaskKeys: pbKeys}, nil
}
// UpdateHTTPCacheTaskKeysStatus 更新一组Key状态
func (this *HTTPCacheTaskKeyService) UpdateHTTPCacheTaskKeysStatus(ctx context.Context, req *pb.UpdateHTTPCacheTaskKeysStatusRequest) (*pb.RPCSuccess, error) {
nodeId, err := this.ValidateNode(ctx)
if err != nil {
return nil, err
}
var tx *dbs.Tx
var nodesJSONMap = map[int64][]byte{} // clusterId => nodesJSON
for _, result := range req.KeyResults {
// 集群Id
var clusterId = result.NodeClusterId
nodesJSON, ok := nodesJSONMap[clusterId]
if !ok {
nodeIdsInCluster, err := models.SharedNodeDAO.FindEnabledAndOnNodeIdsWithClusterId(tx, clusterId, true)
if err != nil {
return nil, err
}
var nodeMap = map[int64]bool{}
for _, nodeIdInCluster := range nodeIdsInCluster {
nodeMap[nodeIdInCluster] = true
}
nodesJSON, err = json.Marshal(nodeMap)
if err != nil {
return nil, err
}
nodesJSONMap[clusterId] = nodesJSON
}
err = models.SharedHTTPCacheTaskKeyDAO.UpdateKeyStatus(tx, result.Id, nodeId, result.Error, nodesJSON)
if err != nil {
return nil, err
}
}
return this.Success()
}
// CountHTTPCacheTaskKeysWithDay 计算当天已经清理的Key数量
func (this *HTTPCacheTaskKeyService) CountHTTPCacheTaskKeysWithDay(ctx context.Context, req *pb.CountHTTPCacheTaskKeysWithDayRequest) (*pb.RPCCountResponse, error) {
userId, err := this.ValidateUserNode(ctx, true)
if err != nil {
return nil, err
}
if !regexputils.YYYYMMDD.MatchString(req.Day) {
return nil, errors.New("invalid format 'day'")
}
var tx = this.NullTx()
countKeys, err := models.SharedHTTPCacheTaskKeyDAO.CountUserTasksInDay(tx, userId, req.Day, req.KeyType)
if err != nil {
return nil, err
}
return this.SuccessCount(countKeys)
}

View File

@@ -0,0 +1,23 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package services
import (
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/iwind/TeaGo/assert"
"testing"
)
func TestHTTPCacheTaskService_ParseDomain(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(utils.ParseDomainFromKey("aaa") == "aaa")
a.IsTrue(utils.ParseDomainFromKey("AAA") == "aaa")
a.IsTrue(utils.ParseDomainFromKey("a.b-c.com") == "a.b-c.com")
a.IsTrue(utils.ParseDomainFromKey("a.b-c.com/hello/world") == "a.b-c.com")
a.IsTrue(utils.ParseDomainFromKey("https://a.b-c.com") == "a.b-c.com")
a.IsTrue(utils.ParseDomainFromKey("http://a.b-c.com/hello/world") == "a.b-c.com")
a.IsTrue(utils.ParseDomainFromKey("http://a.B-c.com/hello/world") == "a.b-c.com")
a.IsTrue(utils.ParseDomainFromKey("http:/aaaa.com") == "http")
a.IsTrue(utils.ParseDomainFromKey("北京") == "")
}

View File

@@ -0,0 +1,112 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package services
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/types"
)
// HTTPFastcgiService HTTP Fastcgi服务
type HTTPFastcgiService struct {
BaseService
}
// CreateHTTPFastcgi 创建Fastcgi
func (this *HTTPFastcgiService) CreateHTTPFastcgi(ctx context.Context, req *pb.CreateHTTPFastcgiRequest) (*pb.CreateHTTPFastcgiResponse, error) {
adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
fastcgiId, err := models.SharedHTTPFastcgiDAO.CreateFastcgi(tx, adminId, userId, req.IsOn, req.Address, req.ParamsJSON, req.ReadTimeoutJSON, req.ConnTimeoutJSON, req.PoolSize, req.PathInfoPattern)
if err != nil {
return nil, err
}
return &pb.CreateHTTPFastcgiResponse{HttpFastcgiId: fastcgiId}, nil
}
// UpdateHTTPFastcgi 修改Fastcgi
func (this *HTTPFastcgiService) UpdateHTTPFastcgi(ctx context.Context, req *pb.UpdateHTTPFastcgiRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
err = models.SharedHTTPFastcgiDAO.CheckUserFastcgi(tx, userId, req.HttpFastcgiId)
if err != nil {
return nil, err
}
}
err = models.SharedHTTPFastcgiDAO.UpdateFastcgi(tx, req.HttpFastcgiId, req.IsOn, req.Address, req.ParamsJSON, req.ReadTimeoutJSON, req.ConnTimeoutJSON, req.PoolSize, req.PathInfoPattern)
if err != nil {
return nil, err
}
return this.Success()
}
// FindEnabledHTTPFastcgi 获取Fastcgi详情
func (this *HTTPFastcgiService) FindEnabledHTTPFastcgi(ctx context.Context, req *pb.FindEnabledHTTPFastcgiRequest) (*pb.FindEnabledHTTPFastcgiResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
err = models.SharedHTTPFastcgiDAO.CheckUserFastcgi(tx, userId, req.HttpFastcgiId)
if err != nil {
return nil, err
}
}
fastcgi, err := models.SharedHTTPFastcgiDAO.FindEnabledHTTPFastcgi(tx, req.HttpFastcgiId)
if err != nil {
return nil, err
}
if fastcgi == nil {
return &pb.FindEnabledHTTPFastcgiResponse{HttpFastcgi: nil}, nil
}
return &pb.FindEnabledHTTPFastcgiResponse{HttpFastcgi: &pb.HTTPFastcgi{
Id: int64(fastcgi.Id),
IsOn: fastcgi.IsOn,
Address: fastcgi.Address,
ParamsJSON: fastcgi.Params,
ReadTimeoutJSON: fastcgi.ReadTimeout,
ConnTimeoutJSON: fastcgi.ConnTimeout,
PoolSize: types.Int32(fastcgi.PoolSize),
PathInfoPattern: fastcgi.PathInfoPattern,
}}, nil
}
// FindEnabledHTTPFastcgiConfig 获取Fastcgi配置
func (this *HTTPFastcgiService) FindEnabledHTTPFastcgiConfig(ctx context.Context, req *pb.FindEnabledHTTPFastcgiConfigRequest) (*pb.FindEnabledHTTPFastcgiConfigResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
err = models.SharedHTTPFastcgiDAO.CheckUserFastcgi(tx, userId, req.HttpFastcgiId)
if err != nil {
return nil, err
}
}
config, err := models.SharedHTTPFastcgiDAO.ComposeFastcgiConfig(tx, req.HttpFastcgiId)
if err != nil {
return nil, err
}
configJSON, err := json.Marshal(config)
if err != nil {
return nil, err
}
return &pb.FindEnabledHTTPFastcgiConfigResponse{HttpFastcgiJSON: configJSON}, nil
}

View File

@@ -0,0 +1,974 @@
package services
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ipconfigs"
"github.com/iwind/TeaGo/lists"
"net"
)
// HTTPFirewallPolicyService HTTP防火墙WAF相关服务
type HTTPFirewallPolicyService struct {
BaseService
}
// FindAllEnabledHTTPFirewallPolicies 获取所有可用策略
func (this *HTTPFirewallPolicyService) FindAllEnabledHTTPFirewallPolicies(ctx context.Context, req *pb.FindAllEnabledHTTPFirewallPoliciesRequest) (*pb.FindAllEnabledHTTPFirewallPoliciesResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
policies, err := models.SharedHTTPFirewallPolicyDAO.FindAllEnabledFirewallPolicies(tx)
if err != nil {
return nil, err
}
var result = []*pb.HTTPFirewallPolicy{}
for _, p := range policies {
result = append(result, &pb.HTTPFirewallPolicy{
Id: int64(p.Id),
Name: p.Name,
Description: p.Description,
IsOn: p.IsOn,
InboundJSON: p.Inbound,
OutboundJSON: p.Outbound,
Mode: p.Mode,
UseLocalFirewall: p.UseLocalFirewall == 1,
})
}
return &pb.FindAllEnabledHTTPFirewallPoliciesResponse{FirewallPolicies: result}, nil
}
// CreateHTTPFirewallPolicy 创建防火墙策略
func (this *HTTPFirewallPolicyService) CreateHTTPFirewallPolicy(ctx context.Context, req *pb.CreateHTTPFirewallPolicyRequest) (*pb.CreateHTTPFirewallPolicyResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
policyId, err := models.SharedHTTPFirewallPolicyDAO.CreateFirewallPolicy(tx, userId, req.ServerGroupId, req.ServerId, req.IsOn, req.Name, req.Description, nil, nil)
if err != nil {
return nil, err
}
// 初始化
var inboundConfig = &firewallconfigs.HTTPFirewallInboundConfig{IsOn: true}
var outboundConfig = &firewallconfigs.HTTPFirewallOutboundConfig{IsOn: true}
var templatePolicy = firewallconfigs.HTTPFirewallTemplate()
if templatePolicy.Inbound != nil {
for _, group := range templatePolicy.Inbound.Groups {
isOn := lists.ContainsString(req.HttpFirewallGroupCodes, group.Code)
group.IsOn = isOn
groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, group)
if err != nil {
return nil, err
}
inboundConfig.GroupRefs = append(inboundConfig.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{
IsOn: true,
GroupId: groupId,
})
}
}
if templatePolicy.Outbound != nil {
for _, group := range templatePolicy.Outbound.Groups {
isOn := lists.ContainsString(req.HttpFirewallGroupCodes, group.Code)
group.IsOn = isOn
groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, group)
if err != nil {
return nil, err
}
outboundConfig.GroupRefs = append(outboundConfig.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{
IsOn: true,
GroupId: groupId,
})
}
}
inboundConfigJSON, err := json.Marshal(inboundConfig)
if err != nil {
return nil, err
}
outboundConfigJSON, err := json.Marshal(outboundConfig)
if err != nil {
return nil, err
}
err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(tx, policyId, userId, req.ServerId, inboundConfigJSON, outboundConfigJSON, false)
if err != nil {
return nil, err
}
return &pb.CreateHTTPFirewallPolicyResponse{HttpFirewallPolicyId: policyId}, nil
}
// CreateEmptyHTTPFirewallPolicy 创建空防火墙策略
func (this *HTTPFirewallPolicyService) CreateEmptyHTTPFirewallPolicy(ctx context.Context, req *pb.CreateEmptyHTTPFirewallPolicyRequest) (*pb.CreateEmptyHTTPFirewallPolicyResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
var sourceUserId = userId
if userId > 0 {
if req.ServerId > 0 {
err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId)
if err != nil {
return nil, err
}
}
} else if req.ServerId > 0 {
sourceUserId, err = models.SharedServerDAO.FindServerUserId(tx, req.ServerId)
if err != nil {
return nil, err
}
}
policyId, err := models.SharedHTTPFirewallPolicyDAO.CreateFirewallPolicy(tx, sourceUserId, req.ServerGroupId, req.ServerId, req.IsOn, req.Name, req.Description, nil, nil)
if err != nil {
return nil, err
}
// 初始化
var inboundConfig = &firewallconfigs.HTTPFirewallInboundConfig{IsOn: true}
var outboundConfig = &firewallconfigs.HTTPFirewallOutboundConfig{IsOn: true}
// 准备保存
inboundConfigJSON, err := json.Marshal(inboundConfig)
if err != nil {
return nil, err
}
outboundConfigJSON, err := json.Marshal(outboundConfig)
if err != nil {
return nil, err
}
err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(tx, policyId, sourceUserId, req.ServerId, inboundConfigJSON, outboundConfigJSON, false)
if err != nil {
return nil, err
}
return &pb.CreateEmptyHTTPFirewallPolicyResponse{HttpFirewallPolicyId: policyId}, nil
}
// UpdateHTTPFirewallPolicy 修改防火墙策略
func (this *HTTPFirewallPolicyService) UpdateHTTPFirewallPolicy(ctx context.Context, req *pb.UpdateHTTPFirewallPolicyRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var templatePolicy = firewallconfigs.HTTPFirewallTemplate()
var tx = this.NullTx()
// 已经有的数据
firewallPolicy, err := models.SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, req.HttpFirewallPolicyId, false, nil)
if err != nil {
return nil, err
}
if firewallPolicy == nil {
return nil, errors.New("can not found firewall policy")
}
var inboundConfig = firewallPolicy.Inbound
if inboundConfig == nil {
inboundConfig = &firewallconfigs.HTTPFirewallInboundConfig{IsOn: true}
}
var outboundConfig = firewallPolicy.Outbound
if outboundConfig == nil {
outboundConfig = &firewallconfigs.HTTPFirewallOutboundConfig{IsOn: true}
}
// 更新老的
var oldCodes = []string{}
if firewallPolicy.Inbound != nil {
for _, g := range firewallPolicy.Inbound.Groups {
if len(g.Code) > 0 {
oldCodes = append(oldCodes, g.Code)
if lists.ContainsString(req.FirewallGroupCodes, g.Code) {
err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(tx, g.Id, true)
if err != nil {
return nil, err
}
} else {
err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(tx, g.Id, false)
if err != nil {
return nil, err
}
}
}
}
}
if firewallPolicy.Outbound != nil {
for _, g := range firewallPolicy.Outbound.Groups {
if len(g.Code) > 0 {
oldCodes = append(oldCodes, g.Code)
if lists.ContainsString(req.FirewallGroupCodes, g.Code) {
err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(tx, g.Id, true)
if err != nil {
return nil, err
}
} else {
err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(tx, g.Id, false)
if err != nil {
return nil, err
}
}
}
}
}
// 加入新的
if templatePolicy.Inbound != nil {
for _, group := range templatePolicy.Inbound.Groups {
if lists.ContainsString(oldCodes, group.Code) {
continue
}
isOn := lists.ContainsString(req.FirewallGroupCodes, group.Code)
group.IsOn = isOn
groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, group)
if err != nil {
return nil, err
}
inboundConfig.GroupRefs = append(inboundConfig.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{
IsOn: true,
GroupId: groupId,
})
}
}
if templatePolicy.Outbound != nil {
for _, group := range templatePolicy.Outbound.Groups {
if lists.ContainsString(oldCodes, group.Code) {
continue
}
isOn := lists.ContainsString(req.FirewallGroupCodes, group.Code)
group.IsOn = isOn
groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, group)
if err != nil {
return nil, err
}
outboundConfig.GroupRefs = append(outboundConfig.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{
IsOn: true,
GroupId: groupId,
})
}
}
inboundConfigJSON, err := json.Marshal(inboundConfig)
if err != nil {
return nil, err
}
outboundConfigJSON, err := json.Marshal(outboundConfig)
if err != nil {
return nil, err
}
var synFloodConfig = &firewallconfigs.SYNFloodConfig{}
if len(req.SynFloodJSON) > 0 {
err = json.Unmarshal(req.SynFloodJSON, synFloodConfig)
if err != nil {
return nil, err
}
}
var logConfig = &firewallconfigs.HTTPFirewallPolicyLogConfig{}
if len(req.LogJSON) > 0 {
err = json.Unmarshal(req.LogJSON, logConfig)
if err != nil {
return nil, err
}
}
// MaxRequestBodySize
if req.MaxRequestBodySize < 0 {
req.MaxRequestBodySize = 0
}
err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicy(tx, req.HttpFirewallPolicyId, req.IsOn, req.Name, req.Description, inboundConfigJSON, outboundConfigJSON, req.BlockOptionsJSON, req.PageOptionsJSON, req.CaptchaOptionsJSON, req.JsCookieOptionsJSON, req.Mode, req.UseLocalFirewall, synFloodConfig, logConfig, req.MaxRequestBodySize, req.DenyCountryHTML, req.DenyProvinceHTML)
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateHTTPFirewallPolicyGroups 修改分组信息
func (this *HTTPFirewallPolicyService) UpdateHTTPFirewallPolicyGroups(ctx context.Context, req *pb.UpdateHTTPFirewallPolicyGroupsRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
err = models.SharedHTTPFirewallPolicyDAO.CheckUserFirewallPolicy(nil, userId, req.HttpFirewallPolicyId)
if err != nil {
return nil, err
}
}
var tx = this.NullTx()
err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(tx, req.HttpFirewallPolicyId, userId, 0, req.InboundJSON, req.OutboundJSON, true)
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateHTTPFirewallInboundConfig 修改inbound信息
func (this *HTTPFirewallPolicyService) UpdateHTTPFirewallInboundConfig(ctx context.Context, req *pb.UpdateHTTPFirewallInboundConfigRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
err = models.SharedHTTPFirewallPolicyDAO.CheckUserFirewallPolicy(tx, userId, req.HttpFirewallPolicyId)
if err != nil {
return nil, err
}
}
err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInbound(tx, req.HttpFirewallPolicyId, req.InboundJSON)
if err != nil {
return nil, err
}
return this.Success()
}
// CountAllEnabledHTTPFirewallPolicies 计算可用的防火墙策略数量
func (this *HTTPFirewallPolicyService) CountAllEnabledHTTPFirewallPolicies(ctx context.Context, req *pb.CountAllEnabledHTTPFirewallPoliciesRequest) (*pb.RPCCountResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedHTTPFirewallPolicyDAO.CountAllEnabledFirewallPolicies(tx, req.NodeClusterId, req.Keyword)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListEnabledHTTPFirewallPolicies 列出单页的防火墙策略
func (this *HTTPFirewallPolicyService) ListEnabledHTTPFirewallPolicies(ctx context.Context, req *pb.ListEnabledHTTPFirewallPoliciesRequest) (*pb.ListEnabledHTTPFirewallPoliciesResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
policies, err := models.SharedHTTPFirewallPolicyDAO.ListEnabledFirewallPolicies(tx, req.NodeClusterId, req.Keyword, req.Offset, req.Size)
if err != nil {
return nil, err
}
var result = []*pb.HTTPFirewallPolicy{}
for _, p := range policies {
result = append(result, &pb.HTTPFirewallPolicy{
Id: int64(p.Id),
Name: p.Name,
Description: p.Description,
IsOn: p.IsOn,
InboundJSON: p.Inbound,
OutboundJSON: p.Outbound,
Mode: p.Mode,
UseLocalFirewall: p.UseLocalFirewall == 1,
})
}
return &pb.ListEnabledHTTPFirewallPoliciesResponse{HttpFirewallPolicies: result}, nil
}
// DeleteHTTPFirewallPolicy 删除某个防火墙策略
func (this *HTTPFirewallPolicyService) DeleteHTTPFirewallPolicy(ctx context.Context, req *pb.DeleteHTTPFirewallPolicyRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedHTTPFirewallPolicyDAO.DisableHTTPFirewallPolicy(tx, req.HttpFirewallPolicyId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindEnabledHTTPFirewallPolicyConfig 查找单个防火墙配置
func (this *HTTPFirewallPolicyService) FindEnabledHTTPFirewallPolicyConfig(ctx context.Context, req *pb.FindEnabledHTTPFirewallPolicyConfigRequest) (*pb.FindEnabledHTTPFirewallPolicyConfigResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
// 校验权限
err = models.SharedHTTPFirewallPolicyDAO.CheckUserFirewallPolicy(nil, userId, req.HttpFirewallPolicyId)
if err != nil {
return nil, err
}
}
var tx = this.NullTx()
config, err := models.SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, req.HttpFirewallPolicyId, false, nil)
if err != nil {
return nil, err
}
if config == nil {
return &pb.FindEnabledHTTPFirewallPolicyConfigResponse{HttpFirewallPolicyJSON: nil}, nil
}
configJSON, err := json.Marshal(config)
if err != nil {
return nil, err
}
return &pb.FindEnabledHTTPFirewallPolicyConfigResponse{HttpFirewallPolicyJSON: configJSON}, nil
}
// FindEnabledHTTPFirewallPolicy 获取防火墙的基本信息
func (this *HTTPFirewallPolicyService) FindEnabledHTTPFirewallPolicy(ctx context.Context, req *pb.FindEnabledHTTPFirewallPolicyRequest) (*pb.FindEnabledHTTPFirewallPolicyResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
err = models.SharedHTTPFirewallPolicyDAO.CheckUserFirewallPolicy(nil, userId, req.HttpFirewallPolicyId)
if err != nil {
return nil, err
}
}
var tx = this.NullTx()
policy, err := models.SharedHTTPFirewallPolicyDAO.FindEnabledHTTPFirewallPolicy(tx, req.HttpFirewallPolicyId)
if err != nil {
return nil, err
}
if policy == nil {
return &pb.FindEnabledHTTPFirewallPolicyResponse{HttpFirewallPolicy: nil}, nil
}
return &pb.FindEnabledHTTPFirewallPolicyResponse{
HttpFirewallPolicy: &pb.HTTPFirewallPolicy{
Id: int64(policy.Id),
ServerId: int64(policy.ServerId),
Name: policy.Name,
Description: policy.Description,
IsOn: policy.IsOn,
InboundJSON: policy.Inbound,
OutboundJSON: policy.Outbound,
Mode: policy.Mode,
SynFloodJSON: policy.SynFlood,
BlockOptionsJSON: policy.BlockOptions,
PageOptionsJSON: policy.PageOptions,
CaptchaOptionsJSON: policy.CaptchaOptions,
},
}, nil
}
// ImportHTTPFirewallPolicy 导入策略数据
func (this *HTTPFirewallPolicyService) ImportHTTPFirewallPolicy(ctx context.Context, req *pb.ImportHTTPFirewallPolicyRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
// TODO 检查权限
var tx = this.NullTx()
//stime := time.Now()
oldConfig, err := models.SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, req.HttpFirewallPolicyId, false, nil)
if err != nil {
return nil, err
}
if oldConfig == nil {
return nil, errors.New("can not find policy")
}
//fmt.Printf("调用 ComposeFirewallPolicy 方法 耗时: %v\n", time.Now().Sub(stime))
// 解析数据
newConfig := &firewallconfigs.HTTPFirewallPolicy{}
err = json.Unmarshal(req.HttpFirewallPolicyJSON, newConfig)
if err != nil {
return nil, err
}
// 入站分组
if newConfig.Inbound != nil {
for _, g := range newConfig.Inbound.Groups {
var oldGroup *firewallconfigs.HTTPFirewallRuleGroup
// 使用代号查找
if len(g.Code) > 0 {
oldGroup = oldConfig.FindRuleGroupWithCode(g.Code)
}
// 再次根据Name查找
if oldGroup == nil && len(g.Name) > 0 {
oldGroup = oldConfig.FindRuleGroupWithName(g.Name)
}
if oldGroup == nil {
// 新创建分组
groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, g, true)
if err != nil {
return nil, err
}
oldConfig.Inbound.GroupRefs = append(oldConfig.Inbound.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{
IsOn: true,
GroupId: groupId,
})
} else {
setRefs := []*firewallconfigs.HTTPFirewallRuleSetRef{}
for _, set := range g.Sets {
setId, err := models.SharedHTTPFirewallRuleSetDAO.CreateOrUpdateSetFromConfig(tx, set, true)
if err != nil {
return nil, err
}
setRefs = append(setRefs, &firewallconfigs.HTTPFirewallRuleSetRef{
IsOn: true,
SetId: setId,
})
}
setsJSON, err := json.Marshal(setRefs)
if err != nil {
return nil, err
}
err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroup(tx, oldGroup.Id, g.IsOn, g.Name, g.Code, g.Description)
if err != nil {
return nil, err
}
err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupSets(tx, oldGroup.Id, setsJSON)
if err != nil {
return nil, err
}
}
}
}
//fmt.Printf("调用 range newConfig.Inbound.Groups 方法 耗时: %v\n", time.Now().Sub(stime))
// 出站分组
if newConfig.Outbound != nil {
for _, g := range newConfig.Outbound.Groups {
var oldGroup *firewallconfigs.HTTPFirewallRuleGroup
// 使用代号查找
if len(g.Code) > 0 {
oldGroup = oldConfig.FindRuleGroupWithCode(g.Code)
}
// 再次根据Name查找
if oldGroup == nil && len(g.Name) > 0 {
oldGroup = oldConfig.FindRuleGroupWithName(g.Name)
}
if oldGroup == nil {
// 新创建分组
groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(tx, g, true)
if err != nil {
return nil, err
}
oldConfig.Outbound.GroupRefs = append(oldConfig.Outbound.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{
IsOn: true,
GroupId: groupId,
})
} else {
setRefs := []*firewallconfigs.HTTPFirewallRuleSetRef{}
for _, set := range g.Sets {
setId, err := models.SharedHTTPFirewallRuleSetDAO.CreateOrUpdateSetFromConfig(tx, set, true)
if err != nil {
return nil, err
}
setRefs = append(setRefs, &firewallconfigs.HTTPFirewallRuleSetRef{
IsOn: true,
SetId: setId,
})
}
setsJSON, err := json.Marshal(setRefs)
if err != nil {
return nil, err
}
err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroup(tx, oldGroup.Id, g.IsOn, g.Name, g.Code, g.Description)
if err != nil {
return nil, err
}
err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupSets(tx, oldGroup.Id, setsJSON)
if err != nil {
return nil, err
}
}
}
}
//fmt.Printf("调用 range newConfig.Outbound.Groups 方法 耗时: %v\n", time.Now().Sub(stime))
// 保存Inbound和Outbound
oldConfig.Inbound.Groups = nil
oldConfig.Outbound.Groups = nil
inboundJSON, err := json.Marshal(oldConfig.Inbound)
if err != nil {
return nil, err
}
outboundJSON, err := json.Marshal(oldConfig.Outbound)
if err != nil {
return nil, err
}
err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(tx, req.HttpFirewallPolicyId, 0, 0, inboundJSON, outboundJSON, true)
if err != nil {
return nil, err
}
//fmt.Printf("调用 UpdateFirewallPolicyInboundAndOutbound 方法 耗时: %v\n", time.Now().Sub(stime))
return this.Success()
}
// CheckHTTPFirewallPolicyIPStatus 检查IP状态
func (this *HTTPFirewallPolicyService) CheckHTTPFirewallPolicyIPStatus(ctx context.Context, req *pb.CheckHTTPFirewallPolicyIPStatusRequest) (*pb.CheckHTTPFirewallPolicyIPStatusResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if req.HttpFirewallPolicyId <= 0 {
return nil, errors.New("invalid 'httpFirewallPolicyId'")
}
if userId > 0 {
err = models.SharedHTTPFirewallPolicyDAO.CheckUserFirewallPolicy(tx, userId, req.HttpFirewallPolicyId)
if err != nil {
return nil, err
}
}
// 校验IP
var ip = net.ParseIP(req.Ip)
if len(ip) == 0 {
return &pb.CheckHTTPFirewallPolicyIPStatusResponse{
IsOk: false,
Error: "请输入正确的IP",
}, nil
}
firewallPolicy, err := models.SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, req.HttpFirewallPolicyId, false, nil)
if err != nil {
return nil, err
}
if firewallPolicy == nil {
return &pb.CheckHTTPFirewallPolicyIPStatusResponse{
IsOk: false,
Error: "找不到策略信息",
}, nil
}
// 检查白名单
if firewallPolicy.Inbound != nil &&
firewallPolicy.Inbound.IsOn &&
firewallPolicy.Inbound.AllowListRef != nil &&
firewallPolicy.Inbound.AllowListRef.IsOn &&
firewallPolicy.Inbound.AllowListRef.ListId > 0 {
var listIds = []int64{}
if firewallPolicy.Inbound.AllowListRef.ListId > 0 {
listIds = append(listIds, firewallPolicy.Inbound.AllowListRef.ListId)
}
if len(firewallPolicy.Inbound.PublicAllowListRefs) > 0 {
for _, ref := range firewallPolicy.Inbound.PublicAllowListRefs {
if !ref.IsOn {
continue
}
listIds = append(listIds, ref.ListId)
}
}
for _, listId := range listIds {
item, err := models.SharedIPItemDAO.FindEnabledItemContainsIP(tx, listId, req.Ip)
if err != nil {
return nil, err
}
if item != nil {
listName, err := models.SharedIPListDAO.FindIPListName(tx, listId)
if err != nil {
return nil, err
}
if len(listName) == 0 {
listName = "白名单"
}
return &pb.CheckHTTPFirewallPolicyIPStatusResponse{
IsOk: true,
Error: "",
IsFound: true,
IsAllowed: true,
IpList: &pb.IPList{Name: listName, Id: listId, Type: ipconfigs.IPListTypeWhite},
IpItem: &pb.IPItem{
Id: int64(item.Id),
Value: item.ComposeValue(),
IpFrom: item.IpFrom,
IpTo: item.IpTo,
ExpiredAt: int64(item.ExpiredAt),
Reason: item.Reason,
Type: item.Type,
EventLevel: item.EventLevel,
ListType: ipconfigs.IPListTypeWhite,
},
RegionCountry: nil,
RegionProvince: nil,
}, nil
}
}
}
// 检查黑名单
if firewallPolicy.Inbound != nil &&
firewallPolicy.Inbound.IsOn &&
firewallPolicy.Inbound.DenyListRef != nil &&
firewallPolicy.Inbound.DenyListRef.IsOn &&
firewallPolicy.Inbound.DenyListRef.ListId > 0 {
var listIds = []int64{}
if firewallPolicy.Inbound.DenyListRef.ListId > 0 {
listIds = append(listIds, firewallPolicy.Inbound.DenyListRef.ListId)
}
if len(firewallPolicy.Inbound.PublicDenyListRefs) > 0 {
for _, ref := range firewallPolicy.Inbound.PublicDenyListRefs {
if !ref.IsOn {
continue
}
listIds = append(listIds, ref.ListId)
}
}
for _, listId := range listIds {
item, err := models.SharedIPItemDAO.FindEnabledItemContainsIP(tx, listId, req.Ip)
if err != nil {
return nil, err
}
if item != nil {
listName, err := models.SharedIPListDAO.FindIPListName(tx, listId)
if err != nil {
return nil, err
}
if len(listName) == 0 {
listName = "黑名单"
}
return &pb.CheckHTTPFirewallPolicyIPStatusResponse{
IsOk: true,
Error: "",
IsFound: true,
IsAllowed: false,
IpList: &pb.IPList{Name: listName, Id: listId, Type: ipconfigs.IPListTypeBlack},
IpItem: &pb.IPItem{
Id: int64(item.Id),
Value: item.ComposeValue(),
IpFrom: item.IpFrom,
IpTo: item.IpTo,
ExpiredAt: int64(item.ExpiredAt),
Reason: item.Reason,
Type: item.Type,
EventLevel: item.EventLevel,
ListType: ipconfigs.IPListTypeBlack,
},
RegionCountry: nil,
RegionProvince: nil,
}, nil
}
}
}
// 检查灰名单
if firewallPolicy.Inbound != nil &&
firewallPolicy.Inbound.IsOn &&
firewallPolicy.Inbound.GreyListRef != nil &&
firewallPolicy.Inbound.GreyListRef.IsOn &&
firewallPolicy.Inbound.GreyListRef.ListId > 0 {
var listIds = []int64{}
if firewallPolicy.Inbound.GreyListRef.ListId > 0 {
listIds = append(listIds, firewallPolicy.Inbound.GreyListRef.ListId)
}
if len(firewallPolicy.Inbound.PublicGreyListRefs) > 0 {
for _, ref := range firewallPolicy.Inbound.PublicGreyListRefs {
if !ref.IsOn {
continue
}
listIds = append(listIds, ref.ListId)
}
}
for _, listId := range listIds {
item, err := models.SharedIPItemDAO.FindEnabledItemContainsIP(tx, listId, req.Ip)
if err != nil {
return nil, err
}
if item != nil {
listName, err := models.SharedIPListDAO.FindIPListName(tx, listId)
if err != nil {
return nil, err
}
if len(listName) == 0 {
listName = "灰名单"
}
return &pb.CheckHTTPFirewallPolicyIPStatusResponse{
IsOk: true,
Error: "",
IsFound: true,
IsAllowed: true,
IpList: &pb.IPList{Name: listName, Id: listId, Type: ipconfigs.IPListTypeGrey},
IpItem: &pb.IPItem{
Id: int64(item.Id),
Value: item.ComposeValue(),
IpFrom: item.IpFrom,
IpTo: item.IpTo,
ExpiredAt: int64(item.ExpiredAt),
Reason: item.Reason,
Type: item.Type,
EventLevel: item.EventLevel,
ListType: ipconfigs.IPListTypeGrey,
},
RegionCountry: nil,
RegionProvince: nil,
}, nil
}
}
}
// 检查封禁的地区和省份
var info = iplibrary.LookupIP(req.Ip)
if info != nil && info.IsOk() {
if firewallPolicy.Inbound != nil &&
firewallPolicy.Inbound.IsOn &&
firewallPolicy.Inbound.Region != nil &&
firewallPolicy.Inbound.Region.IsOn {
// 检查封禁的地区
var countryId = info.CountryId()
if countryId > 0 && lists.ContainsInt64(firewallPolicy.Inbound.Region.DenyCountryIds, countryId) {
return &pb.CheckHTTPFirewallPolicyIPStatusResponse{
IsOk: true,
Error: "",
IsFound: true,
IsAllowed: false,
IpList: nil,
IpItem: nil,
RegionCountry: &pb.RegionCountry{
Id: countryId,
Name: info.CountryName(),
},
RegionProvince: nil,
}, nil
}
// 检查封禁的省份
if countryId > 0 {
var provinceId = info.ProvinceId()
if provinceId > 0 && lists.ContainsInt64(firewallPolicy.Inbound.Region.DenyProvinceIds, provinceId) {
return &pb.CheckHTTPFirewallPolicyIPStatusResponse{
IsOk: true,
Error: "",
IsFound: true,
IsAllowed: false,
IpList: nil,
IpItem: nil,
RegionCountry: &pb.RegionCountry{
Id: countryId,
Name: info.CountryName(),
},
RegionProvince: &pb.RegionProvince{
Id: provinceId,
Name: info.ProvinceName(),
},
}, nil
}
}
}
}
return &pb.CheckHTTPFirewallPolicyIPStatusResponse{
IsOk: true,
Error: "",
IsFound: false,
IsAllowed: false,
IpList: nil,
IpItem: nil,
RegionCountry: nil,
RegionProvince: nil,
}, nil
}
// FindServerIdWithHTTPFirewallPolicyId 获取防火墙对应的网站ID
func (this *HTTPFirewallPolicyService) FindServerIdWithHTTPFirewallPolicyId(ctx context.Context, req *pb.FindServerIdWithHTTPFirewallPolicyIdRequest) (*pb.FindServerIdWithHTTPFirewallPolicyIdResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
serverId, err := models.SharedHTTPFirewallPolicyDAO.FindServerIdWithFirewallPolicyId(tx, req.HttpFirewallPolicyId)
if err != nil {
return nil, err
}
// check user
if serverId > 0 && userId > 0 {
err = models.SharedServerDAO.CheckUserServer(tx, userId, serverId)
if err != nil {
return nil, err
}
}
return &pb.FindServerIdWithHTTPFirewallPolicyIdResponse{
ServerId: serverId,
}, nil
}

View File

@@ -0,0 +1,47 @@
package services
import (
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/logs"
"testing"
)
func TestHTTPFirewallPolicyService_CheckHTTPFirewallPolicyIPStatus(t *testing.T) {
dbs.NotifyReady()
service := &HTTPFirewallPolicyService{}
{
resp, err := service.CheckHTTPFirewallPolicyIPStatus(rpcutils.NewMockAdminNodeContext(1), &pb.CheckHTTPFirewallPolicyIPStatusRequest{
HttpFirewallPolicyId: 14,
Ip: "127.0.0.1",
})
if err != nil {
t.Fatal(err)
}
logs.PrintAsJSON(resp, t)
}
{
resp, err := service.CheckHTTPFirewallPolicyIPStatus(rpcutils.NewMockAdminNodeContext(1), &pb.CheckHTTPFirewallPolicyIPStatusRequest{
HttpFirewallPolicyId: 14,
Ip: "192.168.1.100",
})
if err != nil {
t.Fatal(err)
}
logs.PrintAsJSON(resp, t)
}
{
resp, err := service.CheckHTTPFirewallPolicyIPStatus(rpcutils.NewMockAdminNodeContext(1), &pb.CheckHTTPFirewallPolicyIPStatusRequest{
HttpFirewallPolicyId: 14,
Ip: "221.218.201.94",
})
if err != nil {
t.Fatal(err)
}
logs.PrintAsJSON(resp, t)
}
}

View File

@@ -0,0 +1,242 @@
package services
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
)
// HTTPFirewallRuleGroupService WAF规则分组相关服务
type HTTPFirewallRuleGroupService struct {
BaseService
}
// UpdateHTTPFirewallRuleGroupIsOn 设置是否启用分组
func (this *HTTPFirewallRuleGroupService) UpdateHTTPFirewallRuleGroupIsOn(ctx context.Context, req *pb.UpdateHTTPFirewallRuleGroupIsOnRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
// 校验权限
err = models.SharedHTTPFirewallRuleGroupDAO.CheckUserRuleGroup(nil, userId, req.FirewallRuleGroupId)
if err != nil {
return nil, err
}
}
var tx = this.NullTx()
err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(tx, req.FirewallRuleGroupId, req.IsOn)
if err != nil {
return nil, err
}
return this.Success()
}
// CreateHTTPFirewallRuleGroup 创建分组
func (this *HTTPFirewallRuleGroupService) CreateHTTPFirewallRuleGroup(ctx context.Context, req *pb.CreateHTTPFirewallRuleGroupRequest) (*pb.CreateHTTPFirewallRuleGroupResponse, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroup(tx, req.IsOn, req.Name, req.Code, req.Description)
if err != nil {
return nil, err
}
return &pb.CreateHTTPFirewallRuleGroupResponse{FirewallRuleGroupId: groupId}, nil
}
// UpdateHTTPFirewallRuleGroup 修改分组
func (this *HTTPFirewallRuleGroupService) UpdateHTTPFirewallRuleGroup(ctx context.Context, req *pb.UpdateHTTPFirewallRuleGroupRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
// 校验权限
err = models.SharedHTTPFirewallRuleGroupDAO.CheckUserRuleGroup(nil, userId, req.FirewallRuleGroupId)
if err != nil {
return nil, err
}
}
var tx = this.NullTx()
err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroup(tx, req.FirewallRuleGroupId, req.IsOn, req.Name, req.Code, req.Description)
if err != nil {
return nil, err
}
return this.Success()
}
// FindEnabledHTTPFirewallRuleGroupConfig 获取分组配置
func (this *HTTPFirewallRuleGroupService) FindEnabledHTTPFirewallRuleGroupConfig(ctx context.Context, req *pb.FindEnabledHTTPFirewallRuleGroupConfigRequest) (*pb.FindEnabledHTTPFirewallRuleGroupConfigResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
// 校验权限
err = models.SharedHTTPFirewallRuleGroupDAO.CheckUserRuleGroup(nil, userId, req.FirewallRuleGroupId)
if err != nil {
return nil, err
}
}
var tx = this.NullTx()
groupConfig, err := models.SharedHTTPFirewallRuleGroupDAO.ComposeFirewallRuleGroup(tx, req.FirewallRuleGroupId, false)
if err != nil {
return nil, err
}
if groupConfig == nil {
return &pb.FindEnabledHTTPFirewallRuleGroupConfigResponse{FirewallRuleGroupJSON: nil}, nil
}
groupConfigJSON, err := json.Marshal(groupConfig)
if err != nil {
return nil, err
}
return &pb.FindEnabledHTTPFirewallRuleGroupConfigResponse{FirewallRuleGroupJSON: groupConfigJSON}, nil
}
// FindEnabledHTTPFirewallRuleGroup 获取分组信息
func (this *HTTPFirewallRuleGroupService) FindEnabledHTTPFirewallRuleGroup(ctx context.Context, req *pb.FindEnabledHTTPFirewallRuleGroupRequest) (*pb.FindEnabledHTTPFirewallRuleGroupResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
// 校验权限
err = models.SharedHTTPFirewallRuleGroupDAO.CheckUserRuleGroup(nil, userId, req.FirewallRuleGroupId)
if err != nil {
return nil, err
}
}
var tx = this.NullTx()
group, err := models.SharedHTTPFirewallRuleGroupDAO.FindEnabledHTTPFirewallRuleGroup(tx, req.FirewallRuleGroupId)
if err != nil {
return nil, err
}
if group == nil {
return &pb.FindEnabledHTTPFirewallRuleGroupResponse{
FirewallRuleGroup: nil,
}, nil
}
return &pb.FindEnabledHTTPFirewallRuleGroupResponse{
FirewallRuleGroup: &pb.HTTPFirewallRuleGroup{
Id: int64(group.Id),
Name: group.Name,
IsOn: group.IsOn,
Description: group.Description,
Code: group.Code,
},
}, nil
}
// UpdateHTTPFirewallRuleGroupSets 修改分组的规则集
func (this *HTTPFirewallRuleGroupService) UpdateHTTPFirewallRuleGroupSets(ctx context.Context, req *pb.UpdateHTTPFirewallRuleGroupSetsRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
// 校验权限
err = models.SharedHTTPFirewallRuleGroupDAO.CheckUserRuleGroup(nil, userId, req.FirewallRuleGroupId)
if err != nil {
return nil, err
}
}
var tx = this.NullTx()
err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupSets(tx, req.GetFirewallRuleGroupId(), req.FirewallRuleSetsJSON)
if err != nil {
return nil, err
}
return this.Success()
}
// AddHTTPFirewallRuleGroupSet 添加规则集
func (this *HTTPFirewallRuleGroupService) AddHTTPFirewallRuleGroupSet(ctx context.Context, req *pb.AddHTTPFirewallRuleGroupSetRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
// 校验权限
err = models.SharedHTTPFirewallRuleGroupDAO.CheckUserRuleGroup(nil, userId, req.FirewallRuleGroupId)
if err != nil {
return nil, err
}
}
var tx = this.NullTx()
// 已经有的规则
config, err := models.SharedHTTPFirewallRuleGroupDAO.ComposeFirewallRuleGroup(tx, req.FirewallRuleGroupId, false)
if err != nil {
return nil, err
}
if config == nil {
return nil, errors.New("can not find group")
}
var setRefs = config.SetRefs
var set = &firewallconfigs.HTTPFirewallRuleSet{}
err = json.Unmarshal(req.FirewallRuleSetConfigJSON, set)
if err != nil {
return nil, err
}
if set.Id > 0 {
setRefs = append(setRefs, &firewallconfigs.HTTPFirewallRuleSetRef{
IsOn: true,
SetId: set.Id,
})
} else {
setId, err := models.SharedHTTPFirewallRuleSetDAO.CreateOrUpdateSetFromConfig(tx, set)
if err != nil {
return nil, err
}
setRefs = append(setRefs, &firewallconfigs.HTTPFirewallRuleSetRef{
IsOn: true,
SetId: setId,
})
}
setRefsJSON, err := json.Marshal(setRefs)
if err != nil {
return nil, err
}
err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupSets(tx, req.FirewallRuleGroupId, setRefsJSON)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,139 @@
package services
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
)
// HTTPFirewallRuleSetService 规则集相关服务
type HTTPFirewallRuleSetService struct {
BaseService
}
// CreateOrUpdateHTTPFirewallRuleSetFromConfig 根据配置创建规则集
func (this *HTTPFirewallRuleSetService) CreateOrUpdateHTTPFirewallRuleSetFromConfig(ctx context.Context, req *pb.CreateOrUpdateHTTPFirewallRuleSetFromConfigRequest) (*pb.CreateOrUpdateHTTPFirewallRuleSetFromConfigResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
setConfig := &firewallconfigs.HTTPFirewallRuleSet{}
err = json.Unmarshal(req.FirewallRuleSetConfigJSON, setConfig)
if err != nil {
return nil, err
}
if userId > 0 && setConfig.Id > 0 {
err = models.SharedHTTPFirewallRuleSetDAO.CheckUserRuleSet(nil, userId, setConfig.Id)
if err != nil {
return nil, err
}
}
var tx = this.NullTx()
setId, err := models.SharedHTTPFirewallRuleSetDAO.CreateOrUpdateSetFromConfig(tx, setConfig)
if err != nil {
return nil, err
}
return &pb.CreateOrUpdateHTTPFirewallRuleSetFromConfigResponse{FirewallRuleSetId: setId}, nil
}
// UpdateHTTPFirewallRuleSetIsOn 修改是否开启
func (this *HTTPFirewallRuleSetService) UpdateHTTPFirewallRuleSetIsOn(ctx context.Context, req *pb.UpdateHTTPFirewallRuleSetIsOnRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
err = models.SharedHTTPFirewallRuleSetDAO.CheckUserRuleSet(nil, userId, req.FirewallRuleSetId)
if err != nil {
return nil, err
}
}
var tx = this.NullTx()
err = models.SharedHTTPFirewallRuleSetDAO.UpdateRuleSetIsOn(tx, req.FirewallRuleSetId, req.IsOn)
if err != nil {
return nil, err
}
return this.Success()
}
// FindEnabledHTTPFirewallRuleSetConfig 查找规则集配置
func (this *HTTPFirewallRuleSetService) FindEnabledHTTPFirewallRuleSetConfig(ctx context.Context, req *pb.FindEnabledHTTPFirewallRuleSetConfigRequest) (*pb.FindEnabledHTTPFirewallRuleSetConfigResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
err = models.SharedHTTPFirewallRuleSetDAO.CheckUserRuleSet(nil, userId, req.FirewallRuleSetId)
if err != nil {
return nil, err
}
}
var tx = this.NullTx()
config, err := models.SharedHTTPFirewallRuleSetDAO.ComposeFirewallRuleSet(tx, req.FirewallRuleSetId, false)
if err != nil {
return nil, err
}
if config == nil {
return &pb.FindEnabledHTTPFirewallRuleSetConfigResponse{FirewallRuleSetJSON: nil}, nil
}
configJSON, err := json.Marshal(config)
if err != nil {
return nil, err
}
return &pb.FindEnabledHTTPFirewallRuleSetConfigResponse{FirewallRuleSetJSON: configJSON}, nil
}
// FindEnabledHTTPFirewallRuleSet 查找规则集
func (this *HTTPFirewallRuleSetService) FindEnabledHTTPFirewallRuleSet(ctx context.Context, req *pb.FindEnabledHTTPFirewallRuleSetRequest) (*pb.FindEnabledHTTPFirewallRuleSetResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
err = models.SharedHTTPFirewallRuleSetDAO.CheckUserRuleSet(nil, userId, req.FirewallRuleSetId)
if err != nil {
return nil, err
}
}
var tx = this.NullTx()
set, err := models.SharedHTTPFirewallRuleSetDAO.FindEnabledHTTPFirewallRuleSet(tx, req.FirewallRuleSetId)
if err != nil {
return nil, err
}
if set == nil {
return &pb.FindEnabledHTTPFirewallRuleSetResponse{
FirewallRuleSet: nil,
}, nil
}
return &pb.FindEnabledHTTPFirewallRuleSetResponse{
FirewallRuleSet: &pb.HTTPFirewallRuleSet{
Id: int64(set.Id),
Name: set.Name,
IsOn: set.IsOn,
Description: set.Description,
Code: set.Code,
},
}, nil
}

View File

@@ -0,0 +1,114 @@
package services
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
)
type HTTPHeaderService struct {
BaseService
}
// CreateHTTPHeader 创建Header
func (this *HTTPHeaderService) CreateHTTPHeader(ctx context.Context, req *pb.CreateHTTPHeaderRequest) (*pb.CreateHTTPHeaderResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
// TODO 检查用户权限
}
var tx = this.NullTx()
// status
var newStatus = []int{}
for _, status := range req.Status {
newStatus = append(newStatus, int(status))
}
// replace values
var replaceValues = []*shared.HTTPHeaderReplaceValue{}
if len(req.ReplaceValuesJSON) > 0 {
err = json.Unmarshal(req.ReplaceValuesJSON, &replaceValues)
if err != nil {
return nil, errors.New("decode replace values failed: " + err.Error() + ", json: " + string(req.ReplaceValuesJSON))
}
}
headerId, err := models.SharedHTTPHeaderDAO.CreateHeader(tx, userId, req.Name, req.Value, newStatus, req.DisableRedirect, req.ShouldAppend, req.ShouldReplace, replaceValues, req.Methods, req.Domains)
if err != nil {
return nil, err
}
return &pb.CreateHTTPHeaderResponse{HeaderId: headerId}, nil
}
// UpdateHTTPHeader 修改Header
func (this *HTTPHeaderService) UpdateHTTPHeader(ctx context.Context, req *pb.UpdateHTTPHeaderRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
// TODO 检查用户权限
}
var tx = this.NullTx()
// status
var newStatus = []int{}
for _, status := range req.Status {
newStatus = append(newStatus, int(status))
}
// replace values
var replaceValues = []*shared.HTTPHeaderReplaceValue{}
if len(req.ReplaceValuesJSON) > 0 {
err = json.Unmarshal(req.ReplaceValuesJSON, &replaceValues)
if err != nil {
return nil, errors.New("decode replace values failed: " + err.Error())
}
}
err = models.SharedHTTPHeaderDAO.UpdateHeader(tx, req.HeaderId, req.Name, req.Value, newStatus, req.DisableRedirect, req.ShouldAppend, req.ShouldReplace, replaceValues, req.Methods, req.Domains)
if err != nil {
return nil, err
}
return this.Success()
}
// FindEnabledHTTPHeaderConfig 查找配置
func (this *HTTPHeaderService) FindEnabledHTTPHeaderConfig(ctx context.Context, req *pb.FindEnabledHTTPHeaderConfigRequest) (*pb.FindEnabledHTTPHeaderConfigResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
// TODO 检查用户权限
}
var tx = this.NullTx()
config, err := models.SharedHTTPHeaderDAO.ComposeHeaderConfig(tx, req.HeaderId)
if err != nil {
return nil, err
}
configData, err := json.Marshal(config)
if err != nil {
return nil, err
}
return &pb.FindEnabledHTTPHeaderConfigResponse{HeaderJSON: configData}, nil
}

View File

@@ -0,0 +1,246 @@
package services
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
)
type HTTPHeaderPolicyService struct {
BaseService
}
// FindEnabledHTTPHeaderPolicyConfig 查找策略配置
func (this *HTTPHeaderPolicyService) FindEnabledHTTPHeaderPolicyConfig(ctx context.Context, req *pb.FindEnabledHTTPHeaderPolicyConfigRequest) (*pb.FindEnabledHTTPHeaderPolicyConfigResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = models.SharedHTTPHeaderPolicyDAO.CheckUserHeaderPolicy(tx, userId, req.HttpHeaderPolicyId)
if err != nil {
return nil, err
}
}
config, err := models.SharedHTTPHeaderPolicyDAO.ComposeHeaderPolicyConfig(tx, req.HttpHeaderPolicyId)
if err != nil {
return nil, err
}
configData, err := json.Marshal(config)
if err != nil {
return nil, err
}
return &pb.FindEnabledHTTPHeaderPolicyConfigResponse{HttpHeaderPolicyJSON: configData}, nil
}
// CreateHTTPHeaderPolicy 创建策略
func (this *HTTPHeaderPolicyService) CreateHTTPHeaderPolicy(ctx context.Context, req *pb.CreateHTTPHeaderPolicyRequest) (*pb.CreateHTTPHeaderPolicyResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
headerPolicyId, err := models.SharedHTTPHeaderPolicyDAO.CreateHeaderPolicy(tx)
if err != nil {
return nil, err
}
return &pb.CreateHTTPHeaderPolicyResponse{HttpHeaderPolicyId: headerPolicyId}, nil
}
// UpdateHTTPHeaderPolicyAddingHeaders 修改AddHeaders
func (this *HTTPHeaderPolicyService) UpdateHTTPHeaderPolicyAddingHeaders(ctx context.Context, req *pb.UpdateHTTPHeaderPolicyAddingHeadersRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = models.SharedHTTPHeaderPolicyDAO.CheckUserHeaderPolicy(tx, userId, req.HttpHeaderPolicyId)
if err != nil {
return nil, err
}
}
err = models.SharedHTTPHeaderPolicyDAO.UpdateAddingHeaders(tx, req.HttpHeaderPolicyId, req.HeadersJSON)
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateHTTPHeaderPolicySettingHeaders 修改SetHeaders
func (this *HTTPHeaderPolicyService) UpdateHTTPHeaderPolicySettingHeaders(ctx context.Context, req *pb.UpdateHTTPHeaderPolicySettingHeadersRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = models.SharedHTTPHeaderPolicyDAO.CheckUserHeaderPolicy(tx, userId, req.HttpHeaderPolicyId)
if err != nil {
return nil, err
}
}
err = models.SharedHTTPHeaderPolicyDAO.UpdateSettingHeaders(tx, req.HttpHeaderPolicyId, req.HeadersJSON)
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateHTTPHeaderPolicyAddingTrailers 修改AddTrailers
func (this *HTTPHeaderPolicyService) UpdateHTTPHeaderPolicyAddingTrailers(ctx context.Context, req *pb.UpdateHTTPHeaderPolicyAddingTrailersRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = models.SharedHTTPHeaderPolicyDAO.CheckUserHeaderPolicy(tx, userId, req.HttpHeaderPolicyId)
if err != nil {
return nil, err
}
}
err = models.SharedHTTPHeaderPolicyDAO.UpdateAddingTrailers(tx, req.HttpHeaderPolicyId, req.HeadersJSON)
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateHTTPHeaderPolicyReplacingHeaders 修改ReplaceHeaders
func (this *HTTPHeaderPolicyService) UpdateHTTPHeaderPolicyReplacingHeaders(ctx context.Context, req *pb.UpdateHTTPHeaderPolicyReplacingHeadersRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = models.SharedHTTPHeaderPolicyDAO.CheckUserHeaderPolicy(tx, userId, req.HttpHeaderPolicyId)
if err != nil {
return nil, err
}
}
err = models.SharedHTTPHeaderPolicyDAO.UpdateReplacingHeaders(tx, req.HttpHeaderPolicyId, req.HeadersJSON)
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateHTTPHeaderPolicyDeletingHeaders 修改删除的Headers
func (this *HTTPHeaderPolicyService) UpdateHTTPHeaderPolicyDeletingHeaders(ctx context.Context, req *pb.UpdateHTTPHeaderPolicyDeletingHeadersRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = models.SharedHTTPHeaderPolicyDAO.CheckUserHeaderPolicy(tx, userId, req.HttpHeaderPolicyId)
if err != nil {
return nil, err
}
}
err = models.SharedHTTPHeaderPolicyDAO.UpdateDeletingHeaders(tx, req.HttpHeaderPolicyId, req.HeaderNames)
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateHTTPHeaderPolicyCORS 修改策略CORS设置
func (this *HTTPHeaderPolicyService) UpdateHTTPHeaderPolicyCORS(ctx context.Context, req *pb.UpdateHTTPHeaderPolicyCORSRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = models.SharedHTTPHeaderPolicyDAO.CheckUserHeaderPolicy(tx, userId, req.HttpHeaderPolicyId)
if err != nil {
return nil, err
}
}
var corsConfig = shared.NewHTTPCORSHeaderConfig()
err = json.Unmarshal(req.CorsJSON, corsConfig)
if err != nil {
return nil, err
}
err = corsConfig.Init()
if err != nil {
return nil, errors.New("validate CORS config failed: " + err.Error())
}
err = models.SharedHTTPHeaderPolicyDAO.UpdateHeaderPolicyCORS(tx, req.HttpHeaderPolicyId, corsConfig)
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateHTTPHeaderPolicyNonStandardHeaders 修改非标的Headers
func (this *HTTPHeaderPolicyService) UpdateHTTPHeaderPolicyNonStandardHeaders(ctx context.Context, req *pb.UpdateHTTPHeaderPolicyNonStandardHeadersRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查权限
if userId > 0 {
err = models.SharedHTTPHeaderPolicyDAO.CheckUserHeaderPolicy(tx, userId, req.HttpHeaderPolicyId)
if err != nil {
return nil, err
}
}
err = models.SharedHTTPHeaderPolicyDAO.UpdateNonStandardHeaders(tx, req.HttpHeaderPolicyId, req.HeaderNames)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,198 @@
package services
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
)
// HTTPLocationService 路由规则相关服务
type HTTPLocationService struct {
BaseService
}
// CreateHTTPLocation 创建路由规则
func (this *HTTPLocationService) CreateHTTPLocation(ctx context.Context, req *pb.CreateHTTPLocationRequest) (*pb.CreateHTTPLocationResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
locationId, err := models.SharedHTTPLocationDAO.CreateLocation(tx, req.ParentId, req.Name, req.Pattern, req.Description, req.IsBreak, req.CondsJSON, req.Domains)
if err != nil {
return nil, err
}
return &pb.CreateHTTPLocationResponse{LocationId: locationId}, nil
}
// UpdateHTTPLocation 修改路由规则
func (this *HTTPLocationService) UpdateHTTPLocation(ctx context.Context, req *pb.UpdateHTTPLocationRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedHTTPLocationDAO.UpdateLocation(tx, req.LocationId, req.Name, req.Pattern, req.Description, req.IsOn, req.IsBreak, req.CondsJSON, req.Domains)
if err != nil {
return nil, err
}
return this.Success()
}
// FindEnabledHTTPLocationConfig 查找路由规则配置
func (this *HTTPLocationService) FindEnabledHTTPLocationConfig(ctx context.Context, req *pb.FindEnabledHTTPLocationConfigRequest) (*pb.FindEnabledHTTPLocationConfigResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
config, err := models.SharedHTTPLocationDAO.ComposeLocationConfig(tx, req.LocationId, false, nil, nil)
if err != nil {
return nil, err
}
configJSON, err := json.Marshal(config)
if err != nil {
return nil, err
}
return &pb.FindEnabledHTTPLocationConfigResponse{LocationJSON: configJSON}, nil
}
// DeleteHTTPLocation 删除路由规则
func (this *HTTPLocationService) DeleteHTTPLocation(ctx context.Context, req *pb.DeleteHTTPLocationRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedHTTPLocationDAO.DisableHTTPLocation(tx, req.LocationId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindAndInitHTTPLocationReverseProxyConfig 查找反向代理设置
func (this *HTTPLocationService) FindAndInitHTTPLocationReverseProxyConfig(ctx context.Context, req *pb.FindAndInitHTTPLocationReverseProxyConfigRequest) (*pb.FindAndInitHTTPLocationReverseProxyConfigResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
reverseProxyRef, err := models.SharedHTTPLocationDAO.FindLocationReverseProxy(tx, req.LocationId)
if err != nil {
return nil, err
}
if reverseProxyRef == nil || reverseProxyRef.ReverseProxyId <= 0 {
reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(tx, adminId, userId, nil, nil, nil)
if err != nil {
return nil, err
}
reverseProxyRef = &serverconfigs.ReverseProxyRef{
IsOn: false,
ReverseProxyId: reverseProxyId,
}
reverseProxyJSON, err := json.Marshal(reverseProxyRef)
if err != nil {
return nil, err
}
err = models.SharedHTTPLocationDAO.UpdateLocationReverseProxy(tx, req.LocationId, reverseProxyJSON)
if err != nil {
return nil, err
}
}
reverseProxyConfig, err := models.SharedReverseProxyDAO.ComposeReverseProxyConfig(tx, reverseProxyRef.ReverseProxyId, nil, nil)
if err != nil {
return nil, err
}
refJSON, err := json.Marshal(reverseProxyRef)
if err != nil {
return nil, err
}
configJSON, err := json.Marshal(reverseProxyConfig)
if err != nil {
return nil, err
}
return &pb.FindAndInitHTTPLocationReverseProxyConfigResponse{
ReverseProxyJSON: configJSON,
ReverseProxyRefJSON: refJSON,
}, nil
}
// FindAndInitHTTPLocationWebConfig 初始化Web设置
func (this *HTTPLocationService) FindAndInitHTTPLocationWebConfig(ctx context.Context, req *pb.FindAndInitHTTPLocationWebConfigRequest) (*pb.FindAndInitHTTPLocationWebConfigResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, rpcutils.Wrap("ValidateRequest()", err)
}
var tx = this.NullTx()
webId, err := models.SharedHTTPLocationDAO.FindLocationWebId(tx, req.LocationId)
if err != nil {
return nil, rpcutils.Wrap("FindLocationWebId()", err)
}
if webId <= 0 {
webId, err = models.SharedHTTPWebDAO.CreateWeb(tx, adminId, userId, nil)
if err != nil {
return nil, rpcutils.Wrap("CreateWeb()", err)
}
err = models.SharedHTTPLocationDAO.UpdateLocationWeb(tx, req.LocationId, webId)
if err != nil {
return nil, rpcutils.Wrap("UpdateLocationWeb()", err)
}
}
config, err := models.SharedHTTPWebDAO.ComposeWebConfig(tx, webId, true, false, nil, nil)
if err != nil {
return nil, rpcutils.Wrap("ComposeWebConfig()", err)
}
configJSON, err := json.Marshal(config)
if err != nil {
return nil, rpcutils.Wrap("json.Marshal()", err)
}
return &pb.FindAndInitHTTPLocationWebConfigResponse{
WebJSON: configJSON,
}, nil
}
// UpdateHTTPLocationReverseProxy 修改反向代理设置
func (this *HTTPLocationService) UpdateHTTPLocationReverseProxy(ctx context.Context, req *pb.UpdateHTTPLocationReverseProxyRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedHTTPLocationDAO.UpdateLocationReverseProxy(tx, req.LocationId, req.ReverseProxyJSON)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,224 @@
package services
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/utils/regexputils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/iwind/TeaGo/types"
)
type HTTPPageService struct {
BaseService
}
// CreateHTTPPage 创建Page
func (this *HTTPPageService) CreateHTTPPage(ctx context.Context, req *pb.CreateHTTPPageRequest) (*pb.CreateHTTPPageResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// validate
const maxURLLength = 512
const maxBodyLength = 32 * 1024
switch req.BodyType {
case serverconfigs.HTTPPageBodyTypeURL:
if len(req.Url) > maxURLLength {
return nil, errors.New("'url' too long")
}
if !regexputils.HTTPProtocol.MatchString(req.Url) {
return nil, errors.New("invalid 'url' format")
}
if len(req.Body) > maxBodyLength { // we keep short body for user experience
req.Body = ""
}
case serverconfigs.HTTPPageBodyTypeRedirectURL:
if len(req.Url) > maxURLLength {
return nil, errors.New("'url' too long")
}
if !regexputils.HTTPProtocol.MatchString(req.Url) {
return nil, errors.New("invalid 'url' format")
}
if len(req.Body) > maxBodyLength { // we keep short body for user experience
req.Body = ""
}
case serverconfigs.HTTPPageBodyTypeHTML:
if len(req.Body) > maxBodyLength {
return nil, errors.New("'body' too long")
}
if len(req.Url) > maxURLLength { // we keep short url for user experience
req.Url = ""
}
default:
return nil, errors.New("invalid 'bodyType': " + req.BodyType)
}
var exceptURLPatterns = []*shared.URLPattern{}
if len(req.ExceptURLPatternsJSON) > 0 {
err = json.Unmarshal(req.ExceptURLPatternsJSON, &exceptURLPatterns)
if err != nil {
return nil, err
}
for _, pattern := range exceptURLPatterns {
err = pattern.Init()
if err != nil {
return nil, fmt.Errorf("validate url pattern '"+pattern.Pattern+"' failed: %w", err)
}
}
}
var onlyURLPatterns = []*shared.URLPattern{}
if len(req.OnlyURLPatternsJSON) > 0 {
err = json.Unmarshal(req.OnlyURLPatternsJSON, &onlyURLPatterns)
if err != nil {
return nil, err
}
for _, pattern := range onlyURLPatterns {
err = pattern.Init()
if err != nil {
return nil, fmt.Errorf("validate url pattern '"+pattern.Pattern+"' failed: %w", err)
}
}
}
pageId, err := models.SharedHTTPPageDAO.CreatePage(tx, userId, req.StatusList, req.BodyType, req.Url, req.Body, types.Int(req.NewStatus), exceptURLPatterns, onlyURLPatterns)
if err != nil {
return nil, err
}
return &pb.CreateHTTPPageResponse{HttpPageId: pageId}, nil
}
// UpdateHTTPPage 修改Page
func (this *HTTPPageService) UpdateHTTPPage(ctx context.Context, req *pb.UpdateHTTPPageRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
err = models.SharedHTTPPageDAO.CheckUserPage(tx, userId, req.HttpPageId)
if err != nil {
return nil, err
}
}
// validate
const maxURLLength = 512
const maxBodyLength = 32 * 1024
switch req.BodyType {
case serverconfigs.HTTPPageBodyTypeURL:
if len(req.Url) > maxURLLength {
return nil, errors.New("'url' too long")
}
if !regexputils.HTTPProtocol.MatchString(req.Url) {
return nil, errors.New("invalid 'url' format")
}
if len(req.Body) > maxBodyLength { // we keep short body for user experience
req.Body = ""
}
case serverconfigs.HTTPPageBodyTypeRedirectURL:
if len(req.Url) > maxURLLength {
return nil, errors.New("'url' too long")
}
if !regexputils.HTTPProtocol.MatchString(req.Url) {
return nil, errors.New("invalid 'url' format")
}
if len(req.Body) > maxBodyLength { // we keep short body for user experience
req.Body = ""
}
case serverconfigs.HTTPPageBodyTypeHTML:
if len(req.Body) > maxBodyLength {
return nil, errors.New("'body' too long")
}
if len(req.Url) > maxURLLength { // we keep short url for user experience
req.Url = ""
}
default:
return nil, errors.New("invalid 'bodyType': " + req.BodyType)
}
var exceptURLPatterns = []*shared.URLPattern{}
if len(req.ExceptURLPatternsJSON) > 0 {
err = json.Unmarshal(req.ExceptURLPatternsJSON, &exceptURLPatterns)
if err != nil {
return nil, err
}
for _, pattern := range exceptURLPatterns {
err = pattern.Init()
if err != nil {
return nil, fmt.Errorf("validate url pattern '"+pattern.Pattern+"' failed: %w", err)
}
}
}
var onlyURLPatterns = []*shared.URLPattern{}
if len(req.OnlyURLPatternsJSON) > 0 {
err = json.Unmarshal(req.OnlyURLPatternsJSON, &onlyURLPatterns)
if err != nil {
return nil, err
}
for _, pattern := range onlyURLPatterns {
err = pattern.Init()
if err != nil {
return nil, fmt.Errorf("validate url pattern '"+pattern.Pattern+"' failed: %w", err)
}
}
}
err = models.SharedHTTPPageDAO.UpdatePage(tx, req.HttpPageId, req.StatusList, req.BodyType, req.Url, req.Body, types.Int(req.NewStatus), exceptURLPatterns, onlyURLPatterns)
if err != nil {
return nil, err
}
return this.Success()
}
// FindEnabledHTTPPageConfig 查找单个Page配置
func (this *HTTPPageService) FindEnabledHTTPPageConfig(ctx context.Context, req *pb.FindEnabledHTTPPageConfigRequest) (*pb.FindEnabledHTTPPageConfigResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
err = models.SharedHTTPPageDAO.CheckUserPage(tx, userId, req.HttpPageId)
if err != nil {
return nil, err
}
}
config, err := models.SharedHTTPPageDAO.ComposePageConfig(tx, req.HttpPageId, nil)
if err != nil {
return nil, err
}
configJSON, err := json.Marshal(config)
if err != nil {
return nil, err
}
return &pb.FindEnabledHTTPPageConfigResponse{
PageJSON: configJSON,
}, nil
}

View File

@@ -0,0 +1,55 @@
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/types"
)
// HTTPRewriteRuleService 重写规则相关服务
type HTTPRewriteRuleService struct {
BaseService
}
// CreateHTTPRewriteRule 创建重写规则
func (this *HTTPRewriteRuleService) CreateHTTPRewriteRule(ctx context.Context, req *pb.CreateHTTPRewriteRuleRequest) (*pb.CreateHTTPRewriteRuleResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
rewriteRuleId, err := models.SharedHTTPRewriteRuleDAO.CreateRewriteRule(tx, userId, req.Pattern, req.Replace, req.Mode, types.Int(req.RedirectStatus), req.IsBreak, req.ProxyHost, req.WithQuery, req.IsOn, req.CondsJSON)
if err != nil {
return nil, err
}
return &pb.CreateHTTPRewriteRuleResponse{RewriteRuleId: rewriteRuleId}, nil
}
// UpdateHTTPRewriteRule 修改重写规则
func (this *HTTPRewriteRuleService) UpdateHTTPRewriteRule(ctx context.Context, req *pb.UpdateHTTPRewriteRuleRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
err = models.SharedHTTPRewriteRuleDAO.CheckUserRewriteRule(tx, userId, req.RewriteRuleId)
if err != nil {
return nil, err
}
}
err = models.SharedHTTPRewriteRuleDAO.UpdateRewriteRule(tx, req.RewriteRuleId, req.Pattern, req.Replace, req.Mode, types.Int(req.RedirectStatus), req.IsBreak, req.ProxyHost, req.WithQuery, req.IsOn, req.CondsJSON)
if err != nil {
return nil, err
}
return this.Success()
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,44 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
// +build !plus
package services
import (
"context"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// UpdateHTTPWebUAM 修改UAM设置
func (this *HTTPWebService) UpdateHTTPWebUAM(ctx context.Context, req *pb.UpdateHTTPWebUAMRequest) (*pb.RPCSuccess, error) {
return this.Success()
}
// FindHTTPWebUAM 查找UAM设置
func (this *HTTPWebService) FindHTTPWebUAM(ctx context.Context, req *pb.FindHTTPWebUAMRequest) (*pb.FindHTTPWebUAMResponse, error) {
return &pb.FindHTTPWebUAMResponse{UamJSON: nil}, nil
}
func (this *HTTPWebService) UpdateHTTPWebCC(ctx context.Context, req *pb.UpdateHTTPWebCCRequest) (*pb.RPCSuccess, error) {
return nil, this.NotImplementedYet()
}
// FindHTTPWebCC 查找UAM设置
func (this *HTTPWebService) FindHTTPWebCC(ctx context.Context, req *pb.FindHTTPWebCCRequest) (*pb.FindHTTPWebCCResponse, error) {
return nil, this.NotImplementedYet()
}
// UpdateHTTPWebRequestScripts 修改请求脚本
func (this *HTTPWebService) UpdateHTTPWebRequestScripts(ctx context.Context, req *pb.UpdateHTTPWebRequestScriptsRequest) (*pb.RPCSuccess, error) {
return nil, this.NotImplementedYet()
}
// UpdateHTTPWebHLS 修改HLS设置
func (this *HTTPWebService) UpdateHTTPWebHLS(ctx context.Context, req *pb.UpdateHTTPWebHLSRequest) (*pb.RPCSuccess, error) {
return nil, this.NotImplementedYet()
}
// FindHTTPWebHLS 查找HLS设置
func (this *HTTPWebService) FindHTTPWebHLS(ctx context.Context, req *pb.FindHTTPWebHLSRequest) (*pb.FindHTTPWebHLSResponse, error) {
return nil, this.NotImplementedYet()
}

View File

@@ -0,0 +1,316 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package services
import (
"context"
"encoding/json"
"errors"
"fmt"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
)
// UpdateHTTPWebUAM 修改UAM设置
func (this *HTTPWebService) UpdateHTTPWebUAM(ctx context.Context, req *pb.UpdateHTTPWebUAMRequest) (*pb.RPCSuccess, error) {
if !teaconst.IsPlus {
return nil, this.NotImplementedYet()
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx *dbs.Tx
if userId > 0 {
err = models.SharedHTTPWebDAO.CheckUserWeb(tx, userId, req.HttpWebId)
if err != nil {
return nil, err
}
}
var config = &serverconfigs.UAMConfig{}
err = json.Unmarshal(req.UamJSON, config)
if err != nil {
return nil, err
}
err = config.Init()
if err != nil {
return nil, errors.New("valid uam config failed: " + err.Error())
}
err = models.SharedHTTPWebDAO.UpdateWebUAM(tx, req.HttpWebId, config)
if err != nil {
return nil, err
}
return this.Success()
}
// FindHTTPWebUAM 查找UAM设置
func (this *HTTPWebService) FindHTTPWebUAM(ctx context.Context, req *pb.FindHTTPWebUAMRequest) (*pb.FindHTTPWebUAMResponse, error) {
if !teaconst.IsPlus {
return nil, this.NotImplementedYet()
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx *dbs.Tx
if userId > 0 {
err = models.SharedHTTPWebDAO.CheckUserWeb(tx, userId, req.HttpWebId)
if err != nil {
return nil, err
}
}
uamJSON, err := models.SharedHTTPWebDAO.FindWebUAM(tx, req.HttpWebId)
if err != nil {
return nil, err
}
return &pb.FindHTTPWebUAMResponse{
UamJSON: uamJSON,
}, nil
}
// UpdateHTTPWebCC 修改CC设置
func (this *HTTPWebService) UpdateHTTPWebCC(ctx context.Context, req *pb.UpdateHTTPWebCCRequest) (*pb.RPCSuccess, error) {
if !teaconst.IsPlus {
return nil, this.NotImplementedYet()
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx *dbs.Tx
if userId > 0 {
err = models.SharedHTTPWebDAO.CheckUserWeb(tx, userId, req.HttpWebId)
if err != nil {
return nil, err
}
}
var config = serverconfigs.DefaultHTTPCCConfig()
err = json.Unmarshal(req.CcJSON, config)
if err != nil {
return nil, err
}
err = config.Init()
if err != nil {
return nil, errors.New("valid cc config failed: " + err.Error())
}
err = models.SharedHTTPWebDAO.UpdateWebCC(tx, req.HttpWebId, config)
if err != nil {
return nil, err
}
return this.Success()
}
// FindHTTPWebCC 查找CC设置
func (this *HTTPWebService) FindHTTPWebCC(ctx context.Context, req *pb.FindHTTPWebCCRequest) (*pb.FindHTTPWebCCResponse, error) {
if !teaconst.IsPlus {
return nil, this.NotImplementedYet()
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx *dbs.Tx
if userId > 0 {
err = models.SharedHTTPWebDAO.CheckUserWeb(tx, userId, req.HttpWebId)
if err != nil {
return nil, err
}
}
ccJSON, err := models.SharedHTTPWebDAO.FindWebCC(tx, req.HttpWebId)
if err != nil {
return nil, err
}
return &pb.FindHTTPWebCCResponse{
CcJSON: ccJSON,
}, nil
}
// UpdateHTTPWebRequestScripts 修改请求脚本
func (this *HTTPWebService) UpdateHTTPWebRequestScripts(ctx context.Context, req *pb.UpdateHTTPWebRequestScriptsRequest) (*pb.RPCSuccess, error) {
if !teaconst.IsPlus {
return nil, this.NotImplementedYet()
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
err = models.SharedHTTPWebDAO.CheckUserWeb(tx, userId, req.HttpWebId)
if err != nil {
return nil, err
}
}
var config = &serverconfigs.HTTPRequestScriptsConfig{}
err = json.Unmarshal(req.RequestScriptsJSON, config)
if err != nil {
return nil, err
}
err = config.Init()
if err != nil {
return nil, fmt.Errorf("validate config failed: %w", err)
}
// 代码最大长度
// TODO 需要可以在管理员系统中配置
const codeMaxLength = 8192
// 检查是否需要审核
if userId > 0 {
for _, group := range config.AllGroups() {
for _, script := range group.Scripts {
if len(script.Code) > codeMaxLength {
return nil, errors.New("code length should not more than '" + types.String(codeMaxLength) + "'")
}
var realCode = script.TrimCode()
if len(realCode) == 0 {
continue
}
// 是否已审核通过
var codeMD5 = stringutil.Md5(realCode)
isPassed, existErr := models.SharedUserScriptDAO.ExistsPassedCodeMD5(tx, codeMD5)
if existErr != nil {
return nil, existErr
}
if isPassed {
script.AuditingCodeMD5 = ""
script.AuditingCode = ""
continue
}
// 是否已存在
scriptId, findErr := models.SharedUserScriptDAO.FindUserScriptIdWithCodeMD5(tx, userId, codeMD5)
if findErr != nil {
return nil, findErr
}
if scriptId <= 0 {
scriptId, err = models.SharedUserScriptDAO.CreateUserScript(tx, userId, realCode)
if err != nil {
return nil, err
}
}
// 保存WebId用于以后的更新
err = models.SharedUserScriptDAO.AddWebIdToUserScript(tx, scriptId, req.HttpWebId)
if err != nil {
return nil, err
}
// 清空代码,等待审核
// TODO 将 script.Code 还原为老版本的代码
script.AuditingCode = script.Code // not realCode, to keep raw format
script.Code = ""
script.AuditingCodeMD5 = codeMD5
}
}
}
err = models.SharedHTTPWebDAO.UpdateWebRequestScripts(tx, req.HttpWebId, config)
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateHTTPWebHLS 修改HLS设置
func (this *HTTPWebService) UpdateHTTPWebHLS(ctx context.Context, req *pb.UpdateHTTPWebHLSRequest) (*pb.RPCSuccess, error) {
if !teaconst.IsPlus {
return nil, this.NotImplementedYet()
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
err = models.SharedHTTPWebDAO.CheckUserWeb(tx, userId, req.HttpWebId)
if err != nil {
return nil, err
}
}
if len(req.HlsJSON) == 0 {
return nil, errors.New("require 'hlsJSON'")
}
var hlsConfig = &serverconfigs.HLSConfig{}
err = json.Unmarshal(req.HlsJSON, hlsConfig)
if err != nil {
return nil, err
}
err = hlsConfig.Init()
if err != nil {
return nil, fmt.Errorf("validate config failed: %w", err)
}
err = models.SharedHTTPWebDAO.UpdateWebHLS(tx, req.HttpWebId, hlsConfig)
if err != nil {
return nil, err
}
return this.Success()
}
// FindHTTPWebHLS 查找HLS设置
func (this *HTTPWebService) FindHTTPWebHLS(ctx context.Context, req *pb.FindHTTPWebHLSRequest) (*pb.FindHTTPWebHLSResponse, error) {
if !teaconst.IsPlus {
return nil, this.NotImplementedYet()
}
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
err = models.SharedHTTPWebDAO.CheckUserWeb(tx, userId, req.HttpWebId)
if err != nil {
return nil, err
}
}
hlsJSON, err := models.SharedHTTPWebDAO.FindWebHLS(tx, req.HttpWebId)
if err != nil {
return nil, err
}
return &pb.FindHTTPWebHLSResponse{
HlsJSON: hlsJSON,
}, nil
}

View File

@@ -0,0 +1,47 @@
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
type HTTPWebsocketService struct {
BaseService
}
// 创建Websocket配置
func (this *HTTPWebsocketService) CreateHTTPWebsocket(ctx context.Context, req *pb.CreateHTTPWebsocketRequest) (*pb.CreateHTTPWebsocketResponse, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
websocketId, err := models.SharedHTTPWebsocketDAO.CreateWebsocket(tx, req.HandshakeTimeoutJSON, req.AllowAllOrigins, req.AllowedOrigins, req.RequestSameOrigin, req.RequestOrigin)
if err != nil {
return nil, err
}
return &pb.CreateHTTPWebsocketResponse{WebsocketId: websocketId}, nil
}
// 修改Websocket配置
func (this *HTTPWebsocketService) UpdateHTTPWebsocket(ctx context.Context, req *pb.UpdateHTTPWebsocketRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
// TODO 用户不能修改别人的WebSocket设置
var tx = this.NullTx()
err = models.SharedHTTPWebsocketDAO.UpdateWebsocket(tx, req.WebsocketId, req.HandshakeTimeoutJSON, req.AllowAllOrigins, req.AllowedOrigins, req.RequestSameOrigin, req.RequestOrigin)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,924 @@
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/iputils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ipconfigs"
"net"
"time"
)
// IPItemService IP条目相关服务
type IPItemService struct {
BaseService
}
// CreateIPItem 创建IP
func (this *IPItemService) CreateIPItem(ctx context.Context, req *pb.CreateIPItemRequest) (*pb.CreateIPItemResponse, error) {
// 校验请求
userType, _, userId, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeUser, rpcutils.UserTypeNode, rpcutils.UserTypeDNS)
if err != nil {
return nil, err
}
if len(req.Value) > 0 {
newValue, ipFrom, ipTo, ok := models.SharedIPItemDAO.ParseIPValue(req.Value)
if !ok {
return nil, errors.New("invalid 'value' format")
}
req.Value = newValue
req.IpFrom = ipFrom
req.IpTo = ipTo
} else if req.Type != models.IPItemTypeAll {
if !iputils.IsValid(req.IpFrom) {
return nil, errors.New("invalid 'ipFrom'")
}
if len(req.IpTo) > 0 {
if !iputils.IsValid(req.IpTo) {
return nil, errors.New("invalid 'ipTo'")
}
if !iputils.IsSameVersion(req.IpFrom, req.IpTo) {
return nil, errors.New("'ipFrom' and 'ipTo' should be in same version")
}
if iputils.CompareIP(req.IpFrom, req.IpTo) > 0 {
req.IpFrom, req.IpTo = req.IpTo, req.IpFrom
}
}
}
var tx = this.NullTx()
if userType == rpcutils.UserTypeUser {
if userId <= 0 {
return nil, errors.New("invalid userId")
} else {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil {
return nil, err
}
}
}
if len(req.Type) == 0 {
req.Type = models.IPItemTypeIPv4
}
// 删除以前的
err = models.SharedIPItemDAO.DeleteOldItem(tx, req.IpListId, req.IpFrom, req.IpTo)
if err != nil {
return nil, err
}
itemId, err := models.SharedIPItemDAO.CreateIPItem(tx, req.IpListId, req.Value, req.IpFrom, req.IpTo, req.ExpiredAt, req.Reason, req.Type, req.EventLevel, req.NodeId, req.ServerId, req.SourceNodeId, req.SourceServerId, req.SourceHTTPFirewallPolicyId, req.SourceHTTPFirewallRuleGroupId, req.SourceHTTPFirewallRuleSetId, true)
if err != nil {
return nil, err
}
return &pb.CreateIPItemResponse{IpItemId: itemId}, nil
}
// CreateIPItems 创建一组IP
func (this *IPItemService) CreateIPItems(ctx context.Context, req *pb.CreateIPItemsRequest) (*pb.CreateIPItemsResponse, error) {
// 校验请求
userType, _, userId, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeUser, rpcutils.UserTypeNode, rpcutils.UserTypeDNS)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 校验
for _, item := range req.IpItems {
if len(item.Value) > 0 {
newValue, ipFrom, ipTo, ok := models.SharedIPItemDAO.ParseIPValue(item.Value)
if !ok {
return nil, errors.New("invalid 'value': " + item.Value)
}
item.Value = newValue
item.IpFrom = ipFrom
item.IpTo = ipTo
} else if item.Type != models.IPItemTypeAll {
if !iputils.IsValid(item.IpFrom) {
return nil, errors.New("invalid 'ipFrom': " + item.IpFrom)
}
if len(item.IpTo) > 0 {
if !iputils.IsValid(item.IpTo) {
return nil, errors.New("invalid 'ipTo': " + item.IpTo)
}
if !iputils.IsSameVersion(item.IpFrom, item.IpTo) {
return nil, errors.New("'ipFrom' (" + item.IpFrom + ") and 'ipTo' (" + item.IpTo + ") should be in same version")
}
if iputils.CompareIP(item.IpFrom, item.IpTo) > 0 {
item.IpFrom, item.IpTo = item.IpTo, item.IpFrom
}
}
}
if userType == rpcutils.UserTypeUser {
if userId <= 0 {
return nil, errors.New("invalid userId")
} else {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, item.IpListId)
if err != nil {
return nil, err
}
}
}
if len(item.Type) == 0 {
item.Type = models.IPItemTypeIPv4
}
}
// 创建
var ipItemIds = []int64{}
for index, item := range req.IpItems {
var shouldNotify = index == len(req.IpItems)-1
// 删除以前的
if len(item.Value) > 0 {
err = models.SharedIPItemDAO.DeleteOldItemWithValue(tx, item.IpListId, item.Value)
} else {
err = models.SharedIPItemDAO.DeleteOldItem(tx, item.IpListId, item.IpFrom, item.IpTo)
}
if err != nil {
return nil, err
}
itemId, err := models.SharedIPItemDAO.CreateIPItem(tx, item.IpListId, item.Value, item.IpFrom, item.IpTo, item.ExpiredAt, item.Reason, item.Type, item.EventLevel, item.NodeId, item.ServerId, item.SourceNodeId, item.SourceServerId, item.SourceHTTPFirewallPolicyId, item.SourceHTTPFirewallRuleGroupId, item.SourceHTTPFirewallRuleSetId, shouldNotify)
if err != nil {
return nil, err
}
ipItemIds = append(ipItemIds, itemId)
}
return &pb.CreateIPItemsResponse{
IpItemIds: ipItemIds,
}, nil
}
// UpdateIPItem 修改IP
func (this *IPItemService) UpdateIPItem(ctx context.Context, req *pb.UpdateIPItemRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// validate ip
if len(req.Value) > 0 {
newValue, ipFrom, ipTo, ok := models.SharedIPItemDAO.ParseIPValue(req.Value)
if !ok {
return nil, errors.New("invalid 'value' format")
}
req.Value = newValue
req.IpFrom = ipFrom
req.IpTo = ipTo
} else if req.Type != models.IPItemTypeAll {
if !iputils.IsValid(req.IpFrom) {
return nil, errors.New("invalid 'ipFrom'")
}
if len(req.IpTo) > 0 {
if !iputils.IsValid(req.IpTo) {
return nil, errors.New("invalid 'ipTo'")
}
if !iputils.IsSameVersion(req.IpFrom, req.IpTo) {
return nil, errors.New("'ipFrom' and 'ipTo' should be in same version")
}
if iputils.CompareIP(req.IpFrom, req.IpTo) > 0 {
req.IpFrom, req.IpTo = req.IpTo, req.IpFrom
}
}
}
if userId > 0 {
listId, err := models.SharedIPItemDAO.FindItemListId(tx, req.IpItemId)
if err != nil {
return nil, err
}
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, listId)
if err != nil {
return nil, err
}
}
if len(req.Type) == 0 {
req.Type = models.IPItemTypeIPv4
}
err = models.SharedIPItemDAO.UpdateIPItem(tx, req.IpItemId, req.Value, req.IpFrom, req.IpTo, req.ExpiredAt, req.Reason, req.Type, req.EventLevel)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteIPItem 删除IP
func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPItemRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if req.IpItemId <= 0 && len(req.Value) == 0 && len(req.IpFrom) == 0 {
return nil, errors.New("one of 'ipItemId', 'value' or 'ipFrom' params required")
}
// 如果是使用IPItemId删除
if req.IpItemId > 0 {
err = models.SharedIPItemDAO.DisableIPItem(tx, req.IpItemId, userId)
if err != nil {
return nil, err
}
return this.Success()
}
// 使用value删除
if len(req.Value) > 0 {
// 检查IP列表
if req.IpListId > 0 && userId > 0 && !firewallconfigs.IsGlobalListId(req.IpListId) {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil {
return nil, err
}
}
err = models.SharedIPItemDAO.DisableIPItemsWithIPValue(tx, req.Value, userId, req.IpListId)
if err != nil {
return nil, err
}
return this.Success()
}
// 如果是使用ipFrom+ipTo删除
if len(req.IpFrom) > 0 {
// 检查IP列表
if req.IpListId > 0 && userId > 0 && !firewallconfigs.IsGlobalListId(req.IpListId) {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil {
return nil, err
}
}
err = models.SharedIPItemDAO.DisableIPItemsWithIP(tx, req.IpFrom, req.IpTo, userId, req.IpListId)
if err != nil {
return nil, err
}
return this.Success()
}
return this.Success()
}
// DeleteIPItems 批量删除IP
func (this *IPItemService) DeleteIPItems(ctx context.Context, req *pb.DeleteIPItemsRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
for _, itemId := range req.IpItemIds {
err = models.SharedIPItemDAO.DisableIPItem(tx, itemId, userId)
if err != nil {
return nil, err
}
}
return this.Success()
}
// CountIPItemsWithListId 计算IP数量
func (this *IPItemService) CountIPItemsWithListId(ctx context.Context, req *pb.CountIPItemsWithListIdRequest) (*pb.RPCCountResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
// 检查用户所属名单
if !firewallconfigs.IsGlobalListId(req.IpListId) {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil {
return nil, err
}
}
}
count, err := models.SharedIPItemDAO.CountIPItemsWithListId(tx, req.IpListId, userId, req.Keyword, req.IpFrom, req.IpTo, req.EventLevel)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListIPItemsWithListId 列出单页的IP
func (this *IPItemService) ListIPItemsWithListId(ctx context.Context, req *pb.ListIPItemsWithListIdRequest) (*pb.ListIPItemsWithListIdResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
// 检查用户所属名单
if !firewallconfigs.IsGlobalListId(req.IpListId) {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil {
return nil, err
}
}
}
items, err := models.SharedIPItemDAO.ListIPItemsWithListId(tx, req.IpListId, userId, req.Keyword, req.IpFrom, req.IpTo, req.EventLevel, req.Offset, req.Size)
if err != nil {
return nil, err
}
var result = []*pb.IPItem{}
for _, item := range items {
if len(item.Type) == 0 {
item.Type = models.IPItemTypeIPv4
}
// server
var pbSourceServer *pb.Server
if item.SourceServerId > 0 {
serverName, err := models.SharedServerDAO.FindEnabledServerName(tx, int64(item.SourceServerId))
if err != nil {
return nil, err
}
pbSourceServer = &pb.Server{
Id: int64(item.SourceServerId),
Name: serverName,
}
}
// WAF策略
var pbSourcePolicy *pb.HTTPFirewallPolicy
if item.SourceHTTPFirewallPolicyId > 0 {
policy, err := models.SharedHTTPFirewallPolicyDAO.FindEnabledHTTPFirewallPolicyBasic(tx, int64(item.SourceHTTPFirewallPolicyId))
if err != nil {
return nil, err
}
if policy != nil {
pbSourcePolicy = &pb.HTTPFirewallPolicy{
Id: int64(item.SourceHTTPFirewallPolicyId),
Name: policy.Name,
ServerId: int64(policy.ServerId),
}
}
}
// WAF分组
var pbSourceGroup *pb.HTTPFirewallRuleGroup
if item.SourceHTTPFirewallRuleGroupId > 0 {
groupName, err := models.SharedHTTPFirewallRuleGroupDAO.FindHTTPFirewallRuleGroupName(tx, int64(item.SourceHTTPFirewallRuleGroupId))
if err != nil {
return nil, err
}
pbSourceGroup = &pb.HTTPFirewallRuleGroup{
Id: int64(item.SourceHTTPFirewallRuleGroupId),
Name: groupName,
}
}
// WAF规则集
var pbSourceSet *pb.HTTPFirewallRuleSet
if item.SourceHTTPFirewallRuleSetId > 0 {
setName, err := models.SharedHTTPFirewallRuleSetDAO.FindHTTPFirewallRuleSetName(tx, int64(item.SourceHTTPFirewallRuleSetId))
if err != nil {
return nil, err
}
pbSourceSet = &pb.HTTPFirewallRuleSet{
Id: int64(item.SourceHTTPFirewallRuleSetId),
Name: setName,
}
}
result = append(result, &pb.IPItem{
Id: int64(item.Id),
Value: item.ComposeValue(),
IpFrom: item.IpFrom,
IpTo: item.IpTo,
Version: int64(item.Version),
CreatedAt: int64(item.CreatedAt),
ExpiredAt: int64(item.ExpiredAt),
Reason: item.Reason,
Type: item.Type,
EventLevel: item.EventLevel,
NodeId: int64(item.NodeId),
ServerId: int64(item.ServerId),
SourceNodeId: int64(item.SourceNodeId),
SourceServerId: int64(item.SourceServerId),
SourceHTTPFirewallPolicyId: int64(item.SourceHTTPFirewallPolicyId),
SourceHTTPFirewallRuleGroupId: int64(item.SourceHTTPFirewallRuleGroupId),
SourceHTTPFirewallRuleSetId: int64(item.SourceHTTPFirewallRuleSetId),
SourceServer: pbSourceServer,
SourceHTTPFirewallPolicy: pbSourcePolicy,
SourceHTTPFirewallRuleGroup: pbSourceGroup,
SourceHTTPFirewallRuleSet: pbSourceSet,
IsRead: item.IsRead,
})
}
return &pb.ListIPItemsWithListIdResponse{IpItems: result}, nil
}
// FindEnabledIPItem 查找单个IP
func (this *IPItemService) FindEnabledIPItem(ctx context.Context, req *pb.FindEnabledIPItemRequest) (*pb.FindEnabledIPItemResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
item, err := models.SharedIPItemDAO.FindEnabledIPItem(tx, req.IpItemId)
if err != nil {
return nil, err
}
if item == nil {
return &pb.FindEnabledIPItemResponse{IpItem: nil}, nil
}
if userId > 0 {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, int64(item.ListId))
if err != nil {
return nil, err
}
}
if len(item.Type) == 0 {
item.Type = models.IPItemTypeIPv4
}
return &pb.FindEnabledIPItemResponse{IpItem: &pb.IPItem{
Id: int64(item.Id),
Value: item.ComposeValue(),
IpFrom: item.IpFrom,
IpTo: item.IpTo,
Version: int64(item.Version),
CreatedAt: int64(item.CreatedAt),
ExpiredAt: int64(item.ExpiredAt),
Reason: item.Reason,
Type: item.Type,
EventLevel: item.EventLevel,
NodeId: int64(item.NodeId),
ServerId: int64(item.ServerId),
}}, nil
}
// ListIPItemsAfterVersion 根据版本列出一组IP
func (this *IPItemService) ListIPItemsAfterVersion(ctx context.Context, req *pb.ListIPItemsAfterVersionRequest) (*pb.ListIPItemsAfterVersionResponse, error) {
// 校验请求
_, _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeNode)
if err != nil {
return nil, err
}
var tx = this.NullTx()
var result = []*pb.IPItem{}
items, err := models.SharedIPItemDAO.ListIPItemsAfterVersion(tx, req.Version, req.Size)
if err != nil {
return nil, err
}
var latestVersion = req.Version
for _, item := range items {
latestVersion = int64(item.Version)
// 是否已过期
if item.ExpiredAt > 0 && int64(item.ExpiredAt) <= time.Now().Unix() {
item.State = models.IPItemStateDisabled
}
if len(item.Type) == 0 {
item.Type = models.IPItemTypeIPv4
}
// List类型
list, err := models.SharedIPListDAO.FindIPListCacheable(tx, int64(item.ListId))
if err != nil {
return nil, err
}
if list == nil {
continue
}
// 跳过灰名单
if list.Type == ipconfigs.IPListTypeGrey {
continue
}
// 如果已经删除
if list.State != models.IPListStateEnabled {
item.State = models.IPItemStateDisabled
}
result = append(result, &pb.IPItem{
Id: int64(item.Id),
Value: item.ComposeValue(),
IpFrom: item.IpFrom,
IpTo: item.IpTo,
Version: int64(item.Version),
CreatedAt: int64(item.CreatedAt),
ExpiredAt: int64(item.ExpiredAt),
Reason: "", // 这里我们不需要这个数据
ListId: int64(item.ListId),
IsDeleted: item.State == 0,
Type: item.Type,
EventLevel: item.EventLevel,
ListType: list.Type,
IsGlobal: list.IsPublic && list.IsGlobal,
NodeId: int64(item.NodeId),
ServerId: int64(item.ServerId),
})
}
return &pb.ListIPItemsAfterVersionResponse{
IpItems: result,
Version: latestVersion,
}, nil
}
// CheckIPItemStatus 检查IP状态
func (this *IPItemService) CheckIPItemStatus(ctx context.Context, req *pb.CheckIPItemStatusRequest) (*pb.CheckIPItemStatusResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
// 校验IP
var ip = net.ParseIP(req.Ip)
if len(ip) == 0 {
return &pb.CheckIPItemStatusResponse{
IsOk: false,
Error: "请输入正确的IP",
}, nil
}
var tx = this.NullTx()
// 名单类型
list, err := models.SharedIPListDAO.FindEnabledIPList(tx, req.IpListId, nil)
if err != nil {
return nil, err
}
if list == nil {
return &pb.CheckIPItemStatusResponse{
IsOk: false,
Error: "IP名单不存在",
}, nil
}
var isAllowed = list.Type == ipconfigs.IPListTypeWhite || list.Type == ipconfigs.IPListTypeGrey
// 检查IP名单
item, err := models.SharedIPItemDAO.FindEnabledItemContainsIP(tx, req.IpListId, req.Ip)
if err != nil {
return nil, err
}
if item != nil {
return &pb.CheckIPItemStatusResponse{
IsOk: true,
Error: "",
IsFound: true,
IsAllowed: isAllowed,
IpItem: &pb.IPItem{
Id: int64(item.Id),
Value: item.ComposeValue(),
IpFrom: item.IpFrom,
IpTo: item.IpTo,
CreatedAt: int64(item.CreatedAt),
ExpiredAt: int64(item.ExpiredAt),
Reason: item.Reason,
Type: item.Type,
EventLevel: item.EventLevel,
ListType: list.Type,
},
}, nil
}
return &pb.CheckIPItemStatusResponse{
IsOk: true,
Error: "",
IsFound: false,
IsAllowed: false,
IpItem: nil,
}, nil
}
// ExistsEnabledIPItem 检查IP是否存在
func (this *IPItemService) ExistsEnabledIPItem(ctx context.Context, req *pb.ExistsEnabledIPItemRequest) (*pb.ExistsEnabledIPItemResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
b, err := models.SharedIPItemDAO.ExistsEnabledItem(tx, req.IpItemId)
if err != nil {
return nil, err
}
return &pb.ExistsEnabledIPItemResponse{Exists: b}, nil
}
// CountAllEnabledIPItems 计算所有IP数量
func (this *IPItemService) CountAllEnabledIPItems(ctx context.Context, req *pb.CountAllEnabledIPItemsRequest) (*pb.RPCCountResponse, error) {
adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if adminId > 0 {
userId = req.UserId
}
var tx = this.NullTx()
count, err := models.SharedIPItemDAO.CountAllEnabledIPItems(tx, userId, req.Keyword, req.Ip, 0, req.Unread, req.EventLevel, req.ListType, req.GlobalOnly)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListAllEnabledIPItems 搜索IP
func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.ListAllEnabledIPItemsRequest) (*pb.ListAllEnabledIPItemsResponse, error) {
adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if adminId > 0 {
userId = req.UserId
}
var results = []*pb.ListAllEnabledIPItemsResponse_Result{}
var tx = this.NullTx()
items, err := models.SharedIPItemDAO.ListAllEnabledIPItems(tx, userId, req.Keyword, req.Ip, 0, req.Unread, req.EventLevel, req.ListType, req.GlobalOnly, req.Offset, req.Size)
if err != nil {
return nil, err
}
var cacheMap = utils.NewCacheMap()
for _, item := range items {
// server
var pbSourceServer *pb.Server
if item.SourceServerId > 0 {
serverName, err := models.SharedServerDAO.FindEnabledServerName(tx, int64(item.SourceServerId))
if err != nil {
return nil, err
}
pbSourceServer = &pb.Server{
Id: int64(item.SourceServerId),
Name: serverName,
}
}
// WAF策略
var pbSourcePolicy *pb.HTTPFirewallPolicy
if item.SourceHTTPFirewallPolicyId > 0 {
policy, err := models.SharedHTTPFirewallPolicyDAO.FindEnabledHTTPFirewallPolicyBasic(tx, int64(item.SourceHTTPFirewallPolicyId))
if err != nil {
return nil, err
}
if policy != nil {
pbSourcePolicy = &pb.HTTPFirewallPolicy{
Id: int64(item.SourceHTTPFirewallPolicyId),
Name: policy.Name,
ServerId: int64(policy.ServerId),
}
}
}
// WAF分组
var pbSourceGroup *pb.HTTPFirewallRuleGroup
if item.SourceHTTPFirewallRuleGroupId > 0 {
groupName, err := models.SharedHTTPFirewallRuleGroupDAO.FindHTTPFirewallRuleGroupName(tx, int64(item.SourceHTTPFirewallRuleGroupId))
if err != nil {
return nil, err
}
pbSourceGroup = &pb.HTTPFirewallRuleGroup{
Id: int64(item.SourceHTTPFirewallRuleGroupId),
Name: groupName,
}
}
// WAF规则集
var pbSourceSet *pb.HTTPFirewallRuleSet
if item.SourceHTTPFirewallRuleSetId > 0 {
setName, err := models.SharedHTTPFirewallRuleSetDAO.FindHTTPFirewallRuleSetName(tx, int64(item.SourceHTTPFirewallRuleSetId))
if err != nil {
return nil, err
}
pbSourceSet = &pb.HTTPFirewallRuleSet{
Id: int64(item.SourceHTTPFirewallRuleSetId),
Name: setName,
}
}
// 节点
var pbSourceNode *pb.Node
if item.SourceNodeId > 0 {
node, err := models.SharedNodeDAO.FindEnabledBasicNode(tx, int64(item.SourceNodeId))
if err != nil {
return nil, err
}
if node != nil {
pbSourceNode = &pb.Node{
Id: int64(node.Id),
Name: node.Name,
NodeCluster: &pb.NodeCluster{Id: int64(node.ClusterId)},
}
}
}
var pbItem = &pb.IPItem{
Id: int64(item.Id),
Value: item.ComposeValue(),
IpFrom: item.IpFrom,
IpTo: item.IpTo,
Version: int64(item.Version),
CreatedAt: int64(item.CreatedAt),
ExpiredAt: int64(item.ExpiredAt),
Reason: item.Reason,
Type: item.Type,
EventLevel: item.EventLevel,
NodeId: int64(item.NodeId),
ServerId: int64(item.ServerId),
SourceNodeId: int64(item.SourceNodeId),
SourceServerId: int64(item.SourceServerId),
SourceHTTPFirewallPolicyId: int64(item.SourceHTTPFirewallPolicyId),
SourceHTTPFirewallRuleGroupId: int64(item.SourceHTTPFirewallRuleGroupId),
SourceHTTPFirewallRuleSetId: int64(item.SourceHTTPFirewallRuleSetId),
SourceServer: pbSourceServer,
SourceHTTPFirewallPolicy: pbSourcePolicy,
SourceHTTPFirewallRuleGroup: pbSourceGroup,
SourceHTTPFirewallRuleSet: pbSourceSet,
SourceNode: pbSourceNode,
IsRead: item.IsRead,
}
// 所属名单
list, err := models.SharedIPListDAO.FindEnabledIPList(tx, int64(item.ListId), cacheMap)
if err != nil {
return nil, err
}
if list == nil {
err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id), 0)
if err != nil {
return nil, err
}
continue
}
var pbList = &pb.IPList{
Id: int64(list.Id),
Name: list.Name,
Type: list.Type,
IsPublic: list.IsPublic,
IsGlobal: list.IsGlobal,
}
// 所属服务注意与SourceServer不同
var pbFirewallServer *pb.Server
// 所属策略注意与SourceHTTPFirewallPolicy不同
var pbFirewallPolicy *pb.HTTPFirewallPolicy
if !list.IsPublic {
policy, err := models.SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyWithIPListId(tx, int64(list.Id))
if err != nil {
return nil, err
}
if policy == nil {
err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id), 0)
if err != nil {
return nil, err
}
continue
}
pbFirewallPolicy = &pb.HTTPFirewallPolicy{
Id: int64(policy.Id),
Name: policy.Name,
}
if policy.ServerId > 0 {
serverName, err := models.SharedServerDAO.FindEnabledServerName(tx, int64(policy.ServerId))
if err != nil {
return nil, err
}
if len(serverName) == 0 {
serverName = "[已删除]"
}
pbFirewallServer = &pb.Server{
Id: int64(policy.ServerId),
Name: serverName,
}
}
}
results = append(results, &pb.ListAllEnabledIPItemsResponse_Result{
IpList: pbList,
IpItem: pbItem,
Server: pbFirewallServer,
HttpFirewallPolicy: pbFirewallPolicy,
})
}
return &pb.ListAllEnabledIPItemsResponse{Results: results}, nil
}
// ListAllIPItemIds 列出所有名单中的IP ID
func (this *IPItemService) ListAllIPItemIds(ctx context.Context, req *pb.ListAllIPItemIdsRequest) (*pb.ListAllIPItemIdsResponse, error) {
adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if adminId > 0 {
userId = req.UserId
}
var tx = this.NullTx()
itemIds, err := models.SharedIPItemDAO.ListAllIPItemIds(tx, userId, req.Keyword, req.Ip, 0, req.Unread, req.EventLevel, req.ListType, req.Offset, req.Size)
if err != nil {
return nil, err
}
return &pb.ListAllIPItemIdsResponse{IpItemIds: itemIds}, nil
}
// UpdateIPItemsRead 设置所有为已读
func (this *IPItemService) UpdateIPItemsRead(ctx context.Context, req *pb.UpdateIPItemsReadRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedIPItemDAO.UpdateItemsRead(tx, userId)
if err != nil {
return nil, err
}
return this.Success()
}
// FindServerIdWithIPItemId 查找IP对应的名单所属网站ID
func (this *IPItemService) FindServerIdWithIPItemId(ctx context.Context, req *pb.FindServerIdWithIPItemIdRequest) (*pb.FindServerIdWithIPItemIdResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
listId, err := models.SharedIPItemDAO.FindItemListId(tx, req.IpItemId)
if err != nil {
return nil, err
}
if listId > 0 {
var serverId int64
serverId, err = models.SharedIPListDAO.FindServerIdWithListId(tx, listId)
if err != nil {
return nil, err
}
if serverId > 0 {
// check user
if userId > 0 {
err = models.SharedServerDAO.CheckUserServer(tx, userId, serverId)
if err != nil {
return nil, err
}
}
return &pb.FindServerIdWithIPItemIdResponse{ServerId: serverId}, nil
}
}
return &pb.FindServerIdWithIPItemIdResponse{ServerId: 0}, nil
}

View File

@@ -0,0 +1,411 @@
package services
import (
"context"
"errors"
"fmt"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/remotelogs"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/Tea"
"os"
"path/filepath"
"strings"
)
// IPLibraryService IP库服务
type IPLibraryService struct {
BaseService
}
// CreateIPLibrary 创建IP库
func (this *IPLibraryService) CreateIPLibrary(ctx context.Context, req *pb.CreateIPLibraryRequest) (*pb.CreateIPLibraryResponse, error) {
// 校验请求
_, _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
var tx = this.NullTx()
ipLibraryId, err := models.SharedIPLibraryDAO.CreateIPLibrary(tx, req.Type, req.FileId)
if err != nil {
return nil, err
}
return &pb.CreateIPLibraryResponse{
IpLibraryId: ipLibraryId,
}, nil
}
// FindEnabledIPLibrary 查找单个IP库
func (this *IPLibraryService) FindEnabledIPLibrary(ctx context.Context, req *pb.FindEnabledIPLibraryRequest) (*pb.FindEnabledIPLibraryResponse, error) {
// 校验请求
_, _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
var tx = this.NullTx()
ipLibrary, err := models.SharedIPLibraryDAO.FindEnabledIPLibrary(tx, req.IpLibraryId)
if err != nil {
return nil, err
}
if ipLibrary == nil {
return &pb.FindEnabledIPLibraryResponse{IpLibrary: nil}, nil
}
// 文件相关
var pbFile *pb.File = nil
file, err := models.SharedFileDAO.FindEnabledFile(tx, int64(ipLibrary.FileId))
if err != nil {
return nil, err
}
if file != nil {
pbFile = &pb.File{
Id: int64(file.Id),
Filename: file.Filename,
Size: int64(file.Size),
}
}
return &pb.FindEnabledIPLibraryResponse{
IpLibrary: &pb.IPLibrary{
Id: int64(ipLibrary.Id),
Type: ipLibrary.Type,
File: pbFile,
CreatedAt: int64(ipLibrary.CreatedAt),
},
}, nil
}
// FindLatestIPLibraryWithType 查找最新的IP库
func (this *IPLibraryService) FindLatestIPLibraryWithType(ctx context.Context, req *pb.FindLatestIPLibraryWithTypeRequest) (*pb.FindLatestIPLibraryWithTypeResponse, error) {
// 校验请求
_, _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeNode)
if err != nil {
return nil, err
}
var tx = this.NullTx()
ipLibrary, err := models.SharedIPLibraryDAO.FindLatestIPLibraryWithType(tx, req.Type)
if err != nil {
return nil, err
}
if ipLibrary == nil {
return &pb.FindLatestIPLibraryWithTypeResponse{IpLibrary: nil}, nil
}
// 文件相关
var pbFile *pb.File = nil
file, err := models.SharedFileDAO.FindEnabledFile(tx, int64(ipLibrary.FileId))
if err != nil {
return nil, err
}
if file != nil {
pbFile = &pb.File{
Id: int64(file.Id),
Filename: file.Filename,
Size: int64(file.Size),
}
}
return &pb.FindLatestIPLibraryWithTypeResponse{
IpLibrary: &pb.IPLibrary{
Id: int64(ipLibrary.Id),
Type: ipLibrary.Type,
File: pbFile,
CreatedAt: int64(ipLibrary.CreatedAt),
},
}, nil
}
// FindAllEnabledIPLibrariesWithType 列出某个类型的所有IP库
func (this *IPLibraryService) FindAllEnabledIPLibrariesWithType(ctx context.Context, req *pb.FindAllEnabledIPLibrariesWithTypeRequest) (*pb.FindAllEnabledIPLibrariesWithTypeResponse, error) {
// 校验请求
_, _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
var tx = this.NullTx()
ipLibraries, err := models.SharedIPLibraryDAO.FindAllEnabledIPLibrariesWithType(tx, req.Type)
if err != nil {
return nil, err
}
result := []*pb.IPLibrary{}
for _, library := range ipLibraries {
// 文件相关
var pbFile *pb.File = nil
file, err := models.SharedFileDAO.FindEnabledFile(tx, int64(library.FileId))
if err != nil {
return nil, err
}
if file != nil {
pbFile = &pb.File{
Id: int64(file.Id),
Filename: file.Filename,
Size: int64(file.Size),
}
}
result = append(result, &pb.IPLibrary{
Id: int64(library.Id),
Type: library.Type,
File: pbFile,
CreatedAt: int64(library.CreatedAt),
})
}
return &pb.FindAllEnabledIPLibrariesWithTypeResponse{IpLibraries: result}, nil
}
// DeleteIPLibrary 删除IP库
func (this *IPLibraryService) DeleteIPLibrary(ctx context.Context, req *pb.DeleteIPLibraryRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedIPLibraryDAO.DisableIPLibrary(tx, req.IpLibraryId)
if err != nil {
return nil, err
}
return this.Success()
}
// LookupIPRegion 查询某个IP信息
func (this *IPLibraryService) LookupIPRegion(ctx context.Context, req *pb.LookupIPRegionRequest) (*pb.LookupIPRegionResponse, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var result = iplibrary.LookupIP(req.Ip)
if result == nil || !result.IsOk() {
return &pb.LookupIPRegionResponse{IpRegion: nil}, nil
}
return &pb.LookupIPRegionResponse{IpRegion: &pb.IPRegion{
Country: result.CountryName(),
Region: "",
Province: result.ProvinceName(),
City: result.CityName(),
Isp: result.ProviderName(),
CountryId: result.CountryId(),
ProvinceId: result.ProvinceId(),
CityId: result.CityId(),
TownId: result.TownId(),
ProviderId: result.ProviderId(),
Summary: result.Summary(),
}}, nil
}
// LookupIPRegions 查询一组IP信息
func (this *IPLibraryService) LookupIPRegions(ctx context.Context, req *pb.LookupIPRegionsRequest) (*pb.LookupIPRegionsResponse, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var result = map[string]*pb.IPRegion{}
if len(req.IpList) > 0 {
for _, ip := range req.IpList {
var info = iplibrary.LookupIP(ip)
if info != nil && info.IsOk() {
result[ip] = &pb.IPRegion{
Country: info.CountryName(),
Region: "",
Province: info.ProvinceName(),
City: info.CityName(),
Isp: info.ProviderName(),
CountryId: info.CountryId(),
ProvinceId: info.ProvinceId(),
CityId: info.CityId(),
TownId: info.TownId(),
ProviderId: info.ProviderId(),
Summary: info.Summary(),
}
}
}
}
return &pb.LookupIPRegionsResponse{IpRegionMap: result}, nil
}
// ReloadIPLibrary 重新加载IP库
func (this *IPLibraryService) ReloadIPLibrary(ctx context.Context, req *pb.ReloadIPLibraryRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
// 重新加载IP库
err = iplibrary.InitDefault()
if err != nil {
return nil, err
}
return this.Success()
}
// UploadMaxMindFile 上传MaxMind文件到EdgeAPI
func (this *IPLibraryService) UploadMaxMindFile(ctx context.Context, req *pb.UploadMaxMindFileRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
if len(req.Filename) == 0 || len(req.Data) == 0 {
return nil, errors.New("filename and data are required")
}
// 检查文件名
filename := strings.ToLower(req.Filename)
if !strings.HasSuffix(filename, ".mmdb") {
return nil, errors.New("only MaxMind format files (.mmdb) are supported")
}
// 确定目标路径
iplibDir := Tea.Root + "/data/iplibrary"
err = os.MkdirAll(iplibDir, 0755)
if err != nil {
return nil, fmt.Errorf("create IP library directory failed: %w", err)
}
var targetPath string
if strings.Contains(filename, "city") {
targetPath = filepath.Join(iplibDir, "maxmind-city.mmdb")
} else if strings.Contains(filename, "asn") {
targetPath = filepath.Join(iplibDir, "maxmind-asn.mmdb")
} else {
return nil, errors.New("MaxMind filename must contain 'city' or 'asn'")
}
// 保存文件(使用临时文件原子替换)
tmpPath := targetPath + ".tmp"
err = os.WriteFile(tmpPath, req.Data, 0644)
if err != nil {
return nil, fmt.Errorf("save IP library file failed: %w", err)
}
// 原子替换
err = os.Rename(tmpPath, targetPath)
if err != nil {
os.Remove(tmpPath)
return nil, fmt.Errorf("replace IP library file failed: %w", err)
}
// 重新加载IP库
err = iplibrary.InitDefault()
if err != nil {
return nil, fmt.Errorf("reload IP library failed: %w", err)
}
return this.Success()
}
// FindMaxMindFileStatus 查询MaxMind文件状态
func (this *IPLibraryService) FindMaxMindFileStatus(ctx context.Context, req *pb.FindMaxMindFileStatusRequest) (*pb.FindMaxMindFileStatusResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
// 检查EdgeAPI的data/iplibrary/目录
// 使用与 UploadMaxMindFile 相同的路径逻辑
iplibDir := Tea.Root + "/data/iplibrary"
cityDBPath := filepath.Join(iplibDir, "maxmind-city.mmdb")
asnDBPath := filepath.Join(iplibDir, "maxmind-asn.mmdb")
cityExists := false
asnExists := false
// 检查文件是否存在
if stat, err := os.Stat(cityDBPath); err == nil && stat != nil && !stat.IsDir() {
cityExists = true
}
// 文件不存在是正常情况(会使用嵌入的库),不需要记录错误
if stat, err := os.Stat(asnDBPath); err == nil && stat != nil && !stat.IsDir() {
asnExists = true
}
// 检查是否使用了MaxMind库通过测试查询来判断
testIP := "8.8.8.8"
testResult := iplibrary.LookupIP(testIP)
usingMaxMind := false
if testResult != nil && testResult.IsOk() {
// MaxMind库的特征CountryId 和 ProvinceId 通常为 0因为MaxMind不使用ID系统
// 同时有国家名称,说明查询成功
if testResult.CountryId() == 0 && len(testResult.CountryName()) > 0 {
usingMaxMind = true
}
}
usingEmbeddedMaxMind := usingMaxMind && !cityExists
return &pb.FindMaxMindFileStatusResponse{
CityExists: cityExists,
AsnExists: asnExists,
UsingMaxMind: usingMaxMind,
UsingEmbeddedMaxMind: usingEmbeddedMaxMind,
}, nil
}
// DeleteMaxMindFile 删除MaxMind文件
func (this *IPLibraryService) DeleteMaxMindFile(ctx context.Context, req *pb.DeleteMaxMindFileRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
iplibDir := Tea.Root + "/data/iplibrary"
cityDBPath := filepath.Join(iplibDir, "maxmind-city.mmdb")
asnDBPath := filepath.Join(iplibDir, "maxmind-asn.mmdb")
// 根据文件名删除对应的文件,如果为空则删除所有
filename := strings.ToLower(req.Filename)
if len(filename) == 0 {
// 删除所有文件
if err := os.Remove(cityDBPath); err != nil && !os.IsNotExist(err) {
remotelogs.Error("IP_LIBRARY", "delete city file failed: "+err.Error())
}
if err := os.Remove(asnDBPath); err != nil && !os.IsNotExist(err) {
remotelogs.Error("IP_LIBRARY", "delete ASN file failed: "+err.Error())
}
} else if strings.Contains(filename, "city") {
// 只删除 City 文件
if err := os.Remove(cityDBPath); err != nil && !os.IsNotExist(err) {
return nil, fmt.Errorf("delete city file failed: %w", err)
}
} else if strings.Contains(filename, "asn") {
// 只删除 ASN 文件
if err := os.Remove(asnDBPath); err != nil && !os.IsNotExist(err) {
return nil, fmt.Errorf("delete ASN file failed: %w", err)
}
} else {
return nil, errors.New("filename must contain 'city' or 'asn', or be empty to delete all")
}
// 重新加载IP库使用嵌入的默认库
err = iplibrary.InitDefault()
if err != nil {
remotelogs.Error("IP_LIBRARY", "reload IP library after deletion failed: "+err.Error())
// 不返回错误,因为文件已经删除成功
}
return this.Success()
}

View File

@@ -0,0 +1,181 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package services
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// IPLibraryArtifactService IP库制品
type IPLibraryArtifactService struct {
BaseService
}
// CreateIPLibraryArtifact 创建制品
func (this *IPLibraryArtifactService) CreateIPLibraryArtifact(ctx context.Context, req *pb.CreateIPLibraryArtifactRequest) (*pb.CreateIPLibraryArtifactResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
var meta = &iplibrary.Meta{}
err = json.Unmarshal(req.MetaJSON, meta)
if err != nil {
return nil, errors.New("decode meta failed: " + err.Error())
}
// TODO 更新数据库中的省市县等信息?
artifactId, err := models.SharedIPLibraryArtifactDAO.CreateArtifact(tx, req.Name, req.FileId, 0, meta)
if err != nil {
return nil, err
}
return &pb.CreateIPLibraryArtifactResponse{IpLibraryArtifactId: artifactId}, nil
}
// UpdateIPLibraryArtifactIsPublic 使用/取消使用制品
func (this *IPLibraryArtifactService) UpdateIPLibraryArtifactIsPublic(ctx context.Context, req *pb.UpdateIPLibraryArtifactIsPublicRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedIPLibraryArtifactDAO.UpdateArtifactPublic(tx, req.IpLibraryArtifactId, req.IsPublic)
if err != nil {
return nil, err
}
return this.Success()
}
// FindAllIPLibraryArtifacts 查询所有制品
func (this *IPLibraryArtifactService) FindAllIPLibraryArtifacts(ctx context.Context, req *pb.FindAllIPLibraryArtifactsRequest) (*pb.FindAllIPLibraryArtifactsResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
artifacts, err := models.SharedIPLibraryArtifactDAO.FindAllArtifacts(tx)
if err != nil {
return nil, err
}
var pbArtifacts = []*pb.IPLibraryArtifact{}
for _, artifact := range artifacts {
var pbFile *pb.File
if artifact.FileId > 0 {
fileInfo, err := models.SharedFileDAO.FindEnabledFile(tx, int64(artifact.FileId))
if err != nil {
return nil, err
}
if fileInfo != nil {
pbFile = &pb.File{
Id: int64(fileInfo.Id),
Size: int64(fileInfo.Size),
}
}
}
pbArtifacts = append(pbArtifacts, &pb.IPLibraryArtifact{
Id: int64(artifact.Id),
Name: artifact.Name,
FileId: int64(artifact.FileId),
CreatedAt: int64(artifact.CreatedAt),
MetaJSON: artifact.Meta,
IsPublic: artifact.IsPublic,
Code: artifact.Code,
File: pbFile,
})
}
return &pb.FindAllIPLibraryArtifactsResponse{
IpLibraryArtifacts: pbArtifacts,
}, nil
}
// FindIPLibraryArtifact 查找当前正在使用的制品
func (this *IPLibraryArtifactService) FindIPLibraryArtifact(ctx context.Context, req *pb.FindIPLibraryArtifactRequest) (*pb.FindIPLibraryArtifactResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
artifact, err := models.SharedIPLibraryArtifactDAO.FindEnabledIPLibraryArtifact(tx, req.IpLibraryArtifactId)
if err != nil {
return nil, err
}
if artifact == nil {
return &pb.FindIPLibraryArtifactResponse{
IpLibraryArtifact: nil,
}, nil
}
return &pb.FindIPLibraryArtifactResponse{
IpLibraryArtifact: &pb.IPLibraryArtifact{
Id: int64(artifact.Id),
FileId: int64(artifact.FileId),
CreatedAt: int64(artifact.CreatedAt),
MetaJSON: artifact.Meta,
IsPublic: artifact.IsPublic,
Code: artifact.Code,
},
}, nil
}
// FindPublicIPLibraryArtifact 查找当前正在使用的制品
func (this *IPLibraryArtifactService) FindPublicIPLibraryArtifact(ctx context.Context, req *pb.FindPublicIPLibraryArtifactRequest) (*pb.FindPublicIPLibraryArtifactResponse, error) {
_, _, err := this.ValidateNodeId(ctx, rpcutils.UserTypeNode, rpcutils.UserTypeDNS)
if err != nil {
return nil, err
}
var tx = this.NullTx()
artifact, err := models.SharedIPLibraryArtifactDAO.FindPublicArtifact(tx)
if err != nil {
return nil, err
}
if artifact == nil {
return &pb.FindPublicIPLibraryArtifactResponse{
IpLibraryArtifact: nil,
}, nil
}
return &pb.FindPublicIPLibraryArtifactResponse{
IpLibraryArtifact: &pb.IPLibraryArtifact{
Id: int64(artifact.Id),
FileId: int64(artifact.FileId),
CreatedAt: int64(artifact.CreatedAt),
MetaJSON: artifact.Meta,
IsPublic: artifact.IsPublic,
Code: artifact.Code,
},
}, nil
}
// DeleteIPLibraryArtifact 删除制品
func (this *IPLibraryArtifactService) DeleteIPLibraryArtifact(ctx context.Context, req *pb.DeleteIPLibraryArtifactRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedIPLibraryArtifactDAO.DisableIPLibraryArtifact(tx, req.IpLibraryArtifactId)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,716 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package services
import (
"context"
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/regions"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
)
// IPLibraryFileService IP库文件管理
type IPLibraryFileService struct {
BaseService
}
// FindAllFinishedIPLibraryFiles 查找所有已完成的IP库文件
func (this *IPLibraryFileService) FindAllFinishedIPLibraryFiles(ctx context.Context, req *pb.FindAllFinishedIPLibraryFilesRequest) (*pb.FindAllFinishedIPLibraryFilesResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
libraryFiles, err := models.SharedIPLibraryFileDAO.FindAllFinishedLibraryFiles(tx)
if err != nil {
return nil, err
}
var pbLibraryFiles = []*pb.IPLibraryFile{}
for _, libraryFile := range libraryFiles {
var pbCountryNames = libraryFile.DecodeCountries()
var pbProviderNames = libraryFile.DecodeProviders()
var pbProvinces = []*pb.IPLibraryFile_Province{}
for _, province := range libraryFile.DecodeProvinces() {
pbProvinces = append(pbProvinces, &pb.IPLibraryFile_Province{
CountryName: province[0],
ProvinceName: province[1],
})
}
var pbCities = []*pb.IPLibraryFile_City{}
for _, city := range libraryFile.DecodeCities() {
pbCities = append(pbCities, &pb.IPLibraryFile_City{
CountryName: city[0],
ProvinceName: city[1],
CityName: city[2],
})
}
var pbTowns = []*pb.IPLibraryFile_Town{}
for _, town := range libraryFile.DecodeTowns() {
pbTowns = append(pbTowns, &pb.IPLibraryFile_Town{
CountryName: town[0],
ProvinceName: town[1],
CityName: town[2],
TownName: town[3],
})
}
pbLibraryFiles = append(pbLibraryFiles, &pb.IPLibraryFile{
Id: int64(libraryFile.Id),
Name: libraryFile.Name,
FileId: int64(libraryFile.FileId),
IsFinished: libraryFile.IsFinished,
CreatedAt: int64(libraryFile.CreatedAt),
GeneratedFileId: int64(libraryFile.GeneratedFileId),
GeneratedAt: int64(libraryFile.GeneratedAt),
Password: libraryFile.Password,
CountryNames: pbCountryNames,
Provinces: pbProvinces,
Cities: pbCities,
Towns: pbTowns,
ProviderNames: pbProviderNames,
})
}
return &pb.FindAllFinishedIPLibraryFilesResponse{
IpLibraryFiles: pbLibraryFiles,
}, nil
}
// FindAllUnfinishedIPLibraryFiles 查找所有未完成的IP库文件
func (this *IPLibraryFileService) FindAllUnfinishedIPLibraryFiles(ctx context.Context, req *pb.FindAllUnfinishedIPLibraryFilesRequest) (*pb.FindAllUnfinishedIPLibraryFilesResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
libraryFiles, err := models.SharedIPLibraryFileDAO.FindAllUnfinishedLibraryFiles(tx)
if err != nil {
return nil, err
}
var pbLibraryFiles = []*pb.IPLibraryFile{}
for _, libraryFile := range libraryFiles {
var pbCountryNames = libraryFile.DecodeCountries()
var pbProviderNames = libraryFile.DecodeProviders()
var pbProvinces = []*pb.IPLibraryFile_Province{}
for _, province := range libraryFile.DecodeProvinces() {
pbProvinces = append(pbProvinces, &pb.IPLibraryFile_Province{
CountryName: province[0],
ProvinceName: province[1],
})
}
var pbCities = []*pb.IPLibraryFile_City{}
for _, city := range libraryFile.DecodeCities() {
pbCities = append(pbCities, &pb.IPLibraryFile_City{
CountryName: city[0],
ProvinceName: city[1],
CityName: city[2],
})
}
var pbTowns = []*pb.IPLibraryFile_Town{}
for _, town := range libraryFile.DecodeTowns() {
pbTowns = append(pbTowns, &pb.IPLibraryFile_Town{
CountryName: town[0],
ProvinceName: town[1],
CityName: town[2],
TownName: town[3],
})
}
pbLibraryFiles = append(pbLibraryFiles, &pb.IPLibraryFile{
Id: int64(libraryFile.Id),
Name: libraryFile.Name,
FileId: int64(libraryFile.FileId),
IsFinished: libraryFile.IsFinished,
CreatedAt: int64(libraryFile.CreatedAt),
Password: libraryFile.Password,
CountryNames: pbCountryNames,
Provinces: pbProvinces,
Cities: pbCities,
Towns: pbTowns,
ProviderNames: pbProviderNames,
})
}
return &pb.FindAllUnfinishedIPLibraryFilesResponse{
IpLibraryFiles: pbLibraryFiles,
}, nil
}
// FindIPLibraryFile 查找单个IP库文件
func (this *IPLibraryFileService) FindIPLibraryFile(ctx context.Context, req *pb.FindIPLibraryFileRequest) (*pb.FindIPLibraryFileResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
libraryFile, err := models.SharedIPLibraryFileDAO.FindEnabledIPLibraryFile(tx, req.IpLibraryFileId)
if err != nil {
return nil, err
}
if libraryFile == nil {
return &pb.FindIPLibraryFileResponse{
IpLibraryFile: nil,
}, nil
}
var pbCountryNames = libraryFile.DecodeCountries()
var pbProviderNames = libraryFile.DecodeProviders()
var pbProvinces = []*pb.IPLibraryFile_Province{}
for _, province := range libraryFile.DecodeProvinces() {
pbProvinces = append(pbProvinces, &pb.IPLibraryFile_Province{
CountryName: province[0],
ProvinceName: province[1],
})
}
var pbCities = []*pb.IPLibraryFile_City{}
for _, city := range libraryFile.DecodeCities() {
pbCities = append(pbCities, &pb.IPLibraryFile_City{
CountryName: city[0],
ProvinceName: city[1],
CityName: city[2],
})
}
var pbTowns = []*pb.IPLibraryFile_Town{}
for _, town := range libraryFile.DecodeTowns() {
pbTowns = append(pbTowns, &pb.IPLibraryFile_Town{
CountryName: town[0],
ProvinceName: town[1],
CityName: town[2],
TownName: town[3],
})
}
return &pb.FindIPLibraryFileResponse{
IpLibraryFile: &pb.IPLibraryFile{
Id: int64(libraryFile.Id),
Name: libraryFile.Name,
Template: libraryFile.Template,
EmptyValues: libraryFile.DecodeEmptyValues(),
FileId: int64(libraryFile.FileId),
IsFinished: libraryFile.IsFinished,
CreatedAt: int64(libraryFile.CreatedAt),
GeneratedFileId: int64(libraryFile.GeneratedFileId),
Password: libraryFile.Password,
CountryNames: pbCountryNames,
Provinces: pbProvinces,
Cities: pbCities,
Towns: pbTowns,
ProviderNames: pbProviderNames,
},
}, nil
}
// CreateIPLibraryFile 创建IP库文件
func (this *IPLibraryFileService) CreateIPLibraryFile(ctx context.Context, req *pb.CreateIPLibraryFileRequest) (*pb.CreateIPLibraryFileResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var countries = []string{}
var provinces = [][2]string{}
var cities = [][3]string{}
var towns = [][4]string{}
var providers = []string{}
err = json.Unmarshal(req.CountriesJSON, &countries)
if err != nil {
return nil, errors.New("decode countries failed: " + err.Error())
}
err = json.Unmarshal(req.ProvincesJSON, &provinces)
if err != nil {
return nil, errors.New("decode provinces failed: " + err.Error())
}
err = json.Unmarshal(req.CitiesJSON, &cities)
if err != nil {
return nil, errors.New("decode cities failed: " + err.Error())
}
err = json.Unmarshal(req.TownsJSON, &towns)
if err != nil {
return nil, errors.New("decode towns failed: " + err.Error())
}
err = json.Unmarshal(req.ProvidersJSON, &providers)
if err != nil {
return nil, errors.New("decode providers failed: " + err.Error())
}
var tx = this.NullTx()
libraryFileId, err := models.SharedIPLibraryFileDAO.CreateLibraryFile(tx, req.Name, req.Template, req.EmptyValues, req.Password, req.FileId, countries, provinces, cities, towns, providers)
if err != nil {
return nil, err
}
return &pb.CreateIPLibraryFileResponse{
IpLibraryFileId: libraryFileId,
}, nil
}
// CheckCountriesWithIPLibraryFileId 检查国家/地区
func (this *IPLibraryFileService) CheckCountriesWithIPLibraryFileId(ctx context.Context, req *pb.CheckCountriesWithIPLibraryFileIdRequest) (*pb.CheckCountriesWithIPLibraryFileIdResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
allCountries, err := regions.SharedRegionCountryDAO.FindAllCountries(tx)
if err != nil {
return nil, err
}
countryNames, err := models.SharedIPLibraryFileDAO.FindLibraryFileCountries(tx, req.IpLibraryFileId)
if err != nil {
return nil, err
}
var pbMissingCountries = []*pb.CheckCountriesWithIPLibraryFileIdResponse_MissingCountry{}
for _, countryName := range countryNames {
if len(countryName) == 0 {
continue
}
// 检查是否存在
countryId, err := regions.SharedRegionCountryDAO.FindCountryIdWithName(tx, countryName)
if err != nil {
return nil, err
}
if countryId > 0 {
continue
}
var pbMissingCountry = &pb.CheckCountriesWithIPLibraryFileIdResponse_MissingCountry{
CountryName: countryName,
SimilarCountries: nil,
}
// 查找相似
var similarCountries = regions.SharedRegionCountryDAO.FindSimilarCountries(allCountries, countryName, 5)
if err != nil {
return nil, err
}
for _, similarCountry := range similarCountries {
pbMissingCountry.SimilarCountries = append(pbMissingCountry.SimilarCountries, &pb.RegionCountry{
Id: int64(similarCountry.ValueId),
Name: similarCountry.Name,
DisplayName: similarCountry.DisplayName(),
})
}
pbMissingCountries = append(pbMissingCountries, pbMissingCountry)
}
return &pb.CheckCountriesWithIPLibraryFileIdResponse{
MissingCountries: pbMissingCountries,
}, nil
}
// CheckProvincesWithIPLibraryFileId 检查省份/州
func (this *IPLibraryFileService) CheckProvincesWithIPLibraryFileId(ctx context.Context, req *pb.CheckProvincesWithIPLibraryFileIdRequest) (*pb.CheckProvincesWithIPLibraryFileIdResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
provinces, err := models.SharedIPLibraryFileDAO.FindLibraryFileProvinces(tx, req.IpLibraryFileId)
if err != nil {
return nil, err
}
var countryMap = map[string]int64{} // countryName => countryId
var provinceNamesMap = map[int64][][2]string{} // countryId => [][2]{countryName, provinceName}
var countryIds = []int64{}
for _, province := range provinces {
var countryName = province[0]
var provinceName = province[1]
countryId, ok := countryMap[countryName]
if ok {
provinceNamesMap[countryId] = append(provinceNamesMap[countryId], [2]string{countryName, provinceName})
continue
}
countryId, err := regions.SharedRegionCountryDAO.FindCountryIdWithName(tx, countryName)
if err != nil {
return nil, err
}
countryMap[countryName] = countryId
provinceNamesMap[countryId] = append(provinceNamesMap[countryId], [2]string{countryName, provinceName})
if countryId > 0 && !lists.ContainsInt64(countryIds, countryId) {
countryIds = append(countryIds, countryId)
}
}
var pbMissingProvinces = []*pb.CheckProvincesWithIPLibraryFileIdResponse_MissingProvince{}
for _, countryId := range countryIds {
allProvinces, err := regions.SharedRegionProvinceDAO.FindAllEnabledProvincesWithCountryId(tx, countryId)
if err != nil {
return nil, err
}
for _, province := range provinceNamesMap[countryId] {
var countryName = province[0]
var provinceName = province[1]
provinceId, err := regions.SharedRegionProvinceDAO.FindProvinceIdWithName(tx, countryId, provinceName)
if err != nil {
return nil, err
}
if provinceId > 0 {
continue
}
var similarProvinces = regions.SharedRegionProvinceDAO.FindSimilarProvinces(allProvinces, provinceName, 5)
if err != nil {
return nil, err
}
var pbMissingProvince = &pb.CheckProvincesWithIPLibraryFileIdResponse_MissingProvince{}
pbMissingProvince.CountryName = countryName
pbMissingProvince.ProvinceName = provinceName
for _, similarProvince := range similarProvinces {
pbMissingProvince.SimilarProvinces = append(pbMissingProvince.SimilarProvinces, &pb.RegionProvince{
Id: int64(similarProvince.ValueId),
Name: similarProvince.Name,
DisplayName: similarProvince.DisplayName(),
})
}
pbMissingProvinces = append(pbMissingProvinces, pbMissingProvince)
}
}
return &pb.CheckProvincesWithIPLibraryFileIdResponse{MissingProvinces: pbMissingProvinces}, nil
}
// CheckCitiesWithIPLibraryFileId 检查城市/市
func (this *IPLibraryFileService) CheckCitiesWithIPLibraryFileId(ctx context.Context, req *pb.CheckCitiesWithIPLibraryFileIdRequest) (*pb.CheckCitiesWithIPLibraryFileIdResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
cities, err := models.SharedIPLibraryFileDAO.FindLibraryFileCities(tx, req.IpLibraryFileId)
if err != nil {
return nil, err
}
var countryMap = map[string]int64{} // countryName => countryId
var provinceMap = map[string]int64{} // countryId_provinceName => provinceId
var cityNamesMap = map[int64][][3]string{} // provinceId => [][3]{countryName, provinceName, cityName}
var provinceIds = []int64{}
for _, city := range cities {
var countryName = city[0]
var provinceName = city[1]
var cityName = city[2]
countryId, ok := countryMap[countryName]
if !ok {
countryId, err = regions.SharedRegionCountryDAO.FindCountryIdWithName(tx, countryName)
if err != nil {
return nil, err
}
}
countryMap[countryName] = countryId
var key = types.String(countryId) + "_" + provinceName
provinceId, ok := provinceMap[key]
if ok {
cityNamesMap[provinceId] = append(cityNamesMap[provinceId], [3]string{countryName, provinceName, cityName})
} else {
provinceId, err := regions.SharedRegionProvinceDAO.FindProvinceIdWithName(tx, countryId, provinceName)
if err != nil {
return nil, err
}
provinceMap[key] = provinceId
cityNamesMap[provinceId] = append(cityNamesMap[provinceId], [3]string{countryName, provinceName, cityName})
if provinceId > 0 {
provinceIds = append(provinceIds, provinceId)
}
}
}
var pbMissingCities = []*pb.CheckCitiesWithIPLibraryFileIdResponse_MissingCity{}
for _, provinceId := range provinceIds {
allCities, err := regions.SharedRegionCityDAO.FindAllEnabledCitiesWithProvinceId(tx, provinceId)
if err != nil {
return nil, err
}
for _, city := range cityNamesMap[provinceId] {
var countryName = city[0]
var provinceName = city[1]
var cityName = city[2]
cityId, err := regions.SharedRegionCityDAO.FindCityIdWithName(tx, provinceId, cityName)
if err != nil {
return nil, err
}
if cityId > 0 {
continue
}
var similarCities = regions.SharedRegionCityDAO.FindSimilarCities(allCities, cityName, 5)
if err != nil {
return nil, err
}
var pbMissingCity = &pb.CheckCitiesWithIPLibraryFileIdResponse_MissingCity{}
pbMissingCity.CountryName = countryName
pbMissingCity.ProvinceName = provinceName
pbMissingCity.CityName = cityName
for _, similarCity := range similarCities {
pbMissingCity.SimilarCities = append(pbMissingCity.SimilarCities, &pb.RegionCity{
Id: int64(similarCity.ValueId),
Name: similarCity.Name,
DisplayName: similarCity.DisplayName(),
})
}
pbMissingCities = append(pbMissingCities, pbMissingCity)
}
}
return &pb.CheckCitiesWithIPLibraryFileIdResponse{MissingCities: pbMissingCities}, nil
}
// CheckTownsWithIPLibraryFileId 检查区县
func (this *IPLibraryFileService) CheckTownsWithIPLibraryFileId(ctx context.Context, req *pb.CheckTownsWithIPLibraryFileIdRequest) (*pb.CheckTownsWithIPLibraryFileIdResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
towns, err := models.SharedIPLibraryFileDAO.FindLibraryFileTowns(tx, req.IpLibraryFileId)
if err != nil {
return nil, err
}
var countryMap = map[string]int64{} // countryName => countryId
var provinceMap = map[string]int64{} // countryId_provinceName => provinceId
var cityMap = map[string]int64{} // province_cityName => cityId
var townNamesMap = map[int64][][4]string{} // cityId => [][4]{countryName, provinceName, cityName, townName}
var cityIds = []int64{}
for _, town := range towns {
var countryName = town[0]
var provinceName = town[1]
var cityName = town[2]
var townName = town[3]
// country
countryId, ok := countryMap[countryName]
if !ok {
countryId, err = regions.SharedRegionCountryDAO.FindCountryIdWithName(tx, countryName)
if err != nil {
return nil, err
}
}
countryMap[countryName] = countryId
// province
var provinceKey = types.String(countryId) + "_" + provinceName
provinceId, ok := provinceMap[provinceKey]
if !ok {
if countryId > 0 {
provinceId, err = regions.SharedRegionProvinceDAO.FindProvinceIdWithName(tx, countryId, provinceName)
if err != nil {
return nil, err
}
}
provinceMap[provinceKey] = provinceId
}
// city
var cityKey = types.String(provinceId) + "_" + cityName
cityId, ok := cityMap[cityKey]
if !ok {
if provinceId > 0 {
cityId, err = regions.SharedRegionCityDAO.FindCityIdWithName(tx, provinceId, cityName)
if err != nil {
return nil, err
}
}
cityMap[cityKey] = cityId
if cityId > 0 {
cityIds = append(cityIds, cityId)
}
}
// town
townNamesMap[cityId] = append(townNamesMap[cityId], [4]string{countryName, provinceName, cityName, townName})
}
var pbMissingTowns = []*pb.CheckTownsWithIPLibraryFileIdResponse_MissingTown{}
for _, cityId := range cityIds {
allTowns, err := regions.SharedRegionTownDAO.FindAllRegionTownsWithCityId(tx, cityId)
if err != nil {
return nil, err
}
for _, town := range townNamesMap[cityId] {
var countryName = town[0]
var provinceName = town[1]
var cityName = town[2]
var townName = town[3]
townId, err := regions.SharedRegionTownDAO.FindTownIdWithName(tx, cityId, townName)
if err != nil {
return nil, err
}
if townId > 0 {
// 已存在,则跳过
continue
}
var similarTowns = regions.SharedRegionTownDAO.FindSimilarTowns(allTowns, townName, 5)
if err != nil {
return nil, err
}
var pbMissingTown = &pb.CheckTownsWithIPLibraryFileIdResponse_MissingTown{}
pbMissingTown.CountryName = countryName
pbMissingTown.ProvinceName = provinceName
pbMissingTown.CityName = cityName
pbMissingTown.TownName = townName
for _, similarTown := range similarTowns {
pbMissingTown.SimilarTowns = append(pbMissingTown.SimilarTowns, &pb.RegionTown{
Id: int64(similarTown.ValueId),
Name: similarTown.Name,
DisplayName: similarTown.DisplayName(),
})
}
pbMissingTowns = append(pbMissingTowns, pbMissingTown)
}
}
return &pb.CheckTownsWithIPLibraryFileIdResponse{MissingTowns: pbMissingTowns}, nil
}
// CheckProvidersWithIPLibraryFileId 检查ISP运营商
func (this *IPLibraryFileService) CheckProvidersWithIPLibraryFileId(ctx context.Context, req *pb.CheckProvidersWithIPLibraryFileIdRequest) (*pb.CheckProvidersWithIPLibraryFileIdResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
allProviders, err := regions.SharedRegionProviderDAO.FindAllEnabledProviders(tx)
if err != nil {
return nil, err
}
providerNames, err := models.SharedIPLibraryFileDAO.FindLibraryFileProviders(tx, req.IpLibraryFileId)
if err != nil {
return nil, err
}
var pbMissingProviders = []*pb.CheckProvidersWithIPLibraryFileIdResponse_MissingProvider{}
for _, providerName := range providerNames {
if len(providerName) == 0 {
continue
}
// 检查是否存在
providerId, err := regions.SharedRegionProviderDAO.FindProviderIdWithName(tx, providerName)
if err != nil {
return nil, err
}
if providerId > 0 {
continue
}
var pbMissingProvider = &pb.CheckProvidersWithIPLibraryFileIdResponse_MissingProvider{
ProviderName: providerName,
SimilarProviders: nil,
}
// 查找相似
var similarProviders = regions.SharedRegionProviderDAO.FindSimilarProviders(allProviders, providerName, 5)
if err != nil {
return nil, err
}
for _, similarProvider := range similarProviders {
pbMissingProvider.SimilarProviders = append(pbMissingProvider.SimilarProviders, &pb.RegionProvider{
Id: int64(similarProvider.ValueId),
Name: similarProvider.Name,
DisplayName: similarProvider.DisplayName(),
})
}
pbMissingProviders = append(pbMissingProviders, pbMissingProvider)
}
return &pb.CheckProvidersWithIPLibraryFileIdResponse{
MissingProviders: pbMissingProviders,
}, nil
}
// GenerateIPLibraryFile 生成IP库文件
func (this *IPLibraryFileService) GenerateIPLibraryFile(ctx context.Context, req *pb.GenerateIPLibraryFileRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedIPLibraryFileDAO.GenerateIPLibrary(tx, req.IpLibraryFileId)
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateIPLibraryFileFinished 设置某个IP库为已完成
func (this *IPLibraryFileService) UpdateIPLibraryFileFinished(ctx context.Context, req *pb.UpdateIPLibraryFileFinishedRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedIPLibraryFileDAO.UpdateLibraryFileIsFinished(tx, req.IpLibraryFileId)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteIPLibraryFile 删除IP库文件
func (this *IPLibraryFileService) DeleteIPLibraryFile(ctx context.Context, req *pb.DeleteIPLibraryFileRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedIPLibraryFileDAO.DisableIPLibraryFile(tx, req.IpLibraryFileId)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,334 @@
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/rands"
)
// IPListService IP名单相关服务
type IPListService struct {
BaseService
}
// CreateIPList 创建IP列表
func (this *IPListService) CreateIPList(ctx context.Context, req *pb.CreateIPListRequest) (*pb.CreateIPListResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 修正默认的代号
if req.Code == "white" || req.Code == "black" || req.Code == "grey" {
req.Code = req.Code + "-" + rands.HexString(8)
}
// 检查用户相关信息
var sourceUserId = userId
if userId > 0 {
// 检查网站ID
if req.ServerId > 0 {
err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId)
if err != nil {
return nil, err
}
}
} else if req.ServerId > 0 {
sourceUserId, err = models.SharedServerDAO.FindServerUserId(tx, req.ServerId)
if err != nil {
return nil, err
}
}
// 检查代号
if len(req.Code) > 0 {
if len(req.Code) > 100 {
return nil, errors.New("too long 'code', should be short than 100 characters")
}
if !models.SharedIPListDAO.ValidateIPListCode(req.Code) {
return nil, errors.New("invalid 'code' format")
}
oldListId, findErr := models.SharedIPListDAO.FindIPListIdWithCode(tx, req.Code)
if findErr != nil {
return nil, findErr
}
if oldListId > 0 {
return nil, errors.New("the code '" + req.Code + "' has been used")
}
}
listId, err := models.SharedIPListDAO.CreateIPList(tx, sourceUserId, req.ServerId, req.Type, req.Name, req.Code, req.TimeoutJSON, req.Description, req.IsPublic, req.IsGlobal)
if err != nil {
return nil, err
}
return &pb.CreateIPListResponse{IpListId: listId}, nil
}
// UpdateIPList 修改IP列表
func (this *IPListService) UpdateIPList(ctx context.Context, req *pb.UpdateIPListRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 检查代号
if len(req.Code) > 0 {
if len(req.Code) > 100 {
return nil, errors.New("too long 'code', should be short than 100 characters")
}
if !models.SharedIPListDAO.ValidateIPListCode(req.Code) {
return nil, errors.New("invalid 'code' format")
}
oldListId, findErr := models.SharedIPListDAO.FindIPListIdWithCode(tx, req.Code)
if findErr != nil {
return nil, findErr
}
if oldListId > 0 && oldListId != req.IpListId {
return nil, errors.New("the code '" + req.Code + "' has been used")
}
}
err = models.SharedIPListDAO.UpdateIPList(tx, req.IpListId, req.Name, req.Code, req.TimeoutJSON, req.Description)
if err != nil {
return nil, err
}
return this.Success()
}
// FindEnabledIPList 查找IP列表
func (this *IPListService) FindEnabledIPList(ctx context.Context, req *pb.FindEnabledIPListRequest) (*pb.FindEnabledIPListResponse, error) {
// 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
if userId > 0 {
// 检查用户所属名单
if !firewallconfigs.IsGlobalListId(req.IpListId) {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil {
return nil, err
}
}
}
list, err := models.SharedIPListDAO.FindEnabledIPList(tx, req.IpListId, nil)
if err != nil {
return nil, err
}
if list == nil {
return &pb.FindEnabledIPListResponse{IpList: nil}, nil
}
return &pb.FindEnabledIPListResponse{IpList: &pb.IPList{
Id: int64(list.Id),
IsOn: list.IsOn,
Type: list.Type,
Name: list.Name,
Code: list.Code,
TimeoutJSON: list.Timeout,
Description: list.Description,
IsGlobal: list.IsGlobal,
}}, nil
}
// CountAllEnabledIPLists 计算名单数量
func (this *IPListService) CountAllEnabledIPLists(ctx context.Context, req *pb.CountAllEnabledIPListsRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedIPListDAO.CountAllEnabledIPLists(tx, req.Type, req.IsPublic, req.Keyword)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListEnabledIPLists 列出单页名单
func (this *IPListService) ListEnabledIPLists(ctx context.Context, req *pb.ListEnabledIPListsRequest) (*pb.ListEnabledIPListsResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
ipLists, err := models.SharedIPListDAO.ListEnabledIPLists(tx, req.Type, req.IsPublic, req.Keyword, req.Offset, req.Size)
if err != nil {
return nil, err
}
var pbLists []*pb.IPList
for _, list := range ipLists {
pbLists = append(pbLists, &pb.IPList{
Id: int64(list.Id),
IsOn: list.IsOn,
Type: list.Type,
Name: list.Name,
Code: list.Code,
TimeoutJSON: list.Timeout,
IsPublic: list.IsPublic,
Description: list.Description,
IsGlobal: list.IsGlobal,
})
}
return &pb.ListEnabledIPListsResponse{IpLists: pbLists}, nil
}
// DeleteIPList 删除IP名单
func (this *IPListService) DeleteIPList(ctx context.Context, req *pb.DeleteIPListRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedIPListDAO.DisableIPList(tx, req.IpListId)
if err != nil {
return nil, err
}
// 删除所有IP
err = models.SharedIPItemDAO.DisableIPItemsWithListId(tx, req.IpListId)
if err != nil {
return nil, err
}
return this.Success()
}
// ExistsEnabledIPList 检查IPList是否存在
func (this *IPListService) ExistsEnabledIPList(ctx context.Context, req *pb.ExistsEnabledIPListRequest) (*pb.ExistsEnabledIPListResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
b, err := models.SharedIPListDAO.ExistsEnabledIPList(tx, req.IpListId)
if err != nil {
return nil, err
}
return &pb.ExistsEnabledIPListResponse{Exists: b}, nil
}
// FindEnabledIPListContainsIP 根据IP来搜索IP名单
func (this *IPListService) FindEnabledIPListContainsIP(ctx context.Context, req *pb.FindEnabledIPListContainsIPRequest) (*pb.FindEnabledIPListContainsIPResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
items, err := models.SharedIPItemDAO.FindEnabledItemsWithIP(tx, req.Ip)
if err != nil {
return nil, err
}
var pbLists = []*pb.IPList{}
var listIds = []int64{}
var cacheMap = utils.NewCacheMap()
for _, item := range items {
if lists.ContainsInt64(listIds, int64(item.ListId)) {
continue
}
list, err := models.SharedIPListDAO.FindEnabledIPList(tx, int64(item.ListId), cacheMap)
if err != nil {
return nil, err
}
if list == nil {
continue
}
if !list.IsPublic {
continue
}
pbLists = append(pbLists, &pb.IPList{
Id: int64(list.Id),
IsOn: list.IsOn,
Type: list.Type,
Name: list.Name,
Code: list.Code,
IsPublic: list.IsPublic,
IsGlobal: list.IsGlobal,
Description: "",
})
listIds = append(listIds, int64(item.ListId))
}
return &pb.FindEnabledIPListContainsIPResponse{IpLists: pbLists}, nil
}
// FindServerIdWithIPListId 查找IP名单对应的网站ID
func (this *IPListService) FindServerIdWithIPListId(ctx context.Context, req *pb.FindServerIdWithIPListIdRequest) (*pb.FindServerIdWithIPListIdResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
serverId, err := models.SharedIPListDAO.FindServerIdWithListId(tx, req.IpListId)
if err != nil {
return nil, err
}
// check user
if serverId > 0 && userId > 0 {
err = models.SharedServerDAO.CheckUserServer(tx, userId, serverId)
if err != nil {
return nil, err
}
}
return &pb.FindServerIdWithIPListIdResponse{
ServerId: serverId,
}, nil
}
// FindIPListIdWithCode 根据IP名单代号获取IP名单ID
func (this *IPListService) FindIPListIdWithCode(ctx context.Context, req *pb.FindIPListIdWithCodeRequest) (*pb.FindIPListIdWithCodeResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if len(req.Code) == 0 {
return nil, errors.New("require 'code'")
}
var tx = this.NullTx()
listId, err := models.SharedIPListDAO.FindIPListIdWithCode(tx, req.Code)
if err != nil {
return nil, err
}
if listId > 0 {
if userId > 0 {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, listId)
if err != nil {
return nil, err
}
}
}
return &pb.FindIPListIdWithCodeResponse{
IpListId: listId,
}, nil
}

View File

@@ -0,0 +1,28 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// LatestItemService 最近使用的条目服务
type LatestItemService struct {
BaseService
}
// IncreaseLatestItem 记录最近使用的条目
func (this *LatestItemService) IncreaseLatestItem(ctx context.Context, req *pb.IncreaseLatestItemRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedLatestItemDAO.IncreaseItemCount(tx, req.ItemType, req.ItemId)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,190 @@
package services
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/langs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// LogService 管理员、用户或者其他系统用户日志
type LogService struct {
BaseService
}
// CreateLog 创建日志
func (this *LogService) CreateLog(ctx context.Context, req *pb.CreateLogRequest) (*pb.CreateLogResponse, error) {
// 校验请求
userType, _, userId, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeUser, rpcutils.UserTypeProvider)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// i18n
var langMessageArgs = []any{}
if len(req.LangMessageArgsJSON) > 0 {
err = json.Unmarshal(req.LangMessageArgsJSON, &langMessageArgs)
if err != nil {
return nil, err
}
}
err = models.SharedLogDAO.CreateLog(tx, userType, userId, req.Level, req.Description, req.Action, req.Ip, langs.MessageCode(req.LangMessageCode), langMessageArgs)
if err != nil {
return nil, err
}
return &pb.CreateLogResponse{}, nil
}
// CountLogs 计算日志数量
func (this *LogService) CountLogs(ctx context.Context, req *pb.CountLogRequest) (*pb.RPCCountResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedLogDAO.CountLogs(tx, req.DayFrom, req.DayTo, req.Keyword, req.UserType, req.Level)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListLogs 列出单页日志
func (this *LogService) ListLogs(ctx context.Context, req *pb.ListLogsRequest) (*pb.ListLogsResponse, error) {
// 校验请求
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
logs, err := models.SharedLogDAO.ListLogs(tx, req.Offset, req.Size, req.DayFrom, req.DayTo, req.Keyword, req.UserType, req.Level)
if err != nil {
return nil, err
}
result := []*pb.Log{}
for _, log := range logs {
userName := ""
if log.AdminId > 0 {
userName, err = models.SharedAdminDAO.FindAdminFullname(tx, int64(log.AdminId))
} else if log.UserId > 0 {
userName, err = models.SharedUserDAO.FindUserFullname(tx, int64(log.UserId))
} else if log.ProviderId > 0 {
userName, err = models.SharedProviderDAO.FindProviderName(tx, int64(log.ProviderId))
}
if err != nil {
return nil, err
}
result = append(result, &pb.Log{
Id: int64(log.Id),
Level: log.Level,
Action: log.Action,
AdminId: int64(log.AdminId),
UserId: int64(log.UserId),
ProviderId: int64(log.ProviderId),
CreatedAt: int64(log.CreatedAt),
Type: log.Type,
Ip: log.Ip,
UserName: userName,
Description: log.Description,
})
}
return &pb.ListLogsResponse{Logs: result}, nil
}
// DeleteLogPermanently 删除单条
func (this *LogService) DeleteLogPermanently(ctx context.Context, req *pb.DeleteLogPermanentlyRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
// TODO 校验权限
var tx = this.NullTx()
// 执行物理删除
err = models.SharedLogDAO.DeleteLogPermanently(tx, req.LogId)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteLogsPermanently 批量删除
func (this *LogService) DeleteLogsPermanently(ctx context.Context, req *pb.DeleteLogsPermanentlyRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
// TODO 校验权限
var tx = this.NullTx()
// 执行物理删除
for _, logId := range req.LogIds {
err = models.SharedLogDAO.DeleteLogPermanently(tx, logId)
if err != nil {
return nil, err
}
}
return this.Success()
}
// CleanLogsPermanently 清理日志
func (this *LogService) CleanLogsPermanently(ctx context.Context, req *pb.CleanLogsPermanentlyRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
// TODO 校验权限
var tx = this.NullTx()
if req.ClearAll {
err = models.SharedLogDAO.DeleteAllLogsPermanently(tx)
if err != nil {
return nil, err
}
} else if req.Days > 0 {
err = models.SharedLogDAO.DeleteLogsPermanentlyBeforeDays(tx, int(req.Days))
if err != nil {
return nil, err
}
}
return this.Success()
}
// SumLogsSize 计算日志容量大小
func (this *LogService) SumLogsSize(ctx context.Context, req *pb.SumLogsSizeRequest) (*pb.SumLogsResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
// TODO 校验权限
size, err := models.SharedLogDAO.SumLogsSize()
if err != nil {
return nil, err
}
return &pb.SumLogsResponse{SizeBytes: size}, nil
}

View File

@@ -0,0 +1,83 @@
package services
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/maps"
)
// LoginService 管理员认证相关服务
type LoginService struct {
BaseService
}
// FindEnabledLogin 查找认证
func (this *LoginService) FindEnabledLogin(ctx context.Context, req *pb.FindEnabledLoginRequest) (*pb.FindEnabledLoginResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
var tx = this.NullTx()
login, err := models.SharedLoginDAO.FindEnabledLoginWithType(tx, req.AdminId, req.UserId, req.Type)
if err != nil {
return nil, err
}
if login == nil {
return &pb.FindEnabledLoginResponse{Login: nil}, nil
}
return &pb.FindEnabledLoginResponse{Login: &pb.Login{
Id: int64(login.Id),
Type: login.Type,
ParamsJSON: login.Params,
IsOn: login.IsOn,
AdminId: int64(login.AdminId),
UserId: int64(login.UserId),
}}, nil
}
// UpdateLogin 修改认证
func (this *LoginService) UpdateLogin(ctx context.Context, req *pb.UpdateLoginRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
if req.Login == nil {
return nil, errors.New("'login' should not be nil")
}
var tx = this.NullTx()
if userId > 0 {
req.Login.UserId = userId
}
if req.Login.IsOn {
var params = maps.Map{}
if len(req.Login.ParamsJSON) > 0 {
err = json.Unmarshal(req.Login.ParamsJSON, &params)
if err != nil {
return nil, err
}
}
err = models.SharedLoginDAO.UpdateLogin(tx, req.Login.AdminId, req.Login.UserId, req.Login.Type, params, req.Login.IsOn)
if err != nil {
return nil, err
}
} else {
err = models.SharedLoginDAO.DisableLoginWithType(tx, req.Login.AdminId, req.Login.UserId, req.Login.Type)
if err != nil {
return nil, err
}
}
return this.Success()
}

View File

@@ -0,0 +1,114 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package services
import (
"context"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// LoginSessionService 登录SESSION服务
type LoginSessionService struct {
BaseService
}
// WriteLoginSessionValue 写入SESSION数据
func (this *LoginSessionService) WriteLoginSessionValue(ctx context.Context, req *pb.WriteLoginSessionValueRequest) (*pb.RPCSuccess, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedLoginSessionDAO.WriteSessionValue(tx, req.Sid, req.Key, req.Value)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteLoginSession 删除SESSION
func (this *LoginSessionService) DeleteLoginSession(ctx context.Context, req *pb.DeleteLoginSessionRequest) (*pb.RPCSuccess, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
if len(req.Sid) == 0 {
return nil, errors.New("'sid' should not be empty")
}
var tx = this.NullTx()
err = models.SharedLoginSessionDAO.DeleteSession(tx, req.Sid)
if err != nil {
return nil, err
}
return this.Success()
}
// FindLoginSession 查找SESSION
func (this *LoginSessionService) FindLoginSession(ctx context.Context, req *pb.FindLoginSessionRequest) (*pb.FindLoginSessionResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
if len(req.Sid) == 0 {
return nil, errors.New("'token' should not be empty")
}
var tx = this.NullTx()
session, err := models.SharedLoginSessionDAO.FindSession(tx, req.Sid)
if err != nil {
return nil, err
}
if session == nil || !session.IsAvailable() {
return &pb.FindLoginSessionResponse{
LoginSession: nil,
}, nil
}
return &pb.FindLoginSessionResponse{
LoginSession: &pb.LoginSession{
Id: int64(session.Id),
Sid: session.Sid,
AdminId: int64(session.AdminId),
UserId: int64(session.UserId),
Ip: session.Ip,
CreatedAt: int64(session.CreatedAt),
ExpiresAt: int64(session.ExpiresAt),
ValuesJSON: session.Values,
},
}, nil
}
// ClearOldLoginSessions 清理老的SESSION
func (this *LoginSessionService) ClearOldLoginSessions(ctx context.Context, req *pb.ClearOldLoginSessionsRequest) (*pb.RPCSuccess, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
if len(req.Sid) == 0 {
return nil, errors.New("'token' should not be empty")
}
var tx = this.NullTx()
session, err := models.SharedLoginSessionDAO.FindSession(tx, req.Sid)
if err != nil {
return nil, err
}
if session == nil || !session.IsAvailable() {
return nil, errors.New("invalid sid")
}
err = models.SharedLoginSessionDAO.ClearOldSessions(tx, int64(session.AdminId), int64(session.UserId), req.Sid, req.Ip)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,71 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package services
import (
"context"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/iputils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// LoginTicketService 登录票据相关服务
type LoginTicketService struct {
BaseService
}
// CreateLoginTicket 创建票据
func (this *LoginTicketService) CreateLoginTicket(ctx context.Context, req *pb.CreateLoginTicketRequest) (*pb.CreateLoginTicketResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
if req.AdminId <= 0 && req.UserId <= 0 {
return nil, errors.New("either 'adminId' or 'userId' must be greater than 0")
}
if len(req.Ip) > 0 && !iputils.IsValid(req.Ip) {
return nil, errors.New("invalid ip: '" + req.Ip + "'")
}
var tx = this.NullTx()
value, err := models.SharedLoginTicketDAO.CreateLoginTicket(tx, req.AdminId, req.UserId, req.Ip)
if err != nil {
return nil, err
}
return &pb.CreateLoginTicketResponse{Value: value}, nil
}
// FindLoginTicketWithValue 查找票据
// 查找成功后,会自动删除票据信息,所以票据信息只能查询一次
func (this *LoginTicketService) FindLoginTicketWithValue(ctx context.Context, req *pb.FindLoginTicketWithValueRequest) (*pb.FindLoginTicketWithValueResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, false)
if err != nil {
return nil, err
}
var tx = this.NullTx()
ticket, err := models.SharedLoginTicketDAO.FindLoginTicketWithValue(tx, req.Value)
if err != nil {
return nil, err
}
if ticket == nil {
return &pb.FindLoginTicketWithValueResponse{
LoginTicket: nil,
}, nil
}
return &pb.FindLoginTicketWithValueResponse{
LoginTicket: &pb.LoginTicket{
Id: int64(ticket.Id),
ExpiresAt: int64(ticket.ExpiresAt),
Value: ticket.Value,
AdminId: int64(ticket.AdminId),
UserId: int64(ticket.UserId),
Ip: ticket.Ip,
},
}, nil
}

View File

@@ -0,0 +1,192 @@
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// MessageService 消息相关服务
type MessageService struct {
BaseService
}
// CountUnreadMessages 计算未读消息数
func (this *MessageService) CountUnreadMessages(ctx context.Context, req *pb.CountUnreadMessagesRequest) (*pb.RPCCountResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedMessageDAO.CountUnreadMessages(tx, adminId, userId)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListUnreadMessages 列出单页未读消息
func (this *MessageService) ListUnreadMessages(ctx context.Context, req *pb.ListUnreadMessagesRequest) (*pb.ListUnreadMessagesResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
messages, err := models.SharedMessageDAO.ListUnreadMessages(tx, adminId, userId, req.Offset, req.Size)
if err != nil {
return nil, err
}
result := []*pb.Message{}
for _, message := range messages {
var pbCluster *pb.NodeCluster = nil
var pbNode *pb.Node = nil
if message.ClusterId > 0 {
switch message.Role {
case nodeconfigs.NodeRoleNode:
cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(tx, int64(message.ClusterId))
if err != nil {
return nil, err
}
if cluster != nil {
pbCluster = &pb.NodeCluster{
Id: int64(cluster.Id),
Name: cluster.Name,
}
}
case nodeconfigs.NodeRoleDNS:
cluster, err := models.SharedNSClusterDAO.FindEnabledNSCluster(tx, int64(message.ClusterId))
if err != nil {
return nil, err
}
if cluster != nil {
pbCluster = &pb.NodeCluster{
Id: int64(cluster.Id),
Name: cluster.Name,
}
}
}
}
if message.NodeId > 0 {
switch message.Role {
case nodeconfigs.NodeRoleNode:
node, err := models.SharedNodeDAO.FindEnabledNode(tx, int64(message.NodeId))
if err != nil {
return nil, err
}
if node != nil {
pbNode = &pb.Node{
Id: int64(node.Id),
Name: node.Name,
}
}
case nodeconfigs.NodeRoleDNS:
node, err := models.SharedNSNodeDAO.FindEnabledNSNode(tx, int64(message.NodeId))
if err != nil {
return nil, err
}
if node != nil {
pbNode = &pb.Node{
Id: int64(node.Id),
Name: node.Name,
}
}
}
}
result = append(result, &pb.Message{
Id: int64(message.Id),
Role: message.Role,
Type: message.Type,
Body: message.Body,
Level: message.Level,
ParamsJSON: message.Params,
IsRead: message.IsRead,
CreatedAt: int64(message.CreatedAt),
NodeCluster: pbCluster,
Node: pbNode,
})
}
return &pb.ListUnreadMessagesResponse{Messages: result}, nil
}
// UpdateMessageRead 设置消息已读状态
func (this *MessageService) UpdateMessageRead(ctx context.Context, req *pb.UpdateMessageReadRequest) (*pb.RPCSuccess, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 校验权限
exists, err := models.SharedMessageDAO.CheckMessageUser(tx, req.MessageId, adminId, userId)
if err != nil {
return nil, err
}
if !exists {
return nil, this.PermissionError()
}
err = models.SharedMessageDAO.UpdateMessageRead(tx, req.MessageId, req.IsRead)
if err != nil {
return nil, err
}
return this.Success()
}
// UpdateMessagesRead 设置一组消息已读状态
func (this *MessageService) UpdateMessagesRead(ctx context.Context, req *pb.UpdateMessagesReadRequest) (*pb.RPCSuccess, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
// 校验权限
for _, messageId := range req.MessageIds {
exists, err := models.SharedMessageDAO.CheckMessageUser(tx, messageId, adminId, userId)
if err != nil {
return nil, err
}
if !exists {
return nil, this.PermissionError()
}
err = models.SharedMessageDAO.UpdateMessageRead(tx, messageId, req.IsRead)
if err != nil {
return nil, err
}
}
return this.Success()
}
// UpdateAllMessagesRead 设置所有消息为已读
func (this *MessageService) UpdateAllMessagesRead(ctx context.Context, req *pb.UpdateAllMessagesReadRequest) (*pb.RPCSuccess, error) {
// 校验请求
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedMessageDAO.UpdateAllMessagesRead(tx, adminId, userId)
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,190 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package services
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
)
// MessageMediaInstanceService 消息媒介实例服务
type MessageMediaInstanceService struct {
BaseService
}
// CreateMessageMediaInstance 创建消息媒介实例
func (this *MessageMediaInstanceService) CreateMessageMediaInstance(ctx context.Context, req *pb.CreateMessageMediaInstanceRequest) (*pb.CreateMessageMediaInstanceResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
params := maps.Map{}
if len(req.ParamsJSON) > 0 {
err = json.Unmarshal(req.ParamsJSON, &params)
if err != nil {
return nil, err
}
}
instanceId, err := models.SharedMessageMediaInstanceDAO.CreateMediaInstance(tx, req.Name, req.MediaType, params, req.Description, req.RateJSON, req.HashLife)
if err != nil {
return nil, err
}
return &pb.CreateMessageMediaInstanceResponse{MessageMediaInstanceId: instanceId}, nil
}
// UpdateMessageMediaInstance 修改消息实例
func (this *MessageMediaInstanceService) UpdateMessageMediaInstance(ctx context.Context, req *pb.UpdateMessageMediaInstanceRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
params := maps.Map{}
if len(req.ParamsJSON) > 0 {
err = json.Unmarshal(req.ParamsJSON, &params)
if err != nil {
return nil, err
}
}
var tx = this.NullTx()
err = models.SharedMessageMediaInstanceDAO.UpdateMediaInstance(tx, req.MessageMediaInstanceId, req.Name, req.MediaType, params, req.Description, req.RateJSON, req.HashLife, req.IsOn)
if err != nil {
return nil, err
}
return this.Success()
}
// DeleteMessageMediaInstance 删除媒介实例
func (this *MessageMediaInstanceService) DeleteMessageMediaInstance(ctx context.Context, req *pb.DeleteMessageMediaInstanceRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedMessageMediaInstanceDAO.DisableMessageMediaInstance(tx, req.MessageMediaInstanceId)
if err != nil {
return nil, err
}
return this.Success()
}
// CountAllEnabledMessageMediaInstances 计算媒介实例数量
func (this *MessageMediaInstanceService) CountAllEnabledMessageMediaInstances(ctx context.Context, req *pb.CountAllEnabledMessageMediaInstancesRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedMessageMediaInstanceDAO.CountAllEnabledMediaInstances(tx, req.MediaType, req.Keyword)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListEnabledMessageMediaInstances 列出单页媒介实例
func (this *MessageMediaInstanceService) ListEnabledMessageMediaInstances(ctx context.Context, req *pb.ListEnabledMessageMediaInstancesRequest) (*pb.ListEnabledMessageMediaInstancesResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
instances, err := models.SharedMessageMediaInstanceDAO.ListAllEnabledMediaInstances(tx, req.MediaType, req.Keyword, req.Offset, req.Size)
if err != nil {
return nil, err
}
pbInstances := []*pb.MessageMediaInstance{}
for _, instance := range instances {
// 媒介
media, err := models.SharedMessageMediaDAO.FindEnabledMediaWithType(tx, instance.MediaType)
if err != nil {
return nil, err
}
if media == nil {
continue
}
pbMedia := &pb.MessageMedia{
Id: int64(media.Id),
Type: media.Type,
Name: media.Name,
Description: media.Description,
UserDescription: media.UserDescription,
IsOn: media.IsOn,
}
pbInstances = append(pbInstances, &pb.MessageMediaInstance{
Id: int64(instance.Id),
Name: instance.Name,
IsOn: instance.IsOn,
MessageMedia: pbMedia,
ParamsJSON: instance.Params,
Description: instance.Description,
RateJSON: instance.Rate,
})
}
return &pb.ListEnabledMessageMediaInstancesResponse{MessageMediaInstances: pbInstances}, nil
}
// FindEnabledMessageMediaInstance 查找单个媒介实例信息
func (this *MessageMediaInstanceService) FindEnabledMessageMediaInstance(ctx context.Context, req *pb.FindEnabledMessageMediaInstanceRequest) (*pb.FindEnabledMessageMediaInstanceResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
var cacheMap = utils.NewCacheMap()
instance, err := models.SharedMessageMediaInstanceDAO.FindEnabledMessageMediaInstance(tx, req.MessageMediaInstanceId, cacheMap)
if err != nil {
return nil, err
}
if instance == nil {
return &pb.FindEnabledMessageMediaInstanceResponse{MessageMediaInstance: nil}, nil
}
// 媒介
media, err := models.SharedMessageMediaDAO.FindEnabledMediaWithType(tx, instance.MediaType)
if err != nil {
return nil, err
}
if media == nil {
return &pb.FindEnabledMessageMediaInstanceResponse{MessageMediaInstance: nil}, nil
}
pbMedia := &pb.MessageMedia{
Id: int64(media.Id),
Type: media.Type,
Name: media.Name,
Description: media.Description,
UserDescription: media.UserDescription,
IsOn: media.IsOn,
}
return &pb.FindEnabledMessageMediaInstanceResponse{MessageMediaInstance: &pb.MessageMediaInstance{
Id: int64(instance.Id),
Name: instance.Name,
IsOn: instance.IsOn,
MessageMedia: pbMedia,
ParamsJSON: instance.Params,
Description: instance.Description,
RateJSON: instance.Rate,
HashLife: types.Int32(instance.HashLife),
}}, nil
}

View File

@@ -0,0 +1,84 @@
//go:build plus
package services
import (
"context"
"errors"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/senders/mediasenders"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
timeutil "github.com/iwind/TeaGo/utils/time"
)
// MessageMediaService 消息媒介服务
type MessageMediaService struct {
BaseService
}
// FindAllMessageMedias 获取所有支持的媒介
func (this *MessageMediaService) FindAllMessageMedias(ctx context.Context, req *pb.FindAllMessageMediasRequest) (*pb.FindAllMessageMediasResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
medias, err := models.SharedMessageMediaDAO.FindAllEnabledMessageMedias(tx)
if err != nil {
return nil, err
}
pbMedias := []*pb.MessageMedia{}
for _, media := range medias {
pbMedias = append(pbMedias, &pb.MessageMedia{
Id: int64(media.Id),
Type: media.Type,
Name: media.Name,
Description: media.Description,
UserDescription: media.UserDescription,
IsOn: media.IsOn,
})
}
return &pb.FindAllMessageMediasResponse{MessageMedias: pbMedias}, nil
}
// SendMediaMessage 发送媒介信息
func (this *MessageMediaService) SendMediaMessage(ctx context.Context, req *pb.SendMediaMessageRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
if len(req.MediaType) == 0 {
return nil, errors.New("'mediaType' should not be empty")
}
if len(req.OptionsJSON) == 0 {
return nil, errors.New("invalid 'optionsJSON'")
}
media, err := mediasenders.NewMediaInstance(req.MediaType, req.OptionsJSON)
if err != nil {
return nil, err
}
if media == nil {
return nil, errors.New("can not find media with mediaType '" + req.MediaType + "'")
}
// 产品名称
var tx = this.NullTx()
productName, err := models.SharedSysSettingDAO.ReadProductName(tx)
if err != nil {
return nil, err
}
if len(productName) == 0 {
productName = teaconst.GlobalProductName
}
_, err = media.Send(req.User, req.Subject, req.Body, productName, timeutil.Format("Y-m-d H:i:s"))
if err != nil {
return nil, err
}
return this.Success()
}

View File

@@ -0,0 +1,229 @@
//go:build plus
package services
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps"
)
// MessageReceiverService 消息对象接收人
type MessageReceiverService struct {
BaseService
}
// UpdateMessageReceivers 创建接收者
func (this *MessageReceiverService) UpdateMessageReceivers(ctx context.Context, req *pb.UpdateMessageReceiversRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
if len(req.Role) == 0 {
req.Role = nodeconfigs.NodeRoleNode
}
params := maps.Map{}
if len(req.ParamsJSON) > 0 {
err = json.Unmarshal(req.ParamsJSON, &params)
if err != nil {
return nil, err
}
}
err = this.RunTx(func(tx *dbs.Tx) error {
err = models.SharedMessageReceiverDAO.DisableReceivers(tx, req.NodeClusterId, req.NodeId, req.ServerId)
if err != nil {
return err
}
for messageType, options := range req.RecipientOptions {
for _, option := range options.RecipientOptions {
_, err := models.SharedMessageReceiverDAO.CreateReceiver(tx, req.Role, req.NodeClusterId, req.NodeId, req.ServerId, messageType, params, option.MessageRecipientId, option.MessageRecipientGroupId)
if err != nil {
return err
}
}
}
return nil
})
if err != nil {
return nil, err
}
return this.Success()
}
// FindAllEnabledMessageReceivers 查找接收者
func (this *MessageReceiverService) FindAllEnabledMessageReceivers(ctx context.Context, req *pb.FindAllEnabledMessageReceiversRequest) (*pb.FindAllEnabledMessageReceiversResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
if len(req.Role) == 0 {
req.Role = nodeconfigs.NodeRoleNode
}
var tx = this.NullTx()
var cacheMap = utils.NewCacheMap()
receivers, err := models.SharedMessageReceiverDAO.FindAllEnabledReceivers(tx, req.Role, req.NodeClusterId, req.NodeId, req.ServerId, "")
if err != nil {
return nil, err
}
var pbReceivers = []*pb.MessageReceiver{}
for _, receiver := range receivers {
var pbRecipient *pb.MessageRecipient = nil
// 接收人
if receiver.RecipientId > 0 {
recipient, err := models.SharedMessageRecipientDAO.FindEnabledMessageRecipient(tx, int64(receiver.RecipientId), cacheMap)
if err != nil {
return nil, err
}
if recipient == nil {
continue
}
// 管理员
admin, err := models.SharedAdminDAO.FindEnabledAdmin(tx, int64(recipient.AdminId))
if err != nil {
return nil, err
}
if admin == nil {
continue
}
// 接收人
instance, err := models.SharedMessageMediaInstanceDAO.FindEnabledMessageMediaInstance(tx, int64(recipient.InstanceId), cacheMap)
if err != nil {
return nil, err
}
if instance == nil {
continue
}
pbRecipient = &pb.MessageRecipient{
Id: int64(recipient.Id),
Admin: &pb.Admin{
Id: int64(admin.Id),
Fullname: admin.Fullname,
Username: admin.Username,
IsOn: admin.IsOn,
},
MessageMediaInstance: &pb.MessageMediaInstance{
Id: int64(instance.Id),
Name: instance.Name,
IsOn: instance.IsOn,
},
IsOn: recipient.IsOn,
MessageRecipientGroups: nil,
Description: "",
User: "",
}
}
// 接收人分组
var pbRecipientGroup *pb.MessageRecipientGroup = nil
if receiver.RecipientGroupId > 0 {
group, err := models.SharedMessageRecipientGroupDAO.FindEnabledMessageRecipientGroup(tx, int64(receiver.RecipientGroupId))
if err != nil {
return nil, err
}
if group == nil {
continue
}
pbRecipientGroup = &pb.MessageRecipientGroup{
Id: int64(group.Id),
Name: group.Name,
IsOn: group.IsOn,
}
}
pbReceivers = append(pbReceivers, &pb.MessageReceiver{
Id: int64(receiver.Id),
ClusterId: int64(receiver.ClusterId),
NodeId: int64(receiver.NodeId),
ServerId: int64(receiver.ServerId),
Type: receiver.Type,
ParamsJSON: receiver.Params,
MessageRecipient: pbRecipient,
MessageRecipientGroup: pbRecipientGroup,
Role: receiver.Role,
})
}
return &pb.FindAllEnabledMessageReceiversResponse{MessageReceivers: pbReceivers}, nil
}
// FindAllEnabledMessageReceiversWithMessageRecipientId 根据接收人查找关联的接收者
func (this *MessageReceiverService) FindAllEnabledMessageReceiversWithMessageRecipientId(ctx context.Context, req *pb.FindAllEnabledMessageReceiversWithMessageRecipientIdRequest) (*pb.FindAllEnabledMessageReceiversWithMessageRecipientIdResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
receivers, err := models.SharedMessageReceiverDAO.FindAllEnabledReceiversWithRecipientId(tx, req.MessageRecipientId)
if err != nil {
return nil, err
}
var pbReceivers = []*pb.MessageReceiver{}
for _, receiver := range receivers {
pbReceivers = append(pbReceivers, &pb.MessageReceiver{
Id: int64(receiver.Id),
ClusterId: int64(receiver.ClusterId),
NodeId: int64(receiver.NodeId),
ServerId: int64(receiver.ServerId),
Type: receiver.Type,
ParamsJSON: nil,
MessageRecipient: nil,
MessageRecipientGroup: nil,
Role: receiver.Role,
})
}
return &pb.FindAllEnabledMessageReceiversWithMessageRecipientIdResponse{
MessageReceivers: pbReceivers,
}, nil
}
// DeleteMessageReceiver 删除接收者
func (this *MessageReceiverService) DeleteMessageReceiver(ctx context.Context, req *pb.DeleteMessageReceiverRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedMessageReceiverDAO.DisableMessageReceiver(tx, req.MessageReceiverId)
if err != nil {
return nil, err
}
return this.Success()
}
// CountAllEnabledMessageReceivers 计算接收者数量
func (this *MessageReceiverService) CountAllEnabledMessageReceivers(ctx context.Context, req *pb.CountAllEnabledMessageReceiversRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
if err != nil {
return nil, err
}
if len(req.Role) == 0 {
req.Role = nodeconfigs.NodeRoleNode
}
var tx = this.NullTx()
count, err := models.SharedMessageReceiverDAO.CountAllEnabledReceivers(tx, req.Role, req.NodeClusterId, req.NodeId, req.ServerId, "")
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}

Some files were not shown because too many files have changed in this diff Show More