This commit is contained in:
unknown
2026-02-04 20:27:13 +08:00
commit 3b042d1dad
9410 changed files with 1488147 additions and 0 deletions

BIN
EdgeDNS/.DS_Store vendored Normal file

Binary file not shown.

73
EdgeDNS/.golangci.yaml Normal file
View File

@@ -0,0 +1,73 @@
# https://golangci-lint.run/usage/configuration/
linters:
enable-all: true
disable:
- ifshort
- golint
- nosnakecase
- scopelint
- varcheck
- structcheck
- interfacer
- exhaustivestruct
- maligned
- deadcode
- dogsled
- wrapcheck
- wastedassign
- varnamelen
- testpackage
- thelper
- nilerr
- sqlclosecheck
- paralleltest
- nonamedreturns
- nlreturn
- nakedret
- ireturn
- interfacebloat
- gosmopolitan
- gomnd
- goerr113
- gochecknoglobals
- exhaustruct
- errorlint
- depguard
- exhaustive
- containedctx
- wsl
- cyclop
- dupword
- errchkjson
- contextcheck
- tagalign
- dupl
- forbidigo
- funlen
- goconst
- godox
- gosec
- lll
- nestif
- revive
- unparam
- stylecheck
- gocritic
- gofumpt
- gomoddirectives
- godot
- gofmt
- gocognit
- mirror
- gocyclo
- gochecknoinits
- gci
- maintidx
- prealloc
- goimports
- errname
- musttag
- forcetypeassert
- whitespace
- noctx

View File

@@ -0,0 +1,9 @@
#!/usr/bin/env bash
./build.sh linux amd64
#./build.sh linux 386
./build.sh linux arm64
#./build.sh linux mips64
#./build.sh linux mips64le
#./build.sh darwin amd64
#./build.sh darwin arm64

110
EdgeDNS/build/build.sh Normal file
View File

@@ -0,0 +1,110 @@
#!/usr/bin/env bash
function build() {
ROOT=$(dirname "$0")
NAME="edge-dns"
VERSION=$(lookup-version "$ROOT"/../internal/const/const.go)
DIST=$ROOT/"../dist/${NAME}"
OS=${1}
ARCH=${2}
if [ -z "$OS" ]; then
echo "usage: build.sh OS ARCH"
exit
fi
if [ -z "$ARCH" ]; then
echo "usage: build.sh OS ARCH"
exit
fi
echo "checking ..."
ZIP_PATH=$(which zip)
if [ -z "$ZIP_PATH" ]; then
echo "we need 'zip' command to compress files"
exit
fi
echo "building v${VERSION}/${OS}/${ARCH} ..."
ZIP="${NAME}-${OS}-${ARCH}-v${VERSION}.zip"
echo "copying ..."
if [ ! -d "$DIST" ]; then
mkdir "$DIST"
mkdir "$DIST"/bin
mkdir "$DIST"/configs
mkdir "$DIST"/logs
mkdir "$DIST"/data
fi
cp "$ROOT"/configs/api_dns.template.yaml "$DIST"/configs
echo "building ..."
MUSL_DIR="/usr/local/opt/musl-cross/bin"
CC_PATH=""
CXX_PATH=""
if [[ $(uname -a) == *"Darwin"* && "${OS}" == "linux" ]]; then
# /usr/local/opt/musl-cross/bin/
if [ "${ARCH}" == "amd64" ]; then
CC_PATH="x86_64-linux-musl-gcc"
CXX_PATH="x86_64-linux-musl-g++"
fi
if [ "${ARCH}" == "386" ]; then
CC_PATH="i486-linux-musl-gcc"
CXX_PATH="i486-linux-musl-g++"
fi
if [ "${ARCH}" == "arm64" ]; then
CC_PATH="aarch64-linux-musl-gcc"
CXX_PATH="aarch64-linux-musl-g++"
fi
if [ "${ARCH}" == "mips64" ]; then
CC_PATH="mips64-linux-musl-gcc"
CXX_PATH="mips64-linux-musl-g++"
fi
if [ "${ARCH}" == "mips64le" ]; then
CC_PATH="mips64el-linux-musl-gcc"
CXX_PATH="mips64el-linux-musl-g++"
fi
fi
if [ ! -z $CC_PATH ]; then
env CC=$MUSL_DIR/$CC_PATH CXX=$MUSL_DIR/$CXX_PATH GOOS="${OS}" GOARCH="${ARCH}" CGO_ENABLED=1 go build -trimpath -tags="plus" -o "$DIST"/bin/${NAME} -ldflags "-linkmode external -extldflags -static -s -w" "$ROOT"/../cmd/edge-dns/main.go
else
env GOOS="${OS}" GOARCH="${ARCH}" CGO_ENABLED=1 go build -trimpath -tags="plus" -o "$DIST"/bin/${NAME} -ldflags="-s -w" "$ROOT"/../cmd/edge-dns/main.go
fi
# check build result
RESULT=$?
if [ "${RESULT}" != "0" ]; then
exit
fi
# delete hidden files
find "$DIST" -name ".DS_Store" -delete
find "$DIST" -name ".gitignore" -delete
echo "zip files"
cd "${DIST}/../" || exit
if [ -f "${ZIP}" ]; then
rm -f "${ZIP}"
fi
zip -r -X -q "${ZIP}" ${NAME}/
rm -rf ${NAME}
cd - || exit
echo "OK"
}
function lookup-version() {
FILE=$1
VERSION_DATA=$(cat "$FILE")
re="Version[ ]+=[ ]+\"([0-9.]+)\""
if [[ $VERSION_DATA =~ $re ]]; then
VERSION=${BASH_REMATCH[1]}
echo "$VERSION"
else
echo "could not match version"
exit
fi
}
build "$1" "$2"

3
EdgeDNS/build/configs/.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
api.yaml
api_dns.yaml
*.cache

View File

@@ -0,0 +1,3 @@
rpc.endpoints: [ "http://127.0.0.1:8003" ]
nodeId: ""
secret: ""

4
EdgeDNS/build/data/.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
*.db
*.db-shm
*.db-wal
*.lock

View File

@@ -0,0 +1 @@
这个目录下我们列举了所有需要公开声明的第三方License如果有遗漏烦请告知 iwind.liu@gmail.com。再次感谢这些开源软件项目和贡献人员

View File

@@ -0,0 +1,30 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
As this is fork of the official Go code the same license applies.
Extensions of the original work are copyright (c) 2011 Miek Gieben

1
EdgeDNS/build/logs/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
*.log

View File

@@ -0,0 +1,142 @@
//go:build plus
package main
import (
"flag"
"fmt"
"github.com/TeaOSLab/EdgeDNS/internal/apps"
"github.com/TeaOSLab/EdgeDNS/internal/configs"
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
"github.com/TeaOSLab/EdgeDNS/internal/nodes"
"github.com/TeaOSLab/EdgeDNS/internal/utils"
"github.com/iwind/TeaGo/Tea"
_ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/gosock/pkg/gosock"
"net/http"
_ "net/http/pprof"
"os"
"path/filepath"
"time"
)
func main() {
var app = apps.NewAppCmd().
Version(teaconst.Version).
Product(teaconst.ProductName).
Usage(teaconst.ProcessName + " [-v|start|stop|restart|service|daemon|pprof|gc|uninstall]")
app.On("start:before", func() {
// validate config
_, err := configs.LoadAPIConfig()
if err != nil {
fmt.Println("[ERROR]start failed: load api config from '" + Tea.ConfigFile(configs.ConfigFileName) + "' failed: " + err.Error())
os.Exit(0)
}
})
app.On("test", func() {
err := nodes.NewDNSNode().Test()
if err != nil {
_, _ = os.Stderr.WriteString(err.Error())
}
})
app.On("daemon", func() {
nodes.NewDNSNode().Daemon()
})
app.On("service", func() {
err := nodes.NewDNSNode().InstallSystemService()
if err != nil {
fmt.Println("[ERROR]install failed: " + err.Error())
return
}
fmt.Println("done")
})
app.On("pprof", func() {
var flagSet = flag.NewFlagSet("pprof", flag.ExitOnError)
var addr string
flagSet.StringVar(&addr, "addr", "", "")
_ = flagSet.Parse(os.Args[2:])
if len(addr) == 0 {
addr = "127.0.0.1:6060"
}
logs.Println("starting with pprof '" + addr + "'...")
go func() {
err := http.ListenAndServe(addr, nil)
if err != nil {
logs.Println("[ERROR]" + err.Error())
}
}()
var node = nodes.NewDNSNode()
node.Start()
})
app.On("gc", func() {
var sock = gosock.NewTmpSock(teaconst.ProcessName)
_, err := sock.Send(&gosock.Command{Code: "gc"})
if err != nil {
fmt.Println("[ERROR]" + err.Error())
} else {
fmt.Println("ok")
}
})
app.On("uninstall", func() {
// service
fmt.Println("Uninstall service ...")
var manager = utils.NewServiceManager(teaconst.ProcessName, teaconst.ProductName)
go func() {
_ = manager.Uninstall()
}()
// stop
fmt.Println("Stopping ...")
_, _ = gosock.NewTmpSock(teaconst.ProcessName).SendTimeout(&gosock.Command{Code: "stop"}, 1*time.Second)
// delete files
var exe, _ = os.Executable()
if len(exe) == 0 {
return
}
var dir = filepath.Dir(filepath.Dir(exe)) // ROOT / bin / exe
// verify dir
{
fmt.Println("Checking '" + dir + "' ...")
for _, subDir := range []string{"bin/" + filepath.Base(exe), "configs", "logs"} {
_, err := os.Stat(dir + "/" + subDir)
if err != nil {
fmt.Println("[ERROR]program directory structure has been broken, please remove it manually.")
return
}
}
fmt.Println("Removing '" + dir + "' ...")
err := os.RemoveAll(dir)
if err != nil {
fmt.Println("[ERROR]remove failed: " + err.Error())
}
}
// delete symbolic links
fmt.Println("Removing symbolic links ...")
_ = os.Remove("/usr/bin/" + teaconst.ProcessName)
_ = os.Remove("/var/log/" + teaconst.ProcessName)
// delete configs
// nothing to delete for EdgeDNS
// delete sock
fmt.Println("Removing temporary files ...")
var tempDir = os.TempDir()
_ = os.Remove(tempDir + "/" + teaconst.ProcessName + ".sock")
// done
fmt.Println("[DONE]")
})
app.Run(func() {
var node = nodes.NewDNSNode()
node.Start()
})
}

2
EdgeDNS/dist/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
*.zip
edge-dns

39
EdgeDNS/go.mod Normal file
View File

@@ -0,0 +1,39 @@
module github.com/TeaOSLab/EdgeDNS
go 1.25
replace github.com/TeaOSLab/EdgeCommon => ../EdgeCommon
require (
github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000
github.com/google/nftables v0.2.0
github.com/iwind/TeaGo v0.0.0-20240128112714-6bcd0529d0ea
github.com/iwind/gosock v0.0.0-20220505115348-f88412125a62
github.com/mattn/go-sqlite3 v1.14.22
github.com/mdlayher/netlink v1.7.2
github.com/miekg/dns v1.1.58
github.com/shirou/gopsutil/v3 v3.24.2
golang.org/x/sys v0.38.0
google.golang.org/grpc v1.78.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/mdlayher/socket v0.5.0 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/tklauser/go-sysconf v0.3.13 // indirect
github.com/tklauser/numcpus v0.7.0 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
golang.org/x/mod v0.29.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/text v0.31.0 // indirect
golang.org/x/tools v0.38.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda // indirect
google.golang.org/protobuf v1.36.10 // indirect
)

99
EdgeDNS/go.sum Normal file
View File

@@ -0,0 +1,99 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8=
github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4=
github.com/iwind/TeaGo v0.0.0-20240128112714-6bcd0529d0ea h1:o0QCF6tMJ9E6OgU1c0L+rYDshKsTu7mEk+7KCGLbnpI=
github.com/iwind/TeaGo v0.0.0-20240128112714-6bcd0529d0ea/go.mod h1:Ng3xWekHSVy0E/6/jYqJ7Htydm/H+mWIl0AS+Eg3H2M=
github.com/iwind/gosock v0.0.0-20220505115348-f88412125a62 h1:HJH6RDheAY156DnIfJSD/bEvqyXzsZuE2gzs8PuUjoo=
github.com/iwind/gosock v0.0.0-20220505115348-f88412125a62/go.mod h1:H5Q7SXwbx3a97ecJkaS2sD77gspzE7HFUafBO0peEyA=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI=
github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI=
github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4=
github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/shirou/gopsutil/v3 v3.24.2 h1:kcR0erMbLg5/3LcInpw0X/rrPSqq4CDPyI6A6ZRC18Y=
github.com/shirou/gopsutil/v3 v3.24.2/go.mod h1:tSg/594BcA+8UdQU2XcW803GWYgdtauFFPgJCJKZlVk=
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/go-sysconf v0.3.13 h1:GBUpcahXSpR2xN01jhkNAbTLRk2Yzgggk8IM08lq3r4=
github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0=
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/tklauser/numcpus v0.7.0 h1:yjuerZP127QG9m5Zh/mSO4wqurYil27tHrqwRoRjpr4=
github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY=
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc h1:R83G5ikgLMxrBvLh22JhdfI8K6YXEPHx5P03Uu3DRs4=
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic=
golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw=
golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240311173647-c811ad7063a7 h1:8EeVk1VKMD+GD/neyEHGmz7pFblqPjHoi+PGQIlLx2s=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240311173647-c811ad7063a7/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM=
google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA=
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,36 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package agents
import (
"regexp"
"strings"
)
type Agent struct {
Code string
suffixes []string
reg *regexp.Regexp
}
func NewAgent(code string, suffixes []string, reg *regexp.Regexp) *Agent {
return &Agent{
Code: code,
suffixes: suffixes,
reg: reg,
}
}
func (this *Agent) Match(ptr string) bool {
if len(this.suffixes) > 0 {
for _, suffix := range this.suffixes {
if strings.HasSuffix(ptr, suffix) {
return true
}
}
}
if this.reg != nil {
return this.reg.MatchString(ptr)
}
return false
}

View File

@@ -0,0 +1,17 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package agents
var AllAgents = []*Agent{
NewAgent("baidu", []string{".baidu.com."}, nil),
NewAgent("google", []string{".googlebot.com."}, nil),
NewAgent("bing", []string{".search.msn.com."}, nil),
NewAgent("sogou", []string{".sogou.com."}, nil),
NewAgent("youdao", []string{".163.com."}, nil),
NewAgent("yahoo", []string{".yahoo.com."}, nil),
NewAgent("bytedance", []string{".bytedance.com."}, nil),
NewAgent("sm", []string{".sm.cn."}, nil),
NewAgent("yandex", []string{".yandex.com.", ".yndx.net."}, nil),
NewAgent("semrush", []string{".semrush.com."}, nil),
NewAgent("facebook", []string{"facebook-waw.1-ix.net.", "facebook.b-ix.net."}, nil),
}

View File

@@ -0,0 +1,54 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package agents
import (
"github.com/TeaOSLab/EdgeDNS/internal/utils/zero"
"sync"
)
type IPCacheMap struct {
m map[string]zero.Zero
list []string
locker sync.RWMutex
maxLen int
}
func NewIPCacheMap(maxLen int) *IPCacheMap {
if maxLen <= 0 {
maxLen = 65535
}
return &IPCacheMap{
m: map[string]zero.Zero{},
maxLen: maxLen,
}
}
func (this *IPCacheMap) Add(ip string) {
this.locker.Lock()
defer this.locker.Unlock()
// 是否已经存在
_, ok := this.m[ip]
if ok {
return
}
// 超出长度删除第一个
if len(this.list) >= this.maxLen {
delete(this.m, this.list[0])
this.list = this.list[1:]
}
// 加入新数据
this.m[ip] = zero.Zero{}
this.list = append(this.list, ip)
}
func (this *IPCacheMap) Contains(ip string) bool {
this.locker.RLock()
defer this.locker.RUnlock()
_, ok := this.m[ip]
return ok
}

View File

@@ -0,0 +1,34 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package agents
import (
"github.com/iwind/TeaGo/logs"
"testing"
)
func TestNewIPCacheMap(t *testing.T) {
var cacheMap = NewIPCacheMap(3)
t.Log("====")
cacheMap.Add("1")
cacheMap.Add("2")
logs.PrintAsJSON(cacheMap.m, t)
logs.PrintAsJSON(cacheMap.list, t)
t.Log("====")
cacheMap.Add("3")
logs.PrintAsJSON(cacheMap.m, t)
logs.PrintAsJSON(cacheMap.list, t)
t.Log("====")
cacheMap.Add("4")
logs.PrintAsJSON(cacheMap.m, t)
logs.PrintAsJSON(cacheMap.list, t)
t.Log("====")
cacheMap.Add("3")
logs.PrintAsJSON(cacheMap.m, t)
logs.PrintAsJSON(cacheMap.list, t)
}

View File

@@ -0,0 +1,173 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package agents
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"github.com/iwind/TeaGo/Tea"
"sync"
"time"
)
// SharedManager 此值在外面调用时指定
var SharedManager *Manager
// Manager Agent管理器
type Manager struct {
ipMap map[string]string // ip => agentCode
locker sync.RWMutex
db *dbs.DB
lastId int64
}
func NewManager(db *dbs.DB) *Manager {
return &Manager{
ipMap: map[string]string{},
db: db,
}
}
func (this *Manager) Start() {
remotelogs.Println("AGENT_MANAGER", "starting ...")
// 从本地数据库中加载
err := this.Load()
if err != nil {
remotelogs.Error("AGENT_MANAGER", "load failed: "+err.Error())
}
// 先从API获取
err = this.LoopAll()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("AGENT_MANAGER", "retrieve latest agent ip failed: "+err.Error())
} else {
remotelogs.Error("AGENT_MANAGER", "retrieve latest agent ip failed: "+err.Error())
}
}
// 定时获取
var duration = 30 * time.Minute
if Tea.IsTesting() {
duration = 30 * time.Second
}
var ticker = time.NewTicker(duration)
for range ticker.C {
err = this.LoopAll()
if err != nil {
remotelogs.Error("AGENT_MANAGER", "retrieve latest agent ip failed: "+err.Error())
}
}
}
func (this *Manager) Load() error {
var offset int64 = 0
var size int64 = 10000
for {
agentIPs, err := this.db.ListAgentIPs(offset, size)
if err != nil {
return err
}
if len(agentIPs) == 0 {
break
}
for _, agentIP := range agentIPs {
this.locker.Lock()
this.ipMap[agentIP.IP] = agentIP.AgentCode
this.locker.Unlock()
if agentIP.Id > this.lastId {
this.lastId = agentIP.Id
}
}
offset += size
}
return nil
}
func (this *Manager) LoopAll() error {
for {
hasNext, err := this.Loop()
if err != nil {
return err
}
if !hasNext {
break
}
}
return nil
}
// Loop 单次循环获取数据
func (this *Manager) Loop() (hasNext bool, err error) {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return false, err
}
ipsResp, err := rpcClient.ClientAgentIPRPC.ListClientAgentIPsAfterId(rpcClient.Context(), &pb.ListClientAgentIPsAfterIdRequest{
Id: this.lastId,
Size: 10000,
})
if err != nil {
return false, err
}
if len(ipsResp.ClientAgentIPs) == 0 {
return false, nil
}
for _, agentIP := range ipsResp.ClientAgentIPs {
if agentIP.ClientAgent == nil {
// 设置ID
if agentIP.Id > this.lastId {
this.lastId = agentIP.Id
}
continue
}
// 写入到数据库
err = this.db.InsertAgentIP(agentIP.Id, agentIP.Ip, agentIP.ClientAgent.Code)
if err != nil {
return false, err
}
// 写入Map
this.locker.Lock()
this.ipMap[agentIP.Ip] = agentIP.ClientAgent.Code
this.locker.Unlock()
// 设置ID
if agentIP.Id > this.lastId {
this.lastId = agentIP.Id
}
}
return true, nil
}
// AddIP 添加记录
func (this *Manager) AddIP(ip string, agentCode string) {
this.locker.Lock()
this.ipMap[ip] = agentCode
this.locker.Unlock()
}
// LookupIP 查询IP所属Agent
func (this *Manager) LookupIP(ip string) (agentCode string) {
this.locker.RLock()
defer this.locker.RUnlock()
return this.ipMap[ip]
}
// ContainsIP 检查是否有IP相关数据
func (this *Manager) ContainsIP(ip string) bool {
this.locker.RLock()
defer this.locker.RUnlock()
_, ok := this.ipMap[ip]
return ok
}

View File

@@ -0,0 +1,33 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package agents_test
import (
"github.com/TeaOSLab/EdgeDNS/internal/agents"
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
"github.com/iwind/TeaGo/Tea"
_ "github.com/iwind/TeaGo/bootstrap"
"testing"
)
func TestNewManager(t *testing.T) {
var db = dbs.NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
var manager = agents.NewManager(db)
err = manager.Load()
if err != nil {
t.Fatal(err)
}
_, err = manager.Loop()
if err != nil {
t.Fatal(err)
}
t.Log(manager.LookupIP("192.168.3.100"))
}

View File

