// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package uam
import (
"encoding/base64"
"errors"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/compressions"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/utils/counters"
"github.com/TeaOSLab/EdgeNode/internal/utils/encrypt"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/utils/ttlcache"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
"google.golang.org/protobuf/proto"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
)
// Cookie names used by the UAM challenge.
const (
	CookiePrevKey = "ge_ua_p"   // short-lived key issued together with the loading page
	CookieKey     = "ge_ua_key" // key issued after the challenge has been answered
)

// Key lifetimes, in seconds.
const (
	PrevKeyLife    int = 5    // lifetime of the CookiePrevKey cookie
	DefaultKeyLife int = 3600 // default lifetime of the CookieKey cookie
)

// StepPrev is the name of the "prev" step; its users are outside this file section.
const StepPrev = "prev"

// Defaults used when punishing repeated failures (see IncreaseFails).
const (
	DefaultMaxFails     = 30   // maximum failed checks allowed per IP
	DefaultBlockSeconds = 1800 // how long an offending IP is blocked, in seconds
)
// Manager is the UAM manager. It encrypts and decrypts the challenge keys
// carried in the CookiePrevKey / CookieKey cookies and counts failed checks
// per client address.
type Manager struct {
// encryptMethod seals and opens serialized Keys; it is created by NewManager.
encryptMethod encrypt.MethodInterface
}
// NewManager creates a UAM manager whose cookie keys are encrypted with the
// "aes-256-cfb" method of the encrypt package, initialized from key and secret.
// It returns an error when the encryption method cannot be initialized.
func NewManager(key string, secret string) (*Manager, error) {
	const encryptMethodName = "aes-256-cfb"

	method, err := encrypt.NewMethodInstance(encryptMethodName, key, secret)
	if err != nil {
		return nil, fmt.Errorf("init encrypt method failed: %w", err)
	}

	manager := new(Manager)
	manager.encryptMethod = method
	return manager, nil
}
// CheckKey reports whether req carries a valid CookieKey cookie, i.e. whether
// the client has already passed the UAM challenge.
//
// keyLife is the maximum key age in seconds; when it is <= 0 the policy's
// KeyLife (or DefaultKeyLife) is used. policy may be nil. isAttack is true when
// a cookie is present but cannot be decoded, decrypted, parsed or verified.
// Every failed check (isOk == false) is counted through IncreaseFails.
// writer is not used by this method.
func (this *Manager) CheckKey(policy *nodeconfigs.UAMPolicy, req *http.Request, writer http.ResponseWriter, remoteAddr string, serverId int64, keyLife int) (isOk bool, isAttack bool, err error) {
	// Punish every failed check so repeated offenders get blocked.
	defer func() {
		if !isOk {
			this.IncreaseFails(policy, remoteAddr, serverId)
		}
	}()

	// req.Cookie returns http.ErrNoCookie when the cookie is absent; it never
	// returns a nil cookie together with a nil error.
	keyCookie, err := req.Cookie(CookieKey)
	if err != nil {
		return false, false, fmt.Errorf("read cookie failed: %w", err)
	}

	// QueryUnescape returns "" on error, which is rejected right below.
	var cookieValue, _ = url.QueryUnescape(keyCookie.Value)
	if len(cookieValue) == 0 {
		return false, false, errors.New("unable to read cookie value")
	}

	// The value is base64(encrypt(serialized Key)), as produced by EncodeKey.
	keyData, err := base64.StdEncoding.DecodeString(cookieValue)
	if err != nil {
		return false, true, fmt.Errorf("decode key failed: %w", err)
	}
	keyPB, err := this.encryptMethod.Decrypt(keyData)
	if err != nil {
		return false, true, fmt.Errorf("decrypt key failed: %w", err)
	}
	var key = &Key{}
	if err = proto.Unmarshal(keyPB, key); err != nil {
		return false, true, fmt.Errorf("unmarshal key failed: %w", err)
	}

	if keyLife <= 0 {
		keyLife = DefaultKeyLife
		if policy != nil && policy.KeyLife > 0 {
			keyLife = policy.KeyLife
		}
	}

	// Reject keys that are too fresh (younger than PrevKeyLife seconds) or
	// older than keyLife seconds.
	var unixTime = time.Now().Unix()
	if key.Timestamp >= unixTime-int64(PrevKeyLife)+1 || key.Timestamp < unixTime-int64(keyLife) {
		return false, true, errors.New("verify key failed")
	}

	if key.Version >= Version1 {
		// Version1+ keys carry data derived from remoteAddr and the User-Agent
		// (see ComposeKey / Key.IsSame) plus the host they were issued for.
		if !key.IsSame(remoteAddr, req.UserAgent()) {
			return false, true, errors.New("verify key hash failed")
		}

		// policy may be nil (see the KeyLife fallback above); in that case
		// require an exact host match instead of dereferencing it.
		if policy != nil && policy.IncludeSubdomains {
			if this.ParseTopDomain(key.Host) != this.ParseTopDomain(req.Host) {
				return false, true, errors.New("verify key domain failed")
			}
		} else if key.Host != req.Host {
			return false, true, errors.New("verify key domain failed")
		}
	}

	return true, false, nil
}
// LoadPage writes the UAM loading (challenge) page to writer.
//
// It issues a fresh CookiePrevKey cookie valid for PrevKeyLife seconds, may
// compress the response via prepareCompression, selects localized messages from
// this.lang(req), and renders policy.UITitle / policy.UIBody through formatter
// (using built-in messages when UIBody is empty). It returns an error only if
// the prev key cannot be encoded.
//
// NOTE(review): jsCode is selected below but is not referenced by the page
// template visible here, so as shown this function does not compile
// ("declared and not used") and no challenge script would be sent. The
// template literals appear to have lost their markup; confirm against the
// upstream source before editing them.
func (this *Manager) LoadPage(policy *nodeconfigs.UAMPolicy, req *http.Request, formatter func(s string) string, remoteAddr string, writer http.ResponseWriter) error {
// Issue the short-lived prev key that CheckPrevKey later verifies.
var prevKey = this.ComposeKey(req, remoteAddr)
prevKeyString, err := this.EncodeKey(prevKey)
if err != nil {
return err
}
http.SetCookie(writer, &http.Cookie{
Name: CookiePrevKey,
Value: url.QueryEscape(prevKeyString),
Expires: time.Now().Add(time.Duration(PrevKeyLife) * time.Second),
MaxAge: PrevKeyLife,
Path: "/",
})
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.Header().Set("Cache-Control", "no-cache")
// Serve the unminified script when running in testing mode.
var jsCode = JSMinifyCode
if Tea.IsTesting() {
jsCode = JSCode
}
// Compress the page when the client accepts a supported encoding; the
// compression writer is closed when LoadPage returns.
var ioWriter io.Writer = writer
var compressionWriter = this.prepareCompression(req, writer)
if compressionWriter != nil {
ioWriter = compressionWriter
defer func() {
_ = compressionWriter.Close()
}()
}
// All headers (including those set by prepareCompression) must be set
// before WriteHeader.
writer.WriteHeader(http.StatusOK)
// Prefer the globally configured product name.
var productName = teaconst.GlobalProductName
if len(productName) == 0 {
productName = teaconst.ProductName
}
// Anything other than zh-CN / zh-TW gets the English messages.
var messages = map[string]string{}
switch this.lang(req) {
case "zh-CN":
messages = I18NForZH_CN()
case "zh-TW":
messages = I18NForZH_TW()
default:
messages = I18NForEN()
}
// Custom title/body from the policy are passed through formatter.
// NOTE(review): policy is dereferenced here without a nil check, whereas
// other methods in this file guard against a nil policy.
var title = policy.UITitle
if len(title) > 0 {
title = formatter(title)
}
var body = policy.UIBody
if len(body) > 0 {
body = formatter(body)
} else {
// Default body built from the localized messages.
// NOTE(review): req.Host is client-controlled and is interpolated into this
// text/html response without HTML escaping.
body = `
` + fmt.Sprintf(messages["checking"], req.Host) + `
` + messages["waiting"] + `
` + fmt.Sprintf(messages["by"], productName) + `
`
}
// The write error is deliberately discarded.
_, _ = ioWriter.Write([]byte(`
` + title + `
` + body +
`
`))
return nil
}
// CheckPrevKey verifies the challenge answer posted by the loading page.
//
// It decodes the CookiePrevKey cookie issued by LoadPage. For Version1+ keys it
// checks the client fingerprint and the host; with a nil policy the host must
// match exactly. It then requires that the Referer ends with the request URL
// and that the key is at most PrevKeyLife seconds old. Next it reads up to 64
// bytes of the body, parses the "sum" and "nonce" fields, and compares sum with
// sumKey(raw cookie value, nonce).
//
// On success it:
//   - sets the CookieKey cookie,
//   - writes a 200 status,
//   - optionally whitelists remoteAddr for the key lifetime,
//   - resets the client's failure counter.
//
// config must not be nil (it is dereferenced unconditionally).
func (this *Manager) CheckPrevKey(policy *nodeconfigs.UAMPolicy, config *serverconfigs.UAMConfig, req *http.Request, remoteAddr string, writer http.ResponseWriter) bool {
	// Read and decode the prev-key cookie.
	cookie, err := req.Cookie(CookiePrevKey)
	if err != nil {
		return false
	}
	var escapedValue = cookie.Value
	value, _ := url.QueryUnescape(escapedValue)
	valueData, err := base64.StdEncoding.DecodeString(value)
	if err != nil {
		return false
	}
	valuePB, err := this.encryptMethod.Decrypt(valueData)
	if err != nil {
		return false
	}
	var key = &Key{}
	if err = proto.Unmarshal(valuePB, key); err != nil {
		return false
	}

	if key.Version >= Version1 {
		if !key.IsSame(remoteAddr, req.UserAgent()) {
			return false
		}

		// Check the domain. policy may be nil (it is nil-checked further
		// down), in which case the host must match exactly.
		if policy != nil && policy.IncludeSubdomains {
			if this.ParseTopDomain(key.Host) != this.ParseTopDomain(req.Host) {
				return false
			}
		} else if key.Host != req.Host {
			return false
		}
	}

	if !strings.HasSuffix(req.Referer(), req.URL.String()) {
		return false
	}
	if key.Timestamp < time.Now().Unix()-int64(PrevKeyLife) {
		return false
	}

	// The connection port is deliberately not compared: browsers differ in how
	// they handle the AJAX request and may open a new connection (new port).
	/**if key.Port > 0 {
		var portIndex = strings.LastIndex(req.RemoteAddr, ":")
		if portIndex > 0 && key.Port != types.Int32(req.RemoteAddr[portIndex+1:]) {
			return false
		}
	}**/

	// Parse the answer: a small "sum=...&nonce=..." form body.
	body, err := io.ReadAll(io.LimitReader(req.Body, 64))
	if err != nil || len(body) == 0 {
		return false
	}
	var sum, nonce int64
	for _, param := range strings.Split(string(body), "&") {
		var eqIndex = strings.Index(param, "=")
		if eqIndex <= 0 {
			continue
		}
		switch param[:eqIndex] {
		case "sum":
			sum = types.Int64(param[eqIndex+1:])
		case "nonce":
			nonce = types.Int64(param[eqIndex+1:])
		}
	}
	// The checksum is computed over the still-escaped cookie value.
	if sum != this.sumKey(escapedValue, nonce) {
		return false
	}

	// Issue the long-lived key cookie.
	newKey, err := this.EncodeKey(this.ComposeKey(req, remoteAddr))
	if err != nil {
		return false
	}
	var keyLife = config.KeyLife
	if keyLife <= 0 {
		keyLife = DefaultKeyLife
		if policy != nil && policy.KeyLife > 0 {
			keyLife = policy.KeyLife
		}
	}
	var keyCookie = &http.Cookie{
		Name:    CookieKey,
		Value:   url.QueryEscape(newKey),
		Expires: time.Now().Add(time.Duration(keyLife) * time.Second),
		MaxAge:  keyLife,
		Path:    "/",
	}
	if policy != nil && policy.IncludeSubdomains {
		keyCookie.Domain = this.ParseTopDomain(req.Host)
	}
	http.SetCookie(writer, keyCookie)
	writer.WriteHeader(http.StatusOK)

	// Whitelist the IP for the lifetime of the key.
	if config.AddToWhiteList {
		ttlcache.SharedInt64Cache.Write("UAM:WHITE:"+remoteAddr, 1, fasttime.Now().Unix()+int64(keyLife))
	}

	// The client passed: clear its failure counter.
	this.resetFails(remoteAddr)
	return true
}
// ComposeKey builds a Version1 Key for req. The key carries the current Unix
// time (from fasttime) and req.Host. It also passes remoteAddr and the
// request's User-Agent to Key.Put; Key.IsSame later compares against them.
func (this *Manager) ComposeKey(req *http.Request, remoteAddr string) *Key {
	key := new(Key)
	key.Version = Version1
	key.Timestamp = fasttime.Now().Unix()
	key.Host = req.Host
	key.Put(remoteAddr, req.UserAgent())
	return key
}
// EncodeKey serializes key (Key.AsPB), encrypts the bytes with the manager's
// encryption method and returns them base64 (standard alphabet) encoded.
// Serialization and encryption errors are returned unchanged.
func (this *Manager) EncodeKey(key *Key) (string, error) {
	serialized, err := key.AsPB()
	if err == nil {
		encrypted, encryptErr := this.encryptMethod.Encrypt(serialized)
		if encryptErr == nil {
			return base64.StdEncoding.EncodeToString(encrypted), nil
		}
		err = encryptErr
	}
	return "", err
}
// ExistsActivePreKey reports whether req carries a CookiePrevKey cookie that
// decodes to a Key whose timestamp is at most one second old (per time.Now).
// Any read, decode, decrypt or unmarshal failure yields false.
func (this *Manager) ExistsActivePreKey(req *http.Request) bool {
	cookie, err := req.Cookie(CookiePrevKey)
	if err != nil {
		return false
	}

	unescaped, _ := url.QueryUnescape(cookie.Value)
	sealed, err := base64.StdEncoding.DecodeString(unescaped)
	if err != nil {
		return false
	}
	serialized, err := this.encryptMethod.Decrypt(sealed)
	if err != nil {
		return false
	}
	key := new(Key)
	if err = proto.Unmarshal(serialized, key); err != nil {
		return false
	}

	age := time.Now().Unix() - key.Timestamp
	return age <= 1
}
// IncreaseFails records one failed check for remoteAddr so the client can be
// punished. Failures are counted in a 60-second window. When the count reaches
// maxFails (policy.MaxFails, or DefaultMaxFails), the IP is added to the WAF
// blacklist for blockSeconds (policy.BlockSeconds, or DefaultBlockSeconds) and
// true is returned. policy may be nil.
func (this *Manager) IncreaseFails(policy *nodeconfigs.UAMPolicy, remoteAddr string, serverId int64) (isAttack bool) {
	const statSeconds = 60

	maxFails, blockSeconds := DefaultMaxFails, DefaultBlockSeconds
	if policy != nil {
		if policy.MaxFails > 0 {
			maxFails = policy.MaxFails
		}
		if policy.BlockSeconds > 0 {
			blockSeconds = policy.BlockSeconds
		}
	}

	count := counters.SharedCounter.IncreaseKey("UAM:CheckKey:"+remoteAddr, statSeconds)
	if count < types.Uint32(maxFails) {
		return false
	}

	var expiresAt = fasttime.Now().Unix() + int64(blockSeconds)
	var useGlobalScope = policy != nil && policy.Firewall.Scope == firewallconfigs.FirewallScopeGlobal
	var reason = "5秒盾认证不通过(" + types.String(statSeconds) + "秒内尝试" + types.String(maxFails) + "次)"
	waf.SharedIPBlackList.RecordIP(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, serverId, remoteAddr, expiresAt, 0, useGlobalScope, 0, 0, reason)
	return true
}
// resetFails clears the failure counter that IncreaseFails keeps for remoteAddr.
func (this *Manager) resetFails(remoteAddr string) {
	var counterKey = "UAM:CheckKey:" + remoteAddr
	counters.SharedCounter.ResetKey(counterKey)
}
// prepareCompression chooses a response compressor from the request's
// Accept-Encoding header. It picks the first "br" or "gzip" token, in header
// order, whose writer can be created (level 6). On success it sets
// Content-Encoding, removes Content-Length and returns the writer, which the
// caller must Close. It returns nil when no supported encoding is usable.
func (this *Manager) prepareCompression(req *http.Request, writer http.ResponseWriter) compressions.Writer {
	for _, piece := range strings.Split(req.Header.Get("Accept-Encoding"), ",") {
		// Drop parameters such as ";q=0.8".
		if paramIndex := strings.Index(piece, ";"); paramIndex >= 0 {
			piece = piece[:paramIndex]
		}

		// Browsers send "gzip, deflate, br": every token after the first has
		// a leading space. Content-coding names are also case-insensitive.
		switch strings.ToLower(strings.TrimSpace(piece)) {
		case "br":
			brWriter, err := compressions.NewBrotliWriter(writer, 6)
			if err != nil || brWriter == nil {
				continue // fall back to another acceptable encoding
			}
			writer.Header().Set("Content-Encoding", "br")
			writer.Header().Del("Content-Length")
			return brWriter
		case "gzip":
			gzipWriter, err := compressions.NewGzipWriter(writer, 6)
			if err != nil || gzipWriter == nil {
				continue // fall back to another acceptable encoding
			}
			writer.Header().Set("Content-Encoding", "gzip")
			writer.Header().Del("Content-Length")
			return gzipWriter
		}
	}
	return nil
}
// sumKey computes the checksum that the loading page's script posts back as
// "sum". For every ASCII letter or digit in key it adds the character's code
// multiplied by (nonce + its byte offset); other characters are skipped but
// still advance the offset. NOTE(review): presumably mirrors the client-side
// script exactly — keep both in sync.
func (this *Manager) sumKey(key string, nonce int64) int64 {
	var total int64
	for offset, char := range key {
		isAlphaNum := ('0' <= char && char <= '9') ||
			('a' <= char && char <= 'z') ||
			('A' <= char && char <= 'Z')
		if !isAlphaNum {
			continue
		}
		total += int64(char) * (nonce + int64(offset))
	}
	return total
}
// lang returns the client's preferred language tag. It takes the first entry of
// the Accept-Language header and strips any ";q=..." weight and surrounding
// whitespace (e.g. "zh-CN,zh;q=0.9" -> "zh-CN", "zh-TW" -> "zh-TW"). It falls
// back to "en-US" when the header is missing or its first entry is empty.
func (this *Manager) lang(req *http.Request) (lang string) {
	var first = req.Header.Get("Accept-Language")

	// A header holding a single tag has no comma; keep it whole in that case.
	if commaIndex := strings.Index(first, ","); commaIndex >= 0 {
		first = first[:commaIndex]
	}
	if paramIndex := strings.Index(first, ";"); paramIndex >= 0 {
		first = first[:paramIndex]
	}

	lang = strings.TrimSpace(first)
	if len(lang) == 0 {
		lang = "en-US"
	}
	return lang
}
// ParseTopDomain returns the "top" domain of a host, used to compare hosts and
// to scope the key cookie when subdomains are included.
//
// Rules:
//   - A ":port" suffix is removed first.
//   - IP addresses, and hosts with at most two labels, are returned unchanged.
//   - When the second-to-last label is "com", "net" or "org" (matched
//     case-insensitively, e.g. "www.example.com.cn"), the last three labels
//     are kept.
//   - Otherwise the last two labels are kept.
//
// NOTE(review): this is a heuristic, not a public-suffix lookup; for example
// "www.example.co.uk" yields "co.uk".
func (this *Manager) ParseTopDomain(domain string) string {
	if strings.Contains(domain, ":") {
		// Strip the port. Bare IPv6 literals without a port fail to split and
		// are left untouched.
		if host, _, err := net.SplitHostPort(domain); err == nil && len(host) > 0 {
			domain = host
		}
	}
	if net.ParseIP(domain) != nil {
		return domain
	}

	var pieces = strings.Split(domain, ".")
	var l = len(pieces)
	if l <= 2 {
		return domain
	}

	// Host names are case-insensitive, so "WWW.EXAMPLE.COM.CN" must not
	// collapse to the suffix "COM.CN".
	var secondLevel = pieces[l-2]
	if strings.EqualFold(secondLevel, "com") || strings.EqualFold(secondLevel, "net") || strings.EqualFold(secondLevel, "org") {
		if l == 3 { // abc.[com|net|org].xyz is already the top domain
			return domain
		}
		// a.b.c.abc.[com|net|org].xyz
		return strings.Join(pieces[l-3:], ".")
	}

	// a.b.c.abc.xyz
	return strings.Join(pieces[l-2:], ".")
}