520 lines
13 KiB
Go
520 lines
13 KiB
Go
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||
//go:build plus
|
||
|
||
package uam
|
||
|
||
import (
|
||
"encoding/base64"
|
||
"errors"
|
||
"fmt"
|
||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||
"github.com/TeaOSLab/EdgeNode/internal/compressions"
|
||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||
"github.com/TeaOSLab/EdgeNode/internal/utils/counters"
|
||
"github.com/TeaOSLab/EdgeNode/internal/utils/encrypt"
|
||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||
"github.com/TeaOSLab/EdgeNode/internal/utils/ttlcache"
|
||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||
"github.com/iwind/TeaGo/Tea"
|
||
"github.com/iwind/TeaGo/rands"
|
||
"github.com/iwind/TeaGo/types"
|
||
"google.golang.org/protobuf/proto"
|
||
"io"
|
||
"net"
|
||
"net/http"
|
||
"net/url"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
// Cookie names used by the UAM ("5秒盾" / 5-second shield) challenge.
const (
	CookiePrevKey = "ge_ua_p"   // short-lived cookie set by LoadPage and read by CheckPrevKey
	CookieKey     = "ge_ua_key" // verification cookie set by CheckPrevKey and read by CheckKey
)

// Key lifetimes, in seconds.
const (
	PrevKeyLife    int = 5    // lifetime of the CookiePrevKey cookie
	DefaultKeyLife int = 3600 // CookieKey lifetime when neither config nor policy sets one
)

// StepPrev is the step name injected into the challenge page script as `step`.
const StepPrev = "prev"

// Defaults used to punish clients that keep failing verification.
const (
	DefaultMaxFails     = 30   // maximum failed attempts per IP before it is blocked
	DefaultBlockSeconds = 1800 // how long an offending IP is blocked, in seconds
)
|
||
|
||
// Manager UAM管理器
|
||
type Manager struct {
|
||
encryptMethod encrypt.MethodInterface
|
||
}
|
||
|
||
// NewManager 获取新的UAM管理器
|
||
func NewManager(key string, secret string) (*Manager, error) {
|
||
method, err := encrypt.NewMethodInstance("aes-256-cfb", key, secret)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("init encrypt method failed: %w", err)
|
||
}
|
||
|
||
return &Manager{
|
||
encryptMethod: method,
|
||
}, nil
|
||
}
|
||
|
||
// CheckKey 检查是否已经通过验证
|
||
func (this *Manager) CheckKey(policy *nodeconfigs.UAMPolicy, req *http.Request, writer http.ResponseWriter, remoteAddr string, serverId int64, keyLife int) (isOk bool, isAttack bool, err error) {
|
||
// 对攻击行为进行惩罚
|
||
defer func() {
|
||
if !isOk {
|
||
this.IncreaseFails(policy, remoteAddr, serverId)
|
||
}
|
||
}()
|
||
|
||
// 读取Cookie
|
||
keyCookie, err := req.Cookie(CookieKey)
|
||
if err != nil {
|
||
return false, false, fmt.Errorf("read cookie failed: %w", err)
|
||
}
|
||
|
||
if keyCookie == nil {
|
||
return false, false, errors.New("cookie not found")
|
||
}
|
||
|
||
var cookieValue, _ = url.QueryUnescape(keyCookie.Value)
|
||
if len(cookieValue) == 0 {
|
||
return false, false, errors.New("unable to read cookie value")
|
||
}
|
||
|
||
keyData, err := base64.StdEncoding.DecodeString(cookieValue)
|
||
if err != nil {
|
||
return false, true, fmt.Errorf("decode key failed: %w", err)
|
||
}
|
||
keyJSON, err := this.encryptMethod.Decrypt(keyData)
|
||
if err != nil {
|
||
return false, true, fmt.Errorf("decrypt key failed: %w", err)
|
||
}
|
||
|
||
var key = &Key{}
|
||
err = proto.Unmarshal(keyJSON, key)
|
||
if err != nil {
|
||
return false, true, fmt.Errorf("unmarshal key failed: %w", err)
|
||
}
|
||
|
||
if keyLife <= 0 {
|
||
keyLife = DefaultKeyLife
|
||
|
||
if policy != nil && policy.KeyLife > 0 {
|
||
keyLife = policy.KeyLife
|
||
}
|
||
}
|
||
|
||
var unixTime = time.Now().Unix()
|
||
if key.Timestamp >= unixTime-int64(PrevKeyLife)+1 /** 离生成时间过近 **/ || key.Timestamp < unixTime-int64(keyLife) /** 过了有效期 **/ {
|
||
return false, true, errors.New("verify key failed")
|
||
}
|
||
|
||
if key.Version >= Version1 {
|
||
if !key.IsSame(remoteAddr, req.UserAgent()) {
|
||
return false, true, errors.New("verify key hash failed")
|
||
}
|
||
|
||
if policy.IncludeSubdomains {
|
||
if this.ParseTopDomain(key.Host) != this.ParseTopDomain(req.Host) {
|
||
return false, true, errors.New("verify key domain failed")
|
||
}
|
||
} else {
|
||
if key.Host != req.Host {
|
||
return false, true, errors.New("verify key domain failed")
|
||
}
|
||
}
|
||
}
|
||
|
||
return true, false, nil
|
||
}
|
||
|
||
// LoadPage 显示加载页面
|
||
func (this *Manager) LoadPage(policy *nodeconfigs.UAMPolicy, req *http.Request, formatter func(s string) string, remoteAddr string, writer http.ResponseWriter) error {
|
||
var prevKey = this.ComposeKey(req, remoteAddr)
|
||
prevKeyString, err := this.EncodeKey(prevKey)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
http.SetCookie(writer, &http.Cookie{
|
||
Name: CookiePrevKey,
|
||
Value: url.QueryEscape(prevKeyString),
|
||
Expires: time.Now().Add(time.Duration(PrevKeyLife) * time.Second),
|
||
MaxAge: PrevKeyLife,
|
||
Path: "/",
|
||
})
|
||
|
||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||
writer.Header().Set("Cache-Control", "no-cache")
|
||
|
||
var jsCode = JSMinifyCode
|
||
if Tea.IsTesting() {
|
||
jsCode = JSCode
|
||
}
|
||
|
||
// 压缩
|
||
var ioWriter io.Writer = writer
|
||
|
||
var compressionWriter = this.prepareCompression(req, writer)
|
||
if compressionWriter != nil {
|
||
ioWriter = compressionWriter
|
||
defer func() {
|
||
_ = compressionWriter.Close()
|
||
}()
|
||
}
|
||
|
||
writer.WriteHeader(http.StatusOK)
|
||
|
||
var productName = teaconst.GlobalProductName
|
||
if len(productName) == 0 {
|
||
productName = teaconst.ProductName
|
||
}
|
||
|
||
var messages = map[string]string{}
|
||
switch this.lang(req) {
|
||
case "zh-CN":
|
||
messages = I18NForZH_CN()
|
||
case "zh-TW":
|
||
messages = I18NForZH_TW()
|
||
default:
|
||
messages = I18NForEN()
|
||
}
|
||
|
||
var title = policy.UITitle
|
||
if len(title) > 0 {
|
||
title = formatter(title)
|
||
}
|
||
|
||
var body = policy.UIBody
|
||
if len(body) > 0 {
|
||
body = formatter(body)
|
||
} else {
|
||
body = `<div class="ui-uam-box">
|
||
<h1>` + fmt.Sprintf(messages["checking"], req.Host) + `</h1>
|
||
<p>` + messages["waiting"] + `</p>
|
||
<p> </p>
|
||
<p>` + fmt.Sprintf(messages["by"], productName) + `</p>
|
||
</div>`
|
||
}
|
||
|
||
_, _ = ioWriter.Write([]byte(`<!DOCTYPE html>
|
||
<html>
|
||
<head>
|
||
<title>` + title + `</title>
|
||
<meta charset="UTF-8"/>
|
||
<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=0">
|
||
<style type="text/css">
|
||
.ui-uam-box {
|
||
text-align: center;
|
||
font-family: font-family: Roboto,"Helvetica Neue Light","Helvetica Neue",Helvetica,Arial,"Lucida Grande",sans-serif;
|
||
font-size: 16px;
|
||
}
|
||
|
||
.ui-uam-box .ui-counter {
|
||
font-weight: bold;
|
||
}
|
||
|
||
</style>
|
||
<script type="text/javascript">
|
||
var cpk = "` + CookiePrevKey + `"
|
||
var step = "` + StepPrev + `";
|
||
var nonce = ` + types.String(rands.Int(1000, 9999)) + `;
|
||
` + jsCode + `
|
||
</script>
|
||
</head>
|
||
<body>` + body +
|
||
`</body>
|
||
</html>`))
|
||
|
||
return nil
|
||
}
|
||
|
||
// CheckPrevKey 检查第一个Key
|
||
func (this *Manager) CheckPrevKey(policy *nodeconfigs.UAMPolicy, config *serverconfigs.UAMConfig, req *http.Request, remoteAddr string, writer http.ResponseWriter) bool {
|
||
// 检查Cookie
|
||
cookie, err := req.Cookie(CookiePrevKey)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
var escapedValue = cookie.Value
|
||
value, _ := url.QueryUnescape(escapedValue)
|
||
valueData, err := base64.StdEncoding.DecodeString(value)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
valueJSON, err := this.encryptMethod.Decrypt(valueData)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
var key = &Key{}
|
||
err = proto.Unmarshal(valueJSON, key)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
if key.Version >= Version1 {
|
||
if !key.IsSame(remoteAddr, req.UserAgent()) {
|
||
return false
|
||
}
|
||
|
||
// 检查域名
|
||
if policy.IncludeSubdomains {
|
||
if this.ParseTopDomain(key.Host) != this.ParseTopDomain(req.Host) {
|
||
return false
|
||
}
|
||
} else {
|
||
if key.Host != req.Host {
|
||
return false
|
||
}
|
||
}
|
||
}
|
||
|
||
if !strings.HasSuffix(req.Referer(), req.URL.String()) {
|
||
return false
|
||
}
|
||
|
||
if key.Timestamp < time.Now().Unix()-int64(PrevKeyLife) {
|
||
return false
|
||
}
|
||
|
||
// 检查连接端口
|
||
// 因为每个浏览器实现机制不同,可能会导致ajax新开端口,所以这里暂时不判断
|
||
/**if key.Port > 0 {
|
||
var portIndex = strings.LastIndex(req.RemoteAddr, ":")
|
||
if portIndex > 0 && key.Port != types.Int32(req.RemoteAddr[portIndex+1:]) {
|
||
return false
|
||
}
|
||
}**/
|
||
|
||
body, err := io.ReadAll(io.LimitReader(req.Body, 64))
|
||
if err != nil {
|
||
return false
|
||
}
|
||
if len(body) == 0 {
|
||
return false
|
||
}
|
||
var bodyString = string(body)
|
||
var sum int64 = 0
|
||
var nonce int64 = 0
|
||
for _, param := range strings.Split(bodyString, "&") {
|
||
var eqIndex = strings.Index(param, "=")
|
||
if eqIndex > 0 {
|
||
if param[:eqIndex] == "sum" {
|
||
sum = types.Int64(param[eqIndex+1:])
|
||
} else if param[:eqIndex] == "nonce" {
|
||
nonce = types.Int64(param[eqIndex+1:])
|
||
}
|
||
}
|
||
}
|
||
|
||
if sum != this.sumKey(escapedValue, nonce) {
|
||
return false
|
||
}
|
||
|
||
// 设置新的Cookie
|
||
newKey, err := this.EncodeKey(this.ComposeKey(req, remoteAddr))
|
||
if err != nil {
|
||
return false
|
||
}
|
||
|
||
var keyLife = config.KeyLife
|
||
if keyLife <= 0 {
|
||
keyLife = DefaultKeyLife
|
||
|
||
if policy != nil && policy.KeyLife > 0 {
|
||
keyLife = policy.KeyLife
|
||
}
|
||
}
|
||
|
||
var keyCookie = &http.Cookie{
|
||
Name: CookieKey,
|
||
Value: url.QueryEscape(newKey),
|
||
Expires: time.Now().Add(time.Duration(keyLife) * time.Second),
|
||
MaxAge: keyLife,
|
||
Path: "/",
|
||
}
|
||
|
||
if policy != nil && policy.IncludeSubdomains {
|
||
keyCookie.Domain = this.ParseTopDomain(req.Host)
|
||
}
|
||
|
||
http.SetCookie(writer, keyCookie)
|
||
writer.WriteHeader(http.StatusOK)
|
||
|
||
// 记录到IP白名单
|
||
if config.AddToWhiteList {
|
||
ttlcache.SharedInt64Cache.Write("UAM:WHITE:"+remoteAddr, 1, fasttime.Now().Unix()+int64(keyLife))
|
||
}
|
||
|
||
// 清理
|
||
this.resetFails(remoteAddr)
|
||
|
||
return true
|
||
}
|
||
|
||
// ComposeKey 组合Key
|
||
func (this *Manager) ComposeKey(req *http.Request, remoteAddr string) *Key {
|
||
var key = &Key{
|
||
Version: Version1,
|
||
Timestamp: fasttime.Now().Unix(),
|
||
Host: req.Host,
|
||
}
|
||
key.Put(remoteAddr, req.UserAgent())
|
||
|
||
return key
|
||
}
|
||
|
||
// EncodeKey 对Key进行编码
|
||
func (this *Manager) EncodeKey(key *Key) (string, error) {
|
||
keyPB, err := key.AsPB()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
dst, err := this.encryptMethod.Encrypt(keyPB)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return base64.StdEncoding.EncodeToString(dst), nil
|
||
}
|
||
|
||
// ExistsActivePreKey 检查是否已经生成PrevKey
|
||
func (this *Manager) ExistsActivePreKey(req *http.Request) bool {
|
||
// 检查Cookie
|
||
cookie, err := req.Cookie(CookiePrevKey)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
var escapedValue = cookie.Value
|
||
value, _ := url.QueryUnescape(escapedValue)
|
||
valueData, err := base64.StdEncoding.DecodeString(value)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
valueJSON, err := this.encryptMethod.Decrypt(valueData)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
var key = &Key{}
|
||
err = proto.Unmarshal(valueJSON, key)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
return time.Now().Unix()-key.Timestamp <= 1
|
||
}
|
||
|
||
// IncreaseFails 增加失败次数
|
||
// 以便于对客户端进行处罚
|
||
func (this *Manager) IncreaseFails(policy *nodeconfigs.UAMPolicy, remoteAddr string, serverId int64) (isAttack bool) {
|
||
const statSeconds = 60
|
||
|
||
var maxFails = DefaultMaxFails
|
||
if policy != nil && policy.MaxFails > 0 {
|
||
maxFails = policy.MaxFails
|
||
}
|
||
|
||
var blockSeconds = DefaultBlockSeconds
|
||
if policy != nil && policy.BlockSeconds > 0 {
|
||
blockSeconds = policy.BlockSeconds
|
||
}
|
||
|
||
var count = counters.SharedCounter.IncreaseKey("UAM:CheckKey:"+remoteAddr, statSeconds)
|
||
if count >= types.Uint32(maxFails) {
|
||
waf.SharedIPBlackList.RecordIP(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, serverId, remoteAddr, fasttime.Now().Unix()+int64(blockSeconds), 0, policy != nil && policy.Firewall.Scope == firewallconfigs.FirewallScopeGlobal, 0, 0, "5秒盾认证不通过("+types.String(statSeconds)+"秒内尝试"+types.String(maxFails)+"次)")
|
||
return true
|
||
}
|
||
return
|
||
}
|
||
|
||
// 清除失败次数
|
||
func (this *Manager) resetFails(remoteAddr string) {
|
||
counters.SharedCounter.ResetKey("UAM:CheckKey:" + remoteAddr)
|
||
}
|
||
|
||
// 准备压缩
|
||
func (this *Manager) prepareCompression(req *http.Request, writer http.ResponseWriter) compressions.Writer {
|
||
var acceptEncodings = req.Header.Get("Accept-Encoding")
|
||
var encodings = strings.Split(acceptEncodings, ",")
|
||
var compressionWriter compressions.Writer
|
||
for _, piece := range encodings {
|
||
var qualityIndex = strings.Index(piece, ";")
|
||
if qualityIndex >= 0 {
|
||
piece = piece[:qualityIndex]
|
||
}
|
||
|
||
if piece == "br" {
|
||
compressionWriter, _ = compressions.NewBrotliWriter(writer, 6)
|
||
if compressionWriter != nil {
|
||
writer.Header().Set("Content-Encoding", "br")
|
||
writer.Header().Del("Content-Length")
|
||
}
|
||
break
|
||
} else if piece == "gzip" {
|
||
compressionWriter, _ = compressions.NewGzipWriter(writer, 6)
|
||
if compressionWriter != nil {
|
||
writer.Header().Set("Content-Encoding", "gzip")
|
||
writer.Header().Del("Content-Length")
|
||
}
|
||
break
|
||
}
|
||
}
|
||
|
||
return compressionWriter
|
||
}
|
||
|
||
func (this *Manager) sumKey(key string, nonce int64) int64 {
|
||
var result int64 = 0
|
||
for i, r := range key {
|
||
if (r >= '0' && r <= '9') || (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') {
|
||
result += int64(r) * (nonce + int64(i))
|
||
}
|
||
}
|
||
|
||
return result
|
||
}
|
||
|
||
func (this *Manager) lang(req *http.Request) (lang string) {
|
||
var acceptLanguage = req.Header.Get("Accept-Language")
|
||
if len(acceptLanguage) > 0 {
|
||
langIndex := strings.Index(acceptLanguage, ",")
|
||
if langIndex > 0 {
|
||
lang = acceptLanguage[:langIndex]
|
||
}
|
||
}
|
||
if len(lang) == 0 {
|
||
lang = "en-US"
|
||
}
|
||
return
|
||
}
|
||
|
||
func (this *Manager) ParseTopDomain(domain string) string {
|
||
if strings.Contains(domain, ":") {
|
||
newHost, _, splitErr := net.SplitHostPort(domain)
|
||
if splitErr == nil && len(newHost) > 0 {
|
||
domain = newHost
|
||
}
|
||
}
|
||
|
||
if len(net.ParseIP(domain)) > 0 {
|
||
return domain
|
||
}
|
||
|
||
var pieces = strings.Split(domain, ".")
|
||
var l = len(pieces)
|
||
if l <= 2 {
|
||
return domain
|
||
}
|
||
|
||
var topDomain string
|
||
if pieces[l-2] == "net" || pieces[l-2] == "com" || pieces[l-2] == "org" {
|
||
if l == 3 { // *.[net|com|org].abc
|
||
return domain
|
||
}
|
||
|
||
// a.b.c.abc.[net|com|org].abc
|
||
topDomain = strings.Join(pieces[len(pieces)-3:], ".")
|
||
} else { // a.b.c.abc.[com|...]
|
||
topDomain = strings.Join(pieces[len(pieces)-2:], ".")
|
||
}
|
||
return topDomain
|
||
}
|