@@ -0,0 +1,138 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package agents
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
"github.com/TeaOSLab/EdgeDNS/internal/events"
"github.com/TeaOSLab/EdgeDNS/internal/goman"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"github.com/iwind/TeaGo/Tea"
"net"
)
func init() {
if !teaconst.IsMain {
return
}
events.On(events.EventLoaded, func() {
goman.New(func() {
SharedQueue.Start()
})
})
}
var SharedQueue = NewQueue()
type Queue struct {
c chan string // chan ip
cacheMap *IPCacheMap
}
func NewQueue() *Queue {
return &Queue{
c: make(chan string, 128),
cacheMap: NewIPCacheMap(65535),
}
}
func (this *Queue) Start() {
for ip := range this.c {
err := this.Process(ip)
if err != nil {
// 不需要上报错误
if Tea.IsTesting() {
remotelogs.Debug("SharedParseQueue", err.Error())
}
continue
}
}
}
// Push 将IP加入到处理队列
func (this *Queue) Push(ip string) {
// 是否在处理中
if this.cacheMap.Contains(ip) {
return
}
this.cacheMap.Add(ip)
// 加入到队列
select {
case this.c <- ip:
default:
}
}
// Process 处理IP
func (this *Queue) Process(ip string) error {
// 是否已经在库中
if SharedManager.ContainsIP(ip) {
return nil
}
ptr, err := this.ParseIP(ip)
if err != nil {
return err
}
if len(ptr) == 0 || ptr == "." {
return nil
}
//remotelogs.Debug("AGENT", ip+" => "+ptr)
var agentCode = this.ParsePtr(ptr)
if len(agentCode) == 0 {
return nil
}
// 加入到本地
SharedManager.AddIP(ip, agentCode)
var pbAgentIP = &pb.CreateClientAgentIPsRequest_AgentIPInfo{
AgentCode: agentCode,
Ip: ip,
Ptr: ptr,
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
_, err = rpcClient.ClientAgentIPRPC.CreateClientAgentIPs(rpcClient.Context(), &pb.CreateClientAgentIPsRequest{AgentIPs: []*pb.CreateClientAgentIPsRequest_AgentIPInfo{pbAgentIP}})
if err != nil {
return err
}
return nil
}
// ParseIP 分析IP的PTR值
func (this *Queue) ParseIP(ip string) (ptr string, err error) {
if len(ip) == 0 {
return "", nil
}
names, err := net.LookupAddr(ip)
if err != nil {
return "", err
}
if len(names) == 0 {
return "", nil
}
return names[0], nil
}
// ParsePtr 分析PTR对应的Agent
func (this *Queue) ParsePtr(ptr string) (agentCode string) {
for _, agent := range AllAgents {
if agent.Match(ptr) {
return agent.Code
}
}
return ""
}

View File

@@ -0,0 +1,77 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package agents_test
import (
"github.com/TeaOSLab/EdgeDNS/internal/agents"
"github.com/iwind/TeaGo/assert"
_ "github.com/iwind/TeaGo/bootstrap"
"testing"
"time"
)
func TestParseQueue_Process(t *testing.T) {
var queue = agents.NewQueue()
go queue.Start()
time.Sleep(1 * time.Second)
queue.Push("220.181.13.100")
time.Sleep(1 * time.Second)
}
func TestParseQueue_ParseIP(t *testing.T) {
var queue = agents.NewQueue()
for _, ip := range []string{
"192.168.1.100",
"42.120.160.1",
"42.236.10.98",
"124.115.0.100",
} {
ptr, err := queue.ParseIP(ip)
if err != nil {
t.Log(ip, "=>", err)
continue
}
t.Log(ip, "=>", ptr)
}
}
func TestParseQueue_ParsePtr(t *testing.T) {
var a = assert.NewAssertion(t)
var queue = agents.NewQueue()
for _, s := range [][]string{
{"baiduspider-220-181-108-101.crawl.baidu.com.", "baidu"},
{"crawl-66-249-71-219.googlebot.com.", "google"},
{"msnbot-40-77-167-31.search.msn.com.", "bing"},
{"sogouspider-49-7-20-129.crawl.sogou.com.", "sogou"},
{"m13102.mail.163.com.", "youdao"},
{"yeurosport.pat1.tc2.yahoo.com.", "yahoo"},
{"shenmaspider-42-120-160-1.crawl.sm.cn.", "sm"},
{"93-158-161-39.spider.yandex.com.", "yandex"},
{"25.bl.bot.semrush.com.", "semrush"},
} {
a.IsTrue(queue.ParsePtr(s[0]) == s[1])
}
}
func BenchmarkQueue_ParsePtr(b *testing.B) {
var queue = agents.NewQueue()
for i := 0; i < b.N; i++ {
for _, s := range [][]string{
{"baiduspider-220-181-108-101.crawl.baidu.com.", "baidu"},
{"crawl-66-249-71-219.googlebot.com.", "google"},
{"msnbot-40-77-167-31.search.msn.com.", "bing"},
{"sogouspider-49-7-20-129.crawl.sogou.com.", "sogou"},
{"m13102.mail.163.com.", "youdao"},
{"yeurosport.pat1.tc2.yahoo.com.", "yahoo"},
{"shenmaspider-42-120-160-1.crawl.sm.cn.", "sm"},
{"93-158-161-39.spider.yandex.com.", "yandex"},
{"93.158.164.218-red.dhcp.yndx.net.", "yandex"},
{"25.bl.bot.semrush.com.", "semrush"},
} {
queue.ParsePtr(s[0])
}
}
}

View File

@@ -0,0 +1,319 @@
package apps
import (
"errors"
"fmt"
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"github.com/iwind/gosock/pkg/gosock"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"syscall"
"time"
)
// AppCmd App命令帮助
type AppCmd struct {
product string
version string
usage string
options []*CommandHelpOption
appendStrings []string
directives []*Directive
sock *gosock.Sock
}
func NewAppCmd() *AppCmd {
return &AppCmd{
sock: gosock.NewTmpSock(teaconst.ProcessName),
}
}
type CommandHelpOption struct {
Code string
Description string
}
// Product 产品
func (this *AppCmd) Product(product string) *AppCmd {
this.product = product
return this
}
// Version 版本
func (this *AppCmd) Version(version string) *AppCmd {
this.version = version
return this
}
// Usage 使用方法
func (this *AppCmd) Usage(usage string) *AppCmd {
this.usage = usage
return this
}
// Option 选项
func (this *AppCmd) Option(code string, description string) *AppCmd {
this.options = append(this.options, &CommandHelpOption{
Code: code,
Description: description,
})
return this
}
// Append 附加内容
func (this *AppCmd) Append(appendString string) *AppCmd {
this.appendStrings = append(this.appendStrings, appendString)
return this
}
// Print 打印
func (this *AppCmd) Print() {
fmt.Println(this.product + " v" + this.version)
usage := this.usage
fmt.Println("Usage:", "\n "+usage)
if len(this.options) > 0 {
fmt.Println("")
fmt.Println("Options:")
spaces := 20
max := 40
for _, option := range this.options {
l := len(option.Code)
if l < max && l > spaces {
spaces = l + 4
}
}
for _, option := range this.options {
if len(option.Code) > max {
fmt.Println("")
fmt.Println(" " + option.Code)
option.Code = ""
}
fmt.Printf(" %-"+strconv.Itoa(spaces)+"s%s\n", option.Code, ": "+option.Description)
}
}
if len(this.appendStrings) > 0 {
fmt.Println("")
for _, s := range this.appendStrings {
fmt.Println(s)
}
}
}
// On 添加指令
func (this *AppCmd) On(arg string, callback func()) {
this.directives = append(this.directives, &Directive{
Arg: arg,
Callback: callback,
})
}
// Run 运行
func (this *AppCmd) Run(main func()) {
// 获取参数
var args = os.Args[1:]
if len(args) > 0 {
var mainArg = args[0]
this.callDirective(mainArg + ":before")
switch mainArg {
case "-v", "version", "-version", "--version":
this.runVersion()
return
case "?", "help", "-help", "h", "-h":
this.runHelp()
return
case "start":
this.runStart()
return
case "stop":
this.runStop()
return
case "restart":
this.runRestart()
return
case "status":
this.runStatus()
return
}
// 查找指令
for _, directive := range this.directives {
if directive.Arg == mainArg {
directive.Callback()
return
}
}
fmt.Println("unknown command '" + mainArg + "'")
return
}
// 日志
writer := new(LogWriter)
writer.Init()
logs.SetWriter(writer)
// 运行主函数
main()
}
// 版本号
func (this *AppCmd) runVersion() {
fmt.Println(this.product+" v"+this.version, "(build: "+runtime.Version(), runtime.GOOS, runtime.GOARCH+")")
}
// 帮助
func (this *AppCmd) runHelp() {
this.Print()
}
// 启动
func (this *AppCmd) runStart() {
var pid = this.getPID()
if pid > 0 {
fmt.Println(this.product+" already started, pid:", pid)
return
}
var cmd = exec.Command(this.exe())
cmd.SysProcAttr = &syscall.SysProcAttr{
Foreground: false,
Setsid: true,
}
err := cmd.Start()
if err != nil {
fmt.Println(this.product+" start failed:", err.Error())
return
}
// create symbolic links
_ = this.createSymLinks()
fmt.Println(this.product+" started ok, pid:", cmd.Process.Pid)
}
// 停止
func (this *AppCmd) runStop() {
var pid = this.getPID()
if pid == 0 {
fmt.Println(this.product + " not started yet")
return
}
_, _ = this.sock.Send(&gosock.Command{Code: "stop"})
fmt.Println(this.product+" stopped ok, pid:", types.String(pid))
}
// 重启
func (this *AppCmd) runRestart() {
this.runStop()
time.Sleep(1 * time.Second)
this.runStart()
}
// 状态
func (this *AppCmd) runStatus() {
var pid = this.getPID()
if pid == 0 {
fmt.Println(this.product + " not started yet")
return
}
fmt.Println(this.product + " is running, pid: " + types.String(pid))
}
// 获取当前的PID
func (this *AppCmd) getPID() int {
if !this.sock.IsListening() {
return 0
}
reply, err := this.sock.Send(&gosock.Command{Code: "pid"})
if err != nil {
return 0
}
return maps.NewMap(reply.Params).GetInt("pid")
}
func (this *AppCmd) exe() string {
var exe, _ = os.Executable()
if len(exe) == 0 {
exe = os.Args[0]
}
return exe
}
// 创建软链接
func (this *AppCmd) createSymLinks() error {
if runtime.GOOS != "linux" {
return nil
}
var exe, _ = os.Executable()
if len(exe) == 0 {
return nil
}
var errorList = []string{}
// bin
{
var target = "/usr/bin/" + teaconst.ProcessName
old, _ := filepath.EvalSymlinks(target)
if old != exe {
_ = os.Remove(target)
err := os.Symlink(exe, target)
if err != nil {
errorList = append(errorList, err.Error())
}
}
}
// log
{
var realPath = filepath.Dir(filepath.Dir(exe)) + "/logs/run.log"
var target = "/var/log/" + teaconst.ProcessName + ".log"
old, _ := filepath.EvalSymlinks(target)
if old != realPath {
_ = os.Remove(target)
err := os.Symlink(realPath, target)
if err != nil {
errorList = append(errorList, err.Error())
}
}
}
if len(errorList) > 0 {
return errors.New(strings.Join(errorList, "\n"))
}
return nil
}
func (this *AppCmd) callDirective(code string) {
for _, directive := range this.directives {
if directive.Arg == code {
if directive.Callback != nil {
directive.Callback()
}
return
}
}
}

View File

@@ -0,0 +1,6 @@
package apps
type Directive struct {
Arg string
Callback func()
}

View File

@@ -0,0 +1,111 @@
package apps
import (
"github.com/TeaOSLab/EdgeDNS/internal/goman"
"github.com/TeaOSLab/EdgeDNS/internal/utils/sizes"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/files"
timeutil "github.com/iwind/TeaGo/utils/time"
"log"
"os"
"runtime"
"strconv"
"strings"
)
type LogWriter struct {
fp *os.File
c chan string
}
func (this *LogWriter) Init() {
// 创建目录
var dir = files.NewFile(Tea.LogDir())
if !dir.Exists() {
err := dir.Mkdir()
if err != nil {
log.Println("[LOG]create log dir failed: " + err.Error())
}
}
// 打开要写入的日志文件
var logPath = Tea.LogFile("run.log")
fp, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
log.Println("[LOG]open log file failed: " + err.Error())
} else {
this.fp = fp
}
this.c = make(chan string, 1024)
// 异步写入文件
var maxFileSize = 128 * sizes.M // 文件最大尺寸,超出此尺寸则清空
if fp != nil {
goman.New(func() {
var totalSize int64 = 0
stat, err := fp.Stat()
if err == nil {
totalSize = stat.Size()
}
for message := range this.c {
totalSize += int64(len(message))
_, err := fp.WriteString(timeutil.Format("Y/m/d H:i:s ") + message + "\n")
if err != nil {
log.Println("[LOG]write log failed: " + err.Error())
} else {
// 如果太大则Truncate
if totalSize > maxFileSize {
_ = fp.Truncate(0)
totalSize = 0
}
}
}
})
}
}
func (this *LogWriter) Write(message string) {
backgroundEnv, _ := os.LookupEnv("EdgeBackground")
if backgroundEnv != "on" {
// 文件和行号
var file string
var line int
if Tea.IsTesting() {
var callDepth = 3
var ok bool
_, file, line, ok = runtime.Caller(callDepth)
if ok {
file = this.packagePath(file)
}
}
if len(file) > 0 {
log.Println(message + " (" + file + ":" + strconv.Itoa(line) + ")")
} else {
log.Println(message)
}
}
select {
case this.c <- message:
default:
}
}
func (this *LogWriter) Close() {
if this.fp != nil {
_ = this.fp.Close()
}
close(this.c)
}
func (this *LogWriter) packagePath(path string) string {
var pieces = strings.Split(path, "/")
if len(pieces) >= 2 {
return strings.Join(pieces[len(pieces)-2:], "/")
}
return path
}

View File

@@ -0,0 +1,107 @@
package configs
import (
"errors"
"github.com/iwind/TeaGo/Tea"
"gopkg.in/yaml.v3"
"os"
)
const ConfigFileName = "api_dns.yaml"
const oldConfigFileName = "api.yaml"
var sharedAPIConfig *APIConfig
// APIConfig API配置
type APIConfig struct {
OldRPC struct {
Endpoints []string `yaml:"endpoints"`
DisableUpdate bool `yaml:"disableUpdate"`
} `yaml:"rpc,omitempty"`
RPCEndpoints []string `yaml:"rpc.endpoints,flow" json:"rpc.endpoints"`
RPCDisableUpdate bool `yaml:"rpc.disableUpdate" json:"rpc.disableUpdate"`
NodeId string `yaml:"nodeId"`
Secret string `yaml:"secret"`
NumberId int64 `yaml:"numberId"`
}
// SharedAPIConfig 加载API配置
func SharedAPIConfig() (*APIConfig, error) {
if sharedAPIConfig != nil {
return sharedAPIConfig, nil
}
config, err := LoadAPIConfig()
if err != nil {
return nil, err
}
sharedAPIConfig = config
return config, nil
}
// LoadAPIConfig 加载API配置
func LoadAPIConfig() (*APIConfig, error) {
for _, filename := range []string{ConfigFileName, oldConfigFileName} {
data, err := os.ReadFile(Tea.ConfigFile(filename))
if err != nil {
if os.IsNotExist(err) {
continue
}
return nil, err
}
var config = &APIConfig{}
err = yaml.Unmarshal(data, config)
if err != nil {
return nil, err
}
err = config.Init()
if err != nil {
return nil, errors.New("init error: " + err.Error())
}
// 自动生成新的配置文件
if filename == oldConfigFileName {
config.OldRPC.Endpoints = nil
_ = config.WriteFile(Tea.ConfigFile(ConfigFileName))
}
return config, nil
}
return nil, errors.New("no config file '" + ConfigFileName + "' found")
}
func (this *APIConfig) Init() error {
// compatible with old
if len(this.RPCEndpoints) == 0 && len(this.OldRPC.Endpoints) > 0 {
this.RPCEndpoints = this.OldRPC.Endpoints
this.RPCDisableUpdate = this.OldRPC.DisableUpdate
}
if len(this.RPCEndpoints) == 0 {
return errors.New("no valid 'rpc.endpoints'")
}
if len(this.NodeId) == 0 {
return errors.New("'nodeId' required")
}
if len(this.Secret) == 0 {
return errors.New("'secret' required")
}
return nil
}
// WriteFile 写入API配置
func (this *APIConfig) WriteFile(path string) error {
data, err := yaml.Marshal(this)
if err != nil {
return err
}
return os.WriteFile(path, data, 0666)
}

View File

@@ -0,0 +1,21 @@
package configs
import (
_ "github.com/iwind/TeaGo/bootstrap"
"gopkg.in/yaml.v3"
"testing"
)
func TestLoadAPIConfig(t *testing.T) {
config, err := LoadAPIConfig()
if err != nil {
t.Fatal(err)
}
t.Logf("%+v", config)
configData, err := yaml.Marshal(config)
if err != nil {
t.Fatal(err)
}
t.Log(string(configData))
}

View File

@@ -0,0 +1,8 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package configs
import "github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
var SharedNodeConfig *dnsconfigs.NSNodeConfig

View File

@@ -0,0 +1,15 @@
package teaconst
const (
Version = "1.4.5.1" //1.3.8.2
ProductName = "Edge DNS"
ProcessName = "edge-dns"
ProductNameZH = "Edge"
Role = "dns"
EncryptMethod = "aes-256-cfb"
SystemdServiceName = "edge-dns"
)

View File

@@ -0,0 +1,29 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package teaconst
import (
"os"
"strings"
)
var (
IsDaemon = false
IsPlus = true
IsMain = checkMain()
IsQuiting = false // 是否正在退出
EnableDBStat = false // 是否开启本地数据库统计
)
// 检查是否为主程序
func checkMain() bool {
if len(os.Args) == 1 ||
(len(os.Args) >= 2 && os.Args[1] == "pprof") {
return true
}
exe, _ := os.Executable()
return strings.HasSuffix(exe, ".test") ||
strings.HasSuffix(exe, ".test.exe") ||
strings.Contains(exe, "___")
}

783
EdgeDNS/internal/dbs/db.go Normal file
View File

@@ -0,0 +1,783 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package dbs
import (
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeDNS/internal/models"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
dbutils "github.com/TeaOSLab/EdgeDNS/internal/utils/dbs"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/types"
_ "github.com/mattn/go-sqlite3"
"log"
"os"
"path/filepath"
"strings"
)
const (
tableDomains = "domains_v2"
tableRecords = "records_v2"
tableRoutes = "routes_v2"
tableKeys = "keys"
tableAgentIPs = "agentIPs"
)
type DB struct {
db *dbutils.DB
path string
insertDomainStmt *dbutils.Stmt
updateDomainStmt *dbutils.Stmt
deleteDomainStmt *dbutils.Stmt
existsDomainStmt *dbutils.Stmt
listDomainsStmt *dbutils.Stmt
insertRecordStmt *dbutils.Stmt
updateRecordStmt *dbutils.Stmt
existsRecordStmt *dbutils.Stmt
deleteRecordStmt *dbutils.Stmt
listRecordsStmt *dbutils.Stmt
insertRouteStmt *dbutils.Stmt
updateRouteStmt *dbutils.Stmt
deleteRouteStmt *dbutils.Stmt
listRoutesStmt *dbutils.Stmt
existsRouteStmt *dbutils.Stmt
insertKeyStmt *dbutils.Stmt
updateKeyStmt *dbutils.Stmt
deleteKeyStmt *dbutils.Stmt
listKeysStmt *dbutils.Stmt
existsKeyStmt *dbutils.Stmt
insertAgentIPStmt *dbutils.Stmt
listAgentIPsStmt *dbutils.Stmt
}
func NewDB(path string) *DB {
return &DB{path: path}
}
func (this *DB) Init() error {
// 检查目录是否存在
var dir = filepath.Dir(this.path)
_, err := os.Stat(dir)
if err != nil {
err = os.MkdirAll(dir, 0777)
if err != nil {
return err
}
remotelogs.Println("DB", "create database dir '"+dir+"'")
}
// TODO 思考 data.db 的数据安全性
db, err := dbutils.OpenWriter("file:" + this.path + "?cache=shared&mode=rwc&_journal_mode=WAL&_locking_mode=EXCLUSIVE")
if err != nil {
return err
}
db.SetMaxOpenConns(1)
/**_, err = db.Exec("VACUUM")
if err != nil {
return err
}**/
// 创建数据表
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + tableDomains + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"clusterId" integer DEFAULT 0,
"userId" integer DEFAULT 0,
"name" varchar(255),
"version" integer DEFAULT 0,
"tsig" text
);
CREATE INDEX IF NOT EXISTS "clusterId"
ON "` + tableDomains + `" (
"clusterId"
);
`)
if err != nil {
return err
}
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + tableRecords + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"domainId" integer DEFAULT 0,
"name" varchar(255),
"type" varchar(32),
"value" varchar(4096),
"mxPriority" integer DEFAULT 10,
"srvPriority" integer DEFAULT 10,
"srvWeight" integer DEFAULT 10,
"srvPort" integer DEFAULT 0,
"caaFlag" integer DEFAULT 0,
"caaTag" varchar(16),
"ttl" integer DEFAULT 0,
"weight" integer DEFAULT 0,
"routeIds" varchar(512),
"version" integer DEFAULT 0
);
`)
if err != nil {
// 忽略可以预期的错误
if strings.Contains(err.Error(), "duplicate column name") {
err = nil
}
if err != nil {
return err
}
}
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + tableRoutes + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"ranges" text,
"order" integer DEFAULT 0,
"priority" integer DEFAULT 0,
"userId" integer DEFAULT 0,
"version" integer DEFAULT 0
);
`)
if err != nil {
// 忽略可以预期的错误
if strings.Contains(err.Error(), "duplicate column name") {
err = nil
}
if err != nil {
return err
}
}
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + tableKeys + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"domainId" integer DEFAULT 0,
"zoneId" integer DEFAULT 0,
"algo" varchar(128),
"secret" varchar(4096),
"secretType" varchar(32),
"version" integer DEFAULT 0
);`)
if err != nil {
return err
}
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + tableAgentIPs + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"ip" varchar(64),
"agentCode" varchar(128)
);`)
if err != nil {
return err
}
// 预编译语句
// domain statements
this.insertDomainStmt, err = db.Prepare(`INSERT INTO "` + tableDomains + `" ("id", "clusterId", "userId", "name", "tsig", "version") VALUES (?, ?, ?, ?, ?, ?)`)
if err != nil {
return err
}
this.updateDomainStmt, err = db.Prepare(`UPDATE "` + tableDomains + `" SET "clusterId"=?, "userId"=?, "name"=?, "tsig"=?, "version"=? WHERE "id"=?`)
if err != nil {
return err
}
this.deleteDomainStmt, err = db.Prepare(`DELETE FROM "` + tableDomains + `" WHERE id=?`)
if err != nil {
return err
}
this.existsDomainStmt, err = db.Prepare(`SELECT "id" FROM "` + tableDomains + `" WHERE "id"=? LIMIT 1`)
if err != nil {
return err
}
this.listDomainsStmt, err = db.Prepare(`SELECT "id", "clusterId", "userId", "name", "tsig", "version" FROM "` + tableDomains + `" WHERE "clusterId"=? ORDER BY "id" ASC LIMIT ? OFFSET ?`)
if err != nil {
return err
}
// record statements
this.insertRecordStmt, err = db.Prepare(`INSERT INTO "` + tableRecords + `" ("id", "domainId", "name", "type", "value", "mxPriority", "srvPriority", "srvWeight", "srvPort", "caaFlag", "caaTag", "ttl", "weight", "routeIds", "version") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
if err != nil {
return err
}
this.updateRecordStmt, err = db.Prepare(`UPDATE "` + tableRecords + `" SET "domainId"=?, "name"=?, "type"=?, "value"=?, "mxPriority"=?, "srvPriority"=?, "srvWeight"=?, "srvPort"=?, "caaFlag"=?, "caaTag"=?, "ttl"=?, "weight"=?, "routeIds"=?, "version"=? WHERE "id"=?`)
if err != nil {
return err
}
this.existsRecordStmt, err = db.Prepare(`SELECT "id" FROM "` + tableRecords + `" WHERE "id"=? LIMIT 1`)
if err != nil {
return err
}
this.deleteRecordStmt, err = db.Prepare(`DELETE FROM "` + tableRecords + `" WHERE id=?`)
if err != nil {
return err
}
this.listRecordsStmt, err = db.Prepare(`SELECT "id", "domainId", "name", "type", "value", "mxPriority", "srvPriority", "srvWeight", "srvPort", "caaFlag", "caaTag", "ttl", "weight", "routeIds", "version" FROM "` + tableRecords + `" ORDER BY "id" ASC LIMIT ? OFFSET ?`)
if err != nil {
return err
}
// route statements
this.insertRouteStmt, err = db.Prepare(`INSERT INTO "` + tableRoutes + `" ("id", "userId", "ranges", "order", "priority", "version") VALUES (?, ?, ?, ?, ?, ?)`)
if err != nil {
return err
}
this.updateRouteStmt, err = db.Prepare(`UPDATE "` + tableRoutes + `" SET "userId"=?, "ranges"=?, "order"=?, "priority"=?, "version"=? WHERE "id"=?`)
if err != nil {
return err
}
this.deleteRouteStmt, err = db.Prepare(`DELETE FROM "` + tableRoutes + `" WHERE "id"=?`)
if err != nil {
return err
}
this.listRoutesStmt, err = db.Prepare(`SELECT "id", "userId", "ranges", "priority", "order", "version" FROM "` + tableRoutes + `" ORDER BY "id" ASC LIMIT ? OFFSET ?`)
if err != nil {
return err
}
this.existsRouteStmt, err = db.Prepare(`SELECT "id" FROM "` + tableRoutes + `" WHERE "id"=? LIMIT 1`)
if err != nil {
return err
}
// key statements
this.insertKeyStmt, err = db.Prepare(`INSERT INTO "` + tableKeys + `" ("id", "domainId", "zoneId", "algo", "secret", "secretType", "version") VALUES (?, ?, ?, ?, ?, ?, ?)`)
if err != nil {
return err
}
this.updateKeyStmt, err = db.Prepare(`UPDATE "` + tableKeys + `" SET "domainId"=?, "zoneId"=?, "algo"=?, "secret"=?, "secretType"=?, "version"=? WHERE "id"=?`)
if err != nil {
return err
}
this.deleteKeyStmt, err = db.Prepare(`DELETE FROM "` + tableKeys + `" WHERE "id"=?`)
if err != nil {
return err
}
this.listKeysStmt, err = db.Prepare(`SELECT "id", "domainId", "zoneId", "algo", "secret", "secretType", "version" FROM "` + tableKeys + `" ORDER BY "id" ASC LIMIT ? OFFSET ?`)
if err != nil {
return err
}
this.existsKeyStmt, err = db.Prepare(`SELECT "id" FROM "` + tableKeys + `" WHERE "id"=? LIMIT 1`)
if err != nil {
return err
}
// agent ip record statements
this.insertAgentIPStmt, err = db.Prepare(`INSERT INTO "` + tableAgentIPs + `" ("id", "ip", "agentCode") VALUES (?, ?, ?)`)
if err != nil {
return err
}
this.listAgentIPsStmt, err = db.Prepare(`SELECT "id", "ip", "agentCode" FROM "` + tableAgentIPs + `" ORDER BY "id" ASC LIMIT ? OFFSET ?`)
if err != nil {
return err
}
this.db = db
return nil
}
func (this *DB) InsertDomain(domainId int64, clusterId int64, userId int64, name string, tsigJSON []byte, version int64) error {
if this.db == nil {
return errors.New("db should not be nil")
}
this.log("InsertDomain", "domain:", domainId, "user:", userId, "name:", name)
_, err := this.insertDomainStmt.Exec(domainId, clusterId, userId, name, string(tsigJSON), version)
if err != nil {
return err
}
return nil
}
func (this *DB) UpdateDomain(domainId int64, clusterId int64, userId int64, name string, tsigJSON []byte, version int64) error {
if this.db == nil {
return errors.New("db should not be nil")
}
this.log("UpdateDomain", "domain:", domainId, "user:", userId, "name:", name)
_, err := this.updateDomainStmt.Exec(clusterId, userId, name, string(tsigJSON), version, domainId)
if err != nil {
return err
}
return nil
}
func (this *DB) DeleteDomain(domainId int64) error {
if this.db == nil {
return errors.New("db should not be nil")
}
this.log("DeleteDomain", "domain:", domainId)
_, err := this.deleteDomainStmt.Exec(domainId)
if err != nil {
return err
}
return nil
}
func (this *DB) ExistsDomain(domainId int64) (bool, error) {
if this.db == nil {
return false, errors.New("db should not be nil")
}
rows, err := this.existsDomainStmt.Query(domainId)
if err != nil {
return false, err
}
if rows.Err() != nil {
return false, rows.Err()
}
defer func() {
_ = rows.Close()
}()
if rows.Next() {
return true, nil
}
return false, nil
}
func (this *DB) ListDomains(clusterId int64, offset int, size int) (domains []*models.NSDomain, err error) {
if this.db == nil {
return nil, errors.New("db should not be nil")
}
rows, err := this.listDomainsStmt.Query(clusterId, size, offset)
if err != nil {
return nil, err
}
if rows.Err() != nil {
return nil, rows.Err()
}
defer func() {
_ = rows.Close()
}()
for rows.Next() {
var domain = &models.NSDomain{}
var tsigString string
err = rows.Scan(&domain.Id, &domain.ClusterId, &domain.UserId, &domain.Name, &tsigString, &domain.Version)
if err != nil {
return nil, err
}
if len(tsigString) > 0 {
var tsigConfig = &dnsconfigs.NSTSIGConfig{}
err = json.Unmarshal([]byte(tsigString), tsigConfig)
if err != nil {
remotelogs.Error("decode tsig string failed: "+err.Error()+", domain:"+domain.Name, ", domainId: "+types.String(domain.Id))
} else {
domain.TSIG = tsigConfig
}
}
domains = append(domains, domain)
}
return
}
func (this *DB) InsertRecord(recordId int64, domainId int64, name string, recordType dnsconfigs.RecordType, value string, mxPriority int32, srvPriority int32, srvWeight int32, srvPort int32, caaFlag int32, caaTag string, ttl int32, weight int32, routeIds []string, version int64) error {
if this.db == nil {
return errors.New("db should not be nil")
}
this.log("InsertRecord", "domain:", domainId, "name:", name)
_, err := this.insertRecordStmt.Exec(recordId, domainId, name, recordType, value, mxPriority, srvPriority, srvWeight, srvPort, caaFlag, caaTag, ttl, weight, strings.Join(routeIds, ","), version)
if err != nil {
return err
}
return nil
}
func (this *DB) UpdateRecord(recordId int64, domainId int64, name string, recordType dnsconfigs.RecordType, value string, mxPriority int32, srvPriority int32, srvWeight int32, srvPort int32, caaFlag int32, caaTag string, ttl int32, weight int32, routeIds []string, version int64) error {
if this.db == nil {
return errors.New("db should not be nil")
}
this.log("UpdateRecord", "domain:", domainId, "name:", name)
_, err := this.updateRecordStmt.Exec(domainId, name, recordType, value, mxPriority, srvPriority, srvWeight, srvPort, caaFlag, caaTag, ttl, weight, strings.Join(routeIds, ","), version, recordId)
if err != nil {
return err
}
return nil
}
func (this *DB) DeleteRecord(recordId int64) error {
if this.db == nil {
return errors.New("db should not be nil")
}
this.log("DeleteRecord", "record:", recordId)
_, err := this.deleteRecordStmt.Exec(recordId)
if err != nil {
return err
}
return nil
}
func (this *DB) ExistsRecord(recordId int64) (bool, error) {
if this.db == nil {
return false, errors.New("db should not be nil")
}
rows, err := this.existsRecordStmt.Query(recordId)
if err != nil {
return false, err
}
if rows.Err() != nil {
return false, rows.Err()
}
defer func() {
_ = rows.Close()
}()
if rows.Next() {
return true, nil
}
return false, nil
}
// ListRecords 列出一组记录
// TODO 将来只加载本集群上的记录
func (this *DB) ListRecords(offset int, size int) (records []*models.NSRecord, err error) {
if this.db == nil {
return nil, errors.New("db should not be nil")
}
rows, err := this.listRecordsStmt.Query(size, offset)
if err != nil {
return nil, err
}
if rows.Err() != nil {
return nil, rows.Err()
}
defer func() {
_ = rows.Close()
}()
for rows.Next() {
var record = &models.NSRecord{}
var routeIds = ""
err = rows.Scan(&record.Id, &record.DomainId, &record.Name, &record.Type, &record.Value, &record.MXPriority, &record.SRVPriority, &record.SRVWeight, &record.SRVPort, &record.CAAFlag, &record.CAATag, &record.Ttl, &record.Weight, &routeIds, &record.Version)
if err != nil {
return nil, err
}
if len(routeIds) > 0 {
record.RouteIds = strings.Split(routeIds, ",")
}
records = append(records, record)
}
return
}
// InsertRoute 创建线路
func (this *DB) InsertRoute(routeId int64, userId int64, rangesJSON []byte, order int32, priority int32, version int64) error {
if this.db == nil {
return errors.New("db should not be nil")
}
this.log("InsertRoute", "route:", routeId, "user:", userId)
_, err := this.insertRouteStmt.Exec(routeId, userId, string(rangesJSON), order, priority, version)
if err != nil {
return err
}
return nil
}
// UpdateRoute 修改线路
func (this *DB) UpdateRoute(routeId int64, userId int64, rangesJSON []byte, order int32, priority int32, version int64) error {
if this.db == nil {
return errors.New("db should not be nil")
}
this.log("UpdateRoute", "route:", routeId, "user:", userId)
_, err := this.updateRouteStmt.Exec(userId, string(rangesJSON), order, priority, version, routeId)
if err != nil {
return err
}
return nil
}
// DeleteRoute 删除线路
func (this *DB) DeleteRoute(routeId int64) error {
if this.db == nil {
return errors.New("db should not be nil")
}
this.log("DeleteRoute", "route:", routeId)
_, err := this.deleteRouteStmt.Exec(routeId)
if err != nil {
return err
}
return nil
}
// ExistsRoute 检查是否存在线路
func (this *DB) ExistsRoute(routeId int64) (bool, error) {
if this.db == nil {
return false, errors.New("db should not be nil")
}
rows, err := this.existsRouteStmt.Query(routeId)
if err != nil {
return false, err
}
if rows.Err() != nil {
return false, rows.Err()
}
defer func() {
_ = rows.Close()
}()
if rows.Next() {
return true, nil
}
return false, nil
}
// ListRoutes 查找所有线路
func (this *DB) ListRoutes(offset int64, size int64) (routes []*models.NSRoute, err error) {
if this.db == nil {
return nil, errors.New("db should not be nil")
}
rows, err := this.listRoutesStmt.Query(size, offset)
if err != nil {
return nil, err
}
if rows.Err() != nil {
return nil, rows.Err()
}
defer func() {
_ = rows.Close()
}()
for rows.Next() {
var route = &models.NSRoute{}
var rangesString = ""
err = rows.Scan(&route.Id, &route.UserId, &rangesString, &route.Priority, &route.Order, &route.Version)
if err != nil {
return nil, err
}
route.Ranges, err = models.InitRangesFromJSON([]byte(rangesString))
if err != nil {
return nil, err
}
routes = append(routes, route)
}
return
}
func (this *DB) InsertKey(keyId int64, domainId int64, zoneId int64, algo string, secret string, secretType string, version int64) error {
if this.db == nil {
return errors.New("db should not be nil")
}
this.log("InsertKey", "key:", keyId, "domain:", domainId, "zone:", zoneId)
_, err := this.insertKeyStmt.Exec(keyId, domainId, zoneId, algo, secret, secretType, version)
if err != nil {
return err
}
return nil
}
func (this *DB) UpdateKey(keyId int64, domainId int64, zoneId int64, algo string, secret string, secretType string, version int64) error {
if this.db == nil {
return errors.New("db should not be nil")
}
this.log("UpdateKey", "key:", keyId, "domain:", domainId, "zone:", zoneId)
_, err := this.updateKeyStmt.Exec(domainId, zoneId, algo, secret, secretType, version, keyId)
if err != nil {
return err
}
return nil
}
func (this *DB) ListKeys(offset int, size int) (keys []*models.NSKey, err error) {
if this.db == nil {
return nil, errors.New("db should not be nil")
}
rows, err := this.listKeysStmt.Query(size, offset)
if err != nil {
return nil, err
}
if rows.Err() != nil {
return nil, rows.Err()
}
defer func() {
_ = rows.Close()
}()
for rows.Next() {
var key = &models.NSKey{}
err = rows.Scan(&key.Id, &key.DomainId, &key.ZoneId, &key.Algo, &key.Secret, &key.SecretType, &key.Version)
if err != nil {
return nil, err
}
keys = append(keys, key)
}
return
}
func (this *DB) DeleteKey(keyId int64) error {
if this.db == nil {
return errors.New("db should not be nil")
}
this.log("DeleteKey", "key:", keyId)
_, err := this.deleteKeyStmt.Exec(keyId)
if err != nil {
return err
}
return nil
}
func (this *DB) ExistsKey(keyId int64) (bool, error) {
if this.db == nil {
return false, errors.New("db should not be nil")
}
rows, err := this.existsKeyStmt.Query(keyId)
if err != nil {
return false, err
}
if rows.Err() != nil {
return false, rows.Err()
}
defer func() {
_ = rows.Close()
}()
if rows.Next() {
return true, nil
}
return false, nil
}
func (this *DB) InsertAgentIP(ipId int64, ip string, agentCode string) error {
if this.db == nil {
return errors.New("db should not be nil")
}
this.log("InsertAgentIP", "id:", ipId, "ip:", ip, "agent:", agentCode)
_, err := this.insertAgentIPStmt.Exec(ipId, ip, agentCode)
if err != nil {
return err
}
return nil
}
func (this *DB) ListAgentIPs(offset int64, size int64) (agentIPs []*models.AgentIP, err error) {
if this.db == nil {
return nil, errors.New("db should not be nil")
}
rows, err := this.listAgentIPsStmt.Query(size, offset)
if err != nil {
return nil, err
}
if rows.Err() != nil {
return nil, rows.Err()
}
defer func() {
_ = rows.Close()
}()
for rows.Next() {
var agentIP = &models.AgentIP{}
err = rows.Scan(&agentIP.Id, &agentIP.IP, &agentIP.AgentCode)
if err != nil {
return nil, err
}
agentIPs = append(agentIPs, agentIP)
}
return
}
func (this *DB) Close() error {
if this.db == nil {
return nil
}
for _, stmt := range []*dbutils.Stmt{
this.insertDomainStmt,
this.updateDomainStmt,
this.deleteDomainStmt,
this.existsDomainStmt,
this.listDomainsStmt,
this.insertRecordStmt,
this.updateRecordStmt,
this.existsRecordStmt,
this.deleteRecordStmt,
this.listRecordsStmt,
this.insertRouteStmt,
this.updateRouteStmt,
this.deleteRouteStmt,
this.listRoutesStmt,
this.existsRouteStmt,
this.insertKeyStmt,
this.updateKeyStmt,
this.deleteKeyStmt,
this.listKeysStmt,
this.existsKeyStmt,
this.insertAgentIPStmt,
this.listAgentIPsStmt,
} {
if stmt != nil {
_ = stmt.Close()
}
}
err := this.db.Close()
if err != nil {
return err
}
return nil
}
// 打印日志
func (this *DB) log(args ...any) {
if !Tea.IsTesting() {
return
}
if len(args) == 0 {
return
}
args[0] = "[" + types.String(args[0]) + "]"
log.Println(args...)
}

