// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. //go:build plus package uam import ( "encoding/base64" "errors" "fmt" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/TeaOSLab/EdgeNode/internal/compressions" teaconst "github.com/TeaOSLab/EdgeNode/internal/const" "github.com/TeaOSLab/EdgeNode/internal/utils/counters" "github.com/TeaOSLab/EdgeNode/internal/utils/encrypt" "github.com/TeaOSLab/EdgeNode/internal/utils/fasttime" "github.com/TeaOSLab/EdgeNode/internal/utils/ttlcache" "github.com/TeaOSLab/EdgeNode/internal/waf" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/rands" "github.com/iwind/TeaGo/types" "google.golang.org/protobuf/proto" "io" "net" "net/http" "net/url" "strings" "time" ) const ( CookiePrevKey = "ge_ua_p" CookieKey = "ge_ua_key" PrevKeyLife int = 5 DefaultKeyLife int = 3600 StepPrev = "prev" DefaultMaxFails = 30 // 单IP攻击最大次数 DefaultBlockSeconds = 1800 // 攻击封锁时间 ) // Manager UAM管理器 type Manager struct { encryptMethod encrypt.MethodInterface } // NewManager 获取新的UAM管理器 func NewManager(key string, secret string) (*Manager, error) { method, err := encrypt.NewMethodInstance("aes-256-cfb", key, secret) if err != nil { return nil, fmt.Errorf("init encrypt method failed: %w", err) } return &Manager{ encryptMethod: method, }, nil } // CheckKey 检查是否已经通过验证 func (this *Manager) CheckKey(policy *nodeconfigs.UAMPolicy, req *http.Request, writer http.ResponseWriter, remoteAddr string, serverId int64, keyLife int) (isOk bool, isAttack bool, err error) { // 对攻击行为进行惩罚 defer func() { if !isOk { this.IncreaseFails(policy, remoteAddr, serverId) } }() // 读取Cookie keyCookie, err := req.Cookie(CookieKey) if err != nil { return false, false, fmt.Errorf("read cookie failed: %w", err) } if keyCookie == nil { return false, false, errors.New("cookie not found") } var cookieValue, _ = url.QueryUnescape(keyCookie.Value) if len(cookieValue) == 0 { return false, false, errors.New("unable to read cookie value") } keyData, err := base64.StdEncoding.DecodeString(cookieValue) if err != nil { return false, true, fmt.Errorf("decode key failed: %w", err) } keyJSON, err := this.encryptMethod.Decrypt(keyData) if err != nil { return false, true, fmt.Errorf("decrypt key failed: %w", err) } var key = &Key{} err = proto.Unmarshal(keyJSON, key) if err != nil { return false, true, fmt.Errorf("unmarshal key failed: %w", err) } if keyLife <= 0 { keyLife = DefaultKeyLife if policy != nil && policy.KeyLife > 0 { keyLife = policy.KeyLife } } var unixTime = time.Now().Unix() if key.Timestamp >= unixTime-int64(PrevKeyLife)+1 /** 离生成时间过近 **/ || key.Timestamp < unixTime-int64(keyLife) /** 过了有效期 **/ { return false, true, errors.New("verify key failed") } if key.Version >= Version1 { if !key.IsSame(remoteAddr, req.UserAgent()) { return false, true, errors.New("verify key hash failed") } if policy.IncludeSubdomains { if this.ParseTopDomain(key.Host) != this.ParseTopDomain(req.Host) { return false, true, errors.New("verify key domain failed") } } else { if key.Host != req.Host { return false, true, errors.New("verify key domain failed") } } } return true, false, nil } // LoadPage 显示加载页面 func (this *Manager) LoadPage(policy *nodeconfigs.UAMPolicy, req *http.Request, formatter func(s string) string, remoteAddr string, writer http.ResponseWriter) error { var prevKey = this.ComposeKey(req, remoteAddr) prevKeyString, err := this.EncodeKey(prevKey) if err != nil { return err } http.SetCookie(writer, &http.Cookie{ Name: CookiePrevKey, Value: url.QueryEscape(prevKeyString), Expires: time.Now().Add(time.Duration(PrevKeyLife) * time.Second), MaxAge: PrevKeyLife, Path: "/", }) writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.Header().Set("Cache-Control", "no-cache") var jsCode = JSMinifyCode if Tea.IsTesting() { jsCode = JSCode } // 压缩 var ioWriter io.Writer = writer var compressionWriter = this.prepareCompression(req, writer) if compressionWriter != nil { ioWriter = compressionWriter defer func() { _ = compressionWriter.Close() }() } writer.WriteHeader(http.StatusOK) var productName = teaconst.GlobalProductName if len(productName) == 0 { productName = teaconst.ProductName } var messages = map[string]string{} switch this.lang(req) { case "zh-CN": messages = I18NForZH_CN() case "zh-TW": messages = I18NForZH_TW() default: messages = I18NForEN() } var title = policy.UITitle if len(title) > 0 { title = formatter(title) } var body = policy.UIBody if len(body) > 0 { body = formatter(body) } else { body = `

