Files
waf-platform/EdgeAPI/internal/db/models/http_web_dao_plus.go
2026-02-04 20:27:13 +08:00

441 lines
9.0 KiB
Go

// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package models
import (
"encoding/json"
"errors"
"fmt"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/iwind/TeaGo/dbs"
)
// CopyWebConfigs copies the raw JSON value of the column named configField
// from web fromWebId onto every web listed in toWebIds.
//
// The source web is skipped if it also appears in toWebIds. Nothing is
// written when fromWebId is not positive, toWebIds is empty, or the source
// column is empty. The first query/update error aborts the loop and is
// returned; writes already issued are not undone here (rollback, if any,
// depends on the caller's tx).
func (this *HTTPWebDAO) CopyWebConfigs(tx *dbs.Tx, fromWebId int64, toWebIds []int64, configField string) error {
	if fromWebId <= 0 || len(toWebIds) == 0 {
		return nil
	}

	sourceJSON, err := this.Query(tx).
		Pk(fromWebId).
		Result(configField).
		FindJSONCol()
	if err != nil {
		return err
	}

	// An empty source column is intentionally left unhandled for now.
	if len(sourceJSON) == 0 {
		return nil
	}

	for _, targetWebId := range toWebIds {
		if targetWebId == fromWebId {
			continue
		}
		if err = this.Query(tx).
			Pk(targetWebId).
			Set(configField, sourceJSON).
			UpdateQuickly(); err != nil {
			return err
		}
	}
	return nil
}
// CopyWebsocketConfigs copies the Websocket setting ("websocket" column) of
// web fromWebId to every web in toWebIds, skipping the source web itself.
//
// The column holds a serverconfigs.HTTPWebsocketRef:
//   - if the ref points at no Websocket record (WebsocketId <= 0), its raw
//     JSON is written to each target verbatim;
//   - otherwise the referenced record is cloned once per target through
//     SharedHTTPWebsocketDAO.CloneWebsocket, and each target receives a ref
//     pointing at its own clone. A target whose clone id is not positive is
//     left untouched.
//
// Nothing is written when fromWebId is not positive, toWebIds is empty, or
// the source column is empty or JSON null. The first error aborts the copy.
func (this *HTTPWebDAO) CopyWebsocketConfigs(tx *dbs.Tx, fromWebId int64, toWebIds []int64) error {
	if fromWebId <= 0 || len(toWebIds) == 0 {
		return nil
	}
	websocketRefJSON, err := this.Query(tx).
		Pk(fromWebId).
		Result("websocket").
		FindJSONCol()
	if err != nil {
		return err
	}
	if len(websocketRefJSON) == 0 {
		return nil
	}

	var websocketRef = &serverconfigs.HTTPWebsocketRef{}
	err = json.Unmarshal(websocketRefJSON, &websocketRef)
	if err != nil {
		return err
	}
	// Decoding a JSON `null` through &websocketRef sets the pointer to nil;
	// there is nothing to copy then, and dereferencing it below would panic.
	if websocketRef == nil {
		return nil
	}

	// No referenced record: the ref JSON can be shared as-is.
	if websocketRef.WebsocketId <= 0 {
		for _, toWebId := range toWebIds {
			if toWebId == fromWebId {
				continue
			}
			err = this.Query(tx).
				Pk(toWebId).
				Set("websocket", websocketRefJSON).
				UpdateQuickly()
			if err != nil {
				return err
			}
		}
		return nil
	}

	var oldWebsocketId = websocketRef.WebsocketId
	for _, toWebId := range toWebIds {
		if toWebId == fromWebId {
			continue
		}
		// Each target gets its own clone of the source record.
		newWebsocketId, err := SharedHTTPWebsocketDAO.CloneWebsocket(tx, oldWebsocketId)
		if err != nil {
			return err
		}
		if newWebsocketId <= 0 {
			continue
		}

		// websocketRef is reused across iterations; only WebsocketId changes.
		websocketRef.WebsocketId = newWebsocketId
		newWebsocketRefJSON, err := json.Marshal(websocketRef)
		if err != nil {
			return err
		}
		err = this.Query(tx).
			Pk(toWebId).
			Set("websocket", newWebsocketRefJSON).
			UpdateQuickly()
		if err != nil {
			return err
		}
	}
	return nil
}
// CopyPageConfigs copies the custom page list ("pages" column) of web
// fromWebId to every web in toWebIds, skipping the source web itself.
//
// For each target, every source page is cloned through
// SharedHTTPPageDAO.ClonePage. The target's "pages" column is then
// overwritten with a list holding only the ids of those clones. Pages whose
// clone id is not positive are dropped, and an empty result is stored as
// `[]`.
//
// Nothing is written when fromWebId is not positive, toWebIds is empty, or
// the source column is empty. The first error aborts the copy.
func (this *HTTPWebDAO) CopyPageConfigs(tx *dbs.Tx, fromWebId int64, toWebIds []int64) error {
	if fromWebId <= 0 || len(toWebIds) == 0 {
		return nil
	}

	sourcePagesJSON, err := this.Query(tx).
		Pk(fromWebId).
		Result("pages").
		FindJSONCol()
	switch {
	case err != nil:
		return err
	case len(sourcePagesJSON) == 0:
		return nil
	}

	var sourcePages = []*serverconfigs.HTTPPageConfig{}
	if err = json.Unmarshal(sourcePagesJSON, &sourcePages); err != nil {
		return err
	}

	for _, targetWebId := range toWebIds {
		if targetWebId == fromWebId {
			continue
		}

		// Non-nil so that an empty list encodes as `[]`, not `null`.
		var clonedPages = []*serverconfigs.HTTPPageConfig{}
		for _, sourcePage := range sourcePages {
			clonedPageId, cloneErr := SharedHTTPPageDAO.ClonePage(tx, sourcePage.Id)
			if cloneErr != nil {
				return cloneErr
			}
			if clonedPageId > 0 {
				clonedPages = append(clonedPages, &serverconfigs.HTTPPageConfig{Id: clonedPageId})
			}
		}

		clonedPagesJSON, err := json.Marshal(clonedPages)
		if err != nil {
			return err
		}
		if err = this.Query(tx).
			Pk(targetWebId).
			Set("pages", clonedPagesJSON).
			UpdateQuickly(); err != nil {
			return err
		}
	}
	return nil
}
// CopyAuthConfigs copies the access-authentication setting ("auth" column)
// of web fromWebId to every web in toWebIds, skipping the source web itself.
//
// The column holds a serverconfigs.HTTPAuthConfig. For each target, every
// policy ref with a positive AuthPolicyId is cloned through
// SharedHTTPAuthPolicyDAO.CloneAuthPolicy and re-pointed at the clone. Refs
// without a positive id, and refs whose clone id is not positive, are left
// out. All other HTTPAuthConfig fields are written as decoded from the
// source.
//
// Nothing is written when fromWebId is not positive, toWebIds is empty, or
// the source column is empty. The first error aborts the copy.
func (this *HTTPWebDAO) CopyAuthConfigs(tx *dbs.Tx, fromWebId int64, toWebIds []int64) error {
	if fromWebId <= 0 || len(toWebIds) == 0 {
		return nil
	}

	sourceAuthJSON, err := this.Query(tx).
		Pk(fromWebId).
		Result("auth").
		FindJSONCol()
	if err != nil {
		return err
	}
	if len(sourceAuthJSON) == 0 {
		return nil
	}

	authConfig := &serverconfigs.HTTPAuthConfig{}
	if err = json.Unmarshal(sourceAuthJSON, authConfig); err != nil {
		return err
	}

	// Work from a deep copy of the source refs: authConfig.PolicyRefs is
	// overwritten for every target below.
	sourceRefs, err := utils.JSONClone[[]*serverconfigs.HTTPAuthPolicyRef](authConfig.PolicyRefs)
	if err != nil {
		return err
	}

	for _, targetWebId := range toWebIds {
		if targetWebId == fromWebId {
			continue
		}

		clonedRefs := []*serverconfigs.HTTPAuthPolicyRef{}
		for _, sourceRef := range sourceRefs {
			if sourceRef.AuthPolicyId <= 0 {
				continue
			}
			clonedPolicyId, cloneErr := SharedHTTPAuthPolicyDAO.CloneAuthPolicy(tx, sourceRef.AuthPolicyId)
			if cloneErr != nil {
				return cloneErr
			}
			if clonedPolicyId <= 0 {
				continue
			}
			clonedRef, cloneErr := utils.JSONClone[*serverconfigs.HTTPAuthPolicyRef](sourceRef)
			if cloneErr != nil {
				return cloneErr
			}
			clonedRef.AuthPolicyId = clonedPolicyId
			clonedRefs = append(clonedRefs, clonedRef)
		}

		authConfig.PolicyRefs = clonedRefs
		targetAuthJSON, err := json.Marshal(authConfig)
		if err != nil {
			return err
		}
		if err = this.Query(tx).
			Pk(targetWebId).
			Set("auth", targetAuthJSON).
			UpdateQuickly(); err != nil {
			return err
		}
	}
	return nil
}
// CopyFirewallConfigs copies the WAF setting ("firewall" column) of web
// fromWebId to every web in toWebIds, skipping the source web itself.
//
// Only the switches of the source firewallconfigs.HTTPFirewallRef (IsOn,
// IsPrior, IgnoreGlobalRules) are copied onto each target's own ref. The
// target keeps its FirewallPolicyId. When copyRegions is true:
//   - a target without a policy (FirewallPolicyId == 0) gets a new one from
//     SharedHTTPFirewallPolicyDAO.CreateFirewallPolicy for the server that
//     owns the web; a target web with no enabled server is skipped entirely;
//   - if the source has a composed policy with an Inbound section, its
//     Inbound.Region is written into the target's policy.
//
// Nothing is written when fromWebId is not positive, toWebIds is empty, or
// IsNotNull reports the source column as empty/null, consistent with the
// other Copy* methods. The first error aborts the copy.
func (this *HTTPWebDAO) CopyFirewallConfigs(tx *dbs.Tx, fromWebId int64, toWebIds []int64, copyRegions bool) error {
	if fromWebId <= 0 || len(toWebIds) == 0 {
		return nil
	}
	fromFirewallJSON, err := this.Query(tx).
		Pk(fromWebId).
		Result("firewall").
		FindJSONCol()
	if err != nil {
		return err
	}
	// No source config means nothing to copy; decoding an empty column would
	// otherwise fail with a JSON syntax error.
	if !IsNotNull(fromFirewallJSON) {
		return nil
	}
	var fromRef = &firewallconfigs.HTTPFirewallRef{}
	err = json.Unmarshal(fromFirewallJSON, fromRef)
	if err != nil {
		return err
	}

	var fromFirewallPolicy *firewallconfigs.HTTPFirewallPolicy
	if fromRef.FirewallPolicyId > 0 {
		fromFirewallPolicy, err = SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(tx, fromRef.FirewallPolicyId, false, nil)
		if err != nil {
			return err
		}
	}

	for _, toWebId := range toWebIds {
		// Skip the source web itself, as the other Copy* methods do.
		if toWebId == fromWebId {
			continue
		}

		toFirewallJSON, err := this.Query(tx).
			Pk(toWebId).
			Result("firewall").
			FindJSONCol()
		if err != nil {
			return err
		}
		var toRef = &firewallconfigs.HTTPFirewallRef{}
		if IsNotNull(toFirewallJSON) {
			err = json.Unmarshal(toFirewallJSON, toRef)
			if err != nil {
				return fmt.Errorf("decode 'toFirewallJSON' failed: %w", err)
			}
		}
		toRef.IsOn = fromRef.IsOn
		toRef.IsPrior = fromRef.IsPrior
		toRef.IgnoreGlobalRules = fromRef.IgnoreGlobalRules

		// WAF policy: a region can only be copied into an existing policy,
		// so create one for the target's server when it has none.
		if toRef.FirewallPolicyId == 0 && copyRegions {
			serverId, err := SharedServerDAO.FindEnabledServerIdWithWebId(tx, toWebId)
			if err != nil {
				return err
			}
			if serverId <= 0 {
				continue
			}
			toRef.FirewallPolicyId, err = SharedHTTPFirewallPolicyDAO.CreateFirewallPolicy(tx, 0, 0, serverId, true, "", "", nil, nil)
			if err != nil {
				return err
			}
		}

		// Persist the updated ref.
		toRefJSON, err := json.Marshal(toRef)
		if err != nil {
			return err
		}
		err = this.Query(tx).
			Pk(toWebId).
			Set("firewall", toRefJSON).
			UpdateQuickly()
		if err != nil {
			return err
		}

		// Copy the inbound region of the source policy, if requested.
		if copyRegions && toRef.FirewallPolicyId > 0 && fromFirewallPolicy != nil && fromFirewallPolicy.Inbound != nil {
			err = SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundRegion(tx, toRef.FirewallPolicyId, fromFirewallPolicy.Inbound.Region)
			if err != nil {
				return err
			}
		}
	}
	return nil
}
// CopyCacheConfigs copies the cache setting (HTTPWebField_Cache column) of
// web fromWebId to every web in toWebIds, skipping the source web itself.
//
// The cache key setting (HTTPCacheConfig.Key) is never taken from the source.
// Each target keeps the Key from its own existing cache config, or gets none
// if that column is empty.
//
// Nothing is written when fromWebId is not positive, toWebIds is empty, or
// the source column is empty. The first error aborts the copy.
func (this *HTTPWebDAO) CopyCacheConfigs(tx *dbs.Tx, fromWebId int64, toWebIds []int64) error {
	if fromWebId <= 0 || len(toWebIds) == 0 {
		return nil
	}
	const configField = HTTPWebField_Cache
	configJSON, err := this.Query(tx).
		Pk(fromWebId).
		Result(configField).
		FindJSONCol()
	if err != nil {
		return err
	}
	// Nothing to copy; also keeps json.Unmarshal from failing on empty input,
	// matching how the other Copy* methods treat an empty source.
	if len(configJSON) == 0 {
		return nil
	}
	var cacheConfig = &serverconfigs.HTTPCacheConfig{}
	err = json.Unmarshal(configJSON, cacheConfig)
	if err != nil {
		return err
	}

	for _, toWebId := range toWebIds {
		if toWebId == fromWebId {
			continue
		}
		oldConfigJSON, err := this.Query(tx).
			Pk(toWebId).
			Result(configField).
			FindJSONCol()
		if err != nil {
			return err
		}

		// cacheConfig is reused for every target. Reset Key each time so a
		// target without its own config neither receives the source key nor
		// inherits the key preserved for a previous target.
		cacheConfig.Key = nil
		if len(oldConfigJSON) > 0 {
			var oldCacheConfig = &serverconfigs.HTTPCacheConfig{}
			err = json.Unmarshal(oldConfigJSON, oldCacheConfig)
			if err != nil {
				return err
			}
			cacheConfig.Key = oldCacheConfig.Key // keep the target's own key
		}

		newConfigJSON, err := json.Marshal(cacheConfig)
		if err != nil {
			return err
		}
		err = this.Query(tx).
			Pk(toWebId).
			Set(configField, newConfigJSON).
			UpdateQuickly()
		if err != nil {
			return err
		}
	}
	return nil
}
// UpdateWebHLS stores hlsConfig, JSON-encoded, in the HLS column
// (HTTPWebField_Hls) of web webId and then calls NotifyUpdate for that web.
//
// It returns an error if webId is not positive or hlsConfig is nil. Otherwise
// it returns any encoding, update, or notification error unchanged.
func (this *HTTPWebDAO) UpdateWebHLS(tx *dbs.Tx, webId int64, hlsConfig *serverconfigs.HLSConfig) error {
	switch {
	case webId <= 0:
		return errors.New("require 'webId'")
	case hlsConfig == nil:
		return errors.New("'hlsConfig' must not be nil")
	}

	hlsJSON, err := json.Marshal(hlsConfig)
	if err != nil {
		return err
	}

	if err = this.Query(tx).
		Pk(webId).
		Set(HTTPWebField_Hls, hlsJSON).
		UpdateQuickly(); err != nil {
		return err
	}

	return this.NotifyUpdate(tx, webId)
}
// FindWebHLS returns the raw JSON stored in the HLS column (HTTPWebField_Hls)
// of web webId, exactly as FindJSONCol yields it. It returns an error if
// webId is not positive.
func (this *HTTPWebDAO) FindWebHLS(tx *dbs.Tx, webId int64) ([]byte, error) {
	if webId > 0 {
		return this.Query(tx).
			Pk(webId).
			Result(HTTPWebField_Hls).
			FindJSONCol()
	}
	return nil, errors.New("require 'webId'")
}