View File

@@ -0,0 +1,228 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package dbs
import (
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/iwind/TeaGo/Tea"
_ "github.com/iwind/TeaGo/bootstrap"
"testing"
)
func TestDB_Init(t *testing.T) {
db := NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestDB_InsertDomain(t *testing.T) {
db := NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
err = db.InsertDomain(1, 1, 1, "examples.com", nil, 1)
if err != nil {
t.Fatal(err)
}
err = db.InsertDomain(2, 2, 1, "examples2.com", nil, 1)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestDB_UpdateDomain(t *testing.T) {
db := NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
err = db.UpdateDomain(1, 1, 1, "examples2.com", nil, 1)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestDB_DeleteDomain(t *testing.T) {
db := NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
err = db.DeleteDomain(1)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestDB_ExistsDomain(t *testing.T) {
db := NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
b, err := db.ExistsDomain(1)
if err != nil {
t.Fatal(err)
}
t.Log("exists:", b)
}
func TestDB_FindAllDomains(t *testing.T) {
db := NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
limit := 2
offset := 0
for {
t.Log("===", offset, "===")
domains, err := db.ListDomains(2, offset, limit)
if err != nil {
t.Fatal(err)
}
if len(domains) == 0 {
break
}
for _, domain := range domains {
t.Log(domain.Name)
}
offset += limit
}
}
func TestDB_InsertRecord(t *testing.T) {
db := NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
err = db.InsertRecord(1, 1, "a", dnsconfigs.RecordTypeA, "192.168.1.100", 0, 3600, 10, 8080, 1, "", 3600, 10, []string{"id:100", "id:1"}, 1)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestDB_UpdateRecord(t *testing.T) {
db := NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
err = db.UpdateRecord(1, 1, "a1", dnsconfigs.RecordTypeA, "192.168.1.101", 0, 3600, 10, 8080, 1, "", 3600, 10, []string{"id:100", "id:1"}, 1)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestDB_InsertRoute(t *testing.T) {
db := NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
err = db.InsertRoute(1, 1, []byte("[]"), 1, 0, 1)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestDB_UpdateRoute(t *testing.T) {
db := NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
err = db.UpdateRoute(1, 1, []byte("[{}, {}]"), 2, 0, 1)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestDB_InsertKey(t *testing.T) {
db := NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
err = db.InsertKey(1, 2, 3, "md5", "secret123", dnsconfigs.NSKeySecretTypeClear, 4)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestDB_UpdateKey(t *testing.T) {
db := NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
err = db.UpdateKey(1, 22, 33, "sha1", "secret456", dnsconfigs.NSKeySecretTypeBase64, 5)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestDB_DeleteKey(t *testing.T) {
db := NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
err = db.DeleteKey(1)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestDB_ExistsKey(t *testing.T) {
db := NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
for _, id := range []int64{1, 2, 3} {
b, err := db.ExistsKey(id)
if err != nil {
t.Fatal(err)
}
t.Log(id, b)
}
}
func TestDB_ListKeys(t *testing.T) {
db := NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
keys, err := db.ListKeys(0, 10)
if err != nil {
t.Fatal(err)
}
for _, key := range keys {
t.Log(key)
}
}

View File

@@ -0,0 +1,41 @@
package encrypt
import (
"github.com/iwind/TeaGo/logs"
)
const (
MagicKey = "f1c8eafb543f03023e97b7be864a4e9b"
)
// 加密特殊信息
func MagicKeyEncode(data []byte) []byte {
method, err := NewMethodInstance("aes-256-cfb", MagicKey, MagicKey[:16])
if err != nil {
logs.Println("[MagicKeyEncode]" + err.Error())
return data
}
dst, err := method.Encrypt(data)
if err != nil {
logs.Println("[MagicKeyEncode]" + err.Error())
return data
}
return dst
}
// 解密特殊信息
func MagicKeyDecode(data []byte) []byte {
method, err := NewMethodInstance("aes-256-cfb", MagicKey, MagicKey[:16])
if err != nil {
logs.Println("[MagicKeyEncode]" + err.Error())
return data
}
src, err := method.Decrypt(data)
if err != nil {
logs.Println("[MagicKeyEncode]" + err.Error())
return data
}
return src
}

View File

@@ -0,0 +1,11 @@
package encrypt
import "testing"
func TestMagicKeyEncode(t *testing.T) {
dst := MagicKeyEncode([]byte("Hello,World"))
t.Log("dst:", string(dst))
src := MagicKeyDecode(dst)
t.Log("src:", string(src))
}

View File

@@ -0,0 +1,12 @@
package encrypt
type MethodInterface interface {
// 初始化
Init(key []byte, iv []byte) error
// 加密
Encrypt(src []byte) (dst []byte, err error)
// 解密
Decrypt(dst []byte) (src []byte, err error)
}

View File

@@ -0,0 +1,73 @@
package encrypt
import (
"bytes"
"crypto/aes"
"crypto/cipher"
)
type AES128CFBMethod struct {
iv []byte
block cipher.Block
}
func (this *AES128CFBMethod) Init(key, iv []byte) error {
// 判断key是否为32长度
l := len(key)
if l > 16 {
key = key[:16]
} else if l < 16 {
key = append(key, bytes.Repeat([]byte{' '}, 16-l)...)
}
// 判断iv长度
l2 := len(iv)
if l2 > aes.BlockSize {
iv = iv[:aes.BlockSize]
} else if l2 < aes.BlockSize {
iv = append(iv, bytes.Repeat([]byte{' '}, aes.BlockSize-l2)...)
}
this.iv = iv
// block
block, err := aes.NewCipher(key)
if err != nil {
return err
}
this.block = block
return nil
}
func (this *AES128CFBMethod) Encrypt(src []byte) (dst []byte, err error) {
if len(src) == 0 {
return
}
defer func() {
err = RecoverMethodPanic(recover())
}()
dst = make([]byte, len(src))
encrypter := cipher.NewCFBEncrypter(this.block, this.iv)
encrypter.XORKeyStream(dst, src)
return
}
func (this *AES128CFBMethod) Decrypt(dst []byte) (src []byte, err error) {
if len(dst) == 0 {
return
}
defer func() {
err = RecoverMethodPanic(recover())
}()
src = make([]byte, len(dst))
encrypter := cipher.NewCFBDecrypter(this.block, this.iv)
encrypter.XORKeyStream(src, dst)
return
}

View File

@@ -0,0 +1,90 @@
package encrypt
import (
"runtime"
"strings"
"testing"
)
func TestAES128CFBMethod_Encrypt(t *testing.T) {
method, err := NewMethodInstance("aes-128-cfb", "abc", "123")
if err != nil {
t.Fatal(err)
}
src := []byte("Hello, World")
dst, err := method.Encrypt(src)
if err != nil {
t.Fatal(err)
}
dst = dst[:len(src)]
t.Log("dst:", string(dst))
src, err = method.Decrypt(dst)
if err != nil {
t.Fatal(err)
}
t.Log("src:", string(src))
}
func TestAES128CFBMethod_Encrypt2(t *testing.T) {
method, err := NewMethodInstance("aes-128-cfb", "abc", "123")
if err != nil {
t.Fatal(err)
}
sources := [][]byte{}
{
a := []byte{1}
_, err = method.Encrypt(a)
if err != nil {
t.Fatal(err)
}
}
for i := 0; i < 10; i++ {
src := []byte(strings.Repeat("Hello", 1))
dst, err := method.Encrypt(src)
if err != nil {
t.Fatal(err)
}
sources = append(sources, dst)
}
{
a := []byte{1}
_, err = method.Decrypt(a)
if err != nil {
t.Fatal(err)
}
}
for _, dst := range sources {
dst2 := append([]byte{}, dst...)
src2, err := method.Decrypt(dst2)
if err != nil {
t.Fatal(err)
}
t.Log(string(src2))
}
}
func BenchmarkAES128CFBMethod_Encrypt(b *testing.B) {
runtime.GOMAXPROCS(1)
method, err := NewMethodInstance("aes-128-cfb", "abc", "123")
if err != nil {
b.Fatal(err)
}
src := []byte(strings.Repeat("Hello", 1024))
for i := 0; i < b.N; i++ {
dst, err := method.Encrypt(src)
if err != nil {
b.Fatal(err)
}
_ = dst
}
}

View File

@@ -0,0 +1,74 @@
package encrypt
import (
"bytes"
"crypto/aes"
"crypto/cipher"
)
type AES192CFBMethod struct {
block cipher.Block
iv []byte
}
func (this *AES192CFBMethod) Init(key, iv []byte) error {
// 判断key是否为24长度
l := len(key)
if l > 24 {
key = key[:24]
} else if l < 24 {
key = append(key, bytes.Repeat([]byte{' '}, 24-l)...)
}
block, err := aes.NewCipher(key)
if err != nil {
return err
}
this.block = block
// 判断iv长度
l2 := len(iv)
if l2 > aes.BlockSize {
iv = iv[:aes.BlockSize]
} else if l2 < aes.BlockSize {
iv = append(iv, bytes.Repeat([]byte{' '}, aes.BlockSize-l2)...)
}
this.iv = iv
return nil
}
func (this *AES192CFBMethod) Encrypt(src []byte) (dst []byte, err error) {
if len(src) == 0 {
return
}
defer func() {
err = RecoverMethodPanic(recover())
}()
dst = make([]byte, len(src))
encrypter := cipher.NewCFBEncrypter(this.block, this.iv)
encrypter.XORKeyStream(dst, src)
return
}
func (this *AES192CFBMethod) Decrypt(dst []byte) (src []byte, err error) {
if len(dst) == 0 {
return
}
defer func() {
err = RecoverMethodPanic(recover())
}()
src = make([]byte, len(dst))
decrypter := cipher.NewCFBDecrypter(this.block, this.iv)
decrypter.XORKeyStream(src, dst)
return
}

View File

@@ -0,0 +1,45 @@
package encrypt
import (
"runtime"
"strings"
"testing"
)
func TestAES192CFBMethod_Encrypt(t *testing.T) {
method, err := NewMethodInstance("aes-192-cfb", "abc", "123")
if err != nil {
t.Fatal(err)
}
src := []byte("Hello, World")
dst, err := method.Encrypt(src)
if err != nil {
t.Fatal(err)
}
dst = dst[:len(src)]
t.Log("dst:", string(dst))
src, err = method.Decrypt(dst)
if err != nil {
t.Fatal(err)
}
t.Log("src:", string(src))
}
func BenchmarkAES192CFBMethod_Encrypt(b *testing.B) {
runtime.GOMAXPROCS(1)
method, err := NewMethodInstance("aes-192-cfb", "abc", "123")
if err != nil {
b.Fatal(err)
}
src := []byte(strings.Repeat("Hello", 1024))
for i := 0; i < b.N; i++ {
dst, err := method.Encrypt(src)
if err != nil {
b.Fatal(err)
}
_ = dst
}
}

View File

@@ -0,0 +1,72 @@
package encrypt
import (
"bytes"
"crypto/aes"
"crypto/cipher"
)
type AES256CFBMethod struct {
block cipher.Block
iv []byte
}
func (this *AES256CFBMethod) Init(key, iv []byte) error {
// 判断key是否为32长度
l := len(key)
if l > 32 {
key = key[:32]
} else if l < 32 {
key = append(key, bytes.Repeat([]byte{' '}, 32-l)...)
}
block, err := aes.NewCipher(key)
if err != nil {
return err
}
this.block = block
// 判断iv长度
l2 := len(iv)
if l2 > aes.BlockSize {
iv = iv[:aes.BlockSize]
} else if l2 < aes.BlockSize {
iv = append(iv, bytes.Repeat([]byte{' '}, aes.BlockSize-l2)...)
}
this.iv = iv
return nil
}
func (this *AES256CFBMethod) Encrypt(src []byte) (dst []byte, err error) {
if len(src) == 0 {
return
}
defer func() {
err = RecoverMethodPanic(recover())
}()
dst = make([]byte, len(src))
encrypter := cipher.NewCFBEncrypter(this.block, this.iv)
encrypter.XORKeyStream(dst, src)
return
}
func (this *AES256CFBMethod) Decrypt(dst []byte) (src []byte, err error) {
if len(dst) == 0 {
return
}
defer func() {
err = RecoverMethodPanic(recover())
}()
src = make([]byte, len(dst))
decrypter := cipher.NewCFBDecrypter(this.block, this.iv)
decrypter.XORKeyStream(src, dst)
return
}

View File

@@ -0,0 +1,42 @@
package encrypt
import "testing"
func TestAES256CFBMethod_Encrypt(t *testing.T) {
method, err := NewMethodInstance("aes-256-cfb", "abc", "123")
if err != nil {
t.Fatal(err)
}
src := []byte("Hello, World")
dst, err := method.Encrypt(src)
if err != nil {
t.Fatal(err)
}
dst = dst[:len(src)]
t.Log("dst:", string(dst))
src, err = method.Decrypt(dst)
if err != nil {
t.Fatal(err)
}
t.Log("src:", string(src))
}
func TestAES256CFBMethod_Encrypt2(t *testing.T) {
method, err := NewMethodInstance("aes-256-cfb", "abc", "123")
if err != nil {
t.Fatal(err)
}
src := []byte("Hello, World")
dst, err := method.Encrypt(src)
if err != nil {
t.Fatal(err)
}
t.Log("dst:", string(dst))
src, err = method.Decrypt(dst)
if err != nil {
t.Fatal(err)
}
t.Log("src:", string(src))
}

View File

@@ -0,0 +1,26 @@
package encrypt
type RawMethod struct {
}
func (this *RawMethod) Init(key, iv []byte) error {
return nil
}
func (this *RawMethod) Encrypt(src []byte) (dst []byte, err error) {
if len(src) == 0 {
return
}
dst = make([]byte, len(src))
copy(dst, src)
return
}
func (this *RawMethod) Decrypt(dst []byte) (src []byte, err error) {
if len(dst) == 0 {
return
}
src = make([]byte, len(dst))
copy(src, dst)
return
}

View File

@@ -0,0 +1,23 @@
package encrypt
import "testing"
func TestRawMethod_Encrypt(t *testing.T) {
method, err := NewMethodInstance("raw", "abc", "123")
if err != nil {
t.Fatal(err)
}
src := []byte("Hello, World")
dst, err := method.Encrypt(src)
if err != nil {
t.Fatal(err)
}
dst = dst[:len(src)]
t.Log("dst:", string(dst))
src, err = method.Decrypt(dst)
if err != nil {
t.Fatal(err)
}
t.Log("src:", string(src))
}

View File

@@ -0,0 +1,43 @@
package encrypt
import (
"errors"
"reflect"
)
var methods = map[string]reflect.Type{
"raw": reflect.TypeOf(new(RawMethod)).Elem(),
"aes-128-cfb": reflect.TypeOf(new(AES128CFBMethod)).Elem(),
"aes-192-cfb": reflect.TypeOf(new(AES192CFBMethod)).Elem(),
"aes-256-cfb": reflect.TypeOf(new(AES256CFBMethod)).Elem(),
}
func NewMethodInstance(method string, key string, iv string) (MethodInterface, error) {
valueType, ok := methods[method]
if !ok {
return nil, errors.New("method '" + method + "' not found")
}
instance, ok := reflect.New(valueType).Interface().(MethodInterface)
if !ok {
return nil, errors.New("method '" + method + "' must implement MethodInterface")
}
err := instance.Init([]byte(key), []byte(iv))
return instance, err
}
func RecoverMethodPanic(err interface{}) error {
if err != nil {
s, ok := err.(string)
if ok {
return errors.New(s)
}
e, ok := err.(error)
if ok {
return e
}
return errors.New("unknown error")
}
return nil
}

View File

@@ -0,0 +1,8 @@
package encrypt
import "testing"
func TestFindMethodInstance(t *testing.T) {
t.Log(NewMethodInstance("a", "b", ""))
t.Log(NewMethodInstance("aes-256-cfb", "123456", ""))
}

View File

@@ -0,0 +1,12 @@
package events
type Event = string
const (
EventStart Event = "start" // start loading
EventLoaded Event = "load" // loaded
EventQuit Event = "quit" // quit node gracefully
EventTerminated Event = "terminated" // process terminated
EventReload Event = "reload" // reload config
EventNFTablesReady Event = "nftablesReady" // nftables ready
)

View File

@@ -0,0 +1,27 @@
package events
import "sync"
var eventsMap = map[string][]func(){} // event => []callbacks
var locker = sync.Mutex{}
// On 增加事件回调
func On(event string, callback func()) {
locker.Lock()
defer locker.Unlock()
var callbacks = eventsMap[event]
callbacks = append(callbacks, callback)
eventsMap[event] = callbacks
}
// Notify 通知事件
func Notify(event string) {
locker.Lock()
var callbacks = eventsMap[event]
locker.Unlock()
for _, callback := range callbacks {
callback()
}
}

View File

@@ -0,0 +1,16 @@
package events
import "testing"
func TestOn(t *testing.T) {
On("hello", func() {
t.Log("world")
})
On("hello", func() {
t.Log("world2")
})
On("hello2", func() {
t.Log("world2")
})
Notify("hello")
}

View File

@@ -0,0 +1,570 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package firewalls
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
"github.com/TeaOSLab/EdgeDNS/internal/configs"
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
"github.com/TeaOSLab/EdgeDNS/internal/events"
"github.com/TeaOSLab/EdgeDNS/internal/firewalls/nftables"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/utils"
executils "github.com/TeaOSLab/EdgeDNS/internal/utils/exec"
"github.com/TeaOSLab/EdgeDNS/internal/utils/zero"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
"net"
"os/exec"
"strings"
)
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
func init() {
if !teaconst.IsMain {
return
}
events.On(events.EventReload, func() {
if nftablesInstance == nil {
return
}
var nodeConfig = configs.SharedNodeConfig
if nodeConfig != nil {
err := SharedDDoSProtectionManager.Apply(nodeConfig.DDoSProtection)
if err != nil {
remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error())
}
}
})
events.On(events.EventNFTablesReady, func() {
var nodeConfig = configs.SharedNodeConfig
if nodeConfig != nil {
err := SharedDDoSProtectionManager.Apply(nodeConfig.DDoSProtection)
if err != nil {
remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error())
}
}
})
}
// DDoSProtectionManager DDoS防护
type DDoSProtectionManager struct {
nftPath string
lastAllowIPList []string
lastConfig []byte
}
// NewDDoSProtectionManager 获取新对象
func NewDDoSProtectionManager() *DDoSProtectionManager {
nftPath, _ := executils.LookPath("nft")
return &DDoSProtectionManager{
nftPath: nftPath,
}
}
// Apply 应用配置
func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error {
// 同集群节点IP白名单
var allowIPListChanged = false
var nodeConfig = configs.SharedNodeConfig
if nodeConfig != nil {
var allowIPList = nodeConfig.AllowedIPs
if !utils.EqualStrings(allowIPList, this.lastAllowIPList) {
allowIPListChanged = true
this.lastAllowIPList = allowIPList
}
}
// 对比配置
configJSON, err := json.Marshal(config)
if err != nil {
return fmt.Errorf("encode config to json failed: %w", err)
}
if config == nil {
configJSON = nil
}
if !allowIPListChanged && bytes.Equal(this.lastConfig, configJSON) {
return nil
}
remotelogs.Println("FIREWALL", "change DDoS protection config")
if len(this.nftPath) == 0 {
return errors.New("can not find nft command")
}
if nftablesInstance == nil {
return errors.New("nftables instance should not be nil")
}
if config == nil {
// TCP
err := this.removeTCPRules()
if err != nil {
return err
}
// TODO other protocols
return nil
}
// TCP
if config.TCP == nil {
err := this.removeTCPRules()
if err != nil {
return err
}
} else {
// allow ip list
var allowIPList = []string{}
for _, ipConfig := range config.TCP.AllowIPList {
allowIPList = append(allowIPList, ipConfig.IP)
}
for _, ip := range this.lastAllowIPList {
if !lists.ContainsString(allowIPList, ip) {
allowIPList = append(allowIPList, ip)
}
}
err = this.updateAllowIPList(allowIPList)
if err != nil {
return err
}
// tcp
if config.TCP.IsOn {
err := this.addTCPRules(config.TCP)
if err != nil {
return err
}
} else {
err := this.removeTCPRules()
if err != nil {
return err
}
}
}
this.lastConfig = configJSON
return nil
}
// 添加TCP规则
func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig) error {
// 检查nft版本不能小于0.9
if len(nftablesInstance.version) > 0 && stringutil.VersionCompare("0.9", nftablesInstance.version) > 0 {
return nil
}
var ports = []int32{}
for _, portConfig := range tcpConfig.Ports {
if !lists.ContainsInt32(ports, portConfig.Port) {
ports = append(ports, portConfig.Port)
}
}
if len(ports) == 0 {
ports = []int32{53}
}
for _, filter := range nftablesFilters {
chain, oldRules, err := this.getRules(filter)
if err != nil {
return fmt.Errorf("get old rules failed: %w", err)
}
var protocol = filter.protocol()
// max connections
var maxConnections = tcpConfig.MaxConnections
if maxConnections <= 0 {
maxConnections = dnsconfigs.DefaultTCPMaxConnections
if maxConnections <= 0 {
maxConnections = 100000
}
}
// max connections per ip
var maxConnectionsPerIP = tcpConfig.MaxConnectionsPerIP
if maxConnectionsPerIP <= 0 {
maxConnectionsPerIP = dnsconfigs.DefaultTCPMaxConnectionsPerIP
if maxConnectionsPerIP <= 0 {
maxConnectionsPerIP = 100000
}
}
// new connections rate (minutely)
var newConnectionsMinutelyRate = tcpConfig.NewConnectionsMinutelyRate
if newConnectionsMinutelyRate <= 0 {
newConnectionsMinutelyRate = nodeconfigs.DefaultTCPNewConnectionsMinutelyRate
if newConnectionsMinutelyRate <= 0 {
newConnectionsMinutelyRate = 100000
}
}
var newConnectionsMinutelyRateBlockTimeout = tcpConfig.NewConnectionsMinutelyRateBlockTimeout
if newConnectionsMinutelyRateBlockTimeout < 0 {
newConnectionsMinutelyRateBlockTimeout = 0
}
// new connections rate (secondly)
var newConnectionsSecondlyRate = tcpConfig.NewConnectionsSecondlyRate
if newConnectionsSecondlyRate <= 0 {
newConnectionsSecondlyRate = nodeconfigs.DefaultTCPNewConnectionsSecondlyRate
if newConnectionsSecondlyRate <= 0 {
newConnectionsSecondlyRate = 10000
}
}
var newConnectionsSecondlyRateBlockTimeout = tcpConfig.NewConnectionsSecondlyRateBlockTimeout
if newConnectionsSecondlyRateBlockTimeout < 0 {
newConnectionsSecondlyRateBlockTimeout = 0
}
// 检查是否有变化
var hasChanges = false
for _, port := range ports {
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}) {
hasChanges = true
break
}
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}) {
hasChanges = true
break
}
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsMinutelyRate), types.String(newConnectionsMinutelyRateBlockTimeout)}) {
hasChanges = true
break
}
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "newConnectionsSecondlyRate", types.String(newConnectionsSecondlyRate), types.String(newConnectionsSecondlyRateBlockTimeout)}) {
hasChanges = true
break
}
}
if !hasChanges {
// 检查是否有多余的端口
var oldPorts = this.getTCPPorts(oldRules)
if !this.eqPorts(ports, oldPorts) {
hasChanges = true
}
}
if !hasChanges {
return nil
}
// 先清空所有相关规则
err = this.removeOldTCPRules(chain, oldRules)
if err != nil {
return fmt.Errorf("delete old rules failed: %w", err)
}
// 添加新规则
for _, port := range ports {
// TODO 让用户选择是drop还是reject
if maxConnections > 0 {
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "count", "over", types.String(maxConnections), "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}))
var stderr = &bytes.Buffer{}
cmd.Stderr = stderr
err := cmd.Run()
if err != nil {
return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, stderr.String())
}
}
// TODO 让用户选择是drop还是reject
if maxConnectionsPerIP > 0 {
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "meter", "meter-"+protocol+"-"+types.String(port)+"-max-connections", "{ "+protocol+" saddr ct count over "+types.String(maxConnectionsPerIP)+" }", "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}))
var stderr = &bytes.Buffer{}
cmd.Stderr = stderr
err := cmd.Run()
if err != nil {
return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, stderr.String())
}
}
// 超过一定速率就drop或者加入黑名单分钟
// TODO 让用户选择是drop还是reject
if newConnectionsMinutelyRate > 0 {
if newConnectionsMinutelyRateBlockTimeout > 0 {
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsMinutelyRate)+"/minute burst "+types.String(newConnectionsMinutelyRate+3)+" packets }", "add", "@deny_set", "{"+protocol+" saddr timeout "+types.String(newConnectionsMinutelyRateBlockTimeout)+"s}", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsMinutelyRate), types.String(newConnectionsMinutelyRateBlockTimeout)}))
var stderr = &bytes.Buffer{}
cmd.Stderr = stderr
err := cmd.Run()
if err != nil {
return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, stderr.String())
}
} else {
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsMinutelyRate)+"/minute burst "+types.String(newConnectionsMinutelyRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", "0"}))
var stderr = &bytes.Buffer{}
cmd.Stderr = stderr
err := cmd.Run()
if err != nil {
return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, stderr.String())
}
}
}
// 超过一定速率就drop或者加入黑名单
// TODO 让用户选择是drop还是reject
if newConnectionsSecondlyRate > 0 {
if newConnectionsSecondlyRateBlockTimeout > 0 {
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-secondly-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsSecondlyRate)+"/second burst "+types.String(newConnectionsSecondlyRate+3)+" packets }", "add", "@deny_set", "{"+protocol+" saddr timeout "+types.String(newConnectionsSecondlyRateBlockTimeout)+"s}", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsSecondlyRate", types.String(newConnectionsSecondlyRate), types.String(newConnectionsSecondlyRateBlockTimeout)}))
var stderr = &bytes.Buffer{}
cmd.Stderr = stderr
err := cmd.Run()
if err != nil {
return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, stderr.String())
}
} else {
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-secondly-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsSecondlyRate)+"/second burst "+types.String(newConnectionsSecondlyRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsSecondlyRate", "0"}))
var stderr = &bytes.Buffer{}
cmd.Stderr = stderr
err := cmd.Run()
if err != nil {
return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, stderr.String())
}
}
}
}
}
return nil
}
// 删除TCP规则
func (this *DDoSProtectionManager) removeTCPRules() error {
for _, filter := range nftablesFilters {
chain, rules, err := this.getRules(filter)
// TCP
err = this.removeOldTCPRules(chain, rules)
if err != nil {
return err
}
}
return nil
}
// 组合user data
// 数据中不能包含字母、数字、下划线以外的数据
func (this *DDoSProtectionManager) encodeUserData(attrs []string) string {
if attrs == nil {
return ""
}
return "ZZ" + strings.Join(attrs, "_") + "ZZ"
}
// 解码user data
func (this *DDoSProtectionManager) decodeUserData(data []byte) []string {
if len(data) == 0 {
return nil
}
var dataCopy = make([]byte, len(data))
copy(dataCopy, data)
var separatorLen = 2
var index1 = bytes.Index(dataCopy, []byte{'Z', 'Z'})
if index1 < 0 {
return nil
}
dataCopy = dataCopy[index1+separatorLen:]
var index2 = bytes.LastIndex(dataCopy, []byte{'Z', 'Z'})
if index2 < 0 {
return nil
}
var s = string(dataCopy[:index2])
var pieces = strings.Split(s, "_")
for index, piece := range pieces {
pieces[index] = strings.TrimSpace(piece)
}
return pieces
}
// 清除规则
func (this *DDoSProtectionManager) removeOldTCPRules(chain *nftables.Chain, rules []*nftables.Rule) error {
for _, rule := range rules {
var pieces = this.decodeUserData(rule.UserData())
if len(pieces) < 4 {
continue
}
if pieces[0] != "tcp" {
continue
}
switch pieces[2] {
case "maxConnections", "maxConnectionsPerIP", "newConnectionsRate", "newConnectionsSecondlyRate":
err := chain.DeleteRule(rule)
if err != nil {
return err
}
}
}
return nil
}
// 根据参数检查规则是否存在
func (this *DDoSProtectionManager) existsRule(rules []*nftables.Rule, attrs []string) (exists bool) {
if len(attrs) == 0 {
return false
}
for _, oldRule := range rules {
var pieces = this.decodeUserData(oldRule.UserData())
if len(attrs) != len(pieces) {
continue
}
var isSame = true
for index, piece := range pieces {
if strings.TrimSpace(piece) != attrs[index] {
isSame = false
break
}
}
if isSame {
return true
}
}
return false
}
// 获取规则中的端口号
func (this *DDoSProtectionManager) getTCPPorts(rules []*nftables.Rule) []int32 {
var ports = []int32{}
for _, rule := range rules {
var pieces = this.decodeUserData(rule.UserData())
if len(pieces) != 4 {
continue
}
if pieces[0] != "tcp" {
continue
}
var port = types.Int32(pieces[1])
if port > 0 && !lists.ContainsInt32(ports, port) {
ports = append(ports, port)
}
}
return ports
}
// 检查端口是否一样
func (this *DDoSProtectionManager) eqPorts(ports1 []int32, ports2 []int32) bool {
if len(ports1) != len(ports2) {
return false
}
var portMap = map[int32]bool{}
for _, port := range ports2 {
portMap[port] = true
}
for _, port := range ports1 {
_, ok := portMap[port]
if !ok {
return false
}
}
return true
}
// 查找Table
func (this *DDoSProtectionManager) getTable(filter *nftablesTableDefinition) (*nftables.Table, error) {
var family nftables.TableFamily
if filter.IsIPv4 {
family = nftables.TableFamilyIPv4
} else if filter.IsIPv6 {
family = nftables.TableFamilyIPv6
} else {
return nil, errors.New("table '" + filter.Name + "' should be IPv4 or IPv6")
}
return nftablesInstance.conn.GetTable(filter.Name, family)
}
// 查找所有规则
func (this *DDoSProtectionManager) getRules(filter *nftablesTableDefinition) (*nftables.Chain, []*nftables.Rule, error) {
table, err := this.getTable(filter)
if err != nil {
return nil, nil, fmt.Errorf("get table failed: %w", err)
}
chain, err := table.GetChain(nftablesChainName)
if err != nil {
return nil, nil, fmt.Errorf("get chain failed: %w", err)
}
rules, err := chain.GetRules()
return chain, rules, err
}
// 更新白名单
func (this *DDoSProtectionManager) updateAllowIPList(allIPList []string) error {
if nftablesInstance == nil {
return nil
}
var allMap = map[string]zero.Zero{}
for _, ip := range allIPList {
allMap[ip] = zero.New()
}
for _, set := range []*nftables.Set{nftablesInstance.allowIPv4Set, nftablesInstance.allowIPv6Set} {
var isIPv4 = set == nftablesInstance.allowIPv4Set
var isIPv6 = !isIPv4
// 现有的
oldList, err := set.GetIPElements()
if err != nil {
return err
}
var oldMap = map[string]zero.Zero{} // ip=> zero
for _, ip := range oldList {
oldMap[ip] = zero.New()
if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) {
_, ok := allMap[ip]
if !ok {
// 不存在则删除
err = set.DeleteIPElement(ip)
if err != nil {
return fmt.Errorf("delete ip element '%s' failed: %w", ip, err)
}
}
}
}
// 新增的
for _, ip := range allIPList {
var ipObj = net.ParseIP(ip)
if ipObj == nil {
continue
}
if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) {
_, ok := oldMap[ip]
if !ok {
// 不存在则添加
err = set.AddIPElement(ip, nil)
if err != nil {
return fmt.Errorf("add ip '%s' failed: %w", ip, err)
}
}
}
}
}
return nil
}