` + fmt.Sprintf(messages["checking"], req.Host) + `

` + messages["waiting"] + `

 

` + fmt.Sprintf(messages["by"], productName) + `

` } _, _ = ioWriter.Write([]byte(` ` + title + ` ` + body + ` `)) return nil } // CheckPrevKey 检查第一个Key func (this *Manager) CheckPrevKey(policy *nodeconfigs.UAMPolicy, config *serverconfigs.UAMConfig, req *http.Request, remoteAddr string, writer http.ResponseWriter) bool { // 检查Cookie cookie, err := req.Cookie(CookiePrevKey) if err != nil { return false } var escapedValue = cookie.Value value, _ := url.QueryUnescape(escapedValue) valueData, err := base64.StdEncoding.DecodeString(value) if err != nil { return false } valueJSON, err := this.encryptMethod.Decrypt(valueData) if err != nil { return false } var key = &Key{} err = proto.Unmarshal(valueJSON, key) if err != nil { return false } if key.Version >= Version1 { if !key.IsSame(remoteAddr, req.UserAgent()) { return false } // 检查域名 if policy.IncludeSubdomains { if this.ParseTopDomain(key.Host) != this.ParseTopDomain(req.Host) { return false } } else { if key.Host != req.Host { return false } } } if !strings.HasSuffix(req.Referer(), req.URL.String()) { return false } if key.Timestamp < time.Now().Unix()-int64(PrevKeyLife) { return false } // 检查连接端口 // 因为每个浏览器实现机制不同,可能会导致ajax新开端口,所以这里暂时不判断 /**if key.Port > 0 { var portIndex = strings.LastIndex(req.RemoteAddr, ":") if portIndex > 0 && key.Port != types.Int32(req.RemoteAddr[portIndex+1:]) { return false } }**/ body, err := io.ReadAll(io.LimitReader(req.Body, 64)) if err != nil { return false } if len(body) == 0 { return false } var bodyString = string(body) var sum int64 = 0 var nonce int64 = 0 for _, param := range strings.Split(bodyString, "&") { var eqIndex = strings.Index(param, "=") if eqIndex > 0 { if param[:eqIndex] == "sum" { sum = types.Int64(param[eqIndex+1:]) } else if param[:eqIndex] == "nonce" { nonce = types.Int64(param[eqIndex+1:]) } } } if sum != this.sumKey(escapedValue, nonce) { return false } // 设置新的Cookie newKey, err := this.EncodeKey(this.ComposeKey(req, remoteAddr)) if err != nil { return false } var keyLife = config.KeyLife if keyLife <= 0 { keyLife = DefaultKeyLife if policy != nil && policy.KeyLife > 0 { keyLife = policy.KeyLife } } var keyCookie = &http.Cookie{ Name: CookieKey, Value: url.QueryEscape(newKey), Expires: time.Now().Add(time.Duration(keyLife) * time.Second), MaxAge: keyLife, Path: "/", } if policy != nil && policy.IncludeSubdomains { keyCookie.Domain = this.ParseTopDomain(req.Host) } http.SetCookie(writer, keyCookie) writer.WriteHeader(http.StatusOK) // 记录到IP白名单 if config.AddToWhiteList { ttlcache.SharedInt64Cache.Write("UAM:WHITE:"+remoteAddr, 1, fasttime.Now().Unix()+int64(keyLife)) } // 清理 this.resetFails(remoteAddr) return true } // ComposeKey 组合Key func (this *Manager) ComposeKey(req *http.Request, remoteAddr string) *Key { var key = &Key{ Version: Version1, Timestamp: fasttime.Now().Unix(), Host: req.Host, } key.Put(remoteAddr, req.UserAgent()) return key } // EncodeKey 对Key进行编码 func (this *Manager) EncodeKey(key *Key) (string, error) { keyPB, err := key.AsPB() if err != nil { return "", err } dst, err := this.encryptMethod.Encrypt(keyPB) if err != nil { return "", err } return base64.StdEncoding.EncodeToString(dst), nil } // ExistsActivePreKey 检查是否已经生成PrevKey func (this *Manager) ExistsActivePreKey(req *http.Request) bool { // 检查Cookie cookie, err := req.Cookie(CookiePrevKey) if err != nil { return false } var escapedValue = cookie.Value value, _ := url.QueryUnescape(escapedValue) valueData, err := base64.StdEncoding.DecodeString(value) if err != nil { return false } valueJSON, err := this.encryptMethod.Decrypt(valueData) if err != nil { return false } var key = &Key{} err = proto.Unmarshal(valueJSON, key) if err != nil { return false } return time.Now().Unix()-key.Timestamp <= 1 } // IncreaseFails 增加失败次数 // 以便于对客户端进行处罚 func (this *Manager) IncreaseFails(policy *nodeconfigs.UAMPolicy, remoteAddr string, serverId int64) (isAttack bool) { const statSeconds = 60 var maxFails = DefaultMaxFails if policy != nil && policy.MaxFails > 0 { maxFails = policy.MaxFails } var blockSeconds = DefaultBlockSeconds if policy != nil && policy.BlockSeconds > 0 { blockSeconds = policy.BlockSeconds } var count = counters.SharedCounter.IncreaseKey("UAM:CheckKey:"+remoteAddr, statSeconds) if count >= types.Uint32(maxFails) { waf.SharedIPBlackList.RecordIP(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, serverId, remoteAddr, fasttime.Now().Unix()+int64(blockSeconds), 0, policy != nil && policy.Firewall.Scope == firewallconfigs.FirewallScopeGlobal, 0, 0, "5秒盾认证不通过("+types.String(statSeconds)+"秒内尝试"+types.String(maxFails)+"次)") return true } return } // 清除失败次数 func (this *Manager) resetFails(remoteAddr string) { counters.SharedCounter.ResetKey("UAM:CheckKey:" + remoteAddr) } // 准备压缩 func (this *Manager) prepareCompression(req *http.Request, writer http.ResponseWriter) compressions.Writer { var acceptEncodings = req.Header.Get("Accept-Encoding") var encodings = strings.Split(acceptEncodings, ",") var compressionWriter compressions.Writer for _, piece := range encodings { var qualityIndex = strings.Index(piece, ";") if qualityIndex >= 0 { piece = piece[:qualityIndex] } if piece == "br" { compressionWriter, _ = compressions.NewBrotliWriter(writer, 6) if compressionWriter != nil { writer.Header().Set("Content-Encoding", "br") writer.Header().Del("Content-Length") } break } else if piece == "gzip" { compressionWriter, _ = compressions.NewGzipWriter(writer, 6) if compressionWriter != nil { writer.Header().Set("Content-Encoding", "gzip") writer.Header().Del("Content-Length") } break } } return compressionWriter } func (this *Manager) sumKey(key string, nonce int64) int64 { var result int64 = 0 for i, r := range key { if (r >= '0' && r <= '9') || (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') { result += int64(r) * (nonce + int64(i)) } } return result } func (this *Manager) lang(req *http.Request) (lang string) { var acceptLanguage = req.Header.Get("Accept-Language") if len(acceptLanguage) > 0 { langIndex := strings.Index(acceptLanguage, ",") if langIndex > 0 { lang = acceptLanguage[:langIndex] } } if len(lang) == 0 { lang = "en-US" } return } func (this *Manager) ParseTopDomain(domain string) string { if strings.Contains(domain, ":") { newHost, _, splitErr := net.SplitHostPort(domain) if splitErr == nil && len(newHost) > 0 { domain = newHost } } if len(net.ParseIP(domain)) > 0 { return domain } var pieces = strings.Split(domain, ".") var l = len(pieces) if l <= 2 { return domain } var topDomain string if pieces[l-2] == "net" || pieces[l-2] == "com" || pieces[l-2] == "org" { if l == 3 { // *.[net|com|org].abc return domain } // a.b.c.abc.[net|com|org].abc topDomain = strings.Join(pieces[len(pieces)-3:], ".") } else { // a.b.c.abc.[com|...] topDomain = strings.Join(pieces[len(pieces)-2:], ".") } return topDomain }