Files
waf-platform/EdgeNode/internal/uam/manager.go

520 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package uam
import (
"encoding/base64"
"errors"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/compressions"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/utils/counters"
"github.com/TeaOSLab/EdgeNode/internal/utils/encrypt"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/utils/ttlcache"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
"google.golang.org/protobuf/proto"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
)
// Cookie names and the step marker shared with the challenge page script.
const (
	// CookiePrevKey names the short-lived cookie set together with the challenge page.
	CookiePrevKey = "ge_ua_p"
	// CookieKey names the cookie holding the key issued once the challenge is passed.
	CookieKey = "ge_ua_key"
	// StepPrev is embedded into the challenge page script as its `step` value.
	StepPrev = "prev"
)

// Key lifetimes, in seconds.
const (
	PrevKeyLife    int = 5    // lifetime of the CookiePrevKey cookie
	DefaultKeyLife int = 3600 // key lifetime used when neither config nor policy sets one
)

// Defaults for punishing clients that keep failing verification.
const (
	DefaultMaxFails     = 30   // failed attempts per IP before it is blocked
	DefaultBlockSeconds = 1800 // how long a blocked IP stays blocked, in seconds
)
// Manager issues and verifies UAM (challenge page) keys. Keys are encrypted
// with encryptMethod before they are stored in cookies.
type Manager struct {
	encryptMethod encrypt.MethodInterface // cipher applied to key cookies
}
// NewManager creates a UAM manager whose key cookies are encrypted with the
// "aes-256-cfb" method configured with key and secret.
func NewManager(key string, secret string) (*Manager, error) {
	method, err := encrypt.NewMethodInstance("aes-256-cfb", key, secret)
	if err != nil {
		return nil, fmt.Errorf("init encrypt method failed: %w", err)
	}
	manager := &Manager{encryptMethod: method}
	return manager, nil
}
// CheckKey reports whether req carries a valid CookieKey cookie, i.e. whether
// the client has already passed the UAM challenge.
//
// The cookie value is query-unescaped, base64-decoded, decrypted and
// unmarshalled into a Key. The key's age must lie within
// [PrevKeyLife, keyLife] seconds. When keyLife <= 0 it falls back to the
// policy's KeyLife (if > 0), otherwise to DefaultKeyLife. Version-1 keys must
// additionally match remoteAddr and the User-Agent, and the request host:
// exactly, or by top domain when the policy includes subdomains.
//
// isAttack is true when a cookie is present but malformed or fails
// verification; err describes why the check failed. Whenever isOk is false,
// IncreaseFails is called for remoteAddr. The writer parameter is not used.
func (this *Manager) CheckKey(policy *nodeconfigs.UAMPolicy, req *http.Request, writer http.ResponseWriter, remoteAddr string, serverId int64, keyLife int) (isOk bool, isAttack bool, err error) {
	// Punish every failed check (a missing cookie included). The deferred
	// function observes the final value of the named result isOk.
	defer func() {
		if !isOk {
			this.IncreaseFails(policy, remoteAddr, serverId)
		}
	}()

	// Read the key cookie; Request.Cookie returns ErrNoCookie when it is absent
	keyCookie, err := req.Cookie(CookieKey)
	if err != nil {
		return false, false, fmt.Errorf("read cookie failed: %w", err)
	}
	var cookieValue, _ = url.QueryUnescape(keyCookie.Value)
	if len(cookieValue) == 0 {
		return false, false, errors.New("unable to read cookie value")
	}

	keyData, err := base64.StdEncoding.DecodeString(cookieValue)
	if err != nil {
		return false, true, fmt.Errorf("decode key failed: %w", err)
	}
	keyPB, err := this.encryptMethod.Decrypt(keyData)
	if err != nil {
		return false, true, fmt.Errorf("decrypt key failed: %w", err)
	}
	var key = &Key{}
	err = proto.Unmarshal(keyPB, key)
	if err != nil {
		return false, true, fmt.Errorf("unmarshal key failed: %w", err)
	}

	if keyLife <= 0 {
		keyLife = DefaultKeyLife
		if policy != nil && policy.KeyLife > 0 {
			keyLife = policy.KeyLife
		}
	}

	var unixTime = time.Now().Unix()
	if key.Timestamp >= unixTime-int64(PrevKeyLife)+1 /** issued too recently **/ || key.Timestamp < unixTime-int64(keyLife) /** expired **/ {
		return false, true, errors.New("verify key failed")
	}

	if key.Version >= Version1 {
		if !key.IsSame(remoteAddr, req.UserAgent()) {
			return false, true, errors.New("verify key hash failed")
		}
		// A nil policy falls back to exact host matching instead of panicking
		if policy != nil && policy.IncludeSubdomains {
			if this.ParseTopDomain(key.Host) != this.ParseTopDomain(req.Host) {
				return false, true, errors.New("verify key domain failed")
			}
		} else {
			if key.Host != req.Host {
				return false, true, errors.New("verify key domain failed")
			}
		}
	}

	return true, false, nil
}
// LoadPage writes the UAM challenge page for req.
//
// It first sets the CookiePrevKey cookie (lifetime PrevKeyLife seconds) to a
// freshly composed and encoded key bound to this client. It then writes a
// 200 HTML page that embeds the challenge script, the cookie name, StepPrev
// and a random nonce. The page is compressed when the client accepts br or
// gzip.
//
// The page title and body come from the policy's UITitle/UIBody (passed
// through formatter) when set. Otherwise a built-in message localized by the
// client's Accept-Language is shown. A nil policy uses the built-in defaults.
// An error is returned only when the key cannot be encoded.
func (this *Manager) LoadPage(policy *nodeconfigs.UAMPolicy, req *http.Request, formatter func(s string) string, remoteAddr string, writer http.ResponseWriter) error {
	var prevKey = this.ComposeKey(req, remoteAddr)
	prevKeyString, err := this.EncodeKey(prevKey)
	if err != nil {
		return err
	}
	http.SetCookie(writer, &http.Cookie{
		Name:    CookiePrevKey,
		Value:   url.QueryEscape(prevKeyString),
		Expires: time.Now().Add(time.Duration(PrevKeyLife) * time.Second),
		MaxAge:  PrevKeyLife,
		Path:    "/",
	})

	writer.Header().Set("Content-Type", "text/html; charset=utf-8")
	writer.Header().Set("Cache-Control", "no-cache")

	var jsCode = JSMinifyCode
	if Tea.IsTesting() {
		jsCode = JSCode
	}

	// Compress the page when the client supports it
	var ioWriter io.Writer = writer
	var compressionWriter = this.prepareCompression(req, writer)
	if compressionWriter != nil {
		ioWriter = compressionWriter
		defer func() {
			_ = compressionWriter.Close()
		}()
	}

	writer.WriteHeader(http.StatusOK)

	var productName = teaconst.GlobalProductName
	if len(productName) == 0 {
		productName = teaconst.ProductName
	}

	var messages map[string]string
	switch this.lang(req) {
	case "zh-CN":
		messages = I18NForZH_CN()
	case "zh-TW":
		messages = I18NForZH_TW()
	default:
		messages = I18NForEN()
	}

	// Guard against a nil policy, as the other methods of Manager do
	var title, body string
	if policy != nil {
		title = policy.UITitle
		body = policy.UIBody
	}
	if len(title) > 0 {
		title = formatter(title)
	}
	if len(body) > 0 {
		body = formatter(body)
	} else {
		// req.Host is client-supplied: HTML-escape it (same set as html.EscapeString)
		var safeHost = strings.NewReplacer(`&`, "&amp;", `'`, "&#39;", `<`, "&lt;", `>`, "&gt;", `"`, "&#34;").Replace(req.Host)
		body = `<div class="ui-uam-box">
<h1>` + fmt.Sprintf(messages["checking"], safeHost) + `</h1>
<p>` + messages["waiting"] + `</p>
<p>&nbsp;</p>
<p>` + fmt.Sprintf(messages["by"], productName) + `</p>
</div>`
	}

	_, _ = ioWriter.Write([]byte(`<!DOCTYPE html>
<html>
<head>
<title>` + title + `</title>
<meta charset="UTF-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=0">
<style type="text/css">
.ui-uam-box {
text-align: center;
font-family: Roboto,"Helvetica Neue Light","Helvetica Neue",Helvetica,Arial,"Lucida Grande",sans-serif;
font-size: 16px;
}
.ui-uam-box .ui-counter {
font-weight: bold;
}
</style>
<script type="text/javascript">
var cpk = "` + CookiePrevKey + `"
var step = "` + StepPrev + `";
var nonce = ` + types.String(rands.Int(1000, 9999)) + `;
` + jsCode + `
</script>
</head>
<body>` + body +
		`</body>
</html>`))
	return nil
}
// CheckPrevKey verifies the challenge answer posted by the loading page. On
// success it issues the CookieKey cookie and returns true.
//
// The checks, in order:
//   - a decodable CookiePrevKey cookie;
//   - for version-1 keys, a matching remote address/User-Agent and host (top
//     domain when the policy includes subdomains, exact host otherwise or
//     when policy is nil);
//   - a Referer that ends with the request URL;
//   - a key no older than PrevKeyLife seconds;
//   - a body (first 64 bytes only) whose "sum" parameter equals sumKey of the
//     raw cookie value with the posted "nonce".
//
// On success it:
//   - sets CookieKey, with a lifetime from config.KeyLife, then
//     policy.KeyLife, then DefaultKeyLife;
//   - writes 200 OK;
//   - whitelists remoteAddr for that lifetime when config.AddToWhiteList is set;
//   - resets the failure counter.
func (this *Manager) CheckPrevKey(policy *nodeconfigs.UAMPolicy, config *serverconfigs.UAMConfig, req *http.Request, remoteAddr string, writer http.ResponseWriter) bool {
	// Read and decode the prev-key cookie
	cookie, err := req.Cookie(CookiePrevKey)
	if err != nil {
		return false
	}
	var escapedValue = cookie.Value
	value, _ := url.QueryUnescape(escapedValue)
	valueData, err := base64.StdEncoding.DecodeString(value)
	if err != nil {
		return false
	}
	keyPB, err := this.encryptMethod.Decrypt(valueData)
	if err != nil {
		return false
	}
	var key = &Key{}
	err = proto.Unmarshal(keyPB, key)
	if err != nil {
		return false
	}

	if key.Version >= Version1 {
		if !key.IsSame(remoteAddr, req.UserAgent()) {
			return false
		}

		// Check the domain; a nil policy falls back to exact host matching
		if policy != nil && policy.IncludeSubdomains {
			if this.ParseTopDomain(key.Host) != this.ParseTopDomain(req.Host) {
				return false
			}
		} else {
			if key.Host != req.Host {
				return false
			}
		}
	}

	// Require the Referer to end with the current request URL
	if !strings.HasSuffix(req.Referer(), req.URL.String()) {
		return false
	}

	// The prev key is only accepted for PrevKeyLife seconds
	if key.Timestamp < time.Now().Unix()-int64(PrevKeyLife) {
		return false
	}

	// Check the connection port.
	// Browsers implement this differently and the AJAX request may arrive on a
	// new connection (new port), so this check is disabled for now.
	/**if key.Port > 0 {
		var portIndex = strings.LastIndex(req.RemoteAddr, ":")
		if portIndex > 0 && key.Port != types.Int32(req.RemoteAddr[portIndex+1:]) {
			return false
		}
	}**/

	// Parse "sum" and "nonce" from the posted body (at most 64 bytes are read)
	body, err := io.ReadAll(io.LimitReader(req.Body, 64))
	if err != nil {
		return false
	}
	if len(body) == 0 {
		return false
	}
	var bodyString = string(body)
	var sum int64 = 0
	var nonce int64 = 0
	for _, param := range strings.Split(bodyString, "&") {
		var eqIndex = strings.Index(param, "=")
		if eqIndex > 0 {
			if param[:eqIndex] == "sum" {
				sum = types.Int64(param[eqIndex+1:])
			} else if param[:eqIndex] == "nonce" {
				nonce = types.Int64(param[eqIndex+1:])
			}
		}
	}
	if sum != this.sumKey(escapedValue, nonce) {
		return false
	}

	// Issue the long-lived key cookie
	newKey, err := this.EncodeKey(this.ComposeKey(req, remoteAddr))
	if err != nil {
		return false
	}

	var keyLife = config.KeyLife
	if keyLife <= 0 {
		keyLife = DefaultKeyLife
		if policy != nil && policy.KeyLife > 0 {
			keyLife = policy.KeyLife
		}
	}

	var keyCookie = &http.Cookie{
		Name:    CookieKey,
		Value:   url.QueryEscape(newKey),
		Expires: time.Now().Add(time.Duration(keyLife) * time.Second),
		MaxAge:  keyLife,
		Path:    "/",
	}
	if policy != nil && policy.IncludeSubdomains {
		keyCookie.Domain = this.ParseTopDomain(req.Host)
	}
	http.SetCookie(writer, keyCookie)

	writer.WriteHeader(http.StatusOK)

	// Record the IP in the whitelist for the key lifetime
	if config.AddToWhiteList {
		ttlcache.SharedInt64Cache.Write("UAM:WHITE:"+remoteAddr, 1, fasttime.Now().Unix()+int64(keyLife))
	}

	// Clean up the failure counter
	this.resetFails(remoteAddr)

	return true
}
// ComposeKey builds a Version1 key for req. The key is stamped with the
// current (fasttime) Unix time and the request host, and is bound to
// remoteAddr and the User-Agent through Key.Put.
func (this *Manager) ComposeKey(req *http.Request, remoteAddr string) *Key {
	key := new(Key)
	key.Version = Version1
	key.Timestamp = fasttime.Now().Unix()
	key.Host = req.Host
	key.Put(remoteAddr, req.UserAgent())
	return key
}
// EncodeKey serializes key via AsPB, encrypts the bytes with the manager's
// cipher and returns the result as standard base64. Errors from serialization
// or encryption are returned unchanged.
func (this *Manager) EncodeKey(key *Key) (string, error) {
	plain, serializeErr := key.AsPB()
	if serializeErr != nil {
		return "", serializeErr
	}

	cipherText, encryptErr := this.encryptMethod.Encrypt(plain)
	if encryptErr != nil {
		return "", encryptErr
	}

	encoded := base64.StdEncoding.EncodeToString(cipherText)
	return encoded, nil
}
// ExistsActivePreKey reports whether req carries a CookiePrevKey cookie that
// decodes into a Key whose timestamp is at most one second old.
func (this *Manager) ExistsActivePreKey(req *http.Request) bool {
	cookie, err := req.Cookie(CookiePrevKey)
	if err != nil {
		return false
	}

	// The unescape error is ignored, as in the other cookie readers of this file
	unescaped, _ := url.QueryUnescape(cookie.Value)
	cipherText, err := base64.StdEncoding.DecodeString(unescaped)
	if err != nil {
		return false
	}
	plain, err := this.encryptMethod.Decrypt(cipherText)
	if err != nil {
		return false
	}

	key := &Key{}
	if err = proto.Unmarshal(plain, key); err != nil {
		return false
	}

	age := time.Now().Unix() - key.Timestamp
	return age <= 1
}
// IncreaseFails records one failed verification for remoteAddr so that the
// client can be punished.
//
// Failures are counted per remoteAddr via counters.SharedCounter with a
// 60-second stat period. Once the count reaches the policy's MaxFails
// (default DefaultMaxFails), remoteAddr is recorded in the WAF IP blacklist
// until now + BlockSeconds (default DefaultBlockSeconds). It returns true
// exactly when the IP was recorded.
func (this *Manager) IncreaseFails(policy *nodeconfigs.UAMPolicy, remoteAddr string, serverId int64) (isAttack bool) {
	const statSeconds = 60

	var maxFails = DefaultMaxFails
	if policy != nil && policy.MaxFails > 0 {
		maxFails = policy.MaxFails
	}

	var blockSeconds = DefaultBlockSeconds
	if policy != nil && policy.BlockSeconds > 0 {
		blockSeconds = policy.BlockSeconds
	}

	var count = counters.SharedCounter.IncreaseKey("UAM:CheckKey:"+remoteAddr, statSeconds)
	if count < types.Uint32(maxFails) {
		return false
	}

	var isGlobalFirewallScope = policy != nil && policy.Firewall.Scope == firewallconfigs.FirewallScopeGlobal
	// "UAM (5-second shield) verification failed (N attempts within M seconds)"
	var reason = "5秒盾认证不通过(" + types.String(statSeconds) + "秒内尝试" + types.String(maxFails) + "次)"
	waf.SharedIPBlackList.RecordIP(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, serverId, remoteAddr, fasttime.Now().Unix()+int64(blockSeconds), 0, isGlobalFirewallScope, 0, 0, reason)
	return true
}
// resetFails clears the failure counter that IncreaseFails maintains for remoteAddr.
func (this *Manager) resetFails(remoteAddr string) {
	var counterKey = "UAM:CheckKey:" + remoteAddr
	counters.SharedCounter.ResetKey(counterKey)
}
// prepareCompression chooses a compression for the response from the
// request's Accept-Encoding header and returns a writer wrapping writer.
//
// Entries are tried in the order the client lists them. The first "br" or
// "gzip" entry that is not refused with a zero weight, and whose compressor
// can be created, is used. Matching is case-insensitive and ignores
// surrounding blanks. For that entry, Content-Encoding is set and
// Content-Length is removed, since the compressed size differs. It returns
// nil when no entry qualifies; the caller then writes the body uncompressed.
func (this *Manager) prepareCompression(req *http.Request, writer http.ResponseWriter) compressions.Writer {
	for _, piece := range strings.Split(req.Header.Get("Accept-Encoding"), ",") {
		var coding = piece
		if semicolonIndex := strings.Index(piece, ";"); semicolonIndex >= 0 {
			coding = piece[:semicolonIndex]

			// Accept-Encoding entries only carry a "q" weight. q=0 (also written
			// "0.0", "0.000") means the client does not accept this coding.
			var weight = strings.TrimSpace(piece[semicolonIndex+1:])
			if len(weight) > 2 && strings.EqualFold(weight[:2], "q=") && strings.Trim(weight[2:], "0.") == "" {
				continue
			}
		}
		// Browsers send e.g. "gzip, deflate, br": every entry after the first starts with a blank
		coding = strings.TrimSpace(coding)

		var compressionWriter compressions.Writer
		var encoding string
		switch {
		case strings.EqualFold(coding, "br"):
			compressionWriter, _ = compressions.NewBrotliWriter(writer, 6)
			encoding = "br"
		case strings.EqualFold(coding, "gzip"):
			compressionWriter, _ = compressions.NewGzipWriter(writer, 6)
			encoding = "gzip"
		default:
			continue
		}

		// Creation errors are deliberately ignored: try the next acceptable coding instead
		if compressionWriter != nil {
			writer.Header().Set("Content-Encoding", encoding)
			writer.Header().Del("Content-Length")
			return compressionWriter
		}
	}
	return nil
}
// sumKey computes the checksum that CheckPrevKey expects the client to post
// as "sum". For every ASCII letter or digit in key, its character code times
// (nonce + its byte offset in key) is added to the total. Other characters
// contribute nothing but still advance the offset.
func (this *Manager) sumKey(key string, nonce int64) int64 {
	var total int64
	// A byte loop is enough: only ASCII bytes are counted, and their byte
	// offsets are the same offsets a rune loop would report.
	for offset := 0; offset < len(key); offset++ {
		c := key[offset]
		isDigit := '0' <= c && c <= '9'
		isLower := 'a' <= c && c <= 'z'
		isUpper := 'A' <= c && c <= 'Z'
		if isDigit || isLower || isUpper {
			total += int64(c) * (nonce + int64(offset))
		}
	}
	return total
}
// lang returns the client's most preferred language tag. This is the first
// entry of the Accept-Language header, without its ";q=" weight and without
// surrounding blanks. It falls back to "en-US" when the header is absent or
// its first entry is empty. The tag is returned with the client's casing.
func (this *Manager) lang(req *http.Request) (lang string) {
	var acceptLanguage = req.Header.Get("Accept-Language")

	// The first entry ends at the first ',' (next entry) or ';' (its weight).
	// A single-entry header contains neither and is used in full.
	if end := strings.IndexAny(acceptLanguage, ",;"); end >= 0 {
		acceptLanguage = acceptLanguage[:end]
	}

	lang = strings.TrimSpace(acceptLanguage)
	if len(lang) == 0 {
		lang = "en-US"
	}
	return lang
}
// ParseTopDomain reduces a host to the "top" domain used for subdomain-wide
// key matching and as the key cookie's Domain.
//
// Rules:
//   - an optional ":port" suffix is removed first;
//   - IP addresses and hosts with at most two labels are returned unchanged;
//   - otherwise the last two labels are kept, or the last three when the
//     second-to-last label is "net", "com" or "org" (e.g. "a.abc.com.cn" ->
//     "abc.com.cn");
//   - a three-label host of that form is returned whole.
func (this *Manager) ParseTopDomain(domain string) string {
	// Strip the port; a bare IPv6 literal fails to split and is kept as-is
	if strings.Contains(domain, ":") {
		if host, _, err := net.SplitHostPort(domain); err == nil && host != "" {
			domain = host
		}
	}

	if net.ParseIP(domain) != nil {
		return domain
	}

	labels := strings.Split(domain, ".")
	count := len(labels)
	if count <= 2 {
		return domain
	}

	keep := 2
	switch labels[count-2] {
	case "net", "com", "org":
		if count == 3 { // *.[net|com|org].abc
			return domain
		}
		keep = 3 // a.b.c.abc.[net|com|org].abc
	}
	return strings.Join(labels[count-keep:], ".")
}