View File

@@ -0,0 +1,22 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !linux
// +build !linux
package firewalls
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
)
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
type DDoSProtectionManager struct {
}
func NewDDoSProtectionManager() *DDoSProtectionManager {
return &DDoSProtectionManager{}
}
func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error {
return nil
}

View File

@@ -0,0 +1,66 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package firewalls
import (
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
"github.com/TeaOSLab/EdgeDNS/internal/events"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"runtime"
"sync"
)
var currentFirewall FirewallInterface
var firewallLocker = &sync.Mutex{}
// 初始化
func init() {
if !teaconst.IsMain {
return
}
events.On(events.EventLoaded, func() {
var firewall = Firewall()
if firewall.Name() != "mock" {
remotelogs.Println("FIREWALL", "found local firewall '"+firewall.Name()+"'")
}
})
}
// Firewall 查找当前系统中最适合的防火墙
func Firewall() FirewallInterface {
firewallLocker.Lock()
defer firewallLocker.Unlock()
if currentFirewall != nil {
return currentFirewall
}
// nftables
if runtime.GOOS == "linux" {
nftables, err := NewNFTablesFirewall()
if err != nil {
remotelogs.Warn("FIREWALL", "'nftables' should be installed on the system to enhance security (init failed: "+err.Error()+")")
} else {
if nftables.IsReady() {
currentFirewall = nftables
events.Notify(events.EventNFTablesReady)
return nftables
} else {
remotelogs.Warn("FIREWALL", "'nftables' should be enabled on the system to enhance security")
}
}
}
// firewalld
if runtime.GOOS == "linux" {
var firewalld = NewFirewalld()
if firewalld.IsReady() {
currentFirewall = firewalld
return currentFirewall
}
}
// 至少返回一个
currentFirewall = NewMockFirewall()
return currentFirewall
}

View File

@@ -0,0 +1,186 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package firewalls
import (
"github.com/TeaOSLab/EdgeDNS/internal/goman"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
executils "github.com/TeaOSLab/EdgeDNS/internal/utils/exec"
"github.com/iwind/TeaGo/types"
"os/exec"
"strings"
)
type Firewalld struct {
isReady bool
exe string
cmdQueue chan *exec.Cmd
}
func NewFirewalld() *Firewalld {
var firewalld = &Firewalld{
cmdQueue: make(chan *exec.Cmd, 4096),
}
path, err := executils.LookPath("firewall-cmd")
if err == nil && len(path) > 0 {
var cmd = exec.Command(path, "--state")
err := cmd.Run()
if err == nil {
firewalld.exe = path
// TODO check firewalld status with 'firewall-cmd --state' (running or not running),
// but we should recover the state when firewalld state changes, maybe check it every minutes
firewalld.isReady = true
firewalld.init()
}
}
return firewalld
}
func (this *Firewalld) init() {
goman.New(func() {
for cmd := range this.cmdQueue {
err := cmd.Run()
if err != nil {
if strings.HasPrefix(err.Error(), "Warning:") {
continue
}
remotelogs.Warn("FIREWALL", "run command failed '"+cmd.String()+"': "+err.Error())
}
}
})
}
// Name 名称
func (this *Firewalld) Name() string {
return "firewalld"
}
func (this *Firewalld) IsReady() bool {
return this.isReady
}
// IsMock 是否为模拟
func (this *Firewalld) IsMock() bool {
return false
}
func (this *Firewalld) AllowPort(port int, protocol string) error {
if !this.isReady {
return nil
}
var cmd = exec.Command(this.exe, "--add-port="+types.String(port)+"/"+protocol)
this.pushCmd(cmd)
return nil
}
func (this *Firewalld) AllowPortRangesPermanently(portRanges [][2]int, protocol string) error {
for _, portRange := range portRanges {
var port = this.PortRangeString(portRange, protocol)
{
var cmd = exec.Command(this.exe, "--add-port="+port, "--permanent")
this.pushCmd(cmd)
}
{
var cmd = exec.Command(this.exe, "--add-port="+port)
this.pushCmd(cmd)
}
}
return nil
}
func (this *Firewalld) RemovePort(port int, protocol string) error {
if !this.isReady {
return nil
}
var cmd = exec.Command(this.exe, "--remove-port="+types.String(port)+"/"+protocol)
this.pushCmd(cmd)
return nil
}
func (this *Firewalld) RemovePortRangePermanently(portRange [2]int, protocol string) error {
var port = this.PortRangeString(portRange, protocol)
{
var cmd = exec.Command(this.exe, "--remove-port="+port, "--permanent")
this.pushCmd(cmd)
}
{
var cmd = exec.Command(this.exe, "--remove-port="+port)
this.pushCmd(cmd)
}
return nil
}
func (this *Firewalld) PortRangeString(portRange [2]int, protocol string) string {
if portRange[0] == portRange[1] {
return types.String(portRange[0]) + "/" + protocol
} else {
return types.String(portRange[0]) + "-" + types.String(portRange[1]) + "/" + protocol
}
}
func (this *Firewalld) RejectSourceIP(ip string, timeoutSeconds int) error {
if !this.isReady {
return nil
}
var family = "ipv4"
if strings.Contains(ip, ":") {
family = "ipv6"
}
var args = []string{"--add-rich-rule=rule family='" + family + "' source address='" + ip + "' reject"}
if timeoutSeconds > 0 {
args = append(args, "--timeout="+types.String(timeoutSeconds)+"s")
}
var cmd = exec.Command(this.exe, args...)
this.pushCmd(cmd)
return nil
}
func (this *Firewalld) DropSourceIP(ip string, timeoutSeconds int) error {
if !this.isReady {
return nil
}
var family = "ipv4"
if strings.Contains(ip, ":") {
family = "ipv6"
}
var args = []string{"--add-rich-rule=rule family='" + family + "' source address='" + ip + "' drop"}
if timeoutSeconds > 0 {
args = append(args, "--timeout="+types.String(timeoutSeconds)+"s")
}
var cmd = exec.Command(this.exe, args...)
this.pushCmd(cmd)
return nil
}
func (this *Firewalld) RemoveSourceIP(ip string) error {
if !this.isReady {
return nil
}
var family = "ipv4"
if strings.Contains(ip, ":") {
family = "ipv6"
}
for _, action := range []string{"reject", "drop"} {
var args = []string{"--remove-rich-rule=rule family='" + family + "' source address='" + ip + "' " + action}
var cmd = exec.Command(this.exe, args...)
this.pushCmd(cmd)
}
return nil
}
func (this *Firewalld) pushCmd(cmd *exec.Cmd) {
select {
case this.cmdQueue <- cmd:
default:
// we discard the command
}
}

View File

@@ -0,0 +1,30 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package firewalls
// FirewallInterface 防火墙接口
type FirewallInterface interface {
// Name 名称
Name() string
// IsReady 是否已准备被调用
IsReady() bool
// IsMock 是否为模拟
IsMock() bool
// AllowPort 允许端口
AllowPort(port int, protocol string) error
// RemovePort 删除端口
RemovePort(port int, protocol string) error
// RejectSourceIP 拒绝某个源IP连接
RejectSourceIP(ip string, timeoutSeconds int) error
// DropSourceIP 丢弃某个源IP数据
DropSourceIP(ip string, timeoutSeconds int) error
// RemoveSourceIP 删除某个源IP
RemoveSourceIP(ip string) error
}

View File

@@ -0,0 +1,60 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package firewalls
// MockFirewall 模拟防火墙
type MockFirewall struct {
}
func NewMockFirewall() *MockFirewall {
return &MockFirewall{}
}
// Name 名称
func (this *MockFirewall) Name() string {
return "mock"
}
// IsReady 是否已准备被调用
func (this *MockFirewall) IsReady() bool {
return true
}
// IsMock 是否为模拟
func (this *MockFirewall) IsMock() bool {
return true
}
// AllowPort 允许端口
func (this *MockFirewall) AllowPort(port int, protocol string) error {
_ = port
_ = protocol
return nil
}
// RemovePort 删除端口
func (this *MockFirewall) RemovePort(port int, protocol string) error {
_ = port
_ = protocol
return nil
}
// RejectSourceIP 拒绝某个源IP连接
func (this *MockFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
_ = ip
_ = timeoutSeconds
return nil
}
// DropSourceIP 丢弃某个源IP数据
func (this *MockFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
_ = ip
_ = timeoutSeconds
return nil
}
// RemoveSourceIP 删除某个源IP
func (this *MockFirewall) RemoveSourceIP(ip string) error {
_ = ip
return nil
}

View File

@@ -0,0 +1,417 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package firewalls
import (
"bytes"
"errors"
"fmt"
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
"github.com/TeaOSLab/EdgeDNS/internal/events"
"github.com/TeaOSLab/EdgeDNS/internal/firewalls/nftables"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
executils "github.com/TeaOSLab/EdgeDNS/internal/utils/exec"
"github.com/google/nftables/expr"
"github.com/iwind/TeaGo/types"
"net"
"os/exec"
"regexp"
"runtime"
"strings"
"time"
)
// check nft status, if being enabled we load it automatically
func init() {
if !teaconst.IsMain {
return
}
if teaconst.IsDaemon {
return
}
if runtime.GOOS == "linux" {
var ticker = time.NewTicker(3 * time.Minute)
go func() {
for range ticker.C {
// if already ready, we break
if nftablesIsReady {
ticker.Stop()
break
}
_, err := executils.LookPath("nft")
if err == nil {
nftablesFirewall, err := NewNFTablesFirewall()
if err != nil {
continue
}
currentFirewall = nftablesFirewall
remotelogs.Println("FIREWALL", "nftables is ready")
// fire event
if nftablesFirewall.IsReady() {
events.Notify(events.EventNFTablesReady)
}
ticker.Stop()
break
}
}
}()
}
}
var nftablesInstance *NFTablesFirewall
var nftablesIsReady = false
var nftablesFilters = []*nftablesTableDefinition{
// we shorten the name for table name length restriction
{Name: "edge_dns_v4", IsIPv4: true},
{Name: "edge_dns_v6", IsIPv6: true},
}
var nftablesChainName = "input"
type nftablesTableDefinition struct {
Name string
IsIPv4 bool
IsIPv6 bool
}
func (this *nftablesTableDefinition) protocol() string {
if this.IsIPv6 {
return "ip6"
}
return "ip"
}
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
var firewall = &NFTablesFirewall{
conn: nftables.NewConn(),
}
err := firewall.init()
if err != nil {
return nil, err
}
return firewall, nil
}
type NFTablesFirewall struct {
conn *nftables.Conn
isReady bool
version string
allowIPv4Set *nftables.Set
allowIPv6Set *nftables.Set
denyIPv4Set *nftables.Set
denyIPv6Set *nftables.Set
firewalld *Firewalld
}
func (this *NFTablesFirewall) init() error {
// check nft
nftPath, err := executils.LookPath("nft")
if err != nil {
return errors.New("nft not found")
}
this.version = this.readVersion(nftPath)
// table
for _, tableDef := range nftablesFilters {
var family nftables.TableFamily
if tableDef.IsIPv4 {
family = nftables.TableFamilyIPv4
} else if tableDef.IsIPv6 {
family = nftables.TableFamilyIPv6
} else {
return errors.New("invalid table family: " + types.String(tableDef))
}
table, err := this.conn.GetTable(tableDef.Name, family)
if err != nil {
if nftables.IsNotFound(err) {
if tableDef.IsIPv4 {
table, err = this.conn.AddIPv4Table(tableDef.Name)
} else if tableDef.IsIPv6 {
table, err = this.conn.AddIPv6Table(tableDef.Name)
}
if err != nil {
return fmt.Errorf("create table '%s' failed: %w", tableDef.Name, err)
}
} else {
return fmt.Errorf("get table '%s' failed: %w", tableDef.Name, err)
}
}
if table == nil {
return errors.New("can not create table '" + tableDef.Name + "'")
}
// chain
var chainName = nftablesChainName
chain, err := table.GetChain(chainName)
if err != nil {
if nftables.IsNotFound(err) {
chain, err = table.AddAcceptChain(chainName)
if err != nil {
return fmt.Errorf("create chain '%s' failed: %w", chainName, err)
}
} else {
return fmt.Errorf("get chain '%s' failed: %w", chainName, err)
}
}
if chain == nil {
return errors.New("can not create chain '" + chainName + "'")
}
// allow lo
var loRuleName = []byte("lo")
_, err = chain.GetRuleWithUserData(loRuleName)
if err != nil {
if nftables.IsNotFound(err) {
_, err = chain.AddAcceptInterfaceRule("lo", loRuleName)
}
if err != nil {
return fmt.Errorf("add 'lo' rule failed: %w", err)
}
}
// allow set
// "allow" should be always first
for _, setAction := range []string{"allow", "deny"} {
var setName = setAction + "_set"
set, err := table.GetSet(setName)
if err != nil {
if nftables.IsNotFound(err) {
var keyType nftables.SetDataType
if tableDef.IsIPv4 {
keyType = nftables.TypeIPAddr
} else if tableDef.IsIPv6 {
keyType = nftables.TypeIP6Addr
}
set, err = table.AddSet(setName, &nftables.SetOptions{
KeyType: keyType,
HasTimeout: true,
})
if err != nil {
return fmt.Errorf("create set '%s' failed: %w", setName, err)
}
} else {
return fmt.Errorf("get set '%s' failed: %w", setName, err)
}
}
if set == nil {
return errors.New("can not create set '" + setName + "'")
}
if tableDef.IsIPv4 {
if setAction == "allow" {
this.allowIPv4Set = set
} else {
this.denyIPv4Set = set
}
} else if tableDef.IsIPv6 {
if setAction == "allow" {
this.allowIPv6Set = set
} else {
this.denyIPv6Set = set
}
}
// rule
var ruleName = []byte(setAction)
rule, err := chain.GetRuleWithUserData(ruleName)
// 将以前的drop规则删掉替换成后面的reject
if err == nil && setAction != "allow" && rule != nil && rule.VerDict() == expr.VerdictDrop {
deleteErr := chain.DeleteRule(rule)
if deleteErr == nil {
err = nftables.ErrRuleNotFound
rule = nil
}
}
if err != nil {
if nftables.IsNotFound(err) {
if tableDef.IsIPv4 {
if setAction == "allow" {
rule, err = chain.AddAcceptIPv4SetRule(setName, ruleName)
} else {
rule, err = chain.AddRejectIPv4SetRule(setName, ruleName)
}
} else if tableDef.IsIPv6 {
if setAction == "allow" {
rule, err = chain.AddAcceptIPv6SetRule(setName, ruleName)
} else {
rule, err = chain.AddRejectIPv6SetRule(setName, ruleName)
}
}
if err != nil {
return fmt.Errorf("add rule failed: %w", err)
}
} else {
return fmt.Errorf("get rule failed: %w", err)
}
}
if rule == nil {
return errors.New("can not create rule '" + string(ruleName) + "'")
}
}
}
this.isReady = true
nftablesIsReady = true
nftablesInstance = this
// load firewalld
var firewalld = NewFirewalld()
if firewalld.IsReady() {
this.firewalld = firewalld
}
return nil
}
// Name 名称
func (this *NFTablesFirewall) Name() string {
return "nftables"
}
// IsReady 是否已准备被调用
func (this *NFTablesFirewall) IsReady() bool {
return this.isReady
}
// IsMock 是否为模拟
func (this *NFTablesFirewall) IsMock() bool {
return false
}
// AllowPort 允许端口
func (this *NFTablesFirewall) AllowPort(port int, protocol string) error {
if this.firewalld != nil {
return this.firewalld.AllowPort(port, protocol)
}
return nil
}
// RemovePort 删除端口
func (this *NFTablesFirewall) RemovePort(port int, protocol string) error {
if this.firewalld != nil {
return this.firewalld.RemovePort(port, protocol)
}
return nil
}
// AllowSourceIP Allow把IP加入白名单
func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
var data = net.ParseIP(ip)
if data == nil {
return errors.New("invalid ip '" + ip + "'")
}
if strings.Contains(ip, ":") { // ipv6
if this.allowIPv6Set == nil {
return errors.New("ipv6 ip set is nil")
}
return this.allowIPv6Set.AddElement(data.To16(), nil)
}
// ipv4
if this.allowIPv4Set == nil {
return errors.New("ipv4 ip set is nil")
}
return this.allowIPv4Set.AddElement(data.To4(), nil)
}
// RejectSourceIP 拒绝某个源IP连接
// we did not create set for drop ip, so we reuse DropSourceIP() method here
func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
return this.DropSourceIP(ip, timeoutSeconds)
}
// DropSourceIP 丢弃某个源IP数据
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
var data = net.ParseIP(ip)
if data == nil {
return errors.New("invalid ip '" + ip + "'")
}
if strings.Contains(ip, ":") { // ipv6
if this.denyIPv6Set == nil {
return errors.New("ipv6 ip set is nil")
}
return this.denyIPv6Set.AddElement(data.To16(), &nftables.ElementOptions{
Timeout: time.Duration(timeoutSeconds) * time.Second,
})
}
// ipv4
if this.denyIPv4Set == nil {
return errors.New("ipv4 ip set is nil")
}
return this.denyIPv4Set.AddElement(data.To4(), &nftables.ElementOptions{
Timeout: time.Duration(timeoutSeconds) * time.Second,
})
}
// RemoveSourceIP 删除某个源IP
func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
var data = net.ParseIP(ip)
if data == nil {
return errors.New("invalid ip '" + ip + "'")
}
if strings.Contains(ip, ":") { // ipv6
if this.denyIPv6Set != nil {
err := this.denyIPv6Set.DeleteElement(data.To16())
if err != nil {
return err
}
}
if this.allowIPv6Set != nil {
err := this.allowIPv6Set.DeleteElement(data.To16())
if err != nil {
return err
}
}
return nil
}
// ipv4
if this.denyIPv4Set != nil {
err := this.denyIPv4Set.DeleteElement(data.To4())
if err != nil {
return err
}
}
if this.allowIPv4Set != nil {
err := this.allowIPv4Set.DeleteElement(data.To4())
if err != nil {
return err
}
}
return nil
}
// 读取版本号
func (this *NFTablesFirewall) readVersion(nftPath string) string {
var cmd = exec.Command(nftPath, "--version")
var output = &bytes.Buffer{}
cmd.Stdout = output
err := cmd.Run()
if err != nil {
return ""
}
var outputString = output.String()
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
if len(versionMatches) <= 1 {
return ""
}
return versionMatches[1]
}

View File

@@ -0,0 +1,60 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !linux
package firewalls
import (
"errors"
)
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
return nil, errors.New("not implemented")
}
type NFTablesFirewall struct {
}
// Name 名称
func (this *NFTablesFirewall) Name() string {
return "nftables"
}
// IsReady 是否已准备被调用
func (this *NFTablesFirewall) IsReady() bool {
return false
}
// IsMock 是否为模拟
func (this *NFTablesFirewall) IsMock() bool {
return true
}
// AllowPort 允许端口
func (this *NFTablesFirewall) AllowPort(port int, protocol string) error {
return nil
}
// RemovePort 删除端口
func (this *NFTablesFirewall) RemovePort(port int, protocol string) error {
return nil
}
// AllowSourceIP Allow把IP加入白名单
func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
return nil
}
// RejectSourceIP 拒绝某个源IP连接
func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
return nil
}
// DropSourceIP 丢弃某个源IP数据
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
return nil
}
// RemoveSourceIP 删除某个源IP
func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
return nil
}

View File

@@ -0,0 +1 @@
build_remote.sh

View File

@@ -0,0 +1,369 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables
import (
"bytes"
"errors"
nft "github.com/google/nftables"
"github.com/google/nftables/expr"
)
const MaxChainNameLength = 31
type RuleOptions struct {
Exprs []expr.Any
UserData []byte
}
// Chain chain object in table
type Chain struct {
conn *Conn
rawTable *nft.Table
rawChain *nft.Chain
}
func NewChain(conn *Conn, rawTable *nft.Table, rawChain *nft.Chain) *Chain {
return &Chain{
conn: conn,
rawTable: rawTable,
rawChain: rawChain,
}
}
func (this *Chain) Raw() *nft.Chain {
return this.rawChain
}
func (this *Chain) Name() string {
return this.rawChain.Name
}
func (this *Chain) AddRule(options *RuleOptions) (*Rule, error) {
var rawRule = this.conn.Raw().AddRule(&nft.Rule{
Table: this.rawTable,
Chain: this.rawChain,
Exprs: options.Exprs,
UserData: options.UserData,
})
err := this.conn.Commit()
if err != nil {
return nil, err
}
return NewRule(rawRule), nil
}
func (this *Chain) AddAcceptIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: userData,
})
}
func (this *Chain) AddAcceptIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: userData,
})
}
func (this *Chain) AddDropIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip,
},
&expr.Verdict{
Kind: expr.VerdictDrop,
},
},
UserData: userData,
})
}
func (this *Chain) AddDropIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip,
},
&expr.Verdict{
Kind: expr.VerdictDrop,
},
},
UserData: userData,
})
}
func (this *Chain) AddRejectIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip,
},
&expr.Reject{},
},
UserData: userData,
})
}
func (this *Chain) AddRejectIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip,
},
&expr.Reject{},
},
UserData: userData,
})
}
func (this *Chain) AddAcceptIPv4SetRule(setName string, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: setName,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: userData,
})
}
func (this *Chain) AddAcceptIPv6SetRule(setName string, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Lookup{
SourceRegister: 1,
SetName: setName,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: userData,
})
}
func (this *Chain) AddDropIPv4SetRule(setName string, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: setName,
},
&expr.Verdict{
Kind: expr.VerdictDrop,
},
},
UserData: userData,
})
}
func (this *Chain) AddDropIPv6SetRule(setName string, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Lookup{
SourceRegister: 1,
SetName: setName,
},
&expr.Verdict{
Kind: expr.VerdictDrop,
},
},
UserData: userData,
})
}
func (this *Chain) AddRejectIPv4SetRule(setName string, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: setName,
},
&expr.Reject{},
},
UserData: userData,
})
}
func (this *Chain) AddRejectIPv6SetRule(setName string, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Lookup{
SourceRegister: 1,
SetName: setName,
},
&expr.Reject{},
},
UserData: userData,
})
}
func (this *Chain) AddAcceptInterfaceRule(interfaceName string, userData []byte) (*Rule, error) {
if len(interfaceName) >= 16 {
return nil, errors.New("invalid interface name '" + interfaceName + "'")
}
var ifname = make([]byte, 16)
copy(ifname, interfaceName+"\x00")
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: userData,
})
}
func (this *Chain) GetRuleWithUserData(userData []byte) (*Rule, error) {
rawRules, err := this.conn.Raw().GetRule(this.rawTable, this.rawChain)
if err != nil {
return nil, err
}
for _, rawRule := range rawRules {
if bytes.Compare(rawRule.UserData, userData) == 0 {
return NewRule(rawRule), nil
}
}
return nil, ErrRuleNotFound
}
func (this *Chain) GetRules() ([]*Rule, error) {
rawRules, err := this.conn.Raw().GetRule(this.rawTable, this.rawChain)
if err != nil {
return nil, err
}
var result = []*Rule{}
for _, rawRule := range rawRules {
result = append(result, NewRule(rawRule))
}
return result, nil
}
func (this *Chain) DeleteRule(rule *Rule) error {
err := this.conn.Raw().DelRule(rule.Raw())
if err != nil {
return err
}
return this.conn.Commit()
}
func (this *Chain) Flush() error {
this.conn.Raw().FlushChain(this.rawChain)
return this.conn.Commit()
}

View File

@@ -0,0 +1,14 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables
import nft "github.com/google/nftables"
type ChainPolicy = nft.ChainPolicy
// Possible ChainPolicy values.
const (
ChainPolicyDrop = nft.ChainPolicyDrop
ChainPolicyAccept = nft.ChainPolicyAccept
)

View File

@@ -0,0 +1,129 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables_test
import (
"github.com/TeaOSLab/EdgeDNS/internal/firewalls/nftables"
"net"
"testing"
)
func getIPv4Chain(t *testing.T) *nftables.Chain {
var conn = nftables.NewConn()
table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
if err != nil {
if err == nftables.ErrTableNotFound {
table, err = conn.AddIPv4Table("test_ipv4")
if err != nil {
t.Fatal(err)
}
} else {
t.Fatal(err)
}
}
chain, err := table.GetChain("test_chain")
if err != nil {
if err == nftables.ErrChainNotFound {
chain, err = table.AddAcceptChain("test_chain")
}
}
if err != nil {
t.Fatal(err)
}
return chain
}
func TestChain_AddAcceptIPRule(t *testing.T) {
var chain = getIPv4Chain(t)
_, err := chain.AddAcceptIPv4Rule(net.ParseIP("192.168.2.40").To4(), nil)
if err != nil {
t.Fatal(err)
}
}
func TestChain_AddDropIPRule(t *testing.T) {
var chain = getIPv4Chain(t)
_, err := chain.AddDropIPv4Rule(net.ParseIP("192.168.2.31").To4(), nil)
if err != nil {
t.Fatal(err)
}
}
func TestChain_AddAcceptSetRule(t *testing.T) {
var chain = getIPv4Chain(t)
_, err := chain.AddAcceptIPv4SetRule("ipv4_black_set", nil)
if err != nil {
t.Fatal(err)
}
}
func TestChain_AddDropSetRule(t *testing.T) {
var chain = getIPv4Chain(t)
_, err := chain.AddDropIPv4SetRule("ipv4_black_set", nil)
if err != nil {
t.Fatal(err)
}
}
func TestChain_AddRejectSetRule(t *testing.T) {
var chain = getIPv4Chain(t)
_, err := chain.AddRejectIPv4SetRule("ipv4_black_set", nil)
if err != nil {
t.Fatal(err)
}
}
func TestChain_GetRuleWithUserData(t *testing.T) {
var chain = getIPv4Chain(t)
rule, err := chain.GetRuleWithUserData([]byte("test"))
if err != nil {
if err == nftables.ErrRuleNotFound {
t.Log("rule not found")
return
} else {
t.Fatal(err)
}
}
t.Log("rule:", rule)
}
func TestChain_GetRules(t *testing.T) {
var chain = getIPv4Chain(t)
rules, err := chain.GetRules()
if err != nil {
t.Fatal(err)
}
for _, rule := range rules {
t.Log("handle:", rule.Handle(), "set name:", rule.LookupSetName(),
"verdict:", rule.VerDict(), "user data:", string(rule.UserData()))
}
}
func TestChain_DeleteRule(t *testing.T) {
var chain = getIPv4Chain(t)
rule, err := chain.GetRuleWithUserData([]byte("test"))
if err != nil {
if err == nftables.ErrRuleNotFound {
t.Log("rule not found")
return
}
t.Fatal(err)
}
err = chain.DeleteRule(rule)
if err != nil {
t.Fatal(err)
}
}
func TestChain_Flush(t *testing.T) {
var chain = getIPv4Chain(t)
err := chain.Flush()
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -0,0 +1,83 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables
import (
"errors"
nft "github.com/google/nftables"
"github.com/iwind/TeaGo/types"
)
const MaxTableNameLength = 27
type Conn struct {
rawConn *nft.Conn
}
func NewConn() *Conn {
return &Conn{
rawConn: &nft.Conn{},
}
}
func (this *Conn) Raw() *nft.Conn {
return this.rawConn
}
func (this *Conn) GetTable(name string, family TableFamily) (*Table, error) {
rawTables, err := this.rawConn.ListTables()
if err != nil {
return nil, err
}
for _, rawTable := range rawTables {
if rawTable.Name == name && rawTable.Family == family {
return NewTable(this, rawTable), nil
}
}
return nil, ErrTableNotFound
}
func (this *Conn) AddTable(name string, family TableFamily) (*Table, error) {
if len(name) > MaxTableNameLength {
return nil, errors.New("table name too long (max " + types.String(MaxTableNameLength) + ")")
}
var rawTable = this.rawConn.AddTable(&nft.Table{
Family: family,
Name: name,
})
err := this.Commit()
if err != nil {
return nil, err
}
return NewTable(this, rawTable), nil
}
func (this *Conn) AddIPv4Table(name string) (*Table, error) {
return this.AddTable(name, TableFamilyIPv4)
}
func (this *Conn) AddIPv6Table(name string) (*Table, error) {
return this.AddTable(name, TableFamilyIPv6)
}
func (this *Conn) DeleteTable(name string, family TableFamily) error {
table, err := this.GetTable(name, family)
if err != nil {
if err == ErrTableNotFound {
return nil
}
return err
}
this.rawConn.DelTable(table.Raw())
return this.Commit()
}
func (this *Conn) Commit() error {
return this.rawConn.Flush()
}

View File

@@ -0,0 +1,77 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables_test
import (
"github.com/TeaOSLab/EdgeDNS/internal/firewalls/nftables"
executils "github.com/TeaOSLab/EdgeDNS/internal/utils/exec"
"testing"
)
func TestConn_Test(t *testing.T) {
_, err := executils.LookPath("nft")
if err == nil {
t.Log("ok")
return
}
t.Log(err)
}
func TestConn_GetTable_NotFound(t *testing.T) {
var conn = nftables.NewConn()
table, err := conn.GetTable("a", nftables.TableFamilyIPv4)
if err != nil {
if err == nftables.ErrTableNotFound {
t.Log("table not found")
} else {
t.Fatal(err)
}
} else {
t.Log("table:", table)
}
}
func TestConn_GetTable(t *testing.T) {
var conn = nftables.NewConn()
table, err := conn.GetTable("myFilter", nftables.TableFamilyIPv4)
if err != nil {
if err == nftables.ErrTableNotFound {
t.Log("table not found")
} else {
t.Fatal(err)
}
} else {
t.Log("table:", table)
}
}
func TestConn_AddTable(t *testing.T) {
var conn = nftables.NewConn()
{
table, err := conn.AddIPv4Table("test_ipv4")
if err != nil {
t.Fatal(err)
}
t.Log(table.Name())
}
{
table, err := conn.AddIPv6Table("test_ipv6")
if err != nil {
t.Fatal(err)
}
t.Log(table.Name())
}
}
func TestConn_DeleteTable(t *testing.T) {
var conn = nftables.NewConn()
err := conn.DeleteTable("test_ipv4", nftables.TableFamilyIPv4)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -0,0 +1,7 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables
type Element struct {
}

View File

@@ -0,0 +1,18 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables
import "errors"
var ErrTableNotFound = errors.New("table not found")
var ErrChainNotFound = errors.New("chain not found")
var ErrSetNotFound = errors.New("set not found")
var ErrRuleNotFound = errors.New("rule not found")
func IsNotFound(err error) bool {
if err == nil {
return false
}
return err == ErrTableNotFound || err == ErrChainNotFound || err == ErrSetNotFound || err == ErrRuleNotFound
}

View File

@@ -0,0 +1,19 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables
import (
nft "github.com/google/nftables"
)
type TableFamily = nft.TableFamily
const (
TableFamilyINet TableFamily = nft.TableFamilyINet
TableFamilyIPv4 TableFamily = nft.TableFamilyIPv4
TableFamilyIPv6 TableFamily = nft.TableFamilyIPv6
TableFamilyARP TableFamily = nft.TableFamilyARP
TableFamilyNetdev TableFamily = nft.TableFamilyNetdev
TableFamilyBridge TableFamily = nft.TableFamilyBridge
)

View File

@@ -0,0 +1,52 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables
import (
nft "github.com/google/nftables"
"github.com/google/nftables/expr"
)
type Rule struct {
rawRule *nft.Rule
}
func NewRule(rawRule *nft.Rule) *Rule {
return &Rule{
rawRule: rawRule,
}
}
func (this *Rule) Raw() *nft.Rule {
return this.rawRule
}
func (this *Rule) LookupSetName() string {
for _, e := range this.rawRule.Exprs {
exp, ok := e.(*expr.Lookup)
if ok {
return exp.SetName
}
}
return ""
}
func (this *Rule) VerDict() expr.VerdictKind {
for _, e := range this.rawRule.Exprs {
exp, ok := e.(*expr.Verdict)
if ok {
return exp.Kind
}
}
return -100
}
func (this *Rule) Handle() uint64 {
return this.rawRule.Handle
}
func (this *Rule) UserData() []byte {
return this.rawRule.UserData
}

View File

@@ -0,0 +1,160 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables
import (
"errors"
"github.com/TeaOSLab/EdgeDNS/internal/utils"
nft "github.com/google/nftables"
"net"
"strings"
"time"
)
const MaxSetNameLength = 15
type SetOptions struct {
Id uint32
HasTimeout bool
Timeout time.Duration
KeyType SetDataType
DataType SetDataType
Constant bool
Interval bool
Anonymous bool
IsMap bool
}
type ElementOptions struct {
Timeout time.Duration
}
type Set struct {
conn *Conn
rawSet *nft.Set
batch *SetBatch
}
func NewSet(conn *Conn, rawSet *nft.Set) *Set {
return &Set{
conn: conn,
rawSet: rawSet,
batch: &SetBatch{
conn: conn,
rawSet: rawSet,
},
}
}
func (this *Set) Raw() *nft.Set {
return this.rawSet
}
func (this *Set) Name() string {
return this.rawSet.Name
}
func (this *Set) AddElement(key []byte, options *ElementOptions) error {
var rawElement = nft.SetElement{
Key: key,
}
if options != nil {
rawElement.Timeout = options.Timeout
}
err := this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
rawElement,
})
if err != nil {
return err
}
err = this.conn.Commit()
if err != nil {
// retry if exists
if strings.Contains(err.Error(), "file exists") {
deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
{
Key: key,
},
})
if deleteErr == nil {
err = this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
rawElement,
})
if err == nil {
err = this.conn.Commit()
}
}
}
}
return err
}
func (this *Set) AddIPElement(ip string, options *ElementOptions) error {
var ipObj = net.ParseIP(ip)
if ipObj == nil {
return errors.New("invalid ip '" + ip + "'")
}
if utils.IsIPv4(ip) {
return this.AddElement(ipObj.To4(), options)
} else {
return this.AddElement(ipObj.To16(), options)
}
}
func (this *Set) DeleteElement(key []byte) error {
err := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
{
Key: key,
},
})
if err != nil {
return err
}
err = this.conn.Commit()
if err != nil {
if strings.Contains(err.Error(), "no such file or directory") {
err = nil
}
}
return err
}
func (this *Set) DeleteIPElement(ip string) error {
var ipObj = net.ParseIP(ip)
if ipObj == nil {
return errors.New("invalid ip '" + ip + "'")
}
if utils.IsIPv4(ip) {
return this.DeleteElement(ipObj.To4())
} else {
return this.DeleteElement(ipObj.To16())
}
}
func (this *Set) Batch() *SetBatch {
return this.batch
}
func (this *Set) GetIPElements() ([]string, error) {
elements, err := this.conn.Raw().GetSetElements(this.rawSet)
if err != nil {
return nil, err
}
var result = []string{}
for _, element := range elements {
result = append(result, net.IP(element.Key).String())
}
return result, nil
}
// not work current time
/**func (this *Set) Flush() error {
this.conn.Raw().FlushSet(this.rawSet)
return this.conn.Commit()
}**/

View File

@@ -0,0 +1,37 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables
import (
nft "github.com/google/nftables"
)
type SetBatch struct {
conn *Conn
rawSet *nft.Set
}
func (this *SetBatch) AddElement(key []byte, options *ElementOptions) error {
var rawElement = nft.SetElement{
Key: key,
}
if options != nil {
rawElement.Timeout = options.Timeout
}
return this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
rawElement,
})
}
func (this *SetBatch) DeleteElement(key []byte) error {
return this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
{
Key: key,
},
})
}
func (this *SetBatch) Commit() error {
return this.conn.Commit()
}

View File

@@ -0,0 +1,58 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables
import nft "github.com/google/nftables"
type SetDataType = nft.SetDatatype
var (
TypeInvalid = nft.TypeInvalid
TypeVerdict = nft.TypeVerdict
TypeNFProto = nft.TypeNFProto
TypeBitmask = nft.TypeBitmask
TypeInteger = nft.TypeInteger
TypeString = nft.TypeString
TypeLLAddr = nft.TypeLLAddr
TypeIPAddr = nft.TypeIPAddr
TypeIP6Addr = nft.TypeIP6Addr
TypeEtherAddr = nft.TypeEtherAddr
TypeEtherType = nft.TypeEtherType
TypeARPOp = nft.TypeARPOp
TypeInetProto = nft.TypeInetProto
TypeInetService = nft.TypeInetService
TypeICMPType = nft.TypeICMPType
TypeTCPFlag = nft.TypeTCPFlag
TypeDCCPPktType = nft.TypeDCCPPktType
TypeMHType = nft.TypeMHType
TypeTime = nft.TypeTime
TypeMark = nft.TypeMark
TypeIFIndex = nft.TypeIFIndex
TypeARPHRD = nft.TypeARPHRD
TypeRealm = nft.TypeRealm
TypeClassID = nft.TypeClassID
TypeUID = nft.TypeUID
TypeGID = nft.TypeGID
TypeCTState = nft.TypeCTState
TypeCTDir = nft.TypeCTDir
TypeCTStatus = nft.TypeCTStatus
TypeICMP6Type = nft.TypeICMP6Type
TypeCTLabel = nft.TypeCTLabel
TypePktType = nft.TypePktType
TypeICMPCode = nft.TypeICMPCode
TypeICMPV6Code = nft.TypeICMPV6Code
TypeICMPXCode = nft.TypeICMPXCode
TypeDevGroup = nft.TypeDevGroup
TypeDSCP = nft.TypeDSCP
TypeECN = nft.TypeECN
TypeFIBAddr = nft.TypeFIBAddr
TypeBoolean = nft.TypeBoolean
TypeCTEventBit = nft.TypeCTEventBit
TypeIFName = nft.TypeIFName
TypeIGMPType = nft.TypeIGMPType
TypeTimeDate = nft.TypeTimeDate
TypeTimeHour = nft.TypeTimeHour
TypeTimeDay = nft.TypeTimeDay
TypeCGroupV2 = nft.TypeCGroupV2
)

View File

@@ -0,0 +1,111 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables_test
import (
"errors"
"github.com/TeaOSLab/EdgeDNS/internal/firewalls/nftables"
"github.com/iwind/TeaGo/types"
"github.com/mdlayher/netlink"
"net"
"testing"
"time"
)
func getIPv4Set(t *testing.T) *nftables.Set {
var table = getIPv4Table(t)
set, err := table.GetSet("test_ipv4_set")
if err != nil {
if err == nftables.ErrSetNotFound {
set, err = table.AddSet("test_ipv4_set", &nftables.SetOptions{
KeyType: nftables.TypeIPAddr,
HasTimeout: true,
})
if err != nil {
t.Fatal(err)
}
} else {
t.Fatal(err)
}
}
return set
}
func TestSet_AddElement(t *testing.T) {
var set = getIPv4Set(t)
err := set.AddElement(net.ParseIP("192.168.2.31").To4(), &nftables.ElementOptions{Timeout: 86400 * time.Second})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestSet_DeleteElement(t *testing.T) {
var set = getIPv4Set(t)
err := set.DeleteElement(net.ParseIP("192.168.2.31").To4())
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestSet_Batch(t *testing.T) {
var batch = getIPv4Set(t).Batch()
for _, ip := range []string{"192.168.2.30", "192.168.2.31", "192.168.2.32", "192.168.2.33", "192.168.2.34"} {
var ipData = net.ParseIP(ip).To4()
//err := batch.DeleteElement(ipData)
//if err != nil {
// t.Fatal(err)
//}
err := batch.AddElement(ipData, &nftables.ElementOptions{Timeout: 10 * time.Second})
if err != nil {
t.Fatal(err)
}
}
err := batch.Commit()
if err != nil {
t.Logf("%#v", errors.Unwrap(err).(*netlink.OpError))
t.Fatal(err)
}
t.Log("ok")
}
func TestSet_Add_Many(t *testing.T) {
var set = getIPv4Set(t)
for i := 0; i < 255; i++ {
t.Log(i)
for j := 0; j < 255; j++ {
var ip = "192.167." + types.String(i) + "." + types.String(j)
var ipData = net.ParseIP(ip).To4()
err := set.Batch().AddElement(ipData, &nftables.ElementOptions{Timeout: 3600 * time.Second})
if err != nil {
t.Fatal(err)
}
if j%10 == 0 {
err = set.Batch().Commit()
if err != nil {
t.Fatal(err)
}
}
}
err := set.Batch().Commit()
if err != nil {
t.Fatal(err)
}
}
t.Log("ok")
}
/**func TestSet_Flush(t *testing.T) {
var set = getIPv4Set(t)
err := set.Flush()
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}**/

View File

@@ -0,0 +1,156 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables
import (
"errors"
nft "github.com/google/nftables"
"github.com/iwind/TeaGo/types"
"strings"
)
type Table struct {
conn *Conn
rawTable *nft.Table
}
func NewTable(conn *Conn, rawTable *nft.Table) *Table {
return &Table{
conn: conn,
rawTable: rawTable,
}
}
func (this *Table) Raw() *nft.Table {
return this.rawTable
}
func (this *Table) Name() string {
return this.rawTable.Name
}
func (this *Table) Family() TableFamily {
return this.rawTable.Family
}
func (this *Table) GetChain(name string) (*Chain, error) {
rawChains, err := this.conn.Raw().ListChains()
if err != nil {
return nil, err
}
for _, rawChain := range rawChains {
// must compare table name
if rawChain.Name == name && rawChain.Table.Name == this.rawTable.Name {
return NewChain(this.conn, this.rawTable, rawChain), nil
}
}
return nil, ErrChainNotFound
}
func (this *Table) AddChain(name string, chainPolicy *ChainPolicy) (*Chain, error) {
if len(name) > MaxChainNameLength {
return nil, errors.New("chain name too long (max " + types.String(MaxChainNameLength) + ")")
}
var rawChain = this.conn.Raw().AddChain(&nft.Chain{
Name: name,
Table: this.rawTable,
Hooknum: nft.ChainHookInput,
Priority: nft.ChainPriorityFilter,
Type: nft.ChainTypeFilter,
Policy: chainPolicy,
})
err := this.conn.Commit()
if err != nil {
return nil, err
}
return NewChain(this.conn, this.rawTable, rawChain), nil
}
func (this *Table) AddAcceptChain(name string) (*Chain, error) {
var policy = ChainPolicyAccept
return this.AddChain(name, &policy)
}
func (this *Table) AddDropChain(name string) (*Chain, error) {
var policy = ChainPolicyDrop
return this.AddChain(name, &policy)
}
func (this *Table) DeleteChain(name string) error {
chain, err := this.GetChain(name)
if err != nil {
if err == ErrChainNotFound {
return nil
}
return err
}
this.conn.Raw().DelChain(chain.Raw())
return this.conn.Commit()
}
func (this *Table) GetSet(name string) (*Set, error) {
rawSet, err := this.conn.Raw().GetSetByName(this.rawTable, name)
if err != nil {
if strings.Contains(err.Error(), "no such file or directory") {
return nil, ErrSetNotFound
}
return nil, err
}
return NewSet(this.conn, rawSet), nil
}
func (this *Table) AddSet(name string, options *SetOptions) (*Set, error) {
if len(name) > MaxSetNameLength {
return nil, errors.New("set name too long (max " + types.String(MaxSetNameLength) + ")")
}
if options == nil {
options = &SetOptions{}
}
var rawSet = &nft.Set{
Table: this.rawTable,
ID: options.Id,
Name: name,
Anonymous: options.Anonymous,
Constant: options.Constant,
Interval: options.Interval,
IsMap: options.IsMap,
HasTimeout: options.HasTimeout,
Timeout: options.Timeout,
KeyType: options.KeyType,
DataType: options.DataType,
}
err := this.conn.Raw().AddSet(rawSet, nil)
if err != nil {
return nil, err
}
err = this.conn.Commit()
if err != nil {
return nil, err
}
return NewSet(this.conn, rawSet), nil
}
func (this *Table) DeleteSet(name string) error {
set, err := this.GetSet(name)
if err != nil {
if err == ErrSetNotFound {
return nil
}
return err
}
this.conn.Raw().DelSet(set.Raw())
return this.conn.Commit()
}
func (this *Table) Flush() error {
this.conn.Raw().FlushTable(this.rawTable)
return this.conn.Commit()
}

View File

@@ -0,0 +1,139 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
package nftables_test
import (
"github.com/TeaOSLab/EdgeDNS/internal/firewalls/nftables"
"testing"
)
func getIPv4Table(t *testing.T) *nftables.Table {
var conn = nftables.NewConn()
table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
if err != nil {
if err == nftables.ErrTableNotFound {
table, err = conn.AddIPv4Table("test_ipv4")
if err != nil {
t.Fatal(err)
}
} else {
t.Fatal(err)
}
}
return table
}
func TestTable_AddChain(t *testing.T) {
var table = getIPv4Table(t)
{
chain, err := table.AddChain("test_default_chain", nil)
if err != nil {
t.Fatal(err)
}
t.Log("created:", chain.Name())
}
{
chain, err := table.AddAcceptChain("test_accept_chain")
if err != nil {
t.Fatal(err)
}
t.Log("created:", chain.Name())
}
// Do not test drop chain before adding accept rule, you will drop yourself!!!!!!!
/**{
chain, err := table.AddDropChain("test_drop_chain")
if err != nil {
t.Fatal(err)
}
t.Log("created:", chain.Name())
}**/
}
func TestTable_GetChain(t *testing.T) {
var table = getIPv4Table(t)
for _, chainName := range []string{"not_found_chain", "test_default_chain"} {
chain, err := table.GetChain(chainName)
if err != nil {
if err == nftables.ErrChainNotFound {
t.Log(chainName, ":", "not found")
} else {
t.Fatal(err)
}
} else {
t.Log(chainName, ":", chain)
}
}
}
func TestTable_DeleteChain(t *testing.T) {
var table = getIPv4Table(t)
err := table.DeleteChain("test_default_chain")
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestTable_AddSet(t *testing.T) {
var table = getIPv4Table(t)
{
set, err := table.AddSet("ipv4_black_set", &nftables.SetOptions{
HasTimeout: false,
KeyType: nftables.TypeIPAddr,
})
if err != nil {
t.Fatal(err)
}
t.Log(set.Name())
}
{
set, err := table.AddSet("ipv6_black_set", &nftables.SetOptions{
HasTimeout: true,
//Timeout: 3600 * time.Second,
KeyType: nftables.TypeIP6Addr,
})
if err != nil {
t.Fatal(err)
}
t.Log(set.Name())
}
}
func TestTable_GetSet(t *testing.T) {
var table = getIPv4Table(t)
for _, setName := range []string{"not_found_set", "ipv4_black_set"} {
set, err := table.GetSet(setName)
if err != nil {
if err == nftables.ErrSetNotFound {
t.Log(setName, ": not found")
} else {
t.Fatal(err)
}
} else {
t.Log(setName, ":", set)
}
}
}
func TestTable_DeleteSet(t *testing.T) {
var table = getIPv4Table(t)
err := table.DeleteSet("ipv4_black_set")
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestTable_Flush(t *testing.T) {
var table = getIPv4Table(t)
err := table.Flush()
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -0,0 +1,12 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package goman
import "time"
type Instance struct {
Id uint64
CreatedTime time.Time
File string
Line int
}

View File

@@ -0,0 +1,81 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package goman
import (
"runtime"
"sync"
"time"
)
var locker = &sync.Mutex{}
var instanceMap = map[uint64]*Instance{} // id => *Instance
var instanceId = uint64(0)
// New 新创建goroutine
func New(f func()) {
_, file, line, _ := runtime.Caller(1)
go func() {
locker.Lock()
instanceId++
var instance = &Instance{
Id: instanceId,
CreatedTime: time.Now(),
}
instance.File = file
instance.Line = line
instanceMap[instanceId] = instance
locker.Unlock()
// run function
f()
locker.Lock()
delete(instanceMap, instanceId)
locker.Unlock()
}()
}
// NewWithArgs 创建带有参数的goroutine
func NewWithArgs(f func(args ...interface{}), args ...interface{}) {
_, file, line, _ := runtime.Caller(1)
go func() {
locker.Lock()
instanceId++
var instance = &Instance{
Id: instanceId,
CreatedTime: time.Now(),
}
instance.File = file
instance.Line = line
instanceMap[instanceId] = instance
locker.Unlock()
// run function
f(args...)
locker.Lock()
delete(instanceMap, instanceId)
locker.Unlock()
}()
}
// List 列出所有正在运行goroutine
func List() []*Instance {
locker.Lock()
defer locker.Unlock()
var result = []*Instance{}
for _, instance := range instanceMap {
result = append(result, instance)
}
return result
}

View File

@@ -0,0 +1,28 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package goman
import (
"testing"
"time"
)
func TestNew(t *testing.T) {
New(func() {
t.Log("Hello")
t.Log(List())
})
time.Sleep(1 * time.Second)
t.Log(List())
time.Sleep(1 * time.Second)
}
func TestNewWithArgs(t *testing.T) {
NewWithArgs(func(args ...interface{}) {
t.Log(args[0], args[1])
}, 1, 2)
time.Sleep(1 * time.Second)
}

View File

@@ -0,0 +1,9 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package models
type AgentIP struct {
Id int64
IP string
AgentCode string
}

View File

@@ -0,0 +1,15 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package models
import "github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
type NSDomain struct {
Id int64
ClusterId int64
UserId int64
Name string
TSIG *dnsconfigs.NSTSIGConfig
Version int64
}

View File

@@ -0,0 +1,13 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package models
type NSKey struct {
Id int64
DomainId int64
ZoneId int64
Algo string
Secret string
SecretType string
Version int64
}

View File

@@ -0,0 +1,27 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package models
type NSKeys struct {
m map[int64]*NSKey // keyId => *NSKey
}
func NewNSKeys() *NSKeys {
return &NSKeys{m: map[int64]*NSKey{}}
}
func (this *NSKeys) Add(key *NSKey) {
this.m[key.Id] = key
}
func (this *NSKeys) Remove(keyId int64) {
delete(this.m, keyId)
}
func (this *NSKeys) All() []*NSKey {
var result = []*NSKey{}
for _, k := range this.m {
result = append(result, k)
}
return result
}

View File

@@ -0,0 +1,166 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package models
import (
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/iwind/TeaGo/types"
"github.com/miekg/dns"
"net"
"strings"
)
type NSRecord struct {
Id int64
Name string
Type dnsconfigs.RecordType
Value string
MXPriority int32
SRVPriority int32
SRVWeight int32
SRVPort int32
CAAFlag int32
CAATag string
Ttl int32
Weight int32
Version int64
RouteIds []string
DomainId int64
}
func (this *NSRecord) ToRRAnswer(questionName string, rrClass uint16) dns.RR {
if this.Ttl <= 0 {
this.Ttl = 60
}
switch this.Type {
case dnsconfigs.RecordTypeA:
return &dns.A{
Hdr: this.ToRRHeader(questionName, dns.TypeA, rrClass),
A: net.ParseIP(this.Value),
}
case dnsconfigs.RecordTypeCNAME:
var value = this.Value
if !strings.HasSuffix(value, ".") {
value += "."
}
return &dns.CNAME{
Hdr: this.ToRRHeader(questionName, dns.TypeCNAME, rrClass),
Target: value,
}
case dnsconfigs.RecordTypeAAAA:
return &dns.AAAA{
Hdr: this.ToRRHeader(questionName, dns.TypeAAAA, rrClass),
AAAA: net.ParseIP(this.Value),
}
case dnsconfigs.RecordTypeNS:
var value = this.Value
if !strings.HasSuffix(value, ".") {
value += "."
}
return &dns.NS{
Hdr: this.ToRRHeader(questionName, dns.TypeNS, rrClass),
Ns: value,
}
case dnsconfigs.RecordTypeMX:
var value = this.Value
if !strings.HasSuffix(value, ".") {
value += "."
}
var preference uint16 = 0
var priority = this.MXPriority
if priority >= 0 {
if priority > 65535 {
priority = 65535
}
preference = types.Uint16(priority)
}
return &dns.MX{
Hdr: this.ToRRHeader(questionName, dns.TypeMX, rrClass),
Preference: preference,
Mx: value,
}
case dnsconfigs.RecordTypeSRV:
var priority uint16 = 10
if this.SRVPriority > 0 {
priority = uint16(this.SRVPriority)
}
var weight uint16 = 10
if this.SRVWeight > 0 {
weight = uint16(this.SRVWeight)
}
var port uint16 = 0
if this.SRVPort > 0 {
port = uint16(this.SRVPort)
}
var value = this.Value
if !strings.HasSuffix(value, ".") {
value += "."
}
return &dns.SRV{
Hdr: this.ToRRHeader(questionName, dns.TypeSRV, rrClass),
Priority: priority,
Weight: weight,
Port: port,
Target: value,
}
case dnsconfigs.RecordTypeTXT:
var values []string
var runes = []rune(this.Value)
const maxChars = 255
for {
if len(runes) <= maxChars {
values = append(values, string(runes))
break
}
values = append(values, string(runes[:maxChars]))
runes = runes[maxChars:]
if len(runes) == 0 {
break
}
}
return &dns.TXT{
Hdr: this.ToRRHeader(questionName, dns.TypeTXT, rrClass),
Txt: values, // TODO 可以添加多个
}
case dnsconfigs.RecordTypeCAA:
var flag uint8 = 0
if this.CAAFlag >= 0 && this.CAAFlag <= 128 {
flag = uint8(this.CAAFlag)
}
var tag = this.CAATag
if tag != "issue" && tag != "issuewild" && tag != "iodef" {
tag = "issue"
}
return &dns.CAA{
Hdr: this.ToRRHeader(questionName, dns.TypeCAA, rrClass),
Flag: flag, // 0-128
Tag: tag, // issue|issuewild|iodef
Value: this.Value,
}
}
return nil
}
func (this *NSRecord) ToRRHeader(questionName string, rrType uint16, rrClass uint16) dns.RR_Header {
return dns.RR_Header{
Name: questionName,
Rrtype: rrType,
Class: rrClass,
Ttl: uint32(this.Ttl),
}
}

View File

@@ -0,0 +1,45 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package models
import (
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"net"
)
type NSRoute struct {
Id int64
Ranges []dnsconfigs.NSRouteRangeInterface
Priority int32
Order int32
UserId int64
Version int64
}
func (this *NSRoute) Contains(ip net.IP) bool {
if len(ip) == 0 {
return false
}
// 先执行IsReverse
for _, r := range this.Ranges {
if r.IsExcluding() && r.Contains(ip) {
return false
}
}
// 再执行正常的
for _, r := range this.Ranges {
if !r.IsExcluding() && r.Contains(ip) {
return true
}
}
return false
}
// RealCode 代号
// TODO 支持自定义代号
func (this *NSRoute) RealCode() string {
return RouteIdString(this.Id)
}

View File

@@ -0,0 +1,58 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package models
import (
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/iwind/TeaGo/maps"
)
// InitRangesFromJSON 从JSON中初始化线路范围
func InitRangesFromJSON(rangesJSON []byte) (ranges []dnsconfigs.NSRouteRangeInterface, err error) {
if len(rangesJSON) == 0 {
return
}
var rangeMaps = []maps.Map{}
err = json.Unmarshal(rangesJSON, &rangeMaps)
if err != nil {
return nil, err
}
for _, rangeMap := range rangeMaps {
var rangeType = rangeMap.GetString("type")
paramsJSON, err := json.Marshal(rangeMap.Get("params"))
if err != nil {
return nil, err
}
var r dnsconfigs.NSRouteRangeInterface
switch rangeType {
case dnsconfigs.NSRouteRangeTypeIP:
r = &dnsconfigs.NSRouteRangeIPRange{}
case dnsconfigs.NSRouteRangeTypeCIDR:
r = &dnsconfigs.NSRouteRangeCIDR{}
case dnsconfigs.NSRouteRangeTypeRegion:
r = &dnsconfigs.NSRouteRangeRegion{
Connector: rangeMap.GetString("connector"),
}
r.SetRegionResolver(DefaultRegionResolver)
default:
return nil, errors.New("invalid route line type '" + rangeType + "'")
}
err = json.Unmarshal(paramsJSON, r)
if err != nil {
return nil, err
}
err = r.Init()
if err != nil {
return nil, err
}
ranges = append(ranges, r)
}
return
}

View File

@@ -0,0 +1,172 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package models
import (
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/iwind/TeaGo/rands"
"math/rand"
"time"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
type recordIdInfo struct {
Id int64
Weight int32
}
type RecordIds struct {
IdList []*recordIdInfo
IdBucket []int64
RoundIndex int
totalWeight int64
}
func NewRecordIds() *RecordIds {
return &RecordIds{}
}
func (this *RecordIds) IsEmpty() bool {
return len(this.IdList) == 0
}
func (this *RecordIds) Add(newId int64, weight int32) {
if weight <= 0 {
weight = 10
}
const maxWeight = 999999
if weight > maxWeight {
weight = maxWeight
}
// 检查是否存在
for _, idInfo := range this.IdList {
if idInfo.Id == newId {
return
}
}
// 添加
this.IdList = append(this.IdList, &recordIdInfo{
Id: newId,
Weight: weight,
})
// 重置数据
this.resetData()
}
func (this *RecordIds) Remove(oldId int64) {
defer this.resetData()
var newIdList = []*recordIdInfo{}
for _, idInfo := range this.IdList {
if idInfo.Id == oldId {
continue
}
newIdList = append(newIdList, idInfo)
}
this.IdList = newIdList
}
// NextId for round-robin
func (this *RecordIds) NextId() int64 {
var l = len(this.IdList)
if l == 0 {
return 0
}
if l == 1 {
return this.IdList[0].Id
}
if this.RoundIndex > l-1 {
this.RoundIndex = 0
}
var id = this.IdList[this.RoundIndex].Id
this.RoundIndex++
return id
}
func (this *RecordIds) RandomIds(count int) []int64 {
if count <= 0 {
count = dnsconfigs.NSAnswerDefaultSize
}
var totalRecords = len(this.IdList)
if totalRecords == 0 {
return nil
}
if totalRecords == 1 {
return []int64{this.IdList[0].Id} // duplicate
}
if totalRecords < count {
count = totalRecords
}
var totalIds = len(this.IdBucket)
var startIndex = rands.Int(0, totalIds-1)
var endIndex = startIndex + count - 1
if endIndex <= totalIds-1 {
return this.IdBucket[startIndex : endIndex+1]
}
return append(this.IdBucket[startIndex:totalIds], this.IdBucket[0:endIndex-totalIds+1]...)
}
func (this *RecordIds) resetData() {
this.resetWeight()
}
func (this *RecordIds) resetWeight() {
var totalWeight int64
var weightMap = map[int32]bool{} // weight => bool
var hasUniqueWeights = false
var ids []int64
for _, idInfo := range this.IdList {
totalWeight += int64(idInfo.Weight)
// 检查是否有不同的权重
if len(weightMap) > 0 && !weightMap[idInfo.Weight] {
hasUniqueWeights = true
}
weightMap[idInfo.Weight] = true
ids = append(ids, idInfo.Id)
}
// 根据权重重新组织IDs
if hasUniqueWeights {
var newIds = []int64{}
for _, idInfo := range this.IdList {
for i := int32(0); i < idInfo.Weight; i++ {
newIds = append(newIds, idInfo.Id)
}
}
ids = newIds
}
var countIds = len(ids)
if countIds > 0 {
rand.Shuffle(countIds, func(i, j int) {
ids[i], ids[j] = ids[j], ids[i]
})
}
this.totalWeight = totalWeight
this.IdBucket = ids
}

View File

@@ -0,0 +1,131 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package models
import (
"github.com/iwind/TeaGo/logs"
"testing"
)
func TestRecordIds_RandomIds_Once(t *testing.T) {
var recordIds = NewRecordIds()
for id := 1; id <= 10; id++ {
recordIds.Add(int64(id), 1)
}
t.Log("totalWeight:", recordIds.totalWeight)
t.Log(recordIds.RandomIds(5))
}
func TestRecordIds_RandomIds_Once2(t *testing.T) {
var recordIds = NewRecordIds()
for id := 1; id <= 10; id++ {
var weight int32 = 1
if id%3 == 0 {
weight = 3
}
recordIds.Add(int64(id), weight)
}
t.Log("totalWeight:", recordIds.totalWeight)
t.Log(recordIds.RandomIds(5))
}
func TestRecordIds_RandomIds(t *testing.T) {
var recordIds = NewRecordIds()
for id := 1; id <= 10; id++ {
recordIds.Add(int64(id), 1)
}
t.Log("totalWeight:", recordIds.totalWeight)
var statMap = map[int64]int{}
for i := 0; i < 2_000_000; i++ {
var resultIds = recordIds.RandomIds(5)
for _, resultId := range resultIds {
statMap[resultId]++
}
}
logs.PrintAsJSON(statMap, t)
}
func TestRecordIds_RandomIds_Weight1(t *testing.T) {
var recordIds = NewRecordIds()
for id := 1; id <= 10; id++ {
var weight int32 = 10
if id%3 == 0 {
weight = 20
}
recordIds.Add(int64(id), weight)
}
t.Log("totalWeight:", recordIds.totalWeight)
for i := 0; i < 10; i++ {
t.Log(recordIds.RandomIds(5))
}
}
func TestRecordIds_RandomIds_Weight2(t *testing.T) {
var recordIds = NewRecordIds()
for id := 1; id <= 10; id++ {
var weight int32 = 10
if id%3 == 0 {
weight = 20
}
recordIds.Add(int64(id), weight)
}
t.Log("totalWeight:", recordIds.totalWeight)
var statMap = map[int64]int{}
for i := 0; i < 2_000_000; i++ {
var resultIds = recordIds.RandomIds(5)
for _, resultId := range resultIds {
statMap[resultId]++
break
}
}
logs.PrintAsJSON(statMap, t)
}
func TestRecordIds_RandomIds_Weight3(t *testing.T) {
var recordIds = NewRecordIds()
for id := 1; id <= 5; id++ {
var weight int32 = 10
if id%3 == 0 {
weight = 20
}
recordIds.Add(int64(id), weight)
}
t.Log("totalWeight:", recordIds.totalWeight)
for i := 0; i < 10; i++ {
t.Log(recordIds.RandomIds(5))
}
}
func BenchmarkRecordIds_RandomIds_SAME_Weight(b *testing.B) {
var recordIds = NewRecordIds()
for id := 1; id <= 100; id++ {
var weight int32 = 10
recordIds.Add(int64(id), weight)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
recordIds.RandomIds(5)
}
}
func BenchmarkRecordIds_RandomIds(b *testing.B) {
var recordIds = NewRecordIds()
for id := 1; id <= 100; id++ {
var weight int32 = 10
if id%3 == 0 {
weight = 20
}
recordIds.Add(int64(id), weight)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
recordIds.RandomIds(5)
}
}

View File

@@ -0,0 +1,12 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package models
import "strings"
type RecordKey string
func NewRecordKey(recordName string, recordType string) RecordKey {
// 记录名全部使用小写
return RecordKey(strings.ToLower(recordName) + "|" + recordType)
}

View File

@@ -0,0 +1,80 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package models
import (
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"strings"
)
type DomainRecords struct {
RecordsMap map[RecordKey]*RouteRecords // key => records
Keys map[int64]RecordKey // recordId => key
}
func NewDomainRecords() *DomainRecords {
return &DomainRecords{
RecordsMap: map[RecordKey]*RouteRecords{},
Keys: map[int64]RecordKey{},
}
}
func (this *DomainRecords) Add(record *NSRecord) {
var key = NewRecordKey(record.Name, record.Type)
records, ok := this.RecordsMap[key]
if !ok {
records = NewRouteRecords()
this.RecordsMap[key] = records
}
records.Add(record)
this.Keys[record.Id] = key
}
func (this *DomainRecords) Find(routeCodes []string, recordName string, recordType string, config *dnsconfigs.NSAnswerConfig, strictMode bool) (record []*NSRecord, routeCode string) {
// NAME.example.com
var key = NewRecordKey(recordName, recordType)
records, ok := this.RecordsMap[key]
if ok {
return records.Find(routeCodes, config, strictMode)
}
// @.example.com
if len(recordName) == 0 {
records, ok = this.RecordsMap[NewRecordKey("@", recordType)]
if ok {
return records.Find(routeCodes, config, strictMode)
}
return nil, ""
}
// *.NAME.example.com
var dotIndex = strings.Index(recordName, ".")
var wildcardNames = []string{}
if dotIndex > 0 {
wildcardNames = append(wildcardNames, "*."+recordName[dotIndex+1:])
}
wildcardNames = append(wildcardNames, "*")
for _, wildcardName := range wildcardNames {
records, ok = this.RecordsMap[NewRecordKey(wildcardName, recordType)]
if ok {
return records.Find(routeCodes, config, strictMode)
}
}
return nil, ""
}
func (this *DomainRecords) Remove(recordId int64) {
key, ok := this.Keys[recordId]
if ok {
var recordsMap = this.RecordsMap[key]
recordsMap.Remove(recordId)
if recordsMap.IsEmpty() {
delete(this.RecordsMap, key)
}
delete(this.Keys, recordId)
}
}

View File

@@ -0,0 +1,57 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package models
import (
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/iwind/TeaGo/logs"
"testing"
)
func TestDomainRecords_Find(t *testing.T) {
var records = NewDomainRecords()
records.Add(&NSRecord{Id: 1, Name: "", Value: "1", Type: dnsconfigs.RecordTypeA})
records.Add(&NSRecord{Id: 2, Name: "@", Value: "@", Type: dnsconfigs.RecordTypeA})
records.Add(&NSRecord{Id: 3, Name: "*", Value: "*", Type: dnsconfigs.RecordTypeA})
records.Add(&NSRecord{Id: 4, Name: "hello", Value: "HELLO", Type: dnsconfigs.RecordTypeA})
records.Add(&NSRecord{Id: 5, Name: "*.world", Value: "*.world", Type: dnsconfigs.RecordTypeA})
for _, name := range []string{"", "hello", "world", "hello.world", "hello.world2"} {
record, routeCode := records.Find([]string{}, name, dnsconfigs.RecordTypeA, nil, false)
t.Log(name, record, routeCode)
}
}
func TestDomainRecords_Find_RouteIds(t *testing.T) {
var records = NewDomainRecords()
records.Add(&NSRecord{Id: 1, Name: "", Value: "1", Type: dnsconfigs.RecordTypeA, RouteIds: []string{RouteIdString(11), RouteIdString(22)}})
records.Add(&NSRecord{Id: 2, Name: "@", Value: "@", Type: dnsconfigs.RecordTypeA})
records.Add(&NSRecord{Id: 3, Name: "*", Value: "*", Type: dnsconfigs.RecordTypeA})
records.Add(&NSRecord{Id: 4, Name: "hello", Value: "HELLO", Type: dnsconfigs.RecordTypeA})
records.Add(&NSRecord{Id: 41, Name: "hello", Value: "HELLO1", Type: dnsconfigs.RecordTypeA})
records.Add(&NSRecord{Id: 42, Name: "hello", Value: "HELLO2", Type: dnsconfigs.RecordTypeA})
records.Add(&NSRecord{Id: 43, Name: "hello", Value: "HELLO3", Type: dnsconfigs.RecordTypeA})
records.Add(&NSRecord{Id: 5, Name: "*.world", Value: "*.world", Type: dnsconfigs.RecordTypeA})
for _, name := range []string{"", "hello", "world", "hello.world", "hello.world2"} {
record, _ := records.Find([]string{RouteIdString(11), RouteIdString(22)}, name, dnsconfigs.RecordTypeA, nil, false)
if record == nil {
t.Fatal("'" + name + "' record should not be nil")
}
t.Log(name, record)
}
}
func TestDomainRecords_Remove(t *testing.T) {
var records = NewDomainRecords()
records.Add(&NSRecord{Id: 1, Name: "", Value: "1", Type: dnsconfigs.RecordTypeA})
records.Add(&NSRecord{Id: 2, Name: "@", Value: "@", Type: dnsconfigs.RecordTypeA})
records.Add(&NSRecord{Id: 3, Name: "*", Value: "*", Type: dnsconfigs.RecordTypeA})
records.Add(&NSRecord{Id: 4, Name: "hello", Value: "HELLO", Type: dnsconfigs.RecordTypeA})
records.Add(&NSRecord{Id: 5, Name: "*.world", Value: "*.world", Type: dnsconfigs.RecordTypeA})
records.Remove(1)
records.Remove(2)
records.Remove(3)
records.Remove(4)
//records.Remove(5)
logs.PrintAsJSON(records, t)
}

View File

@@ -0,0 +1,134 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package models
import (
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/iwind/TeaGo/types"
)
func RouteIdString(routeId int64) string {
return "id:" + types.String(routeId)
}
type RouteRecords struct {
routeRecordsMap map[string]*RecordIds // routeCode => { recordId1, recordId2, ... }
recordsMap map[int64]*NSRecord // recordId => *NSRecord
}
func NewRouteRecords() *RouteRecords {
return &RouteRecords{
routeRecordsMap: map[string]*RecordIds{},
recordsMap: map[int64]*NSRecord{},
}
}
func (this *RouteRecords) Add(record *NSRecord) {
// 先删除
this.remove(record.Id)
// 添加记录
this.recordsMap[record.Id] = record
// 添加线路
var routeIds = record.RouteIds
if len(routeIds) == 0 || (len(routeIds) == 1 && routeIds[0] == "") {
routeIds = []string{"default"}
}
for _, routeId := range routeIds {
recordIds, ok := this.routeRecordsMap[routeId]
if !ok {
recordIds = NewRecordIds()
this.routeRecordsMap[routeId] = recordIds
}
recordIds.Add(record.Id, record.Weight)
}
}
// Find 查找与线路匹配的记录
// strictMode 表示是否严格匹配线路
func (this *RouteRecords) Find(routeCodes []string, config *dnsconfigs.NSAnswerConfig, strictMode bool) (records []*NSRecord, routeCode string) {
if config == nil {
config = dnsconfigs.DefaultNSAnswerConfig()
}
var maxSize = int(config.MaxSize)
if maxSize <= 0 {
maxSize = dnsconfigs.NSAnswerDefaultSize
}
// 查找匹配的线路
for _, routeId := range routeCodes {
recordIds, ok := this.routeRecordsMap[routeId]
if ok && !recordIds.IsEmpty() {
return this.recordsWithIds(recordIds, config.Mode, maxSize), routeId
}
}
// 查找默认线路
recordIds, ok := this.routeRecordsMap["default"]
if ok && !recordIds.IsEmpty() {
return this.recordsWithIds(recordIds, config.Mode, maxSize), "default"
}
// 随机一个
if !strictMode {
for _, record := range this.recordsMap {
return []*NSRecord{record}, "default"
}
}
return nil, ""
}
func (this *RouteRecords) Remove(recordId int64) {
this.remove(recordId)
}
func (this *RouteRecords) IsEmpty() bool {
return len(this.recordsMap) == 0
}
func (this *RouteRecords) remove(recordId int64) {
oldRecord, ok := this.recordsMap[recordId]
if !ok {
return
}
delete(this.recordsMap, recordId)
var oldRouteIds = oldRecord.RouteIds
if len(oldRouteIds) == 0 || (len(oldRouteIds) == 1 && oldRouteIds[0] == "") {
oldRouteIds = []string{"default"}
}
for _, routeId := range oldRouteIds {
recordIds, ok := this.routeRecordsMap[routeId]
if ok {
recordIds.Remove(recordId)
if recordIds.IsEmpty() {
delete(this.routeRecordsMap, routeId)
}
}
}
}
func (this *RouteRecords) recordsWithIds(recordIds *RecordIds, mode dnsconfigs.NSAnswerMode, maxSize int) (records []*NSRecord) {
// round-robin
if mode == dnsconfigs.NSAnswerModeRoundRobin {
var recordId = recordIds.NextId()
if recordId > 0 {
return []*NSRecord{this.recordsMap[recordId]}
}
}
// random
var randomIds = recordIds.RandomIds(maxSize)
for _, randomId := range randomIds {
records = append(records, this.recordsMap[randomId])
}
return
}

View File

@@ -0,0 +1,118 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package models
import (
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/rands"
"testing"
)
func TestRouteRecords_Add(t *testing.T) {
var records = NewRouteRecords()
{
records.Add(&NSRecord{Id: 1, RouteIds: []string{"CN"}})
records.Add(&NSRecord{Id: 2})
logs.PrintAsJSON(records.routeRecordsMap, t)
logs.PrintAsJSON(records.recordsMap, t)
}
}
func TestRouteRecords_Add2(t *testing.T) {
var records = NewRouteRecords()
{
records.Add(&NSRecord{Id: 1, RouteIds: []string{RouteIdString(11), RouteIdString(22)}})
records.Add(&NSRecord{Id: 2, RouteIds: []string{RouteIdString(11), RouteIdString(22)}})
logs.PrintAsJSON(records.routeRecordsMap, t)
logs.PrintAsJSON(records.recordsMap, t)
}
}
func TestRouteRecords_Add3(t *testing.T) {
var records = NewRouteRecords()
{
records.Add(&NSRecord{Id: 1, RouteIds: []string{RouteIdString(11), RouteIdString(22)}})
records.Add(&NSRecord{Id: 2, RouteIds: []string{RouteIdString(11), RouteIdString(22)}})
records.Add(&NSRecord{Id: 2, RouteIds: []string{RouteIdString(33), RouteIdString(44)}}) // duplicated
logs.PrintAsJSON(records.routeRecordsMap, t)
logs.PrintAsJSON(records.recordsMap, t)
}
}
func TestRouteRecords_Remove(t *testing.T) {
var records = NewRouteRecords()
records.Add(&NSRecord{Id: 1})
records.Add(&NSRecord{Id: 2, RouteIds: []string{RouteIdString(11), RouteIdString(22)}})
records.Add(&NSRecord{Id: 3, RouteIds: []string{RouteIdString(11), RouteIdString(22)}})
records.Add(&NSRecord{Id: 4, RouteIds: []string{RouteIdString(11)}})
t.Log("===before===")
logs.PrintAsJSON(records.routeRecordsMap, t)
logs.PrintAsJSON(records.recordsMap, t)
t.Log("===after===")
//records.Remove(1)
records.Remove(2)
logs.PrintAsJSON(records.routeRecordsMap, t)
logs.PrintAsJSON(records.recordsMap, t)
}
func TestRouteRecords_Find(t *testing.T) {
var records = NewRouteRecords()
records.Add(&NSRecord{Id: 1})
records.Add(&NSRecord{Id: 2, RouteIds: []string{RouteIdString(11), RouteIdString(22)}})
records.Add(&NSRecord{Id: 3, RouteIds: []string{RouteIdString(11), RouteIdString(22)}})
records.Add(&NSRecord{Id: 4, RouteIds: []string{RouteIdString(11)}})
for _, routeIds := range [][]string{
{},
{RouteIdString(11)},
{RouteIdString(22)},
{RouteIdString(100)},
} {
record, routeId := records.Find(routeIds, nil, false)
t.Logf("routeIds: %v, record: %#v, route: %s", routeIds, record, routeId)
}
}
func TestRouteRecords_Find_Balance(t *testing.T) {
var records = NewRouteRecords()
records.Add(&NSRecord{Id: 1, RouteIds: []string{"aa"}})
records.Add(&NSRecord{Id: 2, RouteIds: []string{"aa", "bb", "default"}})
records.Add(&NSRecord{Id: 3})
records.Add(&NSRecord{Id: 4})
for _, route := range []string{"", "default", "aa", "bb", "cc"} {
var m = map[int64]int{} // id => count
for i := 0; i < 1_000_000; i++ {
var records, _ = records.Find([]string{route}, nil, false)
for _, record := range records {
m[record.Id]++
}
}
t.Logf("%s: %+v", route, m)
}
}
func BenchmarkRouteRecords_Remove(b *testing.B) {
var records = NewRouteRecords()
for i := 0; i < 1_000_000; i++ {
records.Add(&NSRecord{Id: int64(i), RouteIds: []string{RouteIdString(int64(i % 100))}})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
records.Remove(int64(rands.Int(0, 100)))
}
}
func BenchmarkRouteRecords_Find(b *testing.B) {
var records = NewRouteRecords()
for i := 0; i < 1_000_000; i++ {
records.Add(&NSRecord{Id: int64(i), RouteIds: []string{RouteIdString(int64(i % 100))}})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = records.Find([]string{RouteIdString(int64(i % 200))}, nil, false)
}
}

View File

@@ -0,0 +1,24 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package models
import (
"github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
"net"
)
var DefaultRegionResolver = &RegionResolver{}
type RegionResolver struct {
}
func (this *RegionResolver) Resolve(ip net.IP) (countryId int64, provinceId int64, cityId int64, providerId int64) {
var result = iplibrary.Lookup(ip)
if result != nil && result.IsOk() {
countryId = result.CountryId()
provinceId = result.ProvinceId()
cityId = result.CityId()
providerId = result.ProviderId()
}
return
}

View File

@@ -0,0 +1,10 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package monitor
// ItemValue 数据值定义
type ItemValue struct {
Item string
ValueJSON []byte
CreatedAt int64
}

View File

@@ -0,0 +1,89 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package monitor
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
"github.com/TeaOSLab/EdgeDNS/internal/events"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"github.com/iwind/TeaGo/maps"
"time"
)
var SharedValueQueue = NewValueQueue()
func init() {
if !teaconst.IsMain {
return
}
events.On(events.EventStart, func() {
go SharedValueQueue.Start()
})
}
// ValueQueue 数据记录队列
type ValueQueue struct {
valuesChan chan *ItemValue
}
func NewValueQueue() *ValueQueue {
return &ValueQueue{
valuesChan: make(chan *ItemValue, 1024),
}
}
// Start 启动队列
func (this *ValueQueue) Start() {
// 这里单次循环就行因为Loop里已经使用了Range通道
err := this.Loop()
if err != nil {
remotelogs.Error("MONITOR_QUEUE", err.Error())
}
}
// Add 添加数据
func (this *ValueQueue) Add(item string, value maps.Map) {
valueJSON, err := json.Marshal(value)
if err != nil {
remotelogs.Error("MONITOR_QUEUE", "marshal value error: "+err.Error())
return
}
select {
case this.valuesChan <- &ItemValue{
Item: item,
ValueJSON: valueJSON,
CreatedAt: time.Now().Unix(),
}:
default:
}
}
// Loop 单次循环
func (this *ValueQueue) Loop() error {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
for value := range this.valuesChan {
_, err = rpcClient.NodeValueRPC.CreateNodeValue(rpcClient.Context(), &pb.CreateNodeValueRequest{
Item: value.Item,
ValueJSON: value.ValueJSON,
CreatedAt: value.CreatedAt,
})
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("MONITOR", err.Error())
} else {
remotelogs.Error("MONITOR", err.Error())
}
continue
}
}
return nil
}

View File

@@ -0,0 +1,256 @@
package nodes
import (
"bytes"
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/messageconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
"github.com/TeaOSLab/EdgeDNS/internal/events"
"github.com/TeaOSLab/EdgeDNS/internal/firewalls"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"github.com/TeaOSLab/EdgeDNS/internal/utils"
executils "github.com/TeaOSLab/EdgeDNS/internal/utils/exec"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/maps"
"os/exec"
"regexp"
"runtime"
"strconv"
"time"
)
type APIStream struct {
stream pb.NSNodeService_NsNodeStreamClient
isQuiting bool
cancelFunc context.CancelFunc
}
func NewAPIStream() *APIStream {
return &APIStream{}
}
func (this *APIStream) Start() {
events.On(events.EventQuit, func() {
this.isQuiting = true
if this.cancelFunc != nil {
this.cancelFunc()
}
})
for {
if this.isQuiting {
return
}
err := this.loop()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("API_STREAM", err.Error())
} else {
remotelogs.Error("API_STREAM", err.Error())
}
time.Sleep(10 * time.Second)
continue
}
time.Sleep(1 * time.Second)
}
}
func (this *APIStream) loop() error {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return errors.Wrap(err)
}
ctx, cancelFunc := context.WithCancel(rpcClient.Context())
this.cancelFunc = cancelFunc
defer func() {
cancelFunc()
}()
nodeStream, err := rpcClient.NSNodeRPC.NsNodeStream(ctx)
if err != nil {
if this.isQuiting {
return nil
}
return errors.Wrap(err)
}
this.stream = nodeStream
for {
if this.isQuiting {
logs.Println("API_STREAM", "quit")
break
}
message, err := nodeStream.Recv()
if err != nil {
if this.isQuiting {
remotelogs.Println("API_STREAM", "quit")
return nil
}
return errors.Wrap(err)
}
// 处理消息
switch message.Code {
case messageconfigs.NSMessageCodeConnectedAPINode: // 连接API节点成功
err = this.handleConnectedAPINode(message)
case messageconfigs.NSMessageCodeNewNodeTask: // 有新的任务
err = this.handleNewNodeTask(message)
case messageconfigs.NSMessageCodeCheckSystemdService: // 检查Systemd服务
err = this.handleCheckSystemdService(message)
case messageconfigs.MessageCodeCheckLocalFirewall: // 检查本地防火墙
err = this.handleCheckLocalFirewall(message)
default:
err = this.handleUnknownMessage(message)
}
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("API_STREAM", "handle message failed: "+err.Error())
} else {
remotelogs.Error("API_STREAM", "handle message failed: "+err.Error())
}
}
}
return nil
}
// 连接API节点成功
func (this *APIStream) handleConnectedAPINode(message *pb.NSNodeStreamMessage) error {
// 更改连接的APINode信息
if len(message.DataJSON) == 0 {
return nil
}
msg := &messageconfigs.ConnectedAPINodeMessage{}
err := json.Unmarshal(message.DataJSON, msg)
if err != nil {
return errors.Wrap(err)
}
_, err = rpc.SharedRPC()
if err != nil {
return errors.Wrap(err)
}
remotelogs.Println("API_STREAM", "connected to api node '"+strconv.FormatInt(msg.APINodeId, 10)+"'")
return nil
}
// 处理配置变化
func (this *APIStream) handleNewNodeTask(message *pb.NSNodeStreamMessage) error {
select {
case nodeTaskNotify <- true:
default:
}
this.replyOk(message.RequestId, "ok")
return nil
}
// 检查Systemd服务
func (this *APIStream) handleCheckSystemdService(message *pb.NSNodeStreamMessage) error {
systemctl, err := executils.LookPath("systemctl")
if err != nil {
this.replyFail(message.RequestId, "'systemctl' not found")
return nil
}
if len(systemctl) == 0 {
this.replyFail(message.RequestId, "'systemctl' not found")
return nil
}
cmd := utils.NewCommandExecutor()
shortName := teaconst.SystemdServiceName
cmd.Add(systemctl, "is-enabled", shortName)
output, err := cmd.Run()
if err != nil {
this.replyFail(message.RequestId, "'systemctl' command error: "+err.Error())
return nil
}
if output == "enabled" {
this.replyOk(message.RequestId, "ok")
} else {
this.replyFail(message.RequestId, "not installed")
}
return nil
}
// 检查本地防火墙
func (this *APIStream) handleCheckLocalFirewall(message *pb.NSNodeStreamMessage) error {
var dataMessage = &messageconfigs.CheckLocalFirewallMessage{}
err := json.Unmarshal(message.DataJSON, dataMessage)
if err != nil {
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
return nil
}
// nft
if dataMessage.Name == "nftables" {
if runtime.GOOS != "linux" {
this.replyFail(message.RequestId, "not Linux system")
return nil
}
nft, err := executils.LookPath("nft")
if err != nil {
this.replyFail(message.RequestId, "'nft' not found: "+err.Error())
return nil
}
var cmd = exec.Command(nft, "--version")
var output = &bytes.Buffer{}
cmd.Stdout = output
err = cmd.Run()
if err != nil {
this.replyFail(message.RequestId, "get version failed: "+err.Error())
return nil
}
var outputString = output.String()
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
if len(versionMatches) <= 1 {
this.replyFail(message.RequestId, "can not get nft version")
return nil
}
var version = versionMatches[1]
var result = maps.Map{
"version": version,
}
var protectionConfig = sharedNodeConfig.DDoSProtection
err = firewalls.SharedDDoSProtectionManager.Apply(protectionConfig)
if err != nil {
this.replyFail(message.RequestId, dataMessage.Name+"was installed, but apply DDoS protection config failed: "+err.Error())
} else {
this.replyOk(message.RequestId, string(result.AsJSON()))
}
} else {
this.replyFail(message.RequestId, "invalid firewall name '"+dataMessage.Name+"'")
}
return nil
}
// 处理未知消息
func (this *APIStream) handleUnknownMessage(message *pb.NSNodeStreamMessage) error {
this.replyFail(message.RequestId, "unknown message code '"+message.Code+"'")
return nil
}
// 回复失败
func (this *APIStream) replyFail(requestId int64, message string) {
_ = this.stream.Send(&pb.NSNodeStreamMessage{RequestId: requestId, IsOk: false, Message: message})
}
// 回复成功
func (this *APIStream) replyOk(requestId int64, message string) {
_ = this.stream.Send(&pb.NSNodeStreamMessage{RequestId: requestId, IsOk: true, Message: message})
}

View File

@@ -0,0 +1,455 @@
//go:build plus
package nodes
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
"github.com/TeaOSLab/EdgeDNS/internal/agents"
"github.com/TeaOSLab/EdgeDNS/internal/configs"
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
"github.com/TeaOSLab/EdgeDNS/internal/events"
"github.com/TeaOSLab/EdgeDNS/internal/firewalls"
"github.com/TeaOSLab/EdgeDNS/internal/goman"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"github.com/TeaOSLab/EdgeDNS/internal/utils"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"github.com/iwind/gosock/pkg/gosock"
"log"
"os"
"os/exec"
"os/signal"
"runtime"
"runtime/debug"
"syscall"
"time"
)
var DaemonIsOn = false
var DaemonPid = 0
var nodeTaskNotify = make(chan bool, 8)
var sharedDomainManager *DomainManager
var sharedRecordManager *RecordManager
var sharedRouteManager *RouteManager
var sharedKeyManager *KeyManager
var sharedNodeConfig = &dnsconfigs.NSNodeConfig{}
func NewDNSNode() *DNSNode {
return &DNSNode{
sock: gosock.NewTmpSock(teaconst.ProcessName),
}
}
type DNSNode struct {
sock *gosock.Sock
RPC *rpc.RPCClient
}
func (this *DNSNode) Start() {
// 设置netdns
// 这个需要放在所有网络访问的最前面
_ = os.Setenv("GODEBUG", "netdns=go")
// 判断是否在守护进程下
_, ok := os.LookupEnv("EdgeDaemon")
if ok {
remotelogs.Println("DNS_NODE", "start from daemon")
DaemonIsOn = true
DaemonPid = os.Getppid()
}
// 设置DNS解析库
err := os.Setenv("GODEBUG", "netdns=go")
if err != nil {
remotelogs.Error("DNS_NODE", "[DNS_RESOLVER]set env failed: "+err.Error())
}
// 处理异常
this.handlePanic()
// 监听signal
this.listenSignals()
// 本地Sock
err = this.listenSock()
if err != nil {
logs.Println("[DNS_NODE]" + err.Error())
return
}
// 启动IP库
remotelogs.Println("DNS_NODE", "initializing ip library ...")
err = iplibrary.InitPlus()
if err != nil {
remotelogs.Error("DNS_NODE", "initialize ip library failed: "+err.Error())
}
runtime.GC()
debug.FreeOSMemory()
// 触发事件
events.Notify(events.EventStart)
// 监控状态
go NewNodeStatusExecutor().Listen()
// 连接API
go NewAPIStream().Start()
// 启动
go this.start()
// Hold住进程
logs.Println("[DNS_NODE]started")
select {}
}
// Daemon 实现守护进程
func (this *DNSNode) Daemon() {
var isDebug = lists.ContainsString(os.Args, "debug")
for {
conn, err := this.sock.Dial()
if err != nil {
if isDebug {
log.Println("[DAEMON]starting ...")
}
// 尝试启动
err = func() error {
exe, err := os.Executable()
if err != nil {
return err
}
// 可以标记当前是从守护进程启动的
_ = os.Setenv("EdgeDaemon", "on")
_ = os.Setenv("EdgeBackground", "on")
var cmd = exec.Command(exe)
err = cmd.Start()
if err != nil {
return err
}
err = cmd.Wait()
if err != nil {
return err
}
return nil
}()
if err != nil {
if isDebug {
log.Println("[DAEMON]", err)
}
time.Sleep(1 * time.Second)
} else {
time.Sleep(5 * time.Second)
}
} else {
_ = conn.Close()
time.Sleep(5 * time.Second)
}
}
}
// Test 测试配置
func (this *DNSNode) Test() error {
// 检查是否能连接API
rpcClient, err := rpc.SharedRPC()
if err != nil {
return fmt.Errorf("test rpc failed: %w", err)
}
_, err = rpcClient.APINodeRPC.FindCurrentAPINodeVersion(rpcClient.Context(), &pb.FindCurrentAPINodeVersionRequest{})
if err != nil {
return fmt.Errorf("test rpc failed: %w", err)
}
return nil
}
// InstallSystemService 安装系统服务
func (this *DNSNode) InstallSystemService() error {
shortName := teaconst.SystemdServiceName
exe, err := os.Executable()
if err != nil {
return err
}
manager := utils.NewServiceManager(shortName, teaconst.ProductName)
err = manager.Install(exe, []string{})
if err != nil {
return err
}
return nil
}
// 监听一些信号
func (this *DNSNode) listenSignals() {
var queue = make(chan os.Signal, 8)
signal.Notify(queue, syscall.SIGTERM, syscall.SIGINT, syscall.SIGKILL, syscall.SIGQUIT)
goman.New(func() {
for range queue {
time.Sleep(100 * time.Millisecond)
utils.Exit()
return
}
})
}
// 监听本地sock
func (this *DNSNode) listenSock() error {
// 检查是否在运行
if this.sock.IsListening() {
reply, err := this.sock.Send(&gosock.Command{Code: "pid"})
if err == nil {
return errors.New("error: the process is already running, pid: " + maps.NewMap(reply.Params).GetString("pid"))
} else {
return errors.New("error: the process is already running")
}
}
// 启动监听
go func() {
this.sock.OnCommand(func(cmd *gosock.Command) {
switch cmd.Code {
case "pid":
_ = cmd.Reply(&gosock.Command{
Code: "pid",
Params: map[string]interface{}{
"pid": os.Getpid(),
},
})
case "info":
exePath, _ := os.Executable()
_ = cmd.Reply(&gosock.Command{
Code: "info",
Params: map[string]interface{}{
"pid": os.Getpid(),
"version": teaconst.Version,
"path": exePath,
},
})
case "stop":
_ = cmd.ReplyOk()
// 退出主进程
events.Notify(events.EventQuit)
time.Sleep(100 * time.Millisecond)
os.Exit(0)
case "gc":
runtime.GC()
debug.FreeOSMemory()
_ = cmd.ReplyOk()
}
})
err := this.sock.Listen()
if err != nil {
logs.Println("NODE", err.Error())
}
}()
events.On(events.EventQuit, func() {
logs.Println("[DNS_NODE]", "quit unix sock")
_ = this.sock.Close()
})
return nil
}
// 启动
func (this *DNSNode) start() {
client, err := rpc.SharedRPC()
if err != nil {
remotelogs.Error("DNS_NODE", err.Error())
return
}
this.RPC = client
tryTimes := 0
var configJSON []byte
for {
resp, err := client.NSNodeRPC.FindCurrentNSNodeConfig(client.Context(), &pb.FindCurrentNSNodeConfigRequest{})
if err != nil {
tryTimes++
if tryTimes%10 == 0 {
remotelogs.Error("NODE", "read config from API failed: "+err.Error())
}
time.Sleep(1 * time.Second)
// 不做长时间的无意义的重试
if tryTimes > 1000 {
remotelogs.Error("NODE", "load failed: unable to read config from API")
return
}
} else {
configJSON = resp.NsNodeJSON
break
}
}
if len(configJSON) == 0 {
remotelogs.Error("NODE", "can not find node config")
return
}
var config = &dnsconfigs.NSNodeConfig{}
err = json.Unmarshal(configJSON, config)
if err != nil {
remotelogs.Error("NODE", "decode config failed: "+err.Error())
return
}
err = config.Init(context.TODO())
if err != nil {
remotelogs.Error("NODE", "init config failed: "+err.Error())
return
}
sharedNodeConfig = config
configs.SharedNodeConfig = config
events.Notify(events.EventReload)
sharedNodeConfigManager.reload(config)
apiConfig, _ := configs.SharedAPIConfig()
if apiConfig != nil {
apiConfig.NumberId = config.Id
}
var db = dbs.NewDB(Tea.Root + "/data/data-" + types.String(config.Id) + "-" + config.NodeId + "-v0.1.0.db")
err = db.Init()
if err != nil {
remotelogs.Error("NODE", "init database failed: "+err.Error())
return
}
go sharedNodeConfigManager.Start()
sharedDomainManager = NewDomainManager(db, config.ClusterId)
go sharedDomainManager.Start()
sharedRecordManager = NewRecordManager(db)
go sharedRecordManager.Start()
sharedRouteManager = NewRouteManager(db)
go sharedRouteManager.Start()
sharedKeyManager = NewKeyManager(db)
go sharedKeyManager.Start()
agents.SharedManager = agents.NewManager(db)
go agents.SharedManager.Start()
// 发送通知这里发送通知需要在DomainManager、RecordeManager等加载完成之后
time.Sleep(1 * time.Second)
events.Notify(events.EventLoaded)
// 启动循环
go this.loop()
}
// 更新配置Loop
func (this *DNSNode) loop() {
var ticker = time.NewTicker(60 * time.Second)
for {
select {
case <-ticker.C:
case <-nodeTaskNotify:
}
err := this.processTasks()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("DNS_NODE", "process tasks: "+err.Error())
} else {
remotelogs.Error("DNS_NODE", "process tasks: "+err.Error())
}
}
}
}
// 处理任务
func (this *DNSNode) processTasks() error {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
// 所有的任务
tasksResp, err := rpcClient.NodeTaskRPC.FindNodeTasks(rpcClient.Context(), &pb.FindNodeTasksRequest{})
if err != nil {
return err
}
for _, task := range tasksResp.NodeTasks {
switch task.Type {
case "nsConfigChanged":
sharedNodeConfigManager.NotifyChange()
case "nsDomainChanged":
sharedDomainManager.NotifyUpdate()
case "nsRecordChanged":
sharedRecordManager.NotifyUpdate()
case "nsRouteChanged":
sharedRouteManager.NotifyUpdate()
case "nsKeyChanged":
sharedKeyManager.NotifyUpdate()
case "nsDDoSProtectionChanged":
err := this.updateDDoS(rpcClient)
if err != nil {
remotelogs.Error("DNS_NODE", "apply DDoS config failed: "+err.Error())
}
}
_, err = rpcClient.NodeTaskRPC.ReportNodeTaskDone(rpcClient.Context(), &pb.ReportNodeTaskDoneRequest{
NodeTaskId: task.Id,
IsOk: true,
Error: "",
})
if err != nil {
return err
}
}
return nil
}
func (this *DNSNode) updateDDoS(rpcClient *rpc.RPCClient) error {
resp, err := rpcClient.NSNodeRPC.FindNSNodeDDoSProtection(rpcClient.Context(), &pb.FindNSNodeDDoSProtectionRequest{})
if err != nil {
return err
}
if len(resp.DdosProtectionJSON) == 0 {
if sharedNodeConfig != nil {
sharedNodeConfig.DDoSProtection = nil
}
} else {
var ddosProtectionConfig = &ddosconfigs.ProtectionConfig{}
err = json.Unmarshal(resp.DdosProtectionJSON, ddosProtectionConfig)
if err != nil {
return fmt.Errorf("decode DDoS protection config failed: %w", err)
}
if sharedNodeConfig != nil {
sharedNodeConfig.DDoSProtection = ddosProtectionConfig
}
err = firewalls.SharedDDoSProtectionManager.Apply(ddosProtectionConfig)
if err != nil {
// 不阻塞
remotelogs.Error("NODE", "apply DDoS protection failed: "+err.Error())
}
}
return nil
}

View File

@@ -0,0 +1,306 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build plus
package nodes_test
import (
"crypto/tls"
"github.com/miekg/dns"
"testing"
"time"
)
func TestDNS_Query_A_UDP(t *testing.T) {
type query struct {
Domain string
Type uint16
}
c := new(dns.Client)
for _, query := range []query{
{"hello.world.teaos.cn", dns.TypeA},
{"cdn.teaos.cn", dns.TypeA},
{"hello.teaos.cn", dns.TypeCNAME},
{"hello.teaos.cn", dns.TypeA},
{"edgecdn.teaos.cn", dns.TypeA},
} {
var m = new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
r, _, err := c.Exchange(m, "127.0.0.1:53")
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
for _, a := range r.Answer {
t.Log(a)
}
}
}
func TestDNS_Query_A_Many(t *testing.T) {
type queryDef struct {
Domain string
Type uint16
}
var c = new(dns.Client)
for i := 0; i < 10000; i++ {
for _, query := range []queryDef{
{"hello.goedge.cn", dns.TypeA},
} {
var m = new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
r, _, err := c.Exchange(m, "127.0.0.1:53")
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
_ = r
}
}
}
func TestDNS_Query_CNAME(t *testing.T) {
type query struct {
Domain string
Type uint16
}
c := new(dns.Client)
for _, query := range []query{
{"hello.teaos.cn", dns.TypeA},
} {
m := new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
r, _, err := c.Exchange(m, "127.0.0.1:53")
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
for _, a := range r.Answer {
t.Log(a)
}
}
}
func TestDNS_Query_A_TCP(t *testing.T) {
type query struct {
Domain string
Type uint16
}
c := new(dns.Client)
for _, query := range []query{
{"hello.world.teaos.cn", dns.TypeA},
{"cdn.teaos.cn", dns.TypeA},
{"hello.teaos.cn", dns.TypeCNAME},
{"goedge.cn", dns.TypeA},
{"www.goedge.cn", dns.TypeA},
} {
m := new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
//r, _, err := c.Exchange(m, "127.0.0.1:54")
conn, err := dns.Dial("tcp", "127.0.0.1:53")
if err != nil {
t.Fatal(err)
}
r, _, err := c.ExchangeWithConn(m, conn)
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
for _, a := range r.Answer {
t.Log(a)
}
}
}
func TestDNS_Query_A_TLS(t *testing.T) {
type query struct {
Domain string
Type uint16
}
c := new(dns.Client)
for _, query := range []query{
{"hello.world.teaos.cn", dns.TypeA},
{"cdn.teaos.cn", dns.TypeA},
{"hello.teaos.cn", dns.TypeCNAME},
} {
m := new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
//r, _, err := c.Exchange(m, "127.0.0.1:54")
conn, err := dns.DialWithTLS("tcp", "127.0.0.1:853", &tls.Config{
InsecureSkipVerify: true,
})
if err != nil {
t.Fatal(err)
}
r, _, err := c.ExchangeWithConn(m, conn)
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
for _, a := range r.Answer {
t.Log(a)
}
}
}
func TestDNS_Query_Internet(t *testing.T) {
type query struct {
Domain string
Type uint16
}
c := new(dns.Client)
for _, query := range []query{
{"1.goedge.cn", dns.TypeA},
} {
m := new(dns.Msg)
//m.RecursionDesired = true
m.SetQuestion(query.Domain+".", query.Type)
conn, err := dns.Dial("udp", "ns1.teaos.cn:53")
if err != nil {
t.Fatal(err)
}
r, _, err := c.ExchangeWithConn(m, conn)
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
for _, a := range r.Answer {
t.Log(a)
}
}
}
func TestDNS_Query_TSIG(t *testing.T) {
type query struct {
Domain string
Type uint16
}
c := new(dns.Client)
c.TsigSecret = map[string]string{"teaos.": "NzhhZDExMzM5NWMwN2Q5OWM5YTFhMzgxZWNkZGMwMDA2ODUzODdiYTM2ODA5N2I2YjYwZWRlNmNlNjlhMzdmM2JmNjcxZmQ4NzVjMjI1Y2QwOTQ2Njk5OWY0MzRkMTJkNTczNjFlZDgwYmQxZWZjZDM4ZjAxNDNmM2Y2NTU1YjE="}
for _, query := range []query{
{"hello.cdn.teaos.cn", dns.TypeA},
{"cdn.teaos.cn", dns.TypeA},
{"hello.teaos.cn", dns.TypeCNAME},
{"hello.teaos.cn", dns.TypeA},
{"edgecdn.teaos.cn", dns.TypeA},
} {
m := new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
m.SetTsig("teaos.", dns.HmacSHA512, 300, time.Now().Unix())
r, _, err := c.Exchange(m, "127.0.0.1:53")
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
for _, a := range r.Answer {
t.Log(a)
}
}
}
func TestDNS_Query_A_Route(t *testing.T) {
type query struct {
Domain string
Type uint16
}
c := new(dns.Client)
for _, query := range []query{
{"route.com", dns.TypeA},
} {
m := new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
r, _, err := c.Exchange(m, "127.0.0.1:53")
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
for _, a := range r.Answer {
t.Log(a)
}
}
}
func TestDNS_Query_Recursion(t *testing.T) {
type query struct {
Domain string
Type uint16
}
c := new(dns.Client)
for _, query := range []query{
{"example.org", dns.TypeA},
} {
m := new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
r, _, err := c.Exchange(m, "127.0.0.1:53")
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
for _, a := range r.Answer {
t.Log(a)
}
}
}
func TestDNS_Flood(t *testing.T) {
type query struct {
Domain string
Type uint16
}
var c = new(dns.Client)
for i := 0; i < 1_000_000; i++ {
for _, query := range []query{
{"hello.world.teaos.cn", dns.TypeA},
{"cdn.teaos.cn", dns.TypeA},
{"hello.teaos.cn", dns.TypeCNAME},
{"hello.teaos.cn", dns.TypeA},
{"edgecdn.teaos.cn", dns.TypeA},
} {
var m = new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
r, _, err := c.Exchange(m, "192.168.2.60:58")
if err != nil {
t.Fatalf("failed to exchange: %v", err)
}
_ = r
}
}
}
func BenchmarkDNSNode(b *testing.B) {
var c = new(dns.Client)
conn, err := c.Dial("192.168.2.60:58")
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
type query struct {
Domain string
Type uint16
}
for i := 0; i < b.N; i++ {
for _, query := range []query{
{"cdn.teaos.cn", dns.TypeA},
} {
var m = new(dns.Msg)
m.SetQuestion(query.Domain+".", query.Type)
_, _, err := c.ExchangeWithConn(m, conn)
if err != nil {
b.Fatal(err)
}
}
}
}

View File

@@ -0,0 +1,179 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"encoding/json"
"errors"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/types"
"github.com/miekg/dns"
"net"
"net/http"
"reflect"
)
type HTTPWriter struct {
rawConn net.Conn
rawWriter http.ResponseWriter
contentType string
}
func NewHTTPWriter(rawWriter http.ResponseWriter, rawConn net.Conn, contentType string) *HTTPWriter {
return &HTTPWriter{
rawWriter: rawWriter,
rawConn: rawConn,
contentType: contentType,
}
}
func (this *HTTPWriter) LocalAddr() net.Addr {
return this.rawConn.LocalAddr()
}
func (this *HTTPWriter) RemoteAddr() net.Addr {
return this.rawConn.RemoteAddr()
}
func (this *HTTPWriter) WriteMsg(msg *dns.Msg) error {
if msg == nil {
return errors.New("'msg' should not be nil")
}
msgData, err := this.encodeMsg(msg)
if err != nil {
return err
}
this.rawWriter.Header().Set("Content-Length", types.String(len(msgData)))
this.rawWriter.Header().Set("Content-Type", this.contentType)
// cache-control
if len(msg.Answer) > 0 {
var minTtl uint32
for _, answer := range msg.Answer {
var header = answer.Header()
if header != nil && header.Ttl > 0 && (minTtl == 0 || header.Ttl < minTtl) {
minTtl = header.Ttl
}
}
if minTtl > 0 {
this.rawWriter.Header().Set("Cache-Control", "max-age="+types.String(minTtl))
}
}
this.rawWriter.WriteHeader(http.StatusOK)
_, err = this.rawWriter.Write(msgData)
return err
}
func (this *HTTPWriter) Write(p []byte) (int, error) {
this.rawWriter.Header().Set("Content-Length", types.String(len(p)))
this.rawWriter.WriteHeader(http.StatusOK)
return this.rawWriter.Write(p)
}
func (this *HTTPWriter) Close() error {
return nil
}
func (this *HTTPWriter) TsigStatus() error {
return nil
}
func (this *HTTPWriter) TsigTimersOnly(timersOnly bool) {
}
func (this *HTTPWriter) Hijack() {
hijacker, ok := this.rawWriter.(http.Hijacker)
if ok {
_, _, _ = hijacker.Hijack()
}
}
func (this *HTTPWriter) encodeMsg(msg *dns.Msg) ([]byte, error) {
if this.contentType == "application/x-javascript" || this.contentType == "application/json" {
var result = map[string]any{
"Status": 0,
"TC": msg.Truncated,
"RD": msg.RecursionDesired,
"RA": msg.RecursionAvailable,
"AD": msg.AuthenticatedData,
"CD": msg.CheckingDisabled,
}
// questions
var questionMaps = []map[string]any{}
for _, question := range msg.Question {
questionMaps = append(questionMaps, map[string]any{
"name": question.Name,
"type": question.Qtype,
})
}
result["Question"] = questionMaps
// answers
var answerMaps = []map[string]any{}
for _, answer := range msg.Answer {
var answerMap = map[string]any{
"name": answer.Header().Name,
"type": answer.Header().Rrtype,
"TTL": answer.Header().Ttl,
}
switch x := answer.(type) {
case *dns.A:
answerMap["data"] = x.A.String()
case *dns.AAAA:
answerMap["data"] = x.AAAA.String()
case *dns.CNAME:
answerMap["data"] = x.Target
case *dns.TXT:
answerMap["data"] = x.Txt
case *dns.NS:
answerMap["data"] = x.Ns
case *dns.MX:
answerMap["data"] = x.Mx
answerMap["preference"] = x.Preference
default:
var answerValue = reflect.ValueOf(answer).Elem()
var answerType = answerValue.Type()
var countFields = answerType.NumField()
for i := 0; i < countFields; i++ {
var fieldName = answerType.Field(i).Name
var fieldValue = answerValue.FieldByName(fieldName)
if !fieldValue.IsValid() {
continue
}
var fieldInterface = fieldValue.Interface()
if fieldInterface == nil {
continue
}
_, ok := fieldInterface.(dns.RR_Header)
if ok {
continue
}
if countFields == 2 {
answerMap["data"] = fieldValue.Interface()
} else {
answerMap[fieldName] = fieldValue.Interface()
}
}
}
answerMaps = append(answerMaps, answerMap)
}
result["Answer"] = answerMaps
if Tea.IsTesting() {
return json.MarshalIndent(result, "", " ")
} else {
return json.Marshal(result)
}
} else {
return msg.Pack()
}
}

View File

@@ -0,0 +1,306 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
"github.com/TeaOSLab/EdgeDNS/internal/events"
"github.com/TeaOSLab/EdgeDNS/internal/firewalls"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/utils"
"github.com/iwind/TeaGo/types"
"runtime"
"sort"
"strings"
"sync"
)
var sharedListenManager *ListenManager = nil
func init() {
if !teaconst.IsMain {
return
}
sharedListenManager = NewListenManager()
events.On(events.EventReload, func() {
sharedListenManager.Update(sharedNodeConfig)
})
events.On(events.EventQuit, func() {
_ = sharedListenManager.ShutdownAll()
})
}
// ListenManager 端口监听管理器
type ListenManager struct {
locker sync.Mutex
serverMap map[string]*Server // addr => *Server
firewalld *firewalls.Firewalld
lastPortStrings string
lastTCPPortRanges [][2]int
lastUDPPortRanges [][2]int
}
// NewListenManager 获取新对象
func NewListenManager() *ListenManager {
return &ListenManager{
serverMap: map[string]*Server{},
firewalld: firewalls.NewFirewalld(),
}
}
// Update 修改配置
func (this *ListenManager) Update(config *dnsconfigs.NSNodeConfig) {
// 构造服务配置
var serverConfigs = []*ServerConfig{}
var serverAddrs = []string{}
// 如果没有配置,则配置一些默认的端口
if config.TCP == nil && config.TLS == nil && config.UDP == nil {
config.TCP = &serverconfigs.TCPProtocolConfig{}
config.TCP.IsOn = true
config.TCP.Listen = []*serverconfigs.NetworkAddressConfig{
{
Protocol: serverconfigs.ProtocolTCP,
MinPort: 53,
MaxPort: 53,
},
}
config.UDP = &serverconfigs.UDPProtocolConfig{}
config.UDP.IsOn = true
config.UDP.Listen = []*serverconfigs.NetworkAddressConfig{
{
Protocol: serverconfigs.ProtocolUDP,
MinPort: 53,
MaxPort: 53,
},
}
}
// 读取配置
if config.TCP != nil && config.TCP.IsOn {
for _, listen := range config.TCP.Listen {
for port := listen.MinPort; port <= listen.MaxPort; port++ {
serverConfigs = append(serverConfigs, &ServerConfig{
Protocol: listen.Protocol,
Host: listen.Host,
Port: port,
SSLPolicy: nil,
})
}
}
}
if config.TLS != nil && config.TLS.IsOn {
for _, listen := range config.TLS.Listen {
if config.TLS.SSLPolicy == nil {
continue
}
for port := listen.MinPort; port <= listen.MaxPort; port++ {
serverConfigs = append(serverConfigs, &ServerConfig{
Protocol: listen.Protocol,
Host: listen.Host,
Port: port,
SSLPolicy: config.TLS.SSLPolicy,
})
}
}
}
if config.DoH != nil && config.DoH.IsOn {
for _, listen := range config.DoH.Listen {
if config.DoH.SSLPolicy == nil {
continue
}
for port := listen.MinPort; port <= listen.MaxPort; port++ {
serverConfigs = append(serverConfigs, &ServerConfig{
Protocol: listen.Protocol,
Host: listen.Host,
Port: port,
SSLPolicy: config.DoH.SSLPolicy,
})
}
}
}
if config.UDP != nil && config.UDP.IsOn {
for _, listen := range config.UDP.Listen {
for port := listen.MinPort; port <= listen.MaxPort; port++ {
serverConfigs = append(serverConfigs, &ServerConfig{
Protocol: listen.Protocol,
Host: listen.Host,
Port: port,
SSLPolicy: nil,
})
}
}
}
// 启动新的
var addrMap = map[string]bool{} // addr => bool
for _, serverConfig := range serverConfigs {
var fullAddr = serverConfig.FullAddr()
serverAddrs = append(serverAddrs, fullAddr)
addrMap[fullAddr] = true
this.locker.Lock()
server, ok := this.serverMap[fullAddr]
this.locker.Unlock()
if !ok {
// 启动新的
var err error
server, err = NewServer(serverConfig)
if err != nil {
remotelogs.Error("LISTEN_MANAGER", "create listener '"+fullAddr+"' failed: "+err.Error())
continue
}
this.locker.Lock()
this.serverMap[fullAddr] = server
this.locker.Unlock()
go func() {
remotelogs.Println("LISTEN_MANAGER", "listen '"+fullAddr+"'")
err = server.ListenAndServe()
if err != nil {
this.locker.Lock()
delete(this.serverMap, fullAddr)
this.locker.Unlock()
remotelogs.Error("LISTEN_MANAGER", "listen '"+fullAddr+"' failed: "+err.Error())
}
}()
} else {
// 更新配置
server.Reload(serverConfig)
}
}
// 停止老的
this.locker.Lock()
for fullAddr, server := range this.serverMap {
_, ok := addrMap[fullAddr]
if !ok {
delete(this.serverMap, fullAddr)
remotelogs.Println("LISTEN_MANAGER", "shutdown "+fullAddr)
err := server.Shutdown()
if err != nil {
remotelogs.Error("LISTEN_MANAGER", "shutdown listener '"+fullAddr+"' failed: "+err.Error())
}
}
}
this.locker.Unlock()
// 添加端口到firewalld
go func() {
this.addToFirewalld(serverAddrs)
}()
}
// ShutdownAll 关闭所有的监听端口
func (this *ListenManager) ShutdownAll() error {
this.locker.Lock()
defer this.locker.Unlock()
var lastErr error
for _, server := range this.serverMap {
err := server.Shutdown()
if err != nil {
lastErr = err
}
}
return lastErr
}
func (this *ListenManager) addToFirewalld(serverAddrs []string) {
if runtime.GOOS != "linux" {
return
}
if this.firewalld == nil || !this.firewalld.IsReady() {
return
}
// 组合端口号
var portStrings = []string{}
var udpPorts = []int{}
var tcpPorts = []int{}
for _, addr := range serverAddrs {
var protocol = "tcp"
if strings.HasPrefix(addr, "udp") {
protocol = "udp"
}
var lastIndex = strings.LastIndex(addr, ":")
if lastIndex > 0 {
var portString = addr[lastIndex+1:]
portStrings = append(portStrings, portString+"/"+protocol)
switch protocol {
case "tcp":
tcpPorts = append(tcpPorts, types.Int(portString))
case "udp":
udpPorts = append(udpPorts, types.Int(portString))
}
}
}
if len(portStrings) == 0 {
return
}
// 检查是否有变化
sort.Strings(portStrings)
var newPortStrings = strings.Join(portStrings, ",")
if newPortStrings == this.lastPortStrings {
return
}
this.lastPortStrings = newPortStrings
remotelogs.Println("FIREWALLD", "opening ports automatically ...")
defer func() {
remotelogs.Println("FIREWALLD", "open ports successfully")
}()
// 合并端口
var tcpPortRanges = utils.MergePorts(tcpPorts)
var udpPortRanges = utils.MergePorts(udpPorts)
defer func() {
this.lastTCPPortRanges = tcpPortRanges
this.lastUDPPortRanges = udpPortRanges
}()
// 删除老的不存在的端口
var tcpPortRangesMap = map[string]bool{}
var udpPortRangesMap = map[string]bool{}
for _, portRange := range tcpPortRanges {
tcpPortRangesMap[this.firewalld.PortRangeString(portRange, "tcp")] = true
}
for _, portRange := range udpPortRanges {
udpPortRangesMap[this.firewalld.PortRangeString(portRange, "udp")] = true
}
for _, portRange := range this.lastTCPPortRanges {
var s = this.firewalld.PortRangeString(portRange, "tcp")
_, ok := tcpPortRangesMap[s]
if ok {
continue
}
remotelogs.Println("FIREWALLD", "remove port '"+s+"'")
_ = this.firewalld.RemovePortRangePermanently(portRange, "tcp")
}
for _, portRange := range this.lastUDPPortRanges {
var s = this.firewalld.PortRangeString(portRange, "udp")
_, ok := udpPortRangesMap[s]
if ok {
continue
}
remotelogs.Println("FIREWALLD", "remove port '"+s+"'")
_ = this.firewalld.RemovePortRangePermanently(portRange, "udp")
}
// 添加新的
_ = this.firewalld.AllowPortRangesPermanently(tcpPortRanges, "tcp")
_ = this.firewalld.AllowPortRangesPermanently(udpPortRanges, "udp")
}

View File

@@ -0,0 +1,319 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
"github.com/TeaOSLab/EdgeDNS/internal/models"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"github.com/iwind/TeaGo/types"
"strings"
"sync"
"time"
)
// DomainManager 域名管理器
type DomainManager struct {
domainMap map[int64]*models.NSDomain // domainId => domain
namesMap map[string]map[int64]*models.NSDomain // domain name => { domainId => domain }
clusterId int64
db *dbs.DB
version int64
locker *sync.RWMutex
notifier chan bool
}
// NewDomainManager 获取域名管理器对象
func NewDomainManager(db *dbs.DB, clusterId int64) *DomainManager {
return &DomainManager{
db: db,
domainMap: map[int64]*models.NSDomain{},
namesMap: map[string]map[int64]*models.NSDomain{},
clusterId: clusterId,
notifier: make(chan bool, 8),
locker: &sync.RWMutex{},
}
}
// Start 启动自动任务
func (this *DomainManager) Start() {
remotelogs.Println("DOMAIN_MANAGER", "starting ...")
// 从本地数据库中加载数据
err := this.Load()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("DOMAIN_MANAGER", "load failed: "+err.Error())
} else {
remotelogs.Error("DOMAIN_MANAGER", "load failed: "+err.Error())
}
}
// 初始化运行
err = this.LoopAll()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("DOMAIN_MANAGER", "loop failed: "+err.Error())
} else {
remotelogs.Error("DOMAIN_MANAGER", "loop failed: "+err.Error())
}
}
// 更新
var ticker = time.NewTicker(20 * time.Second)
for {
select {
case <-ticker.C:
case <-this.notifier:
}
err = this.LoopAll()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("DOMAIN_MANAGER", "loop failed: "+err.Error())
} else {
remotelogs.Error("DOMAIN_MANAGER", "loop failed: "+err.Error())
}
}
}
}
func (this *DomainManager) LoopAll() error {
for {
hasNext, err := this.Loop()
if err != nil {
return err
}
if !hasNext {
break
}
}
return nil
}
// Load 从数据库中加载数据
func (this *DomainManager) Load() error {
var offset = 0
var size = 10000
for {
domains, err := this.db.ListDomains(this.clusterId, offset, size)
if err != nil {
return err
}
if len(domains) == 0 {
break
}
this.locker.Lock()
for _, domain := range domains {
this.domainMap[domain.Id] = domain
nameMap, ok := this.namesMap[domain.Name]
if ok {
nameMap[domain.Id] = domain
} else {
this.namesMap[domain.Name] = map[int64]*models.NSDomain{
domain.Id: domain,
}
}
if domain.Version > this.version {
this.version = domain.Version
}
}
this.locker.Unlock()
offset += size
}
if this.version > 0 {
this.version++
}
return nil
}
// Loop 单次循环任务
func (this *DomainManager) Loop() (hasNext bool, err error) {
client, err := rpc.SharedRPC()
if err != nil {
return false, err
}
resp, err := client.NSDomainRPC.ListNSDomainsAfterVersion(client.Context(), &pb.ListNSDomainsAfterVersionRequest{
Version: this.version,
Size: 20000,
})
if err != nil {
return false, err
}
var domains = resp.NsDomains
if len(domains) == 0 {
return false, nil
}
for _, domain := range domains {
this.processDomain(domain)
if domain.Version > this.version {
this.version = domain.Version
}
}
this.version++
return true, nil
}
// FindDomain 根据名称查找域名
func (this *DomainManager) FindDomain(name string) (domain *models.NSDomain, ok bool) {
this.locker.RLock()
defer this.locker.RUnlock()
nameMap, ok := this.namesMap[name]
if !ok {
return nil, false
}
for _, domain2 := range nameMap {
return domain2, true
}
return
}
// FindDomainWithId 根据域名ID查询域名
func (this *DomainManager) FindDomainWithId(domainId int64) (domain *models.NSDomain) {
this.locker.RLock()
defer this.locker.RUnlock()
return this.domainMap[domainId]
}
// NotifyUpdate 通知更新
func (this *DomainManager) NotifyUpdate() {
select {
case this.notifier <- true:
default:
}
}
// SplitDomain 分解域名
func (this *DomainManager) SplitDomain(fullDomainName string) (rootDomain *models.NSDomain, recordName string) {
if len(fullDomainName) == 0 {
return
}
fullDomainName = strings.TrimSuffix(fullDomainName, ".") // 去除尾部的点(.
fullDomainName = strings.ToLower(fullDomainName) // 转换为小写
var domainName = fullDomainName
var domain, ok = this.FindDomain(domainName)
if !ok {
for {
var index = strings.Index(domainName, ".")
if index < 0 {
break
}
domainName = domainName[index+1:]
domain, ok = this.FindDomain(domainName)
if ok {
recordName = fullDomainName[:len(fullDomainName)-len(domainName)-1]
break
}
}
}
return domain, recordName
}
// 处理域名
func (this *DomainManager) processDomain(domain *pb.NSDomain) {
if !domain.IsOn || domain.IsDeleted || domain.Status != dnsconfigs.NSDomainStatusVerified {
this.locker.Lock()
delete(this.domainMap, domain.Id)
nameMap, ok := this.namesMap[domain.Name]
if ok {
delete(nameMap, domain.Id)
if len(nameMap) == 0 {
delete(this.namesMap, domain.Name)
}
}
this.locker.Unlock()
// 从数据库中删除
if this.db != nil {
err := this.db.DeleteDomain(domain.Id)
if err != nil {
remotelogs.Error("DOMAIN_MANAGER", "delete domain from db failed: "+err.Error())
}
}
return
}
// 存入数据库
if this.db != nil {
exists, err := this.db.ExistsDomain(domain.Id)
if err != nil {
remotelogs.Error("DOMAIN_MANAGER", "query failed: "+err.Error())
} else {
if exists {
err = this.db.UpdateDomain(domain.Id, domain.NsCluster.Id, domain.UserId, domain.Name, domain.TsigJSON, domain.Version)
if err != nil {
remotelogs.Error("DOMAIN_MANAGER", "update failed: "+err.Error())
}
} else {
err = this.db.InsertDomain(domain.Id, domain.NsCluster.Id, domain.UserId, domain.Name, domain.TsigJSON, domain.Version)
if err != nil {
remotelogs.Error("DOMAIN_MANAGER", "insert failed: "+err.Error())
}
}
}
}
// 同集群的才需要加载
if this.clusterId == domain.NsCluster.Id {
this.locker.Lock()
var tsigConfig = &dnsconfigs.NSTSIGConfig{}
if len(domain.TsigJSON) > 0 {
err := json.Unmarshal(domain.TsigJSON, tsigConfig)
if err != nil {
remotelogs.Error("DOMAIN_MANAGER", "decode TSIG json failed: "+err.Error()+", domain: "+domain.Name+", domainId: "+types.String(domain.Id)+", JSON: "+string(domain.TsigJSON))
}
}
var nsDomain = &models.NSDomain{
Id: domain.Id,
ClusterId: domain.NsCluster.Id,
UserId: domain.UserId,
Name: domain.Name,
TSIG: tsigConfig,
Version: domain.Version,
}
this.domainMap[domain.Id] = nsDomain
nameMap, ok := this.namesMap[domain.Name]
if ok {
nameMap[nsDomain.Id] = nsDomain
} else {
this.namesMap[domain.Name] = map[int64]*models.NSDomain{
nsDomain.Id: nsDomain,
}
}
this.locker.Unlock()
} else {
// 不同集群的删除域名
this.locker.Lock()
delete(this.domainMap, domain.Id)
nameMap, ok := this.namesMap[domain.Name]
if ok {
delete(nameMap, domain.Id)
if len(nameMap) == 0 {
delete(this.namesMap, domain.Name)
}
}
this.locker.Unlock()
}
}

View File

@@ -0,0 +1,63 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
"github.com/iwind/TeaGo/Tea"
_ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/logs"
"testing"
)
func TestDomainManager_Loop(t *testing.T) {
var db = dbs.NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
var manager = NewDomainManager(db, 1)
for i := 0; i < 10; i++ {
_, err := manager.Loop()
if err != nil {
t.Fatal(err)
}
}
logs.PrintAsJSON(manager.domainMap, t)
}
func TestDomainManager_Load(t *testing.T) {
var db = dbs.NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
manager := NewDomainManager(db, 2)
err = manager.Load()
if err != nil {
t.Fatal(err)
}
logs.PrintAsJSON(manager.domainMap, t)
t.Log("version:", manager.version)
}
func TestDomainManager_FindDomain(t *testing.T) {
var db = dbs.NewDB(Tea.Root + "/data/data.db")
err := db.Init()
if err != nil {
t.Fatal(err)
}
var manager = NewDomainManager(db, 2)
err = manager.Load()
if err != nil {
t.Fatal(err)
}
for _, name := range []string{"hello.com", "teaos.cn"} {
domain, ok := manager.FindDomain(name)
t.Log(name, ok, domain)
}
}

View File

@@ -0,0 +1,284 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
"github.com/TeaOSLab/EdgeDNS/internal/models"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"sync"
"time"
)
// KeyManager 密钥管理器
type KeyManager struct {
domainKeyMap map[int64]*models.NSKeys // domainId => *NSKeys
zoneKeyMap map[int64]*models.NSKeys // zoneId => *NSKeys
db *dbs.DB
locker sync.RWMutex
version int64
notifier chan bool
}
// NewKeyManager 获取密钥管理器
func NewKeyManager(db *dbs.DB) *KeyManager {
return &KeyManager{
domainKeyMap: map[int64]*models.NSKeys{},
zoneKeyMap: map[int64]*models.NSKeys{},
db: db,
notifier: make(chan bool, 8),
}
}
// Start 启动自动任务
func (this *KeyManager) Start() {
remotelogs.Println("KEY_MANAGER", "starting ...")
// 从本地数据库中加载数据
err := this.Load()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("KEY_MANAGER", "load failed: "+err.Error())
} else {
remotelogs.Error("KEY_MANAGER", "load failed: "+err.Error())
}
}
// 初始化运行
err = this.LoopAll()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("KEY_MANAGER", "loop failed: "+err.Error())
} else {
remotelogs.Error("KEY_MANAGER", "loop failed: "+err.Error())
}
}
// 更新
var ticker = time.NewTicker(1 * time.Minute)
for {
select {
case <-ticker.C:
case <-this.notifier:
}
err := this.LoopAll()
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("KEY_MANAGER", "loop failed: "+err.Error())
} else {
remotelogs.Error("KEY_MANAGER", "loop failed: "+err.Error())
}
}
}
}
// Load 从数据库中加载数据
func (this *KeyManager) Load() error {
var offset = 0
var size = 10000
for {
keys, err := this.db.ListKeys(offset, size)
if err != nil {
return err
}
if len(keys) == 0 {
break
}
this.locker.Lock()
for _, key := range keys {
if key.ZoneId > 0 {
keyList, ok := this.zoneKeyMap[key.ZoneId]
if ok {
keyList.Add(key)
} else {
keyList = models.NewNSKeys()
keyList.Add(key)
this.zoneKeyMap[key.ZoneId] = keyList
}
} else if key.DomainId > 0 {
keyList, ok := this.domainKeyMap[key.DomainId]
if ok {
keyList.Add(key)
} else {
keyList = models.NewNSKeys()
keyList.Add(key)
this.domainKeyMap[key.DomainId] = keyList
}
}
if key.Version > this.version {
this.version = key.Version
}
}
this.locker.Unlock()
offset += size
}
if this.version > 0 {
this.version++
}
return nil
}
func (this *KeyManager) LoopAll() error {
for {
hasNext, err := this.Loop()
if err != nil {
return err
}
if !hasNext {
break
}
}
return nil
}
// Loop 单次循环任务
func (this *KeyManager) Loop() (hasNext bool, err error) {
client, err := rpc.SharedRPC()
if err != nil {
return false, err
}
resp, err := client.NSKeyRPC.ListNSKeysAfterVersion(client.Context(), &pb.ListNSKeysAfterVersionRequest{
Version: this.version,
Size: 20000,
})
if err != nil {
return false, err
}
var keys = resp.NsKeys
if len(keys) == 0 {
return false, nil
}
for _, key := range keys {
this.processKey(key)
if key.Version > this.version {
this.version = key.Version
}
}
this.version++
return true, nil
}
func (this *KeyManager) FindKeysWithDomain(domainId int64) []*models.NSKey {
this.locker.RLock()
defer this.locker.RUnlock()
keys, ok := this.domainKeyMap[domainId]
if ok {
return keys.All()
}
return nil
}
// NotifyUpdate 通知更新
func (this *KeyManager) NotifyUpdate() {
select {
case this.notifier <- true:
default:
}
}
// 处理Key
func (this *KeyManager) processKey(key *pb.NSKey) {
if key.NsDomain == nil && key.NsZone == nil {
return
}
if !key.IsOn || key.IsDeleted {
this.locker.Lock()
if key.NsDomain != nil {
list, ok := this.domainKeyMap[key.NsDomain.Id]
if ok {
list.Remove(key.Id)
}
}
if key.NsZone != nil {
list, ok := this.zoneKeyMap[key.NsZone.Id]
if ok {
list.Remove(key.Id)
}
}
this.locker.Unlock()
// 从数据库中删除
if this.db != nil {
err := this.db.DeleteKey(key.Id)
if err != nil {
remotelogs.Error("KEY_MANAGER", "delete key from db failed: "+err.Error())
}
}
return
}
var domainId int64
var zoneId int64
if key.NsDomain != nil {
domainId = key.NsDomain.Id
}
if key.NsZone != nil {
zoneId = key.NsZone.Id
}
// 存入数据库
if this.db != nil {
exists, err := this.db.ExistsKey(key.Id)
if err != nil {
remotelogs.Error("KEY_MANAGER", "query failed: "+err.Error())
} else {
if exists {
err = this.db.UpdateKey(key.Id, domainId, zoneId, key.Algo, key.Secret, key.SecretType, key.Version)
if err != nil {
remotelogs.Error("KEY_MANAGER", "update failed: "+err.Error())
}
} else {
err = this.db.InsertKey(key.Id, domainId, zoneId, key.Algo, key.Secret, key.SecretType, key.Version)
if err != nil {
remotelogs.Error("KEY_MANAGER", "insert failed: "+err.Error())
}
}
}
}
// 加入缓存Map
this.locker.Lock()
var nsKey = &models.NSKey{
Id: key.Id,
DomainId: domainId,
ZoneId: zoneId,
Algo: key.Algo,
Secret: key.Secret,
SecretType: key.SecretType,
Version: key.Version,
}
if zoneId > 0 {
keyList, ok := this.zoneKeyMap[zoneId]
if ok {
keyList.Add(nsKey)
} else {
keyList = models.NewNSKeys()
keyList.Add(nsKey)
this.zoneKeyMap[zoneId] = keyList
}
} else if domainId > 0 {
keyList, ok := this.domainKeyMap[domainId]
if ok {
keyList.Add(nsKey)
} else {
keyList = models.NewNSKeys()
keyList.Add(nsKey)
this.domainKeyMap[domainId] = keyList
}
}
this.locker.Unlock()
}

Some files were not shown because too many files have changed in this diff Show More