1.4.5.2
This commit is contained in:
BIN
EdgeDNS/.DS_Store
vendored
Normal file
BIN
EdgeDNS/.DS_Store
vendored
Normal file
Binary file not shown.
73
EdgeDNS/.golangci.yaml
Normal file
73
EdgeDNS/.golangci.yaml
Normal file
@@ -0,0 +1,73 @@
|
||||
# https://golangci-lint.run/usage/configuration/
|
||||
|
||||
linters:
|
||||
enable-all: true
|
||||
disable:
|
||||
- ifshort
|
||||
- golint
|
||||
- nosnakecase
|
||||
- scopelint
|
||||
- varcheck
|
||||
- structcheck
|
||||
- interfacer
|
||||
- exhaustivestruct
|
||||
- maligned
|
||||
- deadcode
|
||||
- dogsled
|
||||
- wrapcheck
|
||||
- wastedassign
|
||||
- varnamelen
|
||||
- testpackage
|
||||
- thelper
|
||||
- nilerr
|
||||
- sqlclosecheck
|
||||
- paralleltest
|
||||
- nonamedreturns
|
||||
- nlreturn
|
||||
- nakedret
|
||||
- ireturn
|
||||
- interfacebloat
|
||||
- gosmopolitan
|
||||
- gomnd
|
||||
- goerr113
|
||||
- gochecknoglobals
|
||||
- exhaustruct
|
||||
- errorlint
|
||||
- depguard
|
||||
- exhaustive
|
||||
- containedctx
|
||||
- wsl
|
||||
- cyclop
|
||||
- dupword
|
||||
- errchkjson
|
||||
- contextcheck
|
||||
- tagalign
|
||||
- dupl
|
||||
- forbidigo
|
||||
- funlen
|
||||
- goconst
|
||||
- godox
|
||||
- gosec
|
||||
- lll
|
||||
- nestif
|
||||
- revive
|
||||
- unparam
|
||||
- stylecheck
|
||||
- gocritic
|
||||
- gofumpt
|
||||
- gomoddirectives
|
||||
- godot
|
||||
- gofmt
|
||||
- gocognit
|
||||
- mirror
|
||||
- gocyclo
|
||||
- gochecknoinits
|
||||
- gci
|
||||
- maintidx
|
||||
- prealloc
|
||||
- goimports
|
||||
- errname
|
||||
- musttag
|
||||
- forcetypeassert
|
||||
- whitespace
|
||||
- noctx
|
||||
9
EdgeDNS/build/build-all.sh
Normal file
9
EdgeDNS/build/build-all.sh
Normal file
@@ -0,0 +1,9 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
./build.sh linux amd64
|
||||
#./build.sh linux 386
|
||||
./build.sh linux arm64
|
||||
#./build.sh linux mips64
|
||||
#./build.sh linux mips64le
|
||||
#./build.sh darwin amd64
|
||||
#./build.sh darwin arm64
|
||||
110
EdgeDNS/build/build.sh
Normal file
110
EdgeDNS/build/build.sh
Normal file
@@ -0,0 +1,110 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
function build() {
|
||||
ROOT=$(dirname "$0")
|
||||
NAME="edge-dns"
|
||||
VERSION=$(lookup-version "$ROOT"/../internal/const/const.go)
|
||||
DIST=$ROOT/"../dist/${NAME}"
|
||||
OS=${1}
|
||||
ARCH=${2}
|
||||
|
||||
if [ -z "$OS" ]; then
|
||||
echo "usage: build.sh OS ARCH"
|
||||
exit
|
||||
fi
|
||||
if [ -z "$ARCH" ]; then
|
||||
echo "usage: build.sh OS ARCH"
|
||||
exit
|
||||
fi
|
||||
|
||||
echo "checking ..."
|
||||
ZIP_PATH=$(which zip)
|
||||
if [ -z "$ZIP_PATH" ]; then
|
||||
echo "we need 'zip' command to compress files"
|
||||
exit
|
||||
fi
|
||||
|
||||
echo "building v${VERSION}/${OS}/${ARCH} ..."
|
||||
ZIP="${NAME}-${OS}-${ARCH}-v${VERSION}.zip"
|
||||
|
||||
echo "copying ..."
|
||||
if [ ! -d "$DIST" ]; then
|
||||
mkdir "$DIST"
|
||||
mkdir "$DIST"/bin
|
||||
mkdir "$DIST"/configs
|
||||
mkdir "$DIST"/logs
|
||||
mkdir "$DIST"/data
|
||||
fi
|
||||
|
||||
cp "$ROOT"/configs/api_dns.template.yaml "$DIST"/configs
|
||||
|
||||
echo "building ..."
|
||||
|
||||
MUSL_DIR="/usr/local/opt/musl-cross/bin"
|
||||
CC_PATH=""
|
||||
CXX_PATH=""
|
||||
if [[ $(uname -a) == *"Darwin"* && "${OS}" == "linux" ]]; then
|
||||
# /usr/local/opt/musl-cross/bin/
|
||||
if [ "${ARCH}" == "amd64" ]; then
|
||||
CC_PATH="x86_64-linux-musl-gcc"
|
||||
CXX_PATH="x86_64-linux-musl-g++"
|
||||
fi
|
||||
if [ "${ARCH}" == "386" ]; then
|
||||
CC_PATH="i486-linux-musl-gcc"
|
||||
CXX_PATH="i486-linux-musl-g++"
|
||||
fi
|
||||
if [ "${ARCH}" == "arm64" ]; then
|
||||
CC_PATH="aarch64-linux-musl-gcc"
|
||||
CXX_PATH="aarch64-linux-musl-g++"
|
||||
fi
|
||||
if [ "${ARCH}" == "mips64" ]; then
|
||||
CC_PATH="mips64-linux-musl-gcc"
|
||||
CXX_PATH="mips64-linux-musl-g++"
|
||||
fi
|
||||
if [ "${ARCH}" == "mips64le" ]; then
|
||||
CC_PATH="mips64el-linux-musl-gcc"
|
||||
CXX_PATH="mips64el-linux-musl-g++"
|
||||
fi
|
||||
fi
|
||||
if [ ! -z $CC_PATH ]; then
|
||||
env CC=$MUSL_DIR/$CC_PATH CXX=$MUSL_DIR/$CXX_PATH GOOS="${OS}" GOARCH="${ARCH}" CGO_ENABLED=1 go build -trimpath -tags="plus" -o "$DIST"/bin/${NAME} -ldflags "-linkmode external -extldflags -static -s -w" "$ROOT"/../cmd/edge-dns/main.go
|
||||
else
|
||||
env GOOS="${OS}" GOARCH="${ARCH}" CGO_ENABLED=1 go build -trimpath -tags="plus" -o "$DIST"/bin/${NAME} -ldflags="-s -w" "$ROOT"/../cmd/edge-dns/main.go
|
||||
fi
|
||||
|
||||
# check build result
|
||||
RESULT=$?
|
||||
if [ "${RESULT}" != "0" ]; then
|
||||
exit
|
||||
fi
|
||||
|
||||
# delete hidden files
|
||||
find "$DIST" -name ".DS_Store" -delete
|
||||
find "$DIST" -name ".gitignore" -delete
|
||||
|
||||
echo "zip files"
|
||||
cd "${DIST}/../" || exit
|
||||
if [ -f "${ZIP}" ]; then
|
||||
rm -f "${ZIP}"
|
||||
fi
|
||||
zip -r -X -q "${ZIP}" ${NAME}/
|
||||
rm -rf ${NAME}
|
||||
cd - || exit
|
||||
|
||||
echo "OK"
|
||||
}
|
||||
|
||||
function lookup-version() {
|
||||
FILE=$1
|
||||
VERSION_DATA=$(cat "$FILE")
|
||||
re="Version[ ]+=[ ]+\"([0-9.]+)\""
|
||||
if [[ $VERSION_DATA =~ $re ]]; then
|
||||
VERSION=${BASH_REMATCH[1]}
|
||||
echo "$VERSION"
|
||||
else
|
||||
echo "could not match version"
|
||||
exit
|
||||
fi
|
||||
}
|
||||
|
||||
build "$1" "$2"
|
||||
3
EdgeDNS/build/configs/.gitignore
vendored
Normal file
3
EdgeDNS/build/configs/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
api.yaml
|
||||
api_dns.yaml
|
||||
*.cache
|
||||
3
EdgeDNS/build/configs/api_dns.template.yaml
Normal file
3
EdgeDNS/build/configs/api_dns.template.yaml
Normal file
@@ -0,0 +1,3 @@
|
||||
rpc.endpoints: [ "http://127.0.0.1:8003" ]
|
||||
nodeId: ""
|
||||
secret: ""
|
||||
4
EdgeDNS/build/data/.gitignore
vendored
Normal file
4
EdgeDNS/build/data/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
*.db
|
||||
*.db-shm
|
||||
*.db-wal
|
||||
*.lock
|
||||
1
EdgeDNS/build/licenses/README.md
Normal file
1
EdgeDNS/build/licenses/README.md
Normal file
@@ -0,0 +1 @@
|
||||
这个目录下我们列举了所有需要公开声明的第三方License,如果有遗漏,烦请告知 iwind.liu@gmail.com。再次感谢这些开源软件项目和贡献人员!
|
||||
30
EdgeDNS/build/licenses/miekg-dns.md
Normal file
30
EdgeDNS/build/licenses/miekg-dns.md
Normal file
@@ -0,0 +1,30 @@
|
||||
Copyright (c) 2009 The Go Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
As this is fork of the official Go code the same license applies.
|
||||
Extensions of the original work are copyright (c) 2011 Miek Gieben
|
||||
1
EdgeDNS/build/logs/.gitignore
vendored
Normal file
1
EdgeDNS/build/logs/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
*.log
|
||||
142
EdgeDNS/cmd/edge-dns/main.go
Normal file
142
EdgeDNS/cmd/edge-dns/main.go
Normal file
@@ -0,0 +1,142 @@
|
||||
//go:build plus
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/apps"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/configs"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/nodes"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/utils"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/gosock/pkg/gosock"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var app = apps.NewAppCmd().
|
||||
Version(teaconst.Version).
|
||||
Product(teaconst.ProductName).
|
||||
Usage(teaconst.ProcessName + " [-v|start|stop|restart|service|daemon|pprof|gc|uninstall]")
|
||||
app.On("start:before", func() {
|
||||
// validate config
|
||||
_, err := configs.LoadAPIConfig()
|
||||
if err != nil {
|
||||
fmt.Println("[ERROR]start failed: load api config from '" + Tea.ConfigFile(configs.ConfigFileName) + "' failed: " + err.Error())
|
||||
os.Exit(0)
|
||||
}
|
||||
})
|
||||
app.On("test", func() {
|
||||
err := nodes.NewDNSNode().Test()
|
||||
if err != nil {
|
||||
_, _ = os.Stderr.WriteString(err.Error())
|
||||
}
|
||||
})
|
||||
app.On("daemon", func() {
|
||||
nodes.NewDNSNode().Daemon()
|
||||
})
|
||||
app.On("service", func() {
|
||||
err := nodes.NewDNSNode().InstallSystemService()
|
||||
if err != nil {
|
||||
fmt.Println("[ERROR]install failed: " + err.Error())
|
||||
return
|
||||
}
|
||||
fmt.Println("done")
|
||||
})
|
||||
app.On("pprof", func() {
|
||||
var flagSet = flag.NewFlagSet("pprof", flag.ExitOnError)
|
||||
var addr string
|
||||
flagSet.StringVar(&addr, "addr", "", "")
|
||||
_ = flagSet.Parse(os.Args[2:])
|
||||
|
||||
if len(addr) == 0 {
|
||||
addr = "127.0.0.1:6060"
|
||||
}
|
||||
logs.Println("starting with pprof '" + addr + "'...")
|
||||
|
||||
go func() {
|
||||
err := http.ListenAndServe(addr, nil)
|
||||
if err != nil {
|
||||
logs.Println("[ERROR]" + err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
var node = nodes.NewDNSNode()
|
||||
node.Start()
|
||||
})
|
||||
app.On("gc", func() {
|
||||
var sock = gosock.NewTmpSock(teaconst.ProcessName)
|
||||
_, err := sock.Send(&gosock.Command{Code: "gc"})
|
||||
if err != nil {
|
||||
fmt.Println("[ERROR]" + err.Error())
|
||||
} else {
|
||||
fmt.Println("ok")
|
||||
}
|
||||
})
|
||||
app.On("uninstall", func() {
|
||||
// service
|
||||
fmt.Println("Uninstall service ...")
|
||||
var manager = utils.NewServiceManager(teaconst.ProcessName, teaconst.ProductName)
|
||||
go func() {
|
||||
_ = manager.Uninstall()
|
||||
}()
|
||||
|
||||
// stop
|
||||
fmt.Println("Stopping ...")
|
||||
_, _ = gosock.NewTmpSock(teaconst.ProcessName).SendTimeout(&gosock.Command{Code: "stop"}, 1*time.Second)
|
||||
|
||||
// delete files
|
||||
var exe, _ = os.Executable()
|
||||
if len(exe) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var dir = filepath.Dir(filepath.Dir(exe)) // ROOT / bin / exe
|
||||
|
||||
// verify dir
|
||||
{
|
||||
fmt.Println("Checking '" + dir + "' ...")
|
||||
for _, subDir := range []string{"bin/" + filepath.Base(exe), "configs", "logs"} {
|
||||
_, err := os.Stat(dir + "/" + subDir)
|
||||
if err != nil {
|
||||
fmt.Println("[ERROR]program directory structure has been broken, please remove it manually.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("Removing '" + dir + "' ...")
|
||||
err := os.RemoveAll(dir)
|
||||
if err != nil {
|
||||
fmt.Println("[ERROR]remove failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// delete symbolic links
|
||||
fmt.Println("Removing symbolic links ...")
|
||||
_ = os.Remove("/usr/bin/" + teaconst.ProcessName)
|
||||
_ = os.Remove("/var/log/" + teaconst.ProcessName)
|
||||
|
||||
// delete configs
|
||||
// nothing to delete for EdgeDNS
|
||||
|
||||
// delete sock
|
||||
fmt.Println("Removing temporary files ...")
|
||||
var tempDir = os.TempDir()
|
||||
_ = os.Remove(tempDir + "/" + teaconst.ProcessName + ".sock")
|
||||
|
||||
// done
|
||||
fmt.Println("[DONE]")
|
||||
})
|
||||
app.Run(func() {
|
||||
var node = nodes.NewDNSNode()
|
||||
node.Start()
|
||||
})
|
||||
}
|
||||
2
EdgeDNS/dist/.gitignore
vendored
Normal file
2
EdgeDNS/dist/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
*.zip
|
||||
edge-dns
|
||||
39
EdgeDNS/go.mod
Normal file
39
EdgeDNS/go.mod
Normal file
@@ -0,0 +1,39 @@
|
||||
module github.com/TeaOSLab/EdgeDNS
|
||||
|
||||
go 1.25
|
||||
|
||||
replace github.com/TeaOSLab/EdgeCommon => ../EdgeCommon
|
||||
|
||||
require (
|
||||
github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000
|
||||
github.com/google/nftables v0.2.0
|
||||
github.com/iwind/TeaGo v0.0.0-20240128112714-6bcd0529d0ea
|
||||
github.com/iwind/gosock v0.0.0-20220505115348-f88412125a62
|
||||
github.com/mattn/go-sqlite3 v1.14.22
|
||||
github.com/mdlayher/netlink v1.7.2
|
||||
github.com/miekg/dns v1.1.58
|
||||
github.com/shirou/gopsutil/v3 v3.24.2
|
||||
golang.org/x/sys v0.38.0
|
||||
google.golang.org/grpc v1.78.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/josharian/native v1.1.0 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/mdlayher/socket v0.5.0 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.13 // indirect
|
||||
github.com/tklauser/numcpus v0.7.0 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
golang.org/x/mod v0.29.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
golang.org/x/tools v0.38.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
)
|
||||
99
EdgeDNS/go.sum
Normal file
99
EdgeDNS/go.sum
Normal file
@@ -0,0 +1,99 @@
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8=
|
||||
github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4=
|
||||
github.com/iwind/TeaGo v0.0.0-20240128112714-6bcd0529d0ea h1:o0QCF6tMJ9E6OgU1c0L+rYDshKsTu7mEk+7KCGLbnpI=
|
||||
github.com/iwind/TeaGo v0.0.0-20240128112714-6bcd0529d0ea/go.mod h1:Ng3xWekHSVy0E/6/jYqJ7Htydm/H+mWIl0AS+Eg3H2M=
|
||||
github.com/iwind/gosock v0.0.0-20220505115348-f88412125a62 h1:HJH6RDheAY156DnIfJSD/bEvqyXzsZuE2gzs8PuUjoo=
|
||||
github.com/iwind/gosock v0.0.0-20220505115348-f88412125a62/go.mod h1:H5Q7SXwbx3a97ecJkaS2sD77gspzE7HFUafBO0peEyA=
|
||||
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
|
||||
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
||||
github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI=
|
||||
github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI=
|
||||
github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4=
|
||||
github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/shirou/gopsutil/v3 v3.24.2 h1:kcR0erMbLg5/3LcInpw0X/rrPSqq4CDPyI6A6ZRC18Y=
|
||||
github.com/shirou/gopsutil/v3 v3.24.2/go.mod h1:tSg/594BcA+8UdQU2XcW803GWYgdtauFFPgJCJKZlVk=
|
||||
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
|
||||
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
|
||||
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
|
||||
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||
github.com/tklauser/go-sysconf v0.3.13 h1:GBUpcahXSpR2xN01jhkNAbTLRk2Yzgggk8IM08lq3r4=
|
||||
github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0=
|
||||
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
||||
github.com/tklauser/numcpus v0.7.0 h1:yjuerZP127QG9m5Zh/mSO4wqurYil27tHrqwRoRjpr4=
|
||||
github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY=
|
||||
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc h1:R83G5ikgLMxrBvLh22JhdfI8K6YXEPHx5P03Uu3DRs4=
|
||||
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic=
|
||||
golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
||||
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
|
||||
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
|
||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw=
|
||||
golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240311173647-c811ad7063a7 h1:8EeVk1VKMD+GD/neyEHGmz7pFblqPjHoi+PGQIlLx2s=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240311173647-c811ad7063a7/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
|
||||
google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM=
|
||||
google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA=
|
||||
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
|
||||
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
|
||||
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
36
EdgeDNS/internal/agents/agent.go
Normal file
36
EdgeDNS/internal/agents/agent.go
Normal file
@@ -0,0 +1,36 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package agents
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Agent struct {
|
||||
Code string
|
||||
suffixes []string
|
||||
reg *regexp.Regexp
|
||||
}
|
||||
|
||||
func NewAgent(code string, suffixes []string, reg *regexp.Regexp) *Agent {
|
||||
return &Agent{
|
||||
Code: code,
|
||||
suffixes: suffixes,
|
||||
reg: reg,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Agent) Match(ptr string) bool {
|
||||
if len(this.suffixes) > 0 {
|
||||
for _, suffix := range this.suffixes {
|
||||
if strings.HasSuffix(ptr, suffix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
if this.reg != nil {
|
||||
return this.reg.MatchString(ptr)
|
||||
}
|
||||
return false
|
||||
}
|
||||
17
EdgeDNS/internal/agents/agents.go
Normal file
17
EdgeDNS/internal/agents/agents.go
Normal file
@@ -0,0 +1,17 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package agents
|
||||
|
||||
var AllAgents = []*Agent{
|
||||
NewAgent("baidu", []string{".baidu.com."}, nil),
|
||||
NewAgent("google", []string{".googlebot.com."}, nil),
|
||||
NewAgent("bing", []string{".search.msn.com."}, nil),
|
||||
NewAgent("sogou", []string{".sogou.com."}, nil),
|
||||
NewAgent("youdao", []string{".163.com."}, nil),
|
||||
NewAgent("yahoo", []string{".yahoo.com."}, nil),
|
||||
NewAgent("bytedance", []string{".bytedance.com."}, nil),
|
||||
NewAgent("sm", []string{".sm.cn."}, nil),
|
||||
NewAgent("yandex", []string{".yandex.com.", ".yndx.net."}, nil),
|
||||
NewAgent("semrush", []string{".semrush.com."}, nil),
|
||||
NewAgent("facebook", []string{"facebook-waw.1-ix.net.", "facebook.b-ix.net."}, nil),
|
||||
}
|
||||
54
EdgeDNS/internal/agents/ip_cache_map.go
Normal file
54
EdgeDNS/internal/agents/ip_cache_map.go
Normal file
@@ -0,0 +1,54 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package agents
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/utils/zero"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type IPCacheMap struct {
|
||||
m map[string]zero.Zero
|
||||
list []string
|
||||
|
||||
locker sync.RWMutex
|
||||
maxLen int
|
||||
}
|
||||
|
||||
func NewIPCacheMap(maxLen int) *IPCacheMap {
|
||||
if maxLen <= 0 {
|
||||
maxLen = 65535
|
||||
}
|
||||
return &IPCacheMap{
|
||||
m: map[string]zero.Zero{},
|
||||
maxLen: maxLen,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *IPCacheMap) Add(ip string) {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
// 是否已经存在
|
||||
_, ok := this.m[ip]
|
||||
if ok {
|
||||
return
|
||||
}
|
||||
|
||||
// 超出长度删除第一个
|
||||
if len(this.list) >= this.maxLen {
|
||||
delete(this.m, this.list[0])
|
||||
this.list = this.list[1:]
|
||||
}
|
||||
|
||||
// 加入新数据
|
||||
this.m[ip] = zero.Zero{}
|
||||
this.list = append(this.list, ip)
|
||||
}
|
||||
|
||||
func (this *IPCacheMap) Contains(ip string) bool {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
_, ok := this.m[ip]
|
||||
return ok
|
||||
}
|
||||
34
EdgeDNS/internal/agents/ip_cache_map_test.go
Normal file
34
EdgeDNS/internal/agents/ip_cache_map_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build plus
|
||||
|
||||
package agents
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewIPCacheMap(t *testing.T) {
|
||||
var cacheMap = NewIPCacheMap(3)
|
||||
|
||||
t.Log("====")
|
||||
cacheMap.Add("1")
|
||||
cacheMap.Add("2")
|
||||
logs.PrintAsJSON(cacheMap.m, t)
|
||||
logs.PrintAsJSON(cacheMap.list, t)
|
||||
|
||||
t.Log("====")
|
||||
cacheMap.Add("3")
|
||||
logs.PrintAsJSON(cacheMap.m, t)
|
||||
logs.PrintAsJSON(cacheMap.list, t)
|
||||
|
||||
t.Log("====")
|
||||
cacheMap.Add("4")
|
||||
logs.PrintAsJSON(cacheMap.m, t)
|
||||
logs.PrintAsJSON(cacheMap.list, t)
|
||||
|
||||
t.Log("====")
|
||||
cacheMap.Add("3")
|
||||
logs.PrintAsJSON(cacheMap.m, t)
|
||||
logs.PrintAsJSON(cacheMap.list, t)
|
||||
}
|
||||
173
EdgeDNS/internal/agents/manager.go
Normal file
173
EdgeDNS/internal/agents/manager.go
Normal file
@@ -0,0 +1,173 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package agents
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SharedManager 此值在外面调用时指定
|
||||
var SharedManager *Manager
|
||||
|
||||
// Manager Agent管理器
|
||||
type Manager struct {
|
||||
ipMap map[string]string // ip => agentCode
|
||||
locker sync.RWMutex
|
||||
|
||||
db *dbs.DB
|
||||
|
||||
lastId int64
|
||||
}
|
||||
|
||||
func NewManager(db *dbs.DB) *Manager {
|
||||
return &Manager{
|
||||
ipMap: map[string]string{},
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Manager) Start() {
|
||||
remotelogs.Println("AGENT_MANAGER", "starting ...")
|
||||
|
||||
// 从本地数据库中加载
|
||||
err := this.Load()
|
||||
if err != nil {
|
||||
remotelogs.Error("AGENT_MANAGER", "load failed: "+err.Error())
|
||||
}
|
||||
|
||||
// 先从API获取
|
||||
err = this.LoopAll()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("AGENT_MANAGER", "retrieve latest agent ip failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("AGENT_MANAGER", "retrieve latest agent ip failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 定时获取
|
||||
var duration = 30 * time.Minute
|
||||
if Tea.IsTesting() {
|
||||
duration = 30 * time.Second
|
||||
}
|
||||
var ticker = time.NewTicker(duration)
|
||||
for range ticker.C {
|
||||
err = this.LoopAll()
|
||||
if err != nil {
|
||||
remotelogs.Error("AGENT_MANAGER", "retrieve latest agent ip failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Manager) Load() error {
|
||||
var offset int64 = 0
|
||||
var size int64 = 10000
|
||||
for {
|
||||
agentIPs, err := this.db.ListAgentIPs(offset, size)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(agentIPs) == 0 {
|
||||
break
|
||||
}
|
||||
for _, agentIP := range agentIPs {
|
||||
this.locker.Lock()
|
||||
this.ipMap[agentIP.IP] = agentIP.AgentCode
|
||||
this.locker.Unlock()
|
||||
|
||||
if agentIP.Id > this.lastId {
|
||||
this.lastId = agentIP.Id
|
||||
}
|
||||
}
|
||||
offset += size
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Manager) LoopAll() error {
|
||||
for {
|
||||
hasNext, err := this.Loop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasNext {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Loop 单次循环获取数据
|
||||
func (this *Manager) Loop() (hasNext bool, err error) {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
ipsResp, err := rpcClient.ClientAgentIPRPC.ListClientAgentIPsAfterId(rpcClient.Context(), &pb.ListClientAgentIPsAfterIdRequest{
|
||||
Id: this.lastId,
|
||||
Size: 10000,
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if len(ipsResp.ClientAgentIPs) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
for _, agentIP := range ipsResp.ClientAgentIPs {
|
||||
if agentIP.ClientAgent == nil {
|
||||
// 设置ID
|
||||
if agentIP.Id > this.lastId {
|
||||
this.lastId = agentIP.Id
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// 写入到数据库
|
||||
err = this.db.InsertAgentIP(agentIP.Id, agentIP.Ip, agentIP.ClientAgent.Code)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// 写入Map
|
||||
this.locker.Lock()
|
||||
this.ipMap[agentIP.Ip] = agentIP.ClientAgent.Code
|
||||
this.locker.Unlock()
|
||||
|
||||
// 设置ID
|
||||
if agentIP.Id > this.lastId {
|
||||
this.lastId = agentIP.Id
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// AddIP 添加记录
|
||||
func (this *Manager) AddIP(ip string, agentCode string) {
|
||||
this.locker.Lock()
|
||||
this.ipMap[ip] = agentCode
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
// LookupIP 查询IP所属Agent
|
||||
func (this *Manager) LookupIP(ip string) (agentCode string) {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
return this.ipMap[ip]
|
||||
}
|
||||
|
||||
// ContainsIP 检查是否有IP相关数据
|
||||
func (this *Manager) ContainsIP(ip string) bool {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
_, ok := this.ipMap[ip]
|
||||
return ok
|
||||
}
|
||||
33
EdgeDNS/internal/agents/manager_test.go
Normal file
33
EdgeDNS/internal/agents/manager_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build plus
|
||||
|
||||
package agents_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/agents"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
var db = dbs.NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var manager = agents.NewManager(db)
|
||||
err = manager.Load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = manager.Loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Log(manager.LookupIP("192.168.3.100"))
|
||||
}
|
||||
138
EdgeDNS/internal/agents/queue.go
Normal file
138
EdgeDNS/internal/agents/queue.go
Normal file
@@ -0,0 +1,138 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package agents
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/events"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
return
|
||||
}
|
||||
|
||||
events.On(events.EventLoaded, func() {
|
||||
goman.New(func() {
|
||||
SharedQueue.Start()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
var SharedQueue = NewQueue()
|
||||
|
||||
type Queue struct {
|
||||
c chan string // chan ip
|
||||
cacheMap *IPCacheMap
|
||||
}
|
||||
|
||||
func NewQueue() *Queue {
|
||||
return &Queue{
|
||||
c: make(chan string, 128),
|
||||
cacheMap: NewIPCacheMap(65535),
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Queue) Start() {
|
||||
for ip := range this.c {
|
||||
err := this.Process(ip)
|
||||
if err != nil {
|
||||
// 不需要上报错误
|
||||
if Tea.IsTesting() {
|
||||
remotelogs.Debug("SharedParseQueue", err.Error())
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Push 将IP加入到处理队列
|
||||
func (this *Queue) Push(ip string) {
|
||||
// 是否在处理中
|
||||
if this.cacheMap.Contains(ip) {
|
||||
return
|
||||
}
|
||||
this.cacheMap.Add(ip)
|
||||
|
||||
// 加入到队列
|
||||
select {
|
||||
case this.c <- ip:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Process 处理IP
|
||||
func (this *Queue) Process(ip string) error {
|
||||
// 是否已经在库中
|
||||
if SharedManager.ContainsIP(ip) {
|
||||
return nil
|
||||
}
|
||||
|
||||
ptr, err := this.ParseIP(ip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(ptr) == 0 || ptr == "." {
|
||||
return nil
|
||||
}
|
||||
|
||||
//remotelogs.Debug("AGENT", ip+" => "+ptr)
|
||||
|
||||
var agentCode = this.ParsePtr(ptr)
|
||||
if len(agentCode) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 加入到本地
|
||||
SharedManager.AddIP(ip, agentCode)
|
||||
|
||||
var pbAgentIP = &pb.CreateClientAgentIPsRequest_AgentIPInfo{
|
||||
AgentCode: agentCode,
|
||||
Ip: ip,
|
||||
Ptr: ptr,
|
||||
}
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = rpcClient.ClientAgentIPRPC.CreateClientAgentIPs(rpcClient.Context(), &pb.CreateClientAgentIPsRequest{AgentIPs: []*pb.CreateClientAgentIPsRequest_AgentIPInfo{pbAgentIP}})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseIP 分析IP的PTR值
|
||||
func (this *Queue) ParseIP(ip string) (ptr string, err error) {
|
||||
if len(ip) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
names, err := net.LookupAddr(ip)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(names) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return names[0], nil
|
||||
}
|
||||
|
||||
// ParsePtr 分析PTR对应的Agent
|
||||
func (this *Queue) ParsePtr(ptr string) (agentCode string) {
|
||||
for _, agent := range AllAgents {
|
||||
if agent.Match(ptr) {
|
||||
return agent.Code
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
77
EdgeDNS/internal/agents/queue_test.go
Normal file
77
EdgeDNS/internal/agents/queue_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build plus
|
||||
|
||||
package agents_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/agents"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseQueue_Process(t *testing.T) {
|
||||
var queue = agents.NewQueue()
|
||||
go queue.Start()
|
||||
time.Sleep(1 * time.Second)
|
||||
queue.Push("220.181.13.100")
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
func TestParseQueue_ParseIP(t *testing.T) {
|
||||
var queue = agents.NewQueue()
|
||||
for _, ip := range []string{
|
||||
"192.168.1.100",
|
||||
"42.120.160.1",
|
||||
"42.236.10.98",
|
||||
"124.115.0.100",
|
||||
} {
|
||||
ptr, err := queue.ParseIP(ip)
|
||||
if err != nil {
|
||||
t.Log(ip, "=>", err)
|
||||
continue
|
||||
}
|
||||
t.Log(ip, "=>", ptr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseQueue_ParsePtr(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
|
||||
var queue = agents.NewQueue()
|
||||
for _, s := range [][]string{
|
||||
{"baiduspider-220-181-108-101.crawl.baidu.com.", "baidu"},
|
||||
{"crawl-66-249-71-219.googlebot.com.", "google"},
|
||||
{"msnbot-40-77-167-31.search.msn.com.", "bing"},
|
||||
{"sogouspider-49-7-20-129.crawl.sogou.com.", "sogou"},
|
||||
{"m13102.mail.163.com.", "youdao"},
|
||||
{"yeurosport.pat1.tc2.yahoo.com.", "yahoo"},
|
||||
{"shenmaspider-42-120-160-1.crawl.sm.cn.", "sm"},
|
||||
{"93-158-161-39.spider.yandex.com.", "yandex"},
|
||||
{"25.bl.bot.semrush.com.", "semrush"},
|
||||
} {
|
||||
a.IsTrue(queue.ParsePtr(s[0]) == s[1])
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkQueue_ParsePtr(b *testing.B) {
|
||||
var queue = agents.NewQueue()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, s := range [][]string{
|
||||
{"baiduspider-220-181-108-101.crawl.baidu.com.", "baidu"},
|
||||
{"crawl-66-249-71-219.googlebot.com.", "google"},
|
||||
{"msnbot-40-77-167-31.search.msn.com.", "bing"},
|
||||
{"sogouspider-49-7-20-129.crawl.sogou.com.", "sogou"},
|
||||
{"m13102.mail.163.com.", "youdao"},
|
||||
{"yeurosport.pat1.tc2.yahoo.com.", "yahoo"},
|
||||
{"shenmaspider-42-120-160-1.crawl.sm.cn.", "sm"},
|
||||
{"93-158-161-39.spider.yandex.com.", "yandex"},
|
||||
{"93.158.164.218-red.dhcp.yndx.net.", "yandex"},
|
||||
{"25.bl.bot.semrush.com.", "semrush"},
|
||||
} {
|
||||
queue.ParsePtr(s[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
319
EdgeDNS/internal/apps/app_cmd.go
Normal file
319
EdgeDNS/internal/apps/app_cmd.go
Normal file
@@ -0,0 +1,319 @@
|
||||
package apps
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/iwind/gosock/pkg/gosock"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AppCmd App命令帮助
|
||||
type AppCmd struct {
|
||||
product string
|
||||
version string
|
||||
usage string
|
||||
options []*CommandHelpOption
|
||||
appendStrings []string
|
||||
|
||||
directives []*Directive
|
||||
|
||||
sock *gosock.Sock
|
||||
}
|
||||
|
||||
func NewAppCmd() *AppCmd {
|
||||
return &AppCmd{
|
||||
sock: gosock.NewTmpSock(teaconst.ProcessName),
|
||||
}
|
||||
}
|
||||
|
||||
type CommandHelpOption struct {
|
||||
Code string
|
||||
Description string
|
||||
}
|
||||
|
||||
// Product 产品
|
||||
func (this *AppCmd) Product(product string) *AppCmd {
|
||||
this.product = product
|
||||
return this
|
||||
}
|
||||
|
||||
// Version 版本
|
||||
func (this *AppCmd) Version(version string) *AppCmd {
|
||||
this.version = version
|
||||
return this
|
||||
}
|
||||
|
||||
// Usage 使用方法
|
||||
func (this *AppCmd) Usage(usage string) *AppCmd {
|
||||
this.usage = usage
|
||||
return this
|
||||
}
|
||||
|
||||
// Option 选项
|
||||
func (this *AppCmd) Option(code string, description string) *AppCmd {
|
||||
this.options = append(this.options, &CommandHelpOption{
|
||||
Code: code,
|
||||
Description: description,
|
||||
})
|
||||
return this
|
||||
}
|
||||
|
||||
// Append 附加内容
|
||||
func (this *AppCmd) Append(appendString string) *AppCmd {
|
||||
this.appendStrings = append(this.appendStrings, appendString)
|
||||
return this
|
||||
}
|
||||
|
||||
// Print 打印
|
||||
func (this *AppCmd) Print() {
|
||||
fmt.Println(this.product + " v" + this.version)
|
||||
|
||||
usage := this.usage
|
||||
fmt.Println("Usage:", "\n "+usage)
|
||||
|
||||
if len(this.options) > 0 {
|
||||
fmt.Println("")
|
||||
fmt.Println("Options:")
|
||||
|
||||
spaces := 20
|
||||
max := 40
|
||||
for _, option := range this.options {
|
||||
l := len(option.Code)
|
||||
if l < max && l > spaces {
|
||||
spaces = l + 4
|
||||
}
|
||||
}
|
||||
|
||||
for _, option := range this.options {
|
||||
if len(option.Code) > max {
|
||||
fmt.Println("")
|
||||
fmt.Println(" " + option.Code)
|
||||
option.Code = ""
|
||||
}
|
||||
|
||||
fmt.Printf(" %-"+strconv.Itoa(spaces)+"s%s\n", option.Code, ": "+option.Description)
|
||||
}
|
||||
}
|
||||
|
||||
if len(this.appendStrings) > 0 {
|
||||
fmt.Println("")
|
||||
for _, s := range this.appendStrings {
|
||||
fmt.Println(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// On 添加指令
|
||||
func (this *AppCmd) On(arg string, callback func()) {
|
||||
this.directives = append(this.directives, &Directive{
|
||||
Arg: arg,
|
||||
Callback: callback,
|
||||
})
|
||||
}
|
||||
|
||||
// Run 运行
|
||||
func (this *AppCmd) Run(main func()) {
|
||||
// 获取参数
|
||||
var args = os.Args[1:]
|
||||
if len(args) > 0 {
|
||||
var mainArg = args[0]
|
||||
this.callDirective(mainArg + ":before")
|
||||
|
||||
switch mainArg {
|
||||
case "-v", "version", "-version", "--version":
|
||||
this.runVersion()
|
||||
return
|
||||
case "?", "help", "-help", "h", "-h":
|
||||
this.runHelp()
|
||||
return
|
||||
case "start":
|
||||
this.runStart()
|
||||
return
|
||||
case "stop":
|
||||
this.runStop()
|
||||
return
|
||||
case "restart":
|
||||
this.runRestart()
|
||||
return
|
||||
case "status":
|
||||
this.runStatus()
|
||||
return
|
||||
}
|
||||
|
||||
// 查找指令
|
||||
for _, directive := range this.directives {
|
||||
if directive.Arg == mainArg {
|
||||
directive.Callback()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("unknown command '" + mainArg + "'")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 日志
|
||||
writer := new(LogWriter)
|
||||
writer.Init()
|
||||
logs.SetWriter(writer)
|
||||
|
||||
// 运行主函数
|
||||
main()
|
||||
}
|
||||
|
||||
// 版本号
|
||||
func (this *AppCmd) runVersion() {
|
||||
fmt.Println(this.product+" v"+this.version, "(build: "+runtime.Version(), runtime.GOOS, runtime.GOARCH+")")
|
||||
}
|
||||
|
||||
// 帮助
|
||||
func (this *AppCmd) runHelp() {
|
||||
this.Print()
|
||||
}
|
||||
|
||||
// 启动
|
||||
func (this *AppCmd) runStart() {
|
||||
var pid = this.getPID()
|
||||
if pid > 0 {
|
||||
fmt.Println(this.product+" already started, pid:", pid)
|
||||
return
|
||||
}
|
||||
|
||||
var cmd = exec.Command(this.exe())
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Foreground: false,
|
||||
Setsid: true,
|
||||
}
|
||||
err := cmd.Start()
|
||||
if err != nil {
|
||||
fmt.Println(this.product+" start failed:", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// create symbolic links
|
||||
_ = this.createSymLinks()
|
||||
|
||||
fmt.Println(this.product+" started ok, pid:", cmd.Process.Pid)
|
||||
}
|
||||
|
||||
// 停止
|
||||
func (this *AppCmd) runStop() {
|
||||
var pid = this.getPID()
|
||||
if pid == 0 {
|
||||
fmt.Println(this.product + " not started yet")
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = this.sock.Send(&gosock.Command{Code: "stop"})
|
||||
|
||||
fmt.Println(this.product+" stopped ok, pid:", types.String(pid))
|
||||
}
|
||||
|
||||
// 重启
|
||||
func (this *AppCmd) runRestart() {
|
||||
this.runStop()
|
||||
time.Sleep(1 * time.Second)
|
||||
this.runStart()
|
||||
}
|
||||
|
||||
// 状态
|
||||
func (this *AppCmd) runStatus() {
|
||||
var pid = this.getPID()
|
||||
if pid == 0 {
|
||||
fmt.Println(this.product + " not started yet")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println(this.product + " is running, pid: " + types.String(pid))
|
||||
}
|
||||
|
||||
// 获取当前的PID
|
||||
func (this *AppCmd) getPID() int {
|
||||
if !this.sock.IsListening() {
|
||||
return 0
|
||||
}
|
||||
|
||||
reply, err := this.sock.Send(&gosock.Command{Code: "pid"})
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return maps.NewMap(reply.Params).GetInt("pid")
|
||||
}
|
||||
|
||||
func (this *AppCmd) exe() string {
|
||||
var exe, _ = os.Executable()
|
||||
if len(exe) == 0 {
|
||||
exe = os.Args[0]
|
||||
}
|
||||
return exe
|
||||
}
|
||||
|
||||
// 创建软链接
|
||||
func (this *AppCmd) createSymLinks() error {
|
||||
if runtime.GOOS != "linux" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var exe, _ = os.Executable()
|
||||
if len(exe) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var errorList = []string{}
|
||||
|
||||
// bin
|
||||
{
|
||||
var target = "/usr/bin/" + teaconst.ProcessName
|
||||
old, _ := filepath.EvalSymlinks(target)
|
||||
if old != exe {
|
||||
_ = os.Remove(target)
|
||||
err := os.Symlink(exe, target)
|
||||
if err != nil {
|
||||
errorList = append(errorList, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// log
|
||||
{
|
||||
var realPath = filepath.Dir(filepath.Dir(exe)) + "/logs/run.log"
|
||||
var target = "/var/log/" + teaconst.ProcessName + ".log"
|
||||
old, _ := filepath.EvalSymlinks(target)
|
||||
if old != realPath {
|
||||
_ = os.Remove(target)
|
||||
err := os.Symlink(realPath, target)
|
||||
if err != nil {
|
||||
errorList = append(errorList, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(errorList) > 0 {
|
||||
return errors.New(strings.Join(errorList, "\n"))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *AppCmd) callDirective(code string) {
|
||||
for _, directive := range this.directives {
|
||||
if directive.Arg == code {
|
||||
if directive.Callback != nil {
|
||||
directive.Callback()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
6
EdgeDNS/internal/apps/directive.go
Normal file
6
EdgeDNS/internal/apps/directive.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package apps
|
||||
|
||||
type Directive struct {
|
||||
Arg string
|
||||
Callback func()
|
||||
}
|
||||
111
EdgeDNS/internal/apps/log_writer.go
Normal file
111
EdgeDNS/internal/apps/log_writer.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package apps
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/utils/sizes"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/files"
|
||||
timeutil "github.com/iwind/TeaGo/utils/time"
|
||||
"log"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type LogWriter struct {
|
||||
fp *os.File
|
||||
c chan string
|
||||
}
|
||||
|
||||
func (this *LogWriter) Init() {
|
||||
// 创建目录
|
||||
var dir = files.NewFile(Tea.LogDir())
|
||||
if !dir.Exists() {
|
||||
err := dir.Mkdir()
|
||||
if err != nil {
|
||||
log.Println("[LOG]create log dir failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 打开要写入的日志文件
|
||||
var logPath = Tea.LogFile("run.log")
|
||||
fp, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
log.Println("[LOG]open log file failed: " + err.Error())
|
||||
} else {
|
||||
this.fp = fp
|
||||
}
|
||||
|
||||
this.c = make(chan string, 1024)
|
||||
|
||||
// 异步写入文件
|
||||
var maxFileSize = 128 * sizes.M // 文件最大尺寸,超出此尺寸则清空
|
||||
if fp != nil {
|
||||
goman.New(func() {
|
||||
var totalSize int64 = 0
|
||||
stat, err := fp.Stat()
|
||||
if err == nil {
|
||||
totalSize = stat.Size()
|
||||
}
|
||||
|
||||
for message := range this.c {
|
||||
totalSize += int64(len(message))
|
||||
_, err := fp.WriteString(timeutil.Format("Y/m/d H:i:s ") + message + "\n")
|
||||
if err != nil {
|
||||
log.Println("[LOG]write log failed: " + err.Error())
|
||||
} else {
|
||||
// 如果太大则Truncate
|
||||
if totalSize > maxFileSize {
|
||||
_ = fp.Truncate(0)
|
||||
totalSize = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (this *LogWriter) Write(message string) {
|
||||
backgroundEnv, _ := os.LookupEnv("EdgeBackground")
|
||||
if backgroundEnv != "on" {
|
||||
// 文件和行号
|
||||
var file string
|
||||
var line int
|
||||
if Tea.IsTesting() {
|
||||
var callDepth = 3
|
||||
var ok bool
|
||||
_, file, line, ok = runtime.Caller(callDepth)
|
||||
if ok {
|
||||
file = this.packagePath(file)
|
||||
}
|
||||
}
|
||||
|
||||
if len(file) > 0 {
|
||||
log.Println(message + " (" + file + ":" + strconv.Itoa(line) + ")")
|
||||
} else {
|
||||
log.Println(message)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case this.c <- message:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (this *LogWriter) Close() {
|
||||
if this.fp != nil {
|
||||
_ = this.fp.Close()
|
||||
}
|
||||
|
||||
close(this.c)
|
||||
}
|
||||
|
||||
func (this *LogWriter) packagePath(path string) string {
|
||||
var pieces = strings.Split(path, "/")
|
||||
if len(pieces) >= 2 {
|
||||
return strings.Join(pieces[len(pieces)-2:], "/")
|
||||
}
|
||||
return path
|
||||
}
|
||||
107
EdgeDNS/internal/configs/api_config.go
Normal file
107
EdgeDNS/internal/configs/api_config.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package configs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"gopkg.in/yaml.v3"
|
||||
"os"
|
||||
)
|
||||
|
||||
const ConfigFileName = "api_dns.yaml"
|
||||
const oldConfigFileName = "api.yaml"
|
||||
|
||||
var sharedAPIConfig *APIConfig
|
||||
|
||||
// APIConfig API配置
|
||||
type APIConfig struct {
|
||||
OldRPC struct {
|
||||
Endpoints []string `yaml:"endpoints"`
|
||||
DisableUpdate bool `yaml:"disableUpdate"`
|
||||
} `yaml:"rpc,omitempty"`
|
||||
|
||||
RPCEndpoints []string `yaml:"rpc.endpoints,flow" json:"rpc.endpoints"`
|
||||
RPCDisableUpdate bool `yaml:"rpc.disableUpdate" json:"rpc.disableUpdate"`
|
||||
|
||||
NodeId string `yaml:"nodeId"`
|
||||
Secret string `yaml:"secret"`
|
||||
NumberId int64 `yaml:"numberId"`
|
||||
}
|
||||
|
||||
// SharedAPIConfig 加载API配置
|
||||
func SharedAPIConfig() (*APIConfig, error) {
|
||||
if sharedAPIConfig != nil {
|
||||
return sharedAPIConfig, nil
|
||||
}
|
||||
|
||||
config, err := LoadAPIConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sharedAPIConfig = config
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// LoadAPIConfig 加载API配置
|
||||
func LoadAPIConfig() (*APIConfig, error) {
|
||||
for _, filename := range []string{ConfigFileName, oldConfigFileName} {
|
||||
data, err := os.ReadFile(Tea.ConfigFile(filename))
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var config = &APIConfig{}
|
||||
err = yaml.Unmarshal(data, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = config.Init()
|
||||
if err != nil {
|
||||
return nil, errors.New("init error: " + err.Error())
|
||||
}
|
||||
|
||||
// 自动生成新的配置文件
|
||||
if filename == oldConfigFileName {
|
||||
config.OldRPC.Endpoints = nil
|
||||
_ = config.WriteFile(Tea.ConfigFile(ConfigFileName))
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("no config file '" + ConfigFileName + "' found")
|
||||
}
|
||||
|
||||
func (this *APIConfig) Init() error {
|
||||
// compatible with old
|
||||
if len(this.RPCEndpoints) == 0 && len(this.OldRPC.Endpoints) > 0 {
|
||||
this.RPCEndpoints = this.OldRPC.Endpoints
|
||||
this.RPCDisableUpdate = this.OldRPC.DisableUpdate
|
||||
}
|
||||
|
||||
if len(this.RPCEndpoints) == 0 {
|
||||
return errors.New("no valid 'rpc.endpoints'")
|
||||
}
|
||||
|
||||
if len(this.NodeId) == 0 {
|
||||
return errors.New("'nodeId' required")
|
||||
}
|
||||
if len(this.Secret) == 0 {
|
||||
return errors.New("'secret' required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteFile 写入API配置
|
||||
func (this *APIConfig) WriteFile(path string) error {
|
||||
data, err := yaml.Marshal(this)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, 0666)
|
||||
}
|
||||
21
EdgeDNS/internal/configs/api_config_test.go
Normal file
21
EdgeDNS/internal/configs/api_config_test.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package configs
|
||||
|
||||
import (
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"gopkg.in/yaml.v3"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadAPIConfig(t *testing.T) {
|
||||
config, err := LoadAPIConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("%+v", config)
|
||||
|
||||
configData, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(string(configData))
|
||||
}
|
||||
8
EdgeDNS/internal/configs/node_config.go
Normal file
8
EdgeDNS/internal/configs/node_config.go
Normal file
@@ -0,0 +1,8 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build plus
|
||||
|
||||
package configs
|
||||
|
||||
import "github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
|
||||
var SharedNodeConfig *dnsconfigs.NSNodeConfig
|
||||
15
EdgeDNS/internal/const/const.go
Normal file
15
EdgeDNS/internal/const/const.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package teaconst
|
||||
|
||||
const (
|
||||
Version = "1.4.5.1" //1.3.8.2
|
||||
|
||||
ProductName = "Edge DNS"
|
||||
ProcessName = "edge-dns"
|
||||
ProductNameZH = "Edge"
|
||||
|
||||
Role = "dns"
|
||||
|
||||
EncryptMethod = "aes-256-cfb"
|
||||
|
||||
SystemdServiceName = "edge-dns"
|
||||
)
|
||||
29
EdgeDNS/internal/const/vars.go
Normal file
29
EdgeDNS/internal/const/vars.go
Normal file
@@ -0,0 +1,29 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package teaconst
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
IsDaemon = false
|
||||
IsPlus = true
|
||||
IsMain = checkMain()
|
||||
|
||||
IsQuiting = false // 是否正在退出
|
||||
EnableDBStat = false // 是否开启本地数据库统计
|
||||
)
|
||||
|
||||
// 检查是否为主程序
|
||||
func checkMain() bool {
|
||||
if len(os.Args) == 1 ||
|
||||
(len(os.Args) >= 2 && os.Args[1] == "pprof") {
|
||||
return true
|
||||
}
|
||||
exe, _ := os.Executable()
|
||||
return strings.HasSuffix(exe, ".test") ||
|
||||
strings.HasSuffix(exe, ".test.exe") ||
|
||||
strings.Contains(exe, "___")
|
||||
}
|
||||
783
EdgeDNS/internal/dbs/db.go
Normal file
783
EdgeDNS/internal/dbs/db.go
Normal file
@@ -0,0 +1,783 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build plus
|
||||
|
||||
package dbs
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/models"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
dbutils "github.com/TeaOSLab/EdgeDNS/internal/utils/dbs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
tableDomains = "domains_v2"
|
||||
tableRecords = "records_v2"
|
||||
tableRoutes = "routes_v2"
|
||||
tableKeys = "keys"
|
||||
tableAgentIPs = "agentIPs"
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
db *dbutils.DB
|
||||
path string
|
||||
|
||||
insertDomainStmt *dbutils.Stmt
|
||||
updateDomainStmt *dbutils.Stmt
|
||||
deleteDomainStmt *dbutils.Stmt
|
||||
existsDomainStmt *dbutils.Stmt
|
||||
listDomainsStmt *dbutils.Stmt
|
||||
|
||||
insertRecordStmt *dbutils.Stmt
|
||||
updateRecordStmt *dbutils.Stmt
|
||||
existsRecordStmt *dbutils.Stmt
|
||||
deleteRecordStmt *dbutils.Stmt
|
||||
listRecordsStmt *dbutils.Stmt
|
||||
|
||||
insertRouteStmt *dbutils.Stmt
|
||||
updateRouteStmt *dbutils.Stmt
|
||||
deleteRouteStmt *dbutils.Stmt
|
||||
listRoutesStmt *dbutils.Stmt
|
||||
existsRouteStmt *dbutils.Stmt
|
||||
|
||||
insertKeyStmt *dbutils.Stmt
|
||||
updateKeyStmt *dbutils.Stmt
|
||||
deleteKeyStmt *dbutils.Stmt
|
||||
listKeysStmt *dbutils.Stmt
|
||||
existsKeyStmt *dbutils.Stmt
|
||||
|
||||
insertAgentIPStmt *dbutils.Stmt
|
||||
listAgentIPsStmt *dbutils.Stmt
|
||||
}
|
||||
|
||||
func NewDB(path string) *DB {
|
||||
return &DB{path: path}
|
||||
}
|
||||
|
||||
func (this *DB) Init() error {
|
||||
// 检查目录是否存在
|
||||
var dir = filepath.Dir(this.path)
|
||||
|
||||
_, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
err = os.MkdirAll(dir, 0777)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
remotelogs.Println("DB", "create database dir '"+dir+"'")
|
||||
}
|
||||
|
||||
// TODO 思考 data.db 的数据安全性
|
||||
db, err := dbutils.OpenWriter("file:" + this.path + "?cache=shared&mode=rwc&_journal_mode=WAL&_locking_mode=EXCLUSIVE")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.SetMaxOpenConns(1)
|
||||
|
||||
/**_, err = db.Exec("VACUUM")
|
||||
if err != nil {
|
||||
return err
|
||||
}**/
|
||||
|
||||
// 创建数据表
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + tableDomains + `" (
|
||||
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"clusterId" integer DEFAULT 0,
|
||||
"userId" integer DEFAULT 0,
|
||||
"name" varchar(255),
|
||||
"version" integer DEFAULT 0,
|
||||
"tsig" text
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS "clusterId"
|
||||
ON "` + tableDomains + `" (
|
||||
"clusterId"
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + tableRecords + `" (
|
||||
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"domainId" integer DEFAULT 0,
|
||||
"name" varchar(255),
|
||||
"type" varchar(32),
|
||||
"value" varchar(4096),
|
||||
"mxPriority" integer DEFAULT 10,
|
||||
"srvPriority" integer DEFAULT 10,
|
||||
"srvWeight" integer DEFAULT 10,
|
||||
"srvPort" integer DEFAULT 0,
|
||||
"caaFlag" integer DEFAULT 0,
|
||||
"caaTag" varchar(16),
|
||||
"ttl" integer DEFAULT 0,
|
||||
"weight" integer DEFAULT 0,
|
||||
"routeIds" varchar(512),
|
||||
"version" integer DEFAULT 0
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
// 忽略可以预期的错误
|
||||
if strings.Contains(err.Error(), "duplicate column name") {
|
||||
err = nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + tableRoutes + `" (
|
||||
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"ranges" text,
|
||||
"order" integer DEFAULT 0,
|
||||
"priority" integer DEFAULT 0,
|
||||
"userId" integer DEFAULT 0,
|
||||
"version" integer DEFAULT 0
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
// 忽略可以预期的错误
|
||||
if strings.Contains(err.Error(), "duplicate column name") {
|
||||
err = nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + tableKeys + `" (
|
||||
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"domainId" integer DEFAULT 0,
|
||||
"zoneId" integer DEFAULT 0,
|
||||
"algo" varchar(128),
|
||||
"secret" varchar(4096),
|
||||
"secretType" varchar(32),
|
||||
"version" integer DEFAULT 0
|
||||
);`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + tableAgentIPs + `" (
|
||||
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
"ip" varchar(64),
|
||||
"agentCode" varchar(128)
|
||||
);`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 预编译语句
|
||||
// domain statements
|
||||
this.insertDomainStmt, err = db.Prepare(`INSERT INTO "` + tableDomains + `" ("id", "clusterId", "userId", "name", "tsig", "version") VALUES (?, ?, ?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.updateDomainStmt, err = db.Prepare(`UPDATE "` + tableDomains + `" SET "clusterId"=?, "userId"=?, "name"=?, "tsig"=?, "version"=? WHERE "id"=?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.deleteDomainStmt, err = db.Prepare(`DELETE FROM "` + tableDomains + `" WHERE id=?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.existsDomainStmt, err = db.Prepare(`SELECT "id" FROM "` + tableDomains + `" WHERE "id"=? LIMIT 1`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.listDomainsStmt, err = db.Prepare(`SELECT "id", "clusterId", "userId", "name", "tsig", "version" FROM "` + tableDomains + `" WHERE "clusterId"=? ORDER BY "id" ASC LIMIT ? OFFSET ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// record statements
|
||||
this.insertRecordStmt, err = db.Prepare(`INSERT INTO "` + tableRecords + `" ("id", "domainId", "name", "type", "value", "mxPriority", "srvPriority", "srvWeight", "srvPort", "caaFlag", "caaTag", "ttl", "weight", "routeIds", "version") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.updateRecordStmt, err = db.Prepare(`UPDATE "` + tableRecords + `" SET "domainId"=?, "name"=?, "type"=?, "value"=?, "mxPriority"=?, "srvPriority"=?, "srvWeight"=?, "srvPort"=?, "caaFlag"=?, "caaTag"=?, "ttl"=?, "weight"=?, "routeIds"=?, "version"=? WHERE "id"=?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.existsRecordStmt, err = db.Prepare(`SELECT "id" FROM "` + tableRecords + `" WHERE "id"=? LIMIT 1`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.deleteRecordStmt, err = db.Prepare(`DELETE FROM "` + tableRecords + `" WHERE id=?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.listRecordsStmt, err = db.Prepare(`SELECT "id", "domainId", "name", "type", "value", "mxPriority", "srvPriority", "srvWeight", "srvPort", "caaFlag", "caaTag", "ttl", "weight", "routeIds", "version" FROM "` + tableRecords + `" ORDER BY "id" ASC LIMIT ? OFFSET ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// route statements
|
||||
this.insertRouteStmt, err = db.Prepare(`INSERT INTO "` + tableRoutes + `" ("id", "userId", "ranges", "order", "priority", "version") VALUES (?, ?, ?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.updateRouteStmt, err = db.Prepare(`UPDATE "` + tableRoutes + `" SET "userId"=?, "ranges"=?, "order"=?, "priority"=?, "version"=? WHERE "id"=?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.deleteRouteStmt, err = db.Prepare(`DELETE FROM "` + tableRoutes + `" WHERE "id"=?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.listRoutesStmt, err = db.Prepare(`SELECT "id", "userId", "ranges", "priority", "order", "version" FROM "` + tableRoutes + `" ORDER BY "id" ASC LIMIT ? OFFSET ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.existsRouteStmt, err = db.Prepare(`SELECT "id" FROM "` + tableRoutes + `" WHERE "id"=? LIMIT 1`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// key statements
|
||||
this.insertKeyStmt, err = db.Prepare(`INSERT INTO "` + tableKeys + `" ("id", "domainId", "zoneId", "algo", "secret", "secretType", "version") VALUES (?, ?, ?, ?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.updateKeyStmt, err = db.Prepare(`UPDATE "` + tableKeys + `" SET "domainId"=?, "zoneId"=?, "algo"=?, "secret"=?, "secretType"=?, "version"=? WHERE "id"=?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.deleteKeyStmt, err = db.Prepare(`DELETE FROM "` + tableKeys + `" WHERE "id"=?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.listKeysStmt, err = db.Prepare(`SELECT "id", "domainId", "zoneId", "algo", "secret", "secretType", "version" FROM "` + tableKeys + `" ORDER BY "id" ASC LIMIT ? OFFSET ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.existsKeyStmt, err = db.Prepare(`SELECT "id" FROM "` + tableKeys + `" WHERE "id"=? LIMIT 1`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// agent ip record statements
|
||||
this.insertAgentIPStmt, err = db.Prepare(`INSERT INTO "` + tableAgentIPs + `" ("id", "ip", "agentCode") VALUES (?, ?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.listAgentIPsStmt, err = db.Prepare(`SELECT "id", "ip", "agentCode" FROM "` + tableAgentIPs + `" ORDER BY "id" ASC LIMIT ? OFFSET ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.db = db
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *DB) InsertDomain(domainId int64, clusterId int64, userId int64, name string, tsigJSON []byte, version int64) error {
|
||||
if this.db == nil {
|
||||
return errors.New("db should not be nil")
|
||||
}
|
||||
|
||||
this.log("InsertDomain", "domain:", domainId, "user:", userId, "name:", name)
|
||||
|
||||
_, err := this.insertDomainStmt.Exec(domainId, clusterId, userId, name, string(tsigJSON), version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *DB) UpdateDomain(domainId int64, clusterId int64, userId int64, name string, tsigJSON []byte, version int64) error {
|
||||
if this.db == nil {
|
||||
return errors.New("db should not be nil")
|
||||
}
|
||||
|
||||
this.log("UpdateDomain", "domain:", domainId, "user:", userId, "name:", name)
|
||||
|
||||
_, err := this.updateDomainStmt.Exec(clusterId, userId, name, string(tsigJSON), version, domainId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *DB) DeleteDomain(domainId int64) error {
|
||||
if this.db == nil {
|
||||
return errors.New("db should not be nil")
|
||||
}
|
||||
|
||||
this.log("DeleteDomain", "domain:", domainId)
|
||||
|
||||
_, err := this.deleteDomainStmt.Exec(domainId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *DB) ExistsDomain(domainId int64) (bool, error) {
|
||||
if this.db == nil {
|
||||
return false, errors.New("db should not be nil")
|
||||
}
|
||||
rows, err := this.existsDomainStmt.Query(domainId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if rows.Err() != nil {
|
||||
return false, rows.Err()
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
if rows.Next() {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (this *DB) ListDomains(clusterId int64, offset int, size int) (domains []*models.NSDomain, err error) {
|
||||
if this.db == nil {
|
||||
return nil, errors.New("db should not be nil")
|
||||
}
|
||||
|
||||
rows, err := this.listDomainsStmt.Query(clusterId, size, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rows.Err() != nil {
|
||||
return nil, rows.Err()
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
var domain = &models.NSDomain{}
|
||||
var tsigString string
|
||||
err = rows.Scan(&domain.Id, &domain.ClusterId, &domain.UserId, &domain.Name, &tsigString, &domain.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(tsigString) > 0 {
|
||||
var tsigConfig = &dnsconfigs.NSTSIGConfig{}
|
||||
err = json.Unmarshal([]byte(tsigString), tsigConfig)
|
||||
if err != nil {
|
||||
remotelogs.Error("decode tsig string failed: "+err.Error()+", domain:"+domain.Name, ", domainId: "+types.String(domain.Id))
|
||||
} else {
|
||||
domain.TSIG = tsigConfig
|
||||
}
|
||||
}
|
||||
|
||||
domains = append(domains, domain)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (this *DB) InsertRecord(recordId int64, domainId int64, name string, recordType dnsconfigs.RecordType, value string, mxPriority int32, srvPriority int32, srvWeight int32, srvPort int32, caaFlag int32, caaTag string, ttl int32, weight int32, routeIds []string, version int64) error {
|
||||
if this.db == nil {
|
||||
return errors.New("db should not be nil")
|
||||
}
|
||||
|
||||
this.log("InsertRecord", "domain:", domainId, "name:", name)
|
||||
|
||||
_, err := this.insertRecordStmt.Exec(recordId, domainId, name, recordType, value, mxPriority, srvPriority, srvWeight, srvPort, caaFlag, caaTag, ttl, weight, strings.Join(routeIds, ","), version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *DB) UpdateRecord(recordId int64, domainId int64, name string, recordType dnsconfigs.RecordType, value string, mxPriority int32, srvPriority int32, srvWeight int32, srvPort int32, caaFlag int32, caaTag string, ttl int32, weight int32, routeIds []string, version int64) error {
|
||||
if this.db == nil {
|
||||
return errors.New("db should not be nil")
|
||||
}
|
||||
|
||||
this.log("UpdateRecord", "domain:", domainId, "name:", name)
|
||||
|
||||
_, err := this.updateRecordStmt.Exec(domainId, name, recordType, value, mxPriority, srvPriority, srvWeight, srvPort, caaFlag, caaTag, ttl, weight, strings.Join(routeIds, ","), version, recordId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *DB) DeleteRecord(recordId int64) error {
|
||||
if this.db == nil {
|
||||
return errors.New("db should not be nil")
|
||||
}
|
||||
|
||||
this.log("DeleteRecord", "record:", recordId)
|
||||
|
||||
_, err := this.deleteRecordStmt.Exec(recordId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *DB) ExistsRecord(recordId int64) (bool, error) {
|
||||
if this.db == nil {
|
||||
return false, errors.New("db should not be nil")
|
||||
}
|
||||
rows, err := this.existsRecordStmt.Query(recordId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if rows.Err() != nil {
|
||||
return false, rows.Err()
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
if rows.Next() {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// ListRecords 列出一组记录
|
||||
// TODO 将来只加载本集群上的记录
|
||||
func (this *DB) ListRecords(offset int, size int) (records []*models.NSRecord, err error) {
|
||||
if this.db == nil {
|
||||
return nil, errors.New("db should not be nil")
|
||||
}
|
||||
|
||||
rows, err := this.listRecordsStmt.Query(size, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rows.Err() != nil {
|
||||
return nil, rows.Err()
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
var record = &models.NSRecord{}
|
||||
var routeIds = ""
|
||||
err = rows.Scan(&record.Id, &record.DomainId, &record.Name, &record.Type, &record.Value, &record.MXPriority, &record.SRVPriority, &record.SRVWeight, &record.SRVPort, &record.CAAFlag, &record.CAATag, &record.Ttl, &record.Weight, &routeIds, &record.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(routeIds) > 0 {
|
||||
record.RouteIds = strings.Split(routeIds, ",")
|
||||
}
|
||||
records = append(records, record)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// InsertRoute 创建线路
|
||||
func (this *DB) InsertRoute(routeId int64, userId int64, rangesJSON []byte, order int32, priority int32, version int64) error {
|
||||
if this.db == nil {
|
||||
return errors.New("db should not be nil")
|
||||
}
|
||||
|
||||
this.log("InsertRoute", "route:", routeId, "user:", userId)
|
||||
|
||||
_, err := this.insertRouteStmt.Exec(routeId, userId, string(rangesJSON), order, priority, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateRoute 修改线路
|
||||
func (this *DB) UpdateRoute(routeId int64, userId int64, rangesJSON []byte, order int32, priority int32, version int64) error {
|
||||
if this.db == nil {
|
||||
return errors.New("db should not be nil")
|
||||
}
|
||||
|
||||
this.log("UpdateRoute", "route:", routeId, "user:", userId)
|
||||
|
||||
_, err := this.updateRouteStmt.Exec(userId, string(rangesJSON), order, priority, version, routeId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRoute 删除线路
|
||||
func (this *DB) DeleteRoute(routeId int64) error {
|
||||
if this.db == nil {
|
||||
return errors.New("db should not be nil")
|
||||
}
|
||||
|
||||
this.log("DeleteRoute", "route:", routeId)
|
||||
|
||||
_, err := this.deleteRouteStmt.Exec(routeId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExistsRoute 检查是否存在线路
|
||||
func (this *DB) ExistsRoute(routeId int64) (bool, error) {
|
||||
if this.db == nil {
|
||||
return false, errors.New("db should not be nil")
|
||||
}
|
||||
rows, err := this.existsRouteStmt.Query(routeId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if rows.Err() != nil {
|
||||
return false, rows.Err()
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
if rows.Next() {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// ListRoutes 查找所有线路
|
||||
func (this *DB) ListRoutes(offset int64, size int64) (routes []*models.NSRoute, err error) {
|
||||
if this.db == nil {
|
||||
return nil, errors.New("db should not be nil")
|
||||
}
|
||||
|
||||
rows, err := this.listRoutesStmt.Query(size, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rows.Err() != nil {
|
||||
return nil, rows.Err()
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
var route = &models.NSRoute{}
|
||||
var rangesString = ""
|
||||
err = rows.Scan(&route.Id, &route.UserId, &rangesString, &route.Priority, &route.Order, &route.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
route.Ranges, err = models.InitRangesFromJSON([]byte(rangesString))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
routes = append(routes, route)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (this *DB) InsertKey(keyId int64, domainId int64, zoneId int64, algo string, secret string, secretType string, version int64) error {
|
||||
if this.db == nil {
|
||||
return errors.New("db should not be nil")
|
||||
}
|
||||
|
||||
this.log("InsertKey", "key:", keyId, "domain:", domainId, "zone:", zoneId)
|
||||
|
||||
_, err := this.insertKeyStmt.Exec(keyId, domainId, zoneId, algo, secret, secretType, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *DB) UpdateKey(keyId int64, domainId int64, zoneId int64, algo string, secret string, secretType string, version int64) error {
|
||||
if this.db == nil {
|
||||
return errors.New("db should not be nil")
|
||||
}
|
||||
|
||||
this.log("UpdateKey", "key:", keyId, "domain:", domainId, "zone:", zoneId)
|
||||
|
||||
_, err := this.updateKeyStmt.Exec(domainId, zoneId, algo, secret, secretType, version, keyId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *DB) ListKeys(offset int, size int) (keys []*models.NSKey, err error) {
|
||||
if this.db == nil {
|
||||
return nil, errors.New("db should not be nil")
|
||||
}
|
||||
rows, err := this.listKeysStmt.Query(size, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rows.Err() != nil {
|
||||
return nil, rows.Err()
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
for rows.Next() {
|
||||
var key = &models.NSKey{}
|
||||
err = rows.Scan(&key.Id, &key.DomainId, &key.ZoneId, &key.Algo, &key.Secret, &key.SecretType, &key.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (this *DB) DeleteKey(keyId int64) error {
|
||||
if this.db == nil {
|
||||
return errors.New("db should not be nil")
|
||||
}
|
||||
|
||||
this.log("DeleteKey", "key:", keyId)
|
||||
|
||||
_, err := this.deleteKeyStmt.Exec(keyId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *DB) ExistsKey(keyId int64) (bool, error) {
|
||||
if this.db == nil {
|
||||
return false, errors.New("db should not be nil")
|
||||
}
|
||||
rows, err := this.existsKeyStmt.Query(keyId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if rows.Err() != nil {
|
||||
return false, rows.Err()
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
if rows.Next() {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (this *DB) InsertAgentIP(ipId int64, ip string, agentCode string) error {
|
||||
if this.db == nil {
|
||||
return errors.New("db should not be nil")
|
||||
}
|
||||
|
||||
this.log("InsertAgentIP", "id:", ipId, "ip:", ip, "agent:", agentCode)
|
||||
_, err := this.insertAgentIPStmt.Exec(ipId, ip, agentCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *DB) ListAgentIPs(offset int64, size int64) (agentIPs []*models.AgentIP, err error) {
|
||||
if this.db == nil {
|
||||
return nil, errors.New("db should not be nil")
|
||||
}
|
||||
rows, err := this.listAgentIPsStmt.Query(size, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rows.Err() != nil {
|
||||
return nil, rows.Err()
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
for rows.Next() {
|
||||
var agentIP = &models.AgentIP{}
|
||||
err = rows.Scan(&agentIP.Id, &agentIP.IP, &agentIP.AgentCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
agentIPs = append(agentIPs, agentIP)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (this *DB) Close() error {
|
||||
if this.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, stmt := range []*dbutils.Stmt{
|
||||
this.insertDomainStmt,
|
||||
this.updateDomainStmt,
|
||||
this.deleteDomainStmt,
|
||||
this.existsDomainStmt,
|
||||
this.listDomainsStmt,
|
||||
|
||||
this.insertRecordStmt,
|
||||
this.updateRecordStmt,
|
||||
this.existsRecordStmt,
|
||||
this.deleteRecordStmt,
|
||||
this.listRecordsStmt,
|
||||
|
||||
this.insertRouteStmt,
|
||||
this.updateRouteStmt,
|
||||
this.deleteRouteStmt,
|
||||
this.listRoutesStmt,
|
||||
this.existsRouteStmt,
|
||||
|
||||
this.insertKeyStmt,
|
||||
this.updateKeyStmt,
|
||||
this.deleteKeyStmt,
|
||||
this.listKeysStmt,
|
||||
this.existsKeyStmt,
|
||||
|
||||
this.insertAgentIPStmt,
|
||||
this.listAgentIPsStmt,
|
||||
} {
|
||||
if stmt != nil {
|
||||
_ = stmt.Close()
|
||||
}
|
||||
}
|
||||
|
||||
err := this.db.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 打印日志
|
||||
func (this *DB) log(args ...any) {
|
||||
if !Tea.IsTesting() {
|
||||
return
|
||||
}
|
||||
if len(args) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
args[0] = "[" + types.String(args[0]) + "]"
|
||||
log.Println(args...)
|
||||
}
|
||||
228
EdgeDNS/internal/dbs/db_test.go
Normal file
228
EdgeDNS/internal/dbs/db_test.go
Normal file
@@ -0,0 +1,228 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build plus
|
||||
|
||||
package dbs
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDB_Init(t *testing.T) {
|
||||
db := NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestDB_InsertDomain(t *testing.T) {
|
||||
db := NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.InsertDomain(1, 1, 1, "examples.com", nil, 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.InsertDomain(2, 2, 1, "examples2.com", nil, 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestDB_UpdateDomain(t *testing.T) {
|
||||
db := NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.UpdateDomain(1, 1, 1, "examples2.com", nil, 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestDB_DeleteDomain(t *testing.T) {
|
||||
db := NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.DeleteDomain(1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestDB_ExistsDomain(t *testing.T) {
|
||||
db := NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
b, err := db.ExistsDomain(1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("exists:", b)
|
||||
}
|
||||
|
||||
func TestDB_FindAllDomains(t *testing.T) {
|
||||
db := NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
limit := 2
|
||||
offset := 0
|
||||
for {
|
||||
t.Log("===", offset, "===")
|
||||
domains, err := db.ListDomains(2, offset, limit)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(domains) == 0 {
|
||||
break
|
||||
}
|
||||
for _, domain := range domains {
|
||||
t.Log(domain.Name)
|
||||
}
|
||||
offset += limit
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_InsertRecord(t *testing.T) {
|
||||
db := NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.InsertRecord(1, 1, "a", dnsconfigs.RecordTypeA, "192.168.1.100", 0, 3600, 10, 8080, 1, "", 3600, 10, []string{"id:100", "id:1"}, 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestDB_UpdateRecord(t *testing.T) {
|
||||
db := NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.UpdateRecord(1, 1, "a1", dnsconfigs.RecordTypeA, "192.168.1.101", 0, 3600, 10, 8080, 1, "", 3600, 10, []string{"id:100", "id:1"}, 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestDB_InsertRoute(t *testing.T) {
|
||||
db := NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.InsertRoute(1, 1, []byte("[]"), 1, 0, 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestDB_UpdateRoute(t *testing.T) {
|
||||
db := NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.UpdateRoute(1, 1, []byte("[{}, {}]"), 2, 0, 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestDB_InsertKey(t *testing.T) {
|
||||
db := NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.InsertKey(1, 2, 3, "md5", "secret123", dnsconfigs.NSKeySecretTypeClear, 4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestDB_UpdateKey(t *testing.T) {
|
||||
db := NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.UpdateKey(1, 22, 33, "sha1", "secret456", dnsconfigs.NSKeySecretTypeBase64, 5)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestDB_DeleteKey(t *testing.T) {
|
||||
db := NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.DeleteKey(1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestDB_ExistsKey(t *testing.T) {
|
||||
db := NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, id := range []int64{1, 2, 3} {
|
||||
b, err := db.ExistsKey(id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(id, b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_ListKeys(t *testing.T) {
|
||||
db := NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
keys, err := db.ListKeys(0, 10)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, key := range keys {
|
||||
t.Log(key)
|
||||
}
|
||||
}
|
||||
41
EdgeDNS/internal/encrypt/magic_key.go
Normal file
41
EdgeDNS/internal/encrypt/magic_key.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package encrypt
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
)
|
||||
|
||||
const (
|
||||
MagicKey = "f1c8eafb543f03023e97b7be864a4e9b"
|
||||
)
|
||||
|
||||
// 加密特殊信息
|
||||
func MagicKeyEncode(data []byte) []byte {
|
||||
method, err := NewMethodInstance("aes-256-cfb", MagicKey, MagicKey[:16])
|
||||
if err != nil {
|
||||
logs.Println("[MagicKeyEncode]" + err.Error())
|
||||
return data
|
||||
}
|
||||
|
||||
dst, err := method.Encrypt(data)
|
||||
if err != nil {
|
||||
logs.Println("[MagicKeyEncode]" + err.Error())
|
||||
return data
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// 解密特殊信息
|
||||
func MagicKeyDecode(data []byte) []byte {
|
||||
method, err := NewMethodInstance("aes-256-cfb", MagicKey, MagicKey[:16])
|
||||
if err != nil {
|
||||
logs.Println("[MagicKeyEncode]" + err.Error())
|
||||
return data
|
||||
}
|
||||
|
||||
src, err := method.Decrypt(data)
|
||||
if err != nil {
|
||||
logs.Println("[MagicKeyEncode]" + err.Error())
|
||||
return data
|
||||
}
|
||||
return src
|
||||
}
|
||||
11
EdgeDNS/internal/encrypt/magic_key_test.go
Normal file
11
EdgeDNS/internal/encrypt/magic_key_test.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package encrypt
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestMagicKeyEncode(t *testing.T) {
|
||||
dst := MagicKeyEncode([]byte("Hello,World"))
|
||||
t.Log("dst:", string(dst))
|
||||
|
||||
src := MagicKeyDecode(dst)
|
||||
t.Log("src:", string(src))
|
||||
}
|
||||
12
EdgeDNS/internal/encrypt/method.go
Normal file
12
EdgeDNS/internal/encrypt/method.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package encrypt
|
||||
|
||||
type MethodInterface interface {
|
||||
// 初始化
|
||||
Init(key []byte, iv []byte) error
|
||||
|
||||
// 加密
|
||||
Encrypt(src []byte) (dst []byte, err error)
|
||||
|
||||
// 解密
|
||||
Decrypt(dst []byte) (src []byte, err error)
|
||||
}
|
||||
73
EdgeDNS/internal/encrypt/method_aes_128_cfb.go
Normal file
73
EdgeDNS/internal/encrypt/method_aes_128_cfb.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package encrypt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
)
|
||||
|
||||
type AES128CFBMethod struct {
|
||||
iv []byte
|
||||
block cipher.Block
|
||||
}
|
||||
|
||||
func (this *AES128CFBMethod) Init(key, iv []byte) error {
|
||||
// 判断key是否为32长度
|
||||
l := len(key)
|
||||
if l > 16 {
|
||||
key = key[:16]
|
||||
} else if l < 16 {
|
||||
key = append(key, bytes.Repeat([]byte{' '}, 16-l)...)
|
||||
}
|
||||
|
||||
// 判断iv长度
|
||||
l2 := len(iv)
|
||||
if l2 > aes.BlockSize {
|
||||
iv = iv[:aes.BlockSize]
|
||||
} else if l2 < aes.BlockSize {
|
||||
iv = append(iv, bytes.Repeat([]byte{' '}, aes.BlockSize-l2)...)
|
||||
}
|
||||
|
||||
this.iv = iv
|
||||
|
||||
// block
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.block = block
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *AES128CFBMethod) Encrypt(src []byte) (dst []byte, err error) {
|
||||
if len(src) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = RecoverMethodPanic(recover())
|
||||
}()
|
||||
|
||||
dst = make([]byte, len(src))
|
||||
encrypter := cipher.NewCFBEncrypter(this.block, this.iv)
|
||||
encrypter.XORKeyStream(dst, src)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (this *AES128CFBMethod) Decrypt(dst []byte) (src []byte, err error) {
|
||||
if len(dst) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = RecoverMethodPanic(recover())
|
||||
}()
|
||||
|
||||
src = make([]byte, len(dst))
|
||||
encrypter := cipher.NewCFBDecrypter(this.block, this.iv)
|
||||
encrypter.XORKeyStream(src, dst)
|
||||
|
||||
return
|
||||
}
|
||||
90
EdgeDNS/internal/encrypt/method_aes_128_cfb_test.go
Normal file
90
EdgeDNS/internal/encrypt/method_aes_128_cfb_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package encrypt
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAES128CFBMethod_Encrypt(t *testing.T) {
|
||||
method, err := NewMethodInstance("aes-128-cfb", "abc", "123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
src := []byte("Hello, World")
|
||||
dst, err := method.Encrypt(src)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dst = dst[:len(src)]
|
||||
t.Log("dst:", string(dst))
|
||||
|
||||
src, err = method.Decrypt(dst)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("src:", string(src))
|
||||
}
|
||||
|
||||
func TestAES128CFBMethod_Encrypt2(t *testing.T) {
|
||||
method, err := NewMethodInstance("aes-128-cfb", "abc", "123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sources := [][]byte{}
|
||||
|
||||
{
|
||||
a := []byte{1}
|
||||
_, err = method.Encrypt(a)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
src := []byte(strings.Repeat("Hello", 1))
|
||||
dst, err := method.Encrypt(src)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sources = append(sources, dst)
|
||||
}
|
||||
|
||||
{
|
||||
|
||||
a := []byte{1}
|
||||
_, err = method.Decrypt(a)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, dst := range sources {
|
||||
dst2 := append([]byte{}, dst...)
|
||||
src2, err := method.Decrypt(dst2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(string(src2))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAES128CFBMethod_Encrypt(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
method, err := NewMethodInstance("aes-128-cfb", "abc", "123")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
src := []byte(strings.Repeat("Hello", 1024))
|
||||
for i := 0; i < b.N; i++ {
|
||||
dst, err := method.Encrypt(src)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
_ = dst
|
||||
}
|
||||
}
|
||||
74
EdgeDNS/internal/encrypt/method_aes_192_cfb.go
Normal file
74
EdgeDNS/internal/encrypt/method_aes_192_cfb.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package encrypt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
)
|
||||
|
||||
type AES192CFBMethod struct {
|
||||
block cipher.Block
|
||||
iv []byte
|
||||
}
|
||||
|
||||
func (this *AES192CFBMethod) Init(key, iv []byte) error {
|
||||
// 判断key是否为24长度
|
||||
l := len(key)
|
||||
if l > 24 {
|
||||
key = key[:24]
|
||||
} else if l < 24 {
|
||||
key = append(key, bytes.Repeat([]byte{' '}, 24-l)...)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.block = block
|
||||
|
||||
// 判断iv长度
|
||||
l2 := len(iv)
|
||||
if l2 > aes.BlockSize {
|
||||
iv = iv[:aes.BlockSize]
|
||||
} else if l2 < aes.BlockSize {
|
||||
iv = append(iv, bytes.Repeat([]byte{' '}, aes.BlockSize-l2)...)
|
||||
}
|
||||
|
||||
this.iv = iv
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *AES192CFBMethod) Encrypt(src []byte) (dst []byte, err error) {
|
||||
if len(src) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = RecoverMethodPanic(recover())
|
||||
}()
|
||||
|
||||
dst = make([]byte, len(src))
|
||||
|
||||
encrypter := cipher.NewCFBEncrypter(this.block, this.iv)
|
||||
encrypter.XORKeyStream(dst, src)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (this *AES192CFBMethod) Decrypt(dst []byte) (src []byte, err error) {
|
||||
if len(dst) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = RecoverMethodPanic(recover())
|
||||
}()
|
||||
|
||||
src = make([]byte, len(dst))
|
||||
|
||||
decrypter := cipher.NewCFBDecrypter(this.block, this.iv)
|
||||
decrypter.XORKeyStream(src, dst)
|
||||
|
||||
return
|
||||
}
|
||||
45
EdgeDNS/internal/encrypt/method_aes_192_cfb_test.go
Normal file
45
EdgeDNS/internal/encrypt/method_aes_192_cfb_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package encrypt
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAES192CFBMethod_Encrypt(t *testing.T) {
|
||||
method, err := NewMethodInstance("aes-192-cfb", "abc", "123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
src := []byte("Hello, World")
|
||||
dst, err := method.Encrypt(src)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dst = dst[:len(src)]
|
||||
t.Log("dst:", string(dst))
|
||||
|
||||
src, err = method.Decrypt(dst)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("src:", string(src))
|
||||
}
|
||||
|
||||
func BenchmarkAES192CFBMethod_Encrypt(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
method, err := NewMethodInstance("aes-192-cfb", "abc", "123")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
src := []byte(strings.Repeat("Hello", 1024))
|
||||
for i := 0; i < b.N; i++ {
|
||||
dst, err := method.Encrypt(src)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
_ = dst
|
||||
}
|
||||
}
|
||||
72
EdgeDNS/internal/encrypt/method_aes_256_cfb.go
Normal file
72
EdgeDNS/internal/encrypt/method_aes_256_cfb.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package encrypt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
)
|
||||
|
||||
type AES256CFBMethod struct {
|
||||
block cipher.Block
|
||||
iv []byte
|
||||
}
|
||||
|
||||
func (this *AES256CFBMethod) Init(key, iv []byte) error {
|
||||
// 判断key是否为32长度
|
||||
l := len(key)
|
||||
if l > 32 {
|
||||
key = key[:32]
|
||||
} else if l < 32 {
|
||||
key = append(key, bytes.Repeat([]byte{' '}, 32-l)...)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.block = block
|
||||
|
||||
// 判断iv长度
|
||||
l2 := len(iv)
|
||||
if l2 > aes.BlockSize {
|
||||
iv = iv[:aes.BlockSize]
|
||||
} else if l2 < aes.BlockSize {
|
||||
iv = append(iv, bytes.Repeat([]byte{' '}, aes.BlockSize-l2)...)
|
||||
}
|
||||
this.iv = iv
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *AES256CFBMethod) Encrypt(src []byte) (dst []byte, err error) {
|
||||
if len(src) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = RecoverMethodPanic(recover())
|
||||
}()
|
||||
|
||||
dst = make([]byte, len(src))
|
||||
|
||||
encrypter := cipher.NewCFBEncrypter(this.block, this.iv)
|
||||
encrypter.XORKeyStream(dst, src)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (this *AES256CFBMethod) Decrypt(dst []byte) (src []byte, err error) {
|
||||
if len(dst) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = RecoverMethodPanic(recover())
|
||||
}()
|
||||
|
||||
src = make([]byte, len(dst))
|
||||
decrypter := cipher.NewCFBDecrypter(this.block, this.iv)
|
||||
decrypter.XORKeyStream(src, dst)
|
||||
|
||||
return
|
||||
}
|
||||
42
EdgeDNS/internal/encrypt/method_aes_256_cfb_test.go
Normal file
42
EdgeDNS/internal/encrypt/method_aes_256_cfb_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package encrypt
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestAES256CFBMethod_Encrypt(t *testing.T) {
|
||||
method, err := NewMethodInstance("aes-256-cfb", "abc", "123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
src := []byte("Hello, World")
|
||||
dst, err := method.Encrypt(src)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dst = dst[:len(src)]
|
||||
t.Log("dst:", string(dst))
|
||||
|
||||
src, err = method.Decrypt(dst)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("src:", string(src))
|
||||
}
|
||||
|
||||
func TestAES256CFBMethod_Encrypt2(t *testing.T) {
|
||||
method, err := NewMethodInstance("aes-256-cfb", "abc", "123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
src := []byte("Hello, World")
|
||||
dst, err := method.Encrypt(src)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("dst:", string(dst))
|
||||
|
||||
src, err = method.Decrypt(dst)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("src:", string(src))
|
||||
}
|
||||
26
EdgeDNS/internal/encrypt/method_raw.go
Normal file
26
EdgeDNS/internal/encrypt/method_raw.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package encrypt
|
||||
|
||||
type RawMethod struct {
|
||||
}
|
||||
|
||||
func (this *RawMethod) Init(key, iv []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *RawMethod) Encrypt(src []byte) (dst []byte, err error) {
|
||||
if len(src) == 0 {
|
||||
return
|
||||
}
|
||||
dst = make([]byte, len(src))
|
||||
copy(dst, src)
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RawMethod) Decrypt(dst []byte) (src []byte, err error) {
|
||||
if len(dst) == 0 {
|
||||
return
|
||||
}
|
||||
src = make([]byte, len(dst))
|
||||
copy(src, dst)
|
||||
return
|
||||
}
|
||||
23
EdgeDNS/internal/encrypt/method_raw_test.go
Normal file
23
EdgeDNS/internal/encrypt/method_raw_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package encrypt
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestRawMethod_Encrypt(t *testing.T) {
|
||||
method, err := NewMethodInstance("raw", "abc", "123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
src := []byte("Hello, World")
|
||||
dst, err := method.Encrypt(src)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dst = dst[:len(src)]
|
||||
t.Log("dst:", string(dst))
|
||||
|
||||
src, err = method.Decrypt(dst)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("src:", string(src))
|
||||
}
|
||||
43
EdgeDNS/internal/encrypt/method_utils.go
Normal file
43
EdgeDNS/internal/encrypt/method_utils.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package encrypt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
var methods = map[string]reflect.Type{
|
||||
"raw": reflect.TypeOf(new(RawMethod)).Elem(),
|
||||
"aes-128-cfb": reflect.TypeOf(new(AES128CFBMethod)).Elem(),
|
||||
"aes-192-cfb": reflect.TypeOf(new(AES192CFBMethod)).Elem(),
|
||||
"aes-256-cfb": reflect.TypeOf(new(AES256CFBMethod)).Elem(),
|
||||
}
|
||||
|
||||
func NewMethodInstance(method string, key string, iv string) (MethodInterface, error) {
|
||||
valueType, ok := methods[method]
|
||||
if !ok {
|
||||
return nil, errors.New("method '" + method + "' not found")
|
||||
}
|
||||
instance, ok := reflect.New(valueType).Interface().(MethodInterface)
|
||||
if !ok {
|
||||
return nil, errors.New("method '" + method + "' must implement MethodInterface")
|
||||
}
|
||||
err := instance.Init([]byte(key), []byte(iv))
|
||||
return instance, err
|
||||
}
|
||||
|
||||
func RecoverMethodPanic(err interface{}) error {
|
||||
if err != nil {
|
||||
s, ok := err.(string)
|
||||
if ok {
|
||||
return errors.New(s)
|
||||
}
|
||||
|
||||
e, ok := err.(error)
|
||||
if ok {
|
||||
return e
|
||||
}
|
||||
|
||||
return errors.New("unknown error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
8
EdgeDNS/internal/encrypt/method_utils_test.go
Normal file
8
EdgeDNS/internal/encrypt/method_utils_test.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package encrypt
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestFindMethodInstance(t *testing.T) {
|
||||
t.Log(NewMethodInstance("a", "b", ""))
|
||||
t.Log(NewMethodInstance("aes-256-cfb", "123456", ""))
|
||||
}
|
||||
12
EdgeDNS/internal/events/events.go
Normal file
12
EdgeDNS/internal/events/events.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package events
|
||||
|
||||
type Event = string
|
||||
|
||||
const (
|
||||
EventStart Event = "start" // start loading
|
||||
EventLoaded Event = "load" // loaded
|
||||
EventQuit Event = "quit" // quit node gracefully
|
||||
EventTerminated Event = "terminated" // process terminated
|
||||
EventReload Event = "reload" // reload config
|
||||
EventNFTablesReady Event = "nftablesReady" // nftables ready
|
||||
)
|
||||
27
EdgeDNS/internal/events/utils.go
Normal file
27
EdgeDNS/internal/events/utils.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package events
|
||||
|
||||
import "sync"
|
||||
|
||||
var eventsMap = map[string][]func(){} // event => []callbacks
|
||||
var locker = sync.Mutex{}
|
||||
|
||||
// On 增加事件回调
|
||||
func On(event string, callback func()) {
|
||||
locker.Lock()
|
||||
defer locker.Unlock()
|
||||
|
||||
var callbacks = eventsMap[event]
|
||||
callbacks = append(callbacks, callback)
|
||||
eventsMap[event] = callbacks
|
||||
}
|
||||
|
||||
// Notify 通知事件
|
||||
func Notify(event string) {
|
||||
locker.Lock()
|
||||
var callbacks = eventsMap[event]
|
||||
locker.Unlock()
|
||||
|
||||
for _, callback := range callbacks {
|
||||
callback()
|
||||
}
|
||||
}
|
||||
16
EdgeDNS/internal/events/utils_test.go
Normal file
16
EdgeDNS/internal/events/utils_test.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package events
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestOn(t *testing.T) {
|
||||
On("hello", func() {
|
||||
t.Log("world")
|
||||
})
|
||||
On("hello", func() {
|
||||
t.Log("world2")
|
||||
})
|
||||
On("hello2", func() {
|
||||
t.Log("world2")
|
||||
})
|
||||
Notify("hello")
|
||||
}
|
||||
570
EdgeDNS/internal/firewalls/ddos_protection.go
Normal file
570
EdgeDNS/internal/firewalls/ddos_protection.go
Normal file
@@ -0,0 +1,570 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/configs"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/events"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/firewalls/nftables"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/utils"
|
||||
executils "github.com/TeaOSLab/EdgeDNS/internal/utils/exec"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/utils/zero"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
|
||||
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
return
|
||||
}
|
||||
|
||||
events.On(events.EventReload, func() {
|
||||
if nftablesInstance == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var nodeConfig = configs.SharedNodeConfig
|
||||
if nodeConfig != nil {
|
||||
err := SharedDDoSProtectionManager.Apply(nodeConfig.DDoSProtection)
|
||||
if err != nil {
|
||||
remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
events.On(events.EventNFTablesReady, func() {
|
||||
var nodeConfig = configs.SharedNodeConfig
|
||||
if nodeConfig != nil {
|
||||
err := SharedDDoSProtectionManager.Apply(nodeConfig.DDoSProtection)
|
||||
if err != nil {
|
||||
remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// DDoSProtectionManager DDoS防护
|
||||
type DDoSProtectionManager struct {
|
||||
nftPath string
|
||||
|
||||
lastAllowIPList []string
|
||||
lastConfig []byte
|
||||
}
|
||||
|
||||
// NewDDoSProtectionManager 获取新对象
|
||||
func NewDDoSProtectionManager() *DDoSProtectionManager {
|
||||
nftPath, _ := executils.LookPath("nft")
|
||||
|
||||
return &DDoSProtectionManager{
|
||||
nftPath: nftPath,
|
||||
}
|
||||
}
|
||||
|
||||
// Apply 应用配置
|
||||
func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error {
|
||||
// 同集群节点IP白名单
|
||||
var allowIPListChanged = false
|
||||
var nodeConfig = configs.SharedNodeConfig
|
||||
if nodeConfig != nil {
|
||||
var allowIPList = nodeConfig.AllowedIPs
|
||||
if !utils.EqualStrings(allowIPList, this.lastAllowIPList) {
|
||||
allowIPListChanged = true
|
||||
this.lastAllowIPList = allowIPList
|
||||
}
|
||||
}
|
||||
|
||||
// 对比配置
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode config to json failed: %w", err)
|
||||
}
|
||||
if config == nil {
|
||||
configJSON = nil
|
||||
}
|
||||
|
||||
if !allowIPListChanged && bytes.Equal(this.lastConfig, configJSON) {
|
||||
return nil
|
||||
}
|
||||
remotelogs.Println("FIREWALL", "change DDoS protection config")
|
||||
|
||||
if len(this.nftPath) == 0 {
|
||||
return errors.New("can not find nft command")
|
||||
}
|
||||
|
||||
if nftablesInstance == nil {
|
||||
return errors.New("nftables instance should not be nil")
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
// TCP
|
||||
err := this.removeTCPRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO other protocols
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TCP
|
||||
if config.TCP == nil {
|
||||
err := this.removeTCPRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// allow ip list
|
||||
var allowIPList = []string{}
|
||||
for _, ipConfig := range config.TCP.AllowIPList {
|
||||
allowIPList = append(allowIPList, ipConfig.IP)
|
||||
}
|
||||
for _, ip := range this.lastAllowIPList {
|
||||
if !lists.ContainsString(allowIPList, ip) {
|
||||
allowIPList = append(allowIPList, ip)
|
||||
}
|
||||
}
|
||||
err = this.updateAllowIPList(allowIPList)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// tcp
|
||||
if config.TCP.IsOn {
|
||||
err := this.addTCPRules(config.TCP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err := this.removeTCPRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.lastConfig = configJSON
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 添加TCP规则
|
||||
func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig) error {
|
||||
// 检查nft版本不能小于0.9
|
||||
if len(nftablesInstance.version) > 0 && stringutil.VersionCompare("0.9", nftablesInstance.version) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var ports = []int32{}
|
||||
for _, portConfig := range tcpConfig.Ports {
|
||||
if !lists.ContainsInt32(ports, portConfig.Port) {
|
||||
ports = append(ports, portConfig.Port)
|
||||
}
|
||||
}
|
||||
if len(ports) == 0 {
|
||||
ports = []int32{53}
|
||||
}
|
||||
|
||||
for _, filter := range nftablesFilters {
|
||||
chain, oldRules, err := this.getRules(filter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get old rules failed: %w", err)
|
||||
}
|
||||
|
||||
var protocol = filter.protocol()
|
||||
|
||||
// max connections
|
||||
var maxConnections = tcpConfig.MaxConnections
|
||||
if maxConnections <= 0 {
|
||||
maxConnections = dnsconfigs.DefaultTCPMaxConnections
|
||||
if maxConnections <= 0 {
|
||||
maxConnections = 100000
|
||||
}
|
||||
}
|
||||
|
||||
// max connections per ip
|
||||
var maxConnectionsPerIP = tcpConfig.MaxConnectionsPerIP
|
||||
if maxConnectionsPerIP <= 0 {
|
||||
maxConnectionsPerIP = dnsconfigs.DefaultTCPMaxConnectionsPerIP
|
||||
if maxConnectionsPerIP <= 0 {
|
||||
maxConnectionsPerIP = 100000
|
||||
}
|
||||
}
|
||||
|
||||
// new connections rate (minutely)
|
||||
var newConnectionsMinutelyRate = tcpConfig.NewConnectionsMinutelyRate
|
||||
if newConnectionsMinutelyRate <= 0 {
|
||||
newConnectionsMinutelyRate = nodeconfigs.DefaultTCPNewConnectionsMinutelyRate
|
||||
if newConnectionsMinutelyRate <= 0 {
|
||||
newConnectionsMinutelyRate = 100000
|
||||
}
|
||||
}
|
||||
var newConnectionsMinutelyRateBlockTimeout = tcpConfig.NewConnectionsMinutelyRateBlockTimeout
|
||||
if newConnectionsMinutelyRateBlockTimeout < 0 {
|
||||
newConnectionsMinutelyRateBlockTimeout = 0
|
||||
}
|
||||
|
||||
// new connections rate (secondly)
|
||||
var newConnectionsSecondlyRate = tcpConfig.NewConnectionsSecondlyRate
|
||||
if newConnectionsSecondlyRate <= 0 {
|
||||
newConnectionsSecondlyRate = nodeconfigs.DefaultTCPNewConnectionsSecondlyRate
|
||||
if newConnectionsSecondlyRate <= 0 {
|
||||
newConnectionsSecondlyRate = 10000
|
||||
}
|
||||
}
|
||||
var newConnectionsSecondlyRateBlockTimeout = tcpConfig.NewConnectionsSecondlyRateBlockTimeout
|
||||
if newConnectionsSecondlyRateBlockTimeout < 0 {
|
||||
newConnectionsSecondlyRateBlockTimeout = 0
|
||||
}
|
||||
|
||||
// 检查是否有变化
|
||||
var hasChanges = false
|
||||
for _, port := range ports {
|
||||
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}) {
|
||||
hasChanges = true
|
||||
break
|
||||
}
|
||||
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}) {
|
||||
hasChanges = true
|
||||
break
|
||||
}
|
||||
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsMinutelyRate), types.String(newConnectionsMinutelyRateBlockTimeout)}) {
|
||||
hasChanges = true
|
||||
break
|
||||
}
|
||||
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "newConnectionsSecondlyRate", types.String(newConnectionsSecondlyRate), types.String(newConnectionsSecondlyRateBlockTimeout)}) {
|
||||
hasChanges = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasChanges {
|
||||
// 检查是否有多余的端口
|
||||
var oldPorts = this.getTCPPorts(oldRules)
|
||||
if !this.eqPorts(ports, oldPorts) {
|
||||
hasChanges = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasChanges {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 先清空所有相关规则
|
||||
err = this.removeOldTCPRules(chain, oldRules)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete old rules failed: %w", err)
|
||||
}
|
||||
|
||||
// 添加新规则
|
||||
for _, port := range ports {
|
||||
// TODO 让用户选择是drop还是reject
|
||||
if maxConnections > 0 {
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "count", "over", types.String(maxConnections), "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, stderr.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TODO 让用户选择是drop还是reject
|
||||
if maxConnectionsPerIP > 0 {
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "meter", "meter-"+protocol+"-"+types.String(port)+"-max-connections", "{ "+protocol+" saddr ct count over "+types.String(maxConnectionsPerIP)+" }", "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, stderr.String())
|
||||
}
|
||||
}
|
||||
|
||||
// 超过一定速率就drop或者加入黑名单(分钟)
|
||||
// TODO 让用户选择是drop还是reject
|
||||
if newConnectionsMinutelyRate > 0 {
|
||||
if newConnectionsMinutelyRateBlockTimeout > 0 {
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsMinutelyRate)+"/minute burst "+types.String(newConnectionsMinutelyRate+3)+" packets }", "add", "@deny_set", "{"+protocol+" saddr timeout "+types.String(newConnectionsMinutelyRateBlockTimeout)+"s}", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsMinutelyRate), types.String(newConnectionsMinutelyRateBlockTimeout)}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, stderr.String())
|
||||
}
|
||||
} else {
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsMinutelyRate)+"/minute burst "+types.String(newConnectionsMinutelyRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", "0"}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, stderr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 超过一定速率就drop或者加入黑名单(秒)
|
||||
// TODO 让用户选择是drop还是reject
|
||||
if newConnectionsSecondlyRate > 0 {
|
||||
if newConnectionsSecondlyRateBlockTimeout > 0 {
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-secondly-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsSecondlyRate)+"/second burst "+types.String(newConnectionsSecondlyRate+3)+" packets }", "add", "@deny_set", "{"+protocol+" saddr timeout "+types.String(newConnectionsSecondlyRateBlockTimeout)+"s}", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsSecondlyRate", types.String(newConnectionsSecondlyRate), types.String(newConnectionsSecondlyRateBlockTimeout)}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, stderr.String())
|
||||
}
|
||||
} else {
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-secondly-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsSecondlyRate)+"/second burst "+types.String(newConnectionsSecondlyRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsSecondlyRate", "0"}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return fmt.Errorf("add nftables rule '%s' failed: %w (%s)", cmd.String(), err, stderr.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 删除TCP规则
|
||||
func (this *DDoSProtectionManager) removeTCPRules() error {
|
||||
for _, filter := range nftablesFilters {
|
||||
chain, rules, err := this.getRules(filter)
|
||||
|
||||
// TCP
|
||||
err = this.removeOldTCPRules(chain, rules)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 组合user data
|
||||
// 数据中不能包含字母、数字、下划线以外的数据
|
||||
func (this *DDoSProtectionManager) encodeUserData(attrs []string) string {
|
||||
if attrs == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return "ZZ" + strings.Join(attrs, "_") + "ZZ"
|
||||
}
|
||||
|
||||
// 解码user data
|
||||
func (this *DDoSProtectionManager) decodeUserData(data []byte) []string {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var dataCopy = make([]byte, len(data))
|
||||
copy(dataCopy, data)
|
||||
|
||||
var separatorLen = 2
|
||||
var index1 = bytes.Index(dataCopy, []byte{'Z', 'Z'})
|
||||
if index1 < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
dataCopy = dataCopy[index1+separatorLen:]
|
||||
var index2 = bytes.LastIndex(dataCopy, []byte{'Z', 'Z'})
|
||||
if index2 < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var s = string(dataCopy[:index2])
|
||||
var pieces = strings.Split(s, "_")
|
||||
for index, piece := range pieces {
|
||||
pieces[index] = strings.TrimSpace(piece)
|
||||
}
|
||||
return pieces
|
||||
}
|
||||
|
||||
// 清除规则
|
||||
func (this *DDoSProtectionManager) removeOldTCPRules(chain *nftables.Chain, rules []*nftables.Rule) error {
|
||||
for _, rule := range rules {
|
||||
var pieces = this.decodeUserData(rule.UserData())
|
||||
if len(pieces) < 4 {
|
||||
continue
|
||||
}
|
||||
if pieces[0] != "tcp" {
|
||||
continue
|
||||
}
|
||||
switch pieces[2] {
|
||||
case "maxConnections", "maxConnectionsPerIP", "newConnectionsRate", "newConnectionsSecondlyRate":
|
||||
err := chain.DeleteRule(rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 根据参数检查规则是否存在
|
||||
func (this *DDoSProtectionManager) existsRule(rules []*nftables.Rule, attrs []string) (exists bool) {
|
||||
if len(attrs) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, oldRule := range rules {
|
||||
var pieces = this.decodeUserData(oldRule.UserData())
|
||||
if len(attrs) != len(pieces) {
|
||||
continue
|
||||
}
|
||||
var isSame = true
|
||||
for index, piece := range pieces {
|
||||
if strings.TrimSpace(piece) != attrs[index] {
|
||||
isSame = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if isSame {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取规则中的端口号
|
||||
func (this *DDoSProtectionManager) getTCPPorts(rules []*nftables.Rule) []int32 {
|
||||
var ports = []int32{}
|
||||
for _, rule := range rules {
|
||||
var pieces = this.decodeUserData(rule.UserData())
|
||||
if len(pieces) != 4 {
|
||||
continue
|
||||
}
|
||||
if pieces[0] != "tcp" {
|
||||
continue
|
||||
}
|
||||
var port = types.Int32(pieces[1])
|
||||
if port > 0 && !lists.ContainsInt32(ports, port) {
|
||||
ports = append(ports, port)
|
||||
}
|
||||
}
|
||||
return ports
|
||||
}
|
||||
|
||||
// 检查端口是否一样
|
||||
func (this *DDoSProtectionManager) eqPorts(ports1 []int32, ports2 []int32) bool {
|
||||
if len(ports1) != len(ports2) {
|
||||
return false
|
||||
}
|
||||
|
||||
var portMap = map[int32]bool{}
|
||||
for _, port := range ports2 {
|
||||
portMap[port] = true
|
||||
}
|
||||
|
||||
for _, port := range ports1 {
|
||||
_, ok := portMap[port]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// 查找Table
|
||||
func (this *DDoSProtectionManager) getTable(filter *nftablesTableDefinition) (*nftables.Table, error) {
|
||||
var family nftables.TableFamily
|
||||
if filter.IsIPv4 {
|
||||
family = nftables.TableFamilyIPv4
|
||||
} else if filter.IsIPv6 {
|
||||
family = nftables.TableFamilyIPv6
|
||||
} else {
|
||||
return nil, errors.New("table '" + filter.Name + "' should be IPv4 or IPv6")
|
||||
}
|
||||
return nftablesInstance.conn.GetTable(filter.Name, family)
|
||||
}
|
||||
|
||||
// 查找所有规则
|
||||
func (this *DDoSProtectionManager) getRules(filter *nftablesTableDefinition) (*nftables.Chain, []*nftables.Rule, error) {
|
||||
table, err := this.getTable(filter)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get table failed: %w", err)
|
||||
}
|
||||
chain, err := table.GetChain(nftablesChainName)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get chain failed: %w", err)
|
||||
}
|
||||
rules, err := chain.GetRules()
|
||||
return chain, rules, err
|
||||
}
|
||||
|
||||
// 更新白名单
|
||||
func (this *DDoSProtectionManager) updateAllowIPList(allIPList []string) error {
|
||||
if nftablesInstance == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var allMap = map[string]zero.Zero{}
|
||||
for _, ip := range allIPList {
|
||||
allMap[ip] = zero.New()
|
||||
}
|
||||
|
||||
for _, set := range []*nftables.Set{nftablesInstance.allowIPv4Set, nftablesInstance.allowIPv6Set} {
|
||||
var isIPv4 = set == nftablesInstance.allowIPv4Set
|
||||
var isIPv6 = !isIPv4
|
||||
|
||||
// 现有的
|
||||
oldList, err := set.GetIPElements()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var oldMap = map[string]zero.Zero{} // ip=> zero
|
||||
for _, ip := range oldList {
|
||||
oldMap[ip] = zero.New()
|
||||
|
||||
if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) {
|
||||
_, ok := allMap[ip]
|
||||
if !ok {
|
||||
// 不存在则删除
|
||||
err = set.DeleteIPElement(ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete ip element '%s' failed: %w", ip, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 新增的
|
||||
for _, ip := range allIPList {
|
||||
var ipObj = net.ParseIP(ip)
|
||||
if ipObj == nil {
|
||||
continue
|
||||
}
|
||||
if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) {
|
||||
_, ok := oldMap[ip]
|
||||
if !ok {
|
||||
// 不存在则添加
|
||||
err = set.AddIPElement(ip, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add ip '%s' failed: %w", ip, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
22
EdgeDNS/internal/firewalls/ddos_protection_others.go
Normal file
22
EdgeDNS/internal/firewalls/ddos_protection_others.go
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
|
||||
)
|
||||
|
||||
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
|
||||
|
||||
type DDoSProtectionManager struct {
|
||||
}
|
||||
|
||||
func NewDDoSProtectionManager() *DDoSProtectionManager {
|
||||
return &DDoSProtectionManager{}
|
||||
}
|
||||
|
||||
func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error {
|
||||
return nil
|
||||
}
|
||||
66
EdgeDNS/internal/firewalls/firewall.go
Normal file
66
EdgeDNS/internal/firewalls/firewall.go
Normal file
@@ -0,0 +1,66 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/events"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"runtime"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var currentFirewall FirewallInterface
|
||||
var firewallLocker = &sync.Mutex{}
|
||||
|
||||
// 初始化
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
return
|
||||
}
|
||||
|
||||
events.On(events.EventLoaded, func() {
|
||||
var firewall = Firewall()
|
||||
if firewall.Name() != "mock" {
|
||||
remotelogs.Println("FIREWALL", "found local firewall '"+firewall.Name()+"'")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Firewall 查找当前系统中最适合的防火墙
|
||||
func Firewall() FirewallInterface {
|
||||
firewallLocker.Lock()
|
||||
defer firewallLocker.Unlock()
|
||||
if currentFirewall != nil {
|
||||
return currentFirewall
|
||||
}
|
||||
|
||||
// nftables
|
||||
if runtime.GOOS == "linux" {
|
||||
nftables, err := NewNFTablesFirewall()
|
||||
if err != nil {
|
||||
remotelogs.Warn("FIREWALL", "'nftables' should be installed on the system to enhance security (init failed: "+err.Error()+")")
|
||||
} else {
|
||||
if nftables.IsReady() {
|
||||
currentFirewall = nftables
|
||||
events.Notify(events.EventNFTablesReady)
|
||||
return nftables
|
||||
} else {
|
||||
remotelogs.Warn("FIREWALL", "'nftables' should be enabled on the system to enhance security")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// firewalld
|
||||
if runtime.GOOS == "linux" {
|
||||
var firewalld = NewFirewalld()
|
||||
if firewalld.IsReady() {
|
||||
currentFirewall = firewalld
|
||||
return currentFirewall
|
||||
}
|
||||
}
|
||||
|
||||
// 至少返回一个
|
||||
currentFirewall = NewMockFirewall()
|
||||
return currentFirewall
|
||||
}
|
||||
186
EdgeDNS/internal/firewalls/firewall_firewalld.go
Normal file
186
EdgeDNS/internal/firewalls/firewall_firewalld.go
Normal file
@@ -0,0 +1,186 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
executils "github.com/TeaOSLab/EdgeDNS/internal/utils/exec"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Firewalld struct {
|
||||
isReady bool
|
||||
exe string
|
||||
cmdQueue chan *exec.Cmd
|
||||
}
|
||||
|
||||
func NewFirewalld() *Firewalld {
|
||||
var firewalld = &Firewalld{
|
||||
cmdQueue: make(chan *exec.Cmd, 4096),
|
||||
}
|
||||
|
||||
path, err := executils.LookPath("firewall-cmd")
|
||||
if err == nil && len(path) > 0 {
|
||||
var cmd = exec.Command(path, "--state")
|
||||
err := cmd.Run()
|
||||
if err == nil {
|
||||
firewalld.exe = path
|
||||
// TODO check firewalld status with 'firewall-cmd --state' (running or not running),
|
||||
// but we should recover the state when firewalld state changes, maybe check it every minutes
|
||||
|
||||
firewalld.isReady = true
|
||||
firewalld.init()
|
||||
}
|
||||
}
|
||||
|
||||
return firewalld
|
||||
}
|
||||
|
||||
func (this *Firewalld) init() {
|
||||
goman.New(func() {
|
||||
for cmd := range this.cmdQueue {
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
if strings.HasPrefix(err.Error(), "Warning:") {
|
||||
continue
|
||||
}
|
||||
remotelogs.Warn("FIREWALL", "run command failed '"+cmd.String()+"': "+err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Name 名称
|
||||
func (this *Firewalld) Name() string {
|
||||
return "firewalld"
|
||||
}
|
||||
|
||||
func (this *Firewalld) IsReady() bool {
|
||||
return this.isReady
|
||||
}
|
||||
|
||||
// IsMock 是否为模拟
|
||||
func (this *Firewalld) IsMock() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *Firewalld) AllowPort(port int, protocol string) error {
|
||||
if !this.isReady {
|
||||
return nil
|
||||
}
|
||||
var cmd = exec.Command(this.exe, "--add-port="+types.String(port)+"/"+protocol)
|
||||
this.pushCmd(cmd)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Firewalld) AllowPortRangesPermanently(portRanges [][2]int, protocol string) error {
|
||||
for _, portRange := range portRanges {
|
||||
var port = this.PortRangeString(portRange, protocol)
|
||||
|
||||
{
|
||||
var cmd = exec.Command(this.exe, "--add-port="+port, "--permanent")
|
||||
this.pushCmd(cmd)
|
||||
}
|
||||
|
||||
{
|
||||
var cmd = exec.Command(this.exe, "--add-port="+port)
|
||||
this.pushCmd(cmd)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Firewalld) RemovePort(port int, protocol string) error {
|
||||
if !this.isReady {
|
||||
return nil
|
||||
}
|
||||
var cmd = exec.Command(this.exe, "--remove-port="+types.String(port)+"/"+protocol)
|
||||
this.pushCmd(cmd)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Firewalld) RemovePortRangePermanently(portRange [2]int, protocol string) error {
|
||||
var port = this.PortRangeString(portRange, protocol)
|
||||
|
||||
{
|
||||
var cmd = exec.Command(this.exe, "--remove-port="+port, "--permanent")
|
||||
this.pushCmd(cmd)
|
||||
}
|
||||
|
||||
{
|
||||
var cmd = exec.Command(this.exe, "--remove-port="+port)
|
||||
this.pushCmd(cmd)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Firewalld) PortRangeString(portRange [2]int, protocol string) string {
|
||||
if portRange[0] == portRange[1] {
|
||||
return types.String(portRange[0]) + "/" + protocol
|
||||
} else {
|
||||
return types.String(portRange[0]) + "-" + types.String(portRange[1]) + "/" + protocol
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Firewalld) RejectSourceIP(ip string, timeoutSeconds int) error {
|
||||
if !this.isReady {
|
||||
return nil
|
||||
}
|
||||
var family = "ipv4"
|
||||
if strings.Contains(ip, ":") {
|
||||
family = "ipv6"
|
||||
}
|
||||
var args = []string{"--add-rich-rule=rule family='" + family + "' source address='" + ip + "' reject"}
|
||||
if timeoutSeconds > 0 {
|
||||
args = append(args, "--timeout="+types.String(timeoutSeconds)+"s")
|
||||
}
|
||||
var cmd = exec.Command(this.exe, args...)
|
||||
this.pushCmd(cmd)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Firewalld) DropSourceIP(ip string, timeoutSeconds int) error {
|
||||
if !this.isReady {
|
||||
return nil
|
||||
}
|
||||
var family = "ipv4"
|
||||
if strings.Contains(ip, ":") {
|
||||
family = "ipv6"
|
||||
}
|
||||
var args = []string{"--add-rich-rule=rule family='" + family + "' source address='" + ip + "' drop"}
|
||||
if timeoutSeconds > 0 {
|
||||
args = append(args, "--timeout="+types.String(timeoutSeconds)+"s")
|
||||
}
|
||||
var cmd = exec.Command(this.exe, args...)
|
||||
this.pushCmd(cmd)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Firewalld) RemoveSourceIP(ip string) error {
|
||||
if !this.isReady {
|
||||
return nil
|
||||
}
|
||||
var family = "ipv4"
|
||||
if strings.Contains(ip, ":") {
|
||||
family = "ipv6"
|
||||
}
|
||||
for _, action := range []string{"reject", "drop"} {
|
||||
var args = []string{"--remove-rich-rule=rule family='" + family + "' source address='" + ip + "' " + action}
|
||||
var cmd = exec.Command(this.exe, args...)
|
||||
this.pushCmd(cmd)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Firewalld) pushCmd(cmd *exec.Cmd) {
|
||||
select {
|
||||
case this.cmdQueue <- cmd:
|
||||
default:
|
||||
// we discard the command
|
||||
}
|
||||
}
|
||||
30
EdgeDNS/internal/firewalls/firewall_interface.go
Normal file
30
EdgeDNS/internal/firewalls/firewall_interface.go
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package firewalls
|
||||
|
||||
// FirewallInterface 防火墙接口
|
||||
type FirewallInterface interface {
|
||||
// Name 名称
|
||||
Name() string
|
||||
|
||||
// IsReady 是否已准备被调用
|
||||
IsReady() bool
|
||||
|
||||
// IsMock 是否为模拟
|
||||
IsMock() bool
|
||||
|
||||
// AllowPort 允许端口
|
||||
AllowPort(port int, protocol string) error
|
||||
|
||||
// RemovePort 删除端口
|
||||
RemovePort(port int, protocol string) error
|
||||
|
||||
// RejectSourceIP 拒绝某个源IP连接
|
||||
RejectSourceIP(ip string, timeoutSeconds int) error
|
||||
|
||||
// DropSourceIP 丢弃某个源IP数据
|
||||
DropSourceIP(ip string, timeoutSeconds int) error
|
||||
|
||||
// RemoveSourceIP 删除某个源IP
|
||||
RemoveSourceIP(ip string) error
|
||||
}
|
||||
60
EdgeDNS/internal/firewalls/firewall_mock.go
Normal file
60
EdgeDNS/internal/firewalls/firewall_mock.go
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package firewalls
|
||||
|
||||
// MockFirewall 模拟防火墙
|
||||
type MockFirewall struct {
|
||||
}
|
||||
|
||||
func NewMockFirewall() *MockFirewall {
|
||||
return &MockFirewall{}
|
||||
}
|
||||
|
||||
// Name 名称
|
||||
func (this *MockFirewall) Name() string {
|
||||
return "mock"
|
||||
}
|
||||
|
||||
// IsReady 是否已准备被调用
|
||||
func (this *MockFirewall) IsReady() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// IsMock 是否为模拟
|
||||
func (this *MockFirewall) IsMock() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// AllowPort 允许端口
|
||||
func (this *MockFirewall) AllowPort(port int, protocol string) error {
|
||||
_ = port
|
||||
_ = protocol
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePort 删除端口
|
||||
func (this *MockFirewall) RemovePort(port int, protocol string) error {
|
||||
_ = port
|
||||
_ = protocol
|
||||
return nil
|
||||
}
|
||||
|
||||
// RejectSourceIP 拒绝某个源IP连接
|
||||
func (this *MockFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
|
||||
_ = ip
|
||||
_ = timeoutSeconds
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropSourceIP 丢弃某个源IP数据
|
||||
func (this *MockFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
|
||||
_ = ip
|
||||
_ = timeoutSeconds
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveSourceIP 删除某个源IP
|
||||
func (this *MockFirewall) RemoveSourceIP(ip string) error {
|
||||
_ = ip
|
||||
return nil
|
||||
}
|
||||
417
EdgeDNS/internal/firewalls/firewall_nftables.go
Normal file
417
EdgeDNS/internal/firewalls/firewall_nftables.go
Normal file
@@ -0,0 +1,417 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/events"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/firewalls/nftables"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
executils "github.com/TeaOSLab/EdgeDNS/internal/utils/exec"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// check nft status, if being enabled we load it automatically
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
return
|
||||
}
|
||||
|
||||
if teaconst.IsDaemon {
|
||||
return
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
var ticker = time.NewTicker(3 * time.Minute)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
// if already ready, we break
|
||||
if nftablesIsReady {
|
||||
ticker.Stop()
|
||||
break
|
||||
}
|
||||
_, err := executils.LookPath("nft")
|
||||
if err == nil {
|
||||
nftablesFirewall, err := NewNFTablesFirewall()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
currentFirewall = nftablesFirewall
|
||||
remotelogs.Println("FIREWALL", "nftables is ready")
|
||||
|
||||
// fire event
|
||||
if nftablesFirewall.IsReady() {
|
||||
events.Notify(events.EventNFTablesReady)
|
||||
}
|
||||
|
||||
ticker.Stop()
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
var nftablesInstance *NFTablesFirewall
|
||||
var nftablesIsReady = false
|
||||
var nftablesFilters = []*nftablesTableDefinition{
|
||||
// we shorten the name for table name length restriction
|
||||
{Name: "edge_dns_v4", IsIPv4: true},
|
||||
{Name: "edge_dns_v6", IsIPv6: true},
|
||||
}
|
||||
var nftablesChainName = "input"
|
||||
|
||||
type nftablesTableDefinition struct {
|
||||
Name string
|
||||
IsIPv4 bool
|
||||
IsIPv6 bool
|
||||
}
|
||||
|
||||
func (this *nftablesTableDefinition) protocol() string {
|
||||
if this.IsIPv6 {
|
||||
return "ip6"
|
||||
}
|
||||
return "ip"
|
||||
}
|
||||
|
||||
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
|
||||
var firewall = &NFTablesFirewall{
|
||||
conn: nftables.NewConn(),
|
||||
}
|
||||
err := firewall.init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return firewall, nil
|
||||
}
|
||||
|
||||
type NFTablesFirewall struct {
|
||||
conn *nftables.Conn
|
||||
isReady bool
|
||||
version string
|
||||
|
||||
allowIPv4Set *nftables.Set
|
||||
allowIPv6Set *nftables.Set
|
||||
|
||||
denyIPv4Set *nftables.Set
|
||||
denyIPv6Set *nftables.Set
|
||||
|
||||
firewalld *Firewalld
|
||||
}
|
||||
|
||||
func (this *NFTablesFirewall) init() error {
|
||||
// check nft
|
||||
nftPath, err := executils.LookPath("nft")
|
||||
if err != nil {
|
||||
return errors.New("nft not found")
|
||||
}
|
||||
this.version = this.readVersion(nftPath)
|
||||
|
||||
// table
|
||||
for _, tableDef := range nftablesFilters {
|
||||
var family nftables.TableFamily
|
||||
if tableDef.IsIPv4 {
|
||||
family = nftables.TableFamilyIPv4
|
||||
} else if tableDef.IsIPv6 {
|
||||
family = nftables.TableFamilyIPv6
|
||||
} else {
|
||||
return errors.New("invalid table family: " + types.String(tableDef))
|
||||
}
|
||||
table, err := this.conn.GetTable(tableDef.Name, family)
|
||||
if err != nil {
|
||||
if nftables.IsNotFound(err) {
|
||||
if tableDef.IsIPv4 {
|
||||
table, err = this.conn.AddIPv4Table(tableDef.Name)
|
||||
} else if tableDef.IsIPv6 {
|
||||
table, err = this.conn.AddIPv6Table(tableDef.Name)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("create table '%s' failed: %w", tableDef.Name, err)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("get table '%s' failed: %w", tableDef.Name, err)
|
||||
}
|
||||
}
|
||||
if table == nil {
|
||||
return errors.New("can not create table '" + tableDef.Name + "'")
|
||||
}
|
||||
|
||||
// chain
|
||||
var chainName = nftablesChainName
|
||||
chain, err := table.GetChain(chainName)
|
||||
if err != nil {
|
||||
if nftables.IsNotFound(err) {
|
||||
chain, err = table.AddAcceptChain(chainName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create chain '%s' failed: %w", chainName, err)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("get chain '%s' failed: %w", chainName, err)
|
||||
}
|
||||
}
|
||||
if chain == nil {
|
||||
return errors.New("can not create chain '" + chainName + "'")
|
||||
}
|
||||
|
||||
// allow lo
|
||||
var loRuleName = []byte("lo")
|
||||
_, err = chain.GetRuleWithUserData(loRuleName)
|
||||
if err != nil {
|
||||
if nftables.IsNotFound(err) {
|
||||
_, err = chain.AddAcceptInterfaceRule("lo", loRuleName)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("add 'lo' rule failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// allow set
|
||||
// "allow" should be always first
|
||||
for _, setAction := range []string{"allow", "deny"} {
|
||||
var setName = setAction + "_set"
|
||||
|
||||
set, err := table.GetSet(setName)
|
||||
if err != nil {
|
||||
if nftables.IsNotFound(err) {
|
||||
var keyType nftables.SetDataType
|
||||
if tableDef.IsIPv4 {
|
||||
keyType = nftables.TypeIPAddr
|
||||
} else if tableDef.IsIPv6 {
|
||||
keyType = nftables.TypeIP6Addr
|
||||
}
|
||||
set, err = table.AddSet(setName, &nftables.SetOptions{
|
||||
KeyType: keyType,
|
||||
HasTimeout: true,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("create set '%s' failed: %w", setName, err)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("get set '%s' failed: %w", setName, err)
|
||||
}
|
||||
}
|
||||
if set == nil {
|
||||
return errors.New("can not create set '" + setName + "'")
|
||||
}
|
||||
if tableDef.IsIPv4 {
|
||||
if setAction == "allow" {
|
||||
this.allowIPv4Set = set
|
||||
} else {
|
||||
this.denyIPv4Set = set
|
||||
}
|
||||
} else if tableDef.IsIPv6 {
|
||||
if setAction == "allow" {
|
||||
this.allowIPv6Set = set
|
||||
} else {
|
||||
this.denyIPv6Set = set
|
||||
}
|
||||
}
|
||||
|
||||
// rule
|
||||
var ruleName = []byte(setAction)
|
||||
rule, err := chain.GetRuleWithUserData(ruleName)
|
||||
|
||||
// 将以前的drop规则删掉,替换成后面的reject
|
||||
if err == nil && setAction != "allow" && rule != nil && rule.VerDict() == expr.VerdictDrop {
|
||||
deleteErr := chain.DeleteRule(rule)
|
||||
if deleteErr == nil {
|
||||
err = nftables.ErrRuleNotFound
|
||||
rule = nil
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if nftables.IsNotFound(err) {
|
||||
if tableDef.IsIPv4 {
|
||||
if setAction == "allow" {
|
||||
rule, err = chain.AddAcceptIPv4SetRule(setName, ruleName)
|
||||
} else {
|
||||
rule, err = chain.AddRejectIPv4SetRule(setName, ruleName)
|
||||
}
|
||||
} else if tableDef.IsIPv6 {
|
||||
if setAction == "allow" {
|
||||
rule, err = chain.AddAcceptIPv6SetRule(setName, ruleName)
|
||||
} else {
|
||||
rule, err = chain.AddRejectIPv6SetRule(setName, ruleName)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("add rule failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("get rule failed: %w", err)
|
||||
}
|
||||
}
|
||||
if rule == nil {
|
||||
return errors.New("can not create rule '" + string(ruleName) + "'")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.isReady = true
|
||||
nftablesIsReady = true
|
||||
nftablesInstance = this
|
||||
|
||||
// load firewalld
|
||||
var firewalld = NewFirewalld()
|
||||
if firewalld.IsReady() {
|
||||
this.firewalld = firewalld
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Name 名称
|
||||
func (this *NFTablesFirewall) Name() string {
|
||||
return "nftables"
|
||||
}
|
||||
|
||||
// IsReady 是否已准备被调用
|
||||
func (this *NFTablesFirewall) IsReady() bool {
|
||||
return this.isReady
|
||||
}
|
||||
|
||||
// IsMock 是否为模拟
|
||||
func (this *NFTablesFirewall) IsMock() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// AllowPort 允许端口
|
||||
func (this *NFTablesFirewall) AllowPort(port int, protocol string) error {
|
||||
if this.firewalld != nil {
|
||||
return this.firewalld.AllowPort(port, protocol)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePort 删除端口
|
||||
func (this *NFTablesFirewall) RemovePort(port int, protocol string) error {
|
||||
if this.firewalld != nil {
|
||||
return this.firewalld.RemovePort(port, protocol)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowSourceIP Allow把IP加入白名单
|
||||
func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
|
||||
var data = net.ParseIP(ip)
|
||||
if data == nil {
|
||||
return errors.New("invalid ip '" + ip + "'")
|
||||
}
|
||||
|
||||
if strings.Contains(ip, ":") { // ipv6
|
||||
if this.allowIPv6Set == nil {
|
||||
return errors.New("ipv6 ip set is nil")
|
||||
}
|
||||
return this.allowIPv6Set.AddElement(data.To16(), nil)
|
||||
}
|
||||
|
||||
// ipv4
|
||||
if this.allowIPv4Set == nil {
|
||||
return errors.New("ipv4 ip set is nil")
|
||||
}
|
||||
return this.allowIPv4Set.AddElement(data.To4(), nil)
|
||||
}
|
||||
|
||||
// RejectSourceIP 拒绝某个源IP连接
|
||||
// we did not create set for drop ip, so we reuse DropSourceIP() method here
|
||||
func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
|
||||
return this.DropSourceIP(ip, timeoutSeconds)
|
||||
}
|
||||
|
||||
// DropSourceIP 丢弃某个源IP数据
|
||||
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
|
||||
var data = net.ParseIP(ip)
|
||||
if data == nil {
|
||||
return errors.New("invalid ip '" + ip + "'")
|
||||
}
|
||||
|
||||
if strings.Contains(ip, ":") { // ipv6
|
||||
if this.denyIPv6Set == nil {
|
||||
return errors.New("ipv6 ip set is nil")
|
||||
}
|
||||
return this.denyIPv6Set.AddElement(data.To16(), &nftables.ElementOptions{
|
||||
Timeout: time.Duration(timeoutSeconds) * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
// ipv4
|
||||
if this.denyIPv4Set == nil {
|
||||
return errors.New("ipv4 ip set is nil")
|
||||
}
|
||||
return this.denyIPv4Set.AddElement(data.To4(), &nftables.ElementOptions{
|
||||
Timeout: time.Duration(timeoutSeconds) * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveSourceIP 删除某个源IP
|
||||
func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
|
||||
var data = net.ParseIP(ip)
|
||||
if data == nil {
|
||||
return errors.New("invalid ip '" + ip + "'")
|
||||
}
|
||||
|
||||
if strings.Contains(ip, ":") { // ipv6
|
||||
if this.denyIPv6Set != nil {
|
||||
err := this.denyIPv6Set.DeleteElement(data.To16())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if this.allowIPv6Set != nil {
|
||||
err := this.allowIPv6Set.DeleteElement(data.To16())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ipv4
|
||||
if this.denyIPv4Set != nil {
|
||||
err := this.denyIPv4Set.DeleteElement(data.To4())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if this.allowIPv4Set != nil {
|
||||
err := this.allowIPv4Set.DeleteElement(data.To4())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 读取版本号
|
||||
func (this *NFTablesFirewall) readVersion(nftPath string) string {
|
||||
var cmd = exec.Command(nftPath, "--version")
|
||||
var output = &bytes.Buffer{}
|
||||
cmd.Stdout = output
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var outputString = output.String()
|
||||
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
|
||||
if len(versionMatches) <= 1 {
|
||||
return ""
|
||||
}
|
||||
return versionMatches[1]
|
||||
}
|
||||
60
EdgeDNS/internal/firewalls/firewall_nftables_others.go
Normal file
60
EdgeDNS/internal/firewalls/firewall_nftables_others.go
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build !linux
|
||||
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type NFTablesFirewall struct {
|
||||
}
|
||||
|
||||
// Name 名称
|
||||
func (this *NFTablesFirewall) Name() string {
|
||||
return "nftables"
|
||||
}
|
||||
|
||||
// IsReady 是否已准备被调用
|
||||
func (this *NFTablesFirewall) IsReady() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsMock 是否为模拟
|
||||
func (this *NFTablesFirewall) IsMock() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// AllowPort 允许端口
|
||||
func (this *NFTablesFirewall) AllowPort(port int, protocol string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePort 删除端口
|
||||
func (this *NFTablesFirewall) RemovePort(port int, protocol string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowSourceIP Allow把IP加入白名单
|
||||
func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RejectSourceIP 拒绝某个源IP连接
|
||||
func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropSourceIP 丢弃某个源IP数据
|
||||
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveSourceIP 删除某个源IP
|
||||
func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
|
||||
return nil
|
||||
}
|
||||
1
EdgeDNS/internal/firewalls/nftables/.gitignore
vendored
Normal file
1
EdgeDNS/internal/firewalls/nftables/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
build_remote.sh
|
||||
369
EdgeDNS/internal/firewalls/nftables/chain.go
Normal file
369
EdgeDNS/internal/firewalls/nftables/chain.go
Normal file
@@ -0,0 +1,369 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
nft "github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
)
|
||||
|
||||
const MaxChainNameLength = 31
|
||||
|
||||
type RuleOptions struct {
|
||||
Exprs []expr.Any
|
||||
UserData []byte
|
||||
}
|
||||
|
||||
// Chain chain object in table
|
||||
type Chain struct {
|
||||
conn *Conn
|
||||
rawTable *nft.Table
|
||||
rawChain *nft.Chain
|
||||
}
|
||||
|
||||
func NewChain(conn *Conn, rawTable *nft.Table, rawChain *nft.Chain) *Chain {
|
||||
return &Chain{
|
||||
conn: conn,
|
||||
rawTable: rawTable,
|
||||
rawChain: rawChain,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Chain) Raw() *nft.Chain {
|
||||
return this.rawChain
|
||||
}
|
||||
|
||||
func (this *Chain) Name() string {
|
||||
return this.rawChain.Name
|
||||
}
|
||||
|
||||
func (this *Chain) AddRule(options *RuleOptions) (*Rule, error) {
|
||||
var rawRule = this.conn.Raw().AddRule(&nft.Rule{
|
||||
Table: this.rawTable,
|
||||
Chain: this.rawChain,
|
||||
Exprs: options.Exprs,
|
||||
UserData: options.UserData,
|
||||
})
|
||||
err := this.conn.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewRule(rawRule), nil
|
||||
}
|
||||
|
||||
func (this *Chain) AddAcceptIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ip,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddAcceptIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ip,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddDropIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ip,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictDrop,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddDropIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ip,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictDrop,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddRejectIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ip,
|
||||
},
|
||||
&expr.Reject{},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddRejectIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ip,
|
||||
},
|
||||
&expr.Reject{},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddAcceptIPv4SetRule(setName string, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: setName,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddAcceptIPv6SetRule(setName string, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: setName,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddDropIPv4SetRule(setName string, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: setName,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictDrop,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddDropIPv6SetRule(setName string, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: setName,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictDrop,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddRejectIPv4SetRule(setName string, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: setName,
|
||||
},
|
||||
&expr.Reject{},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddRejectIPv6SetRule(setName string, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: setName,
|
||||
},
|
||||
&expr.Reject{},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddAcceptInterfaceRule(interfaceName string, userData []byte) (*Rule, error) {
|
||||
if len(interfaceName) >= 16 {
|
||||
return nil, errors.New("invalid interface name '" + interfaceName + "'")
|
||||
}
|
||||
var ifname = make([]byte, 16)
|
||||
copy(ifname, interfaceName+"\x00")
|
||||
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) GetRuleWithUserData(userData []byte) (*Rule, error) {
|
||||
rawRules, err := this.conn.Raw().GetRule(this.rawTable, this.rawChain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, rawRule := range rawRules {
|
||||
if bytes.Compare(rawRule.UserData, userData) == 0 {
|
||||
return NewRule(rawRule), nil
|
||||
}
|
||||
}
|
||||
return nil, ErrRuleNotFound
|
||||
}
|
||||
|
||||
func (this *Chain) GetRules() ([]*Rule, error) {
|
||||
rawRules, err := this.conn.Raw().GetRule(this.rawTable, this.rawChain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result = []*Rule{}
|
||||
for _, rawRule := range rawRules {
|
||||
result = append(result, NewRule(rawRule))
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (this *Chain) DeleteRule(rule *Rule) error {
|
||||
err := this.conn.Raw().DelRule(rule.Raw())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return this.conn.Commit()
|
||||
}
|
||||
|
||||
func (this *Chain) Flush() error {
|
||||
this.conn.Raw().FlushChain(this.rawChain)
|
||||
return this.conn.Commit()
|
||||
}
|
||||
14
EdgeDNS/internal/firewalls/nftables/chain_policy.go
Normal file
14
EdgeDNS/internal/firewalls/nftables/chain_policy.go
Normal file
@@ -0,0 +1,14 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import nft "github.com/google/nftables"
|
||||
|
||||
type ChainPolicy = nft.ChainPolicy
|
||||
|
||||
// Possible ChainPolicy values.
|
||||
const (
|
||||
ChainPolicyDrop = nft.ChainPolicyDrop
|
||||
ChainPolicyAccept = nft.ChainPolicyAccept
|
||||
)
|
||||
129
EdgeDNS/internal/firewalls/nftables/chain_test.go
Normal file
129
EdgeDNS/internal/firewalls/nftables/chain_test.go
Normal file
@@ -0,0 +1,129 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/firewalls/nftables"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func getIPv4Chain(t *testing.T) *nftables.Chain {
|
||||
var conn = nftables.NewConn()
|
||||
table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
if err == nftables.ErrTableNotFound {
|
||||
table, err = conn.AddIPv4Table("test_ipv4")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
chain, err := table.GetChain("test_chain")
|
||||
if err != nil {
|
||||
if err == nftables.ErrChainNotFound {
|
||||
chain, err = table.AddAcceptChain("test_chain")
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return chain
|
||||
}
|
||||
|
||||
func TestChain_AddAcceptIPRule(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
_, err := chain.AddAcceptIPv4Rule(net.ParseIP("192.168.2.40").To4(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_AddDropIPRule(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
_, err := chain.AddDropIPv4Rule(net.ParseIP("192.168.2.31").To4(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_AddAcceptSetRule(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
_, err := chain.AddAcceptIPv4SetRule("ipv4_black_set", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_AddDropSetRule(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
_, err := chain.AddDropIPv4SetRule("ipv4_black_set", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_AddRejectSetRule(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
_, err := chain.AddRejectIPv4SetRule("ipv4_black_set", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_GetRuleWithUserData(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
rule, err := chain.GetRuleWithUserData([]byte("test"))
|
||||
if err != nil {
|
||||
if err == nftables.ErrRuleNotFound {
|
||||
t.Log("rule not found")
|
||||
return
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
t.Log("rule:", rule)
|
||||
}
|
||||
|
||||
func TestChain_GetRules(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
rules, err := chain.GetRules()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, rule := range rules {
|
||||
t.Log("handle:", rule.Handle(), "set name:", rule.LookupSetName(),
|
||||
"verdict:", rule.VerDict(), "user data:", string(rule.UserData()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_DeleteRule(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
rule, err := chain.GetRuleWithUserData([]byte("test"))
|
||||
if err != nil {
|
||||
if err == nftables.ErrRuleNotFound {
|
||||
t.Log("rule not found")
|
||||
return
|
||||
}
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = chain.DeleteRule(rule)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_Flush(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
err := chain.Flush()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
83
EdgeDNS/internal/firewalls/nftables/conn.go
Normal file
83
EdgeDNS/internal/firewalls/nftables/conn.go
Normal file
@@ -0,0 +1,83 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
nft "github.com/google/nftables"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
)
|
||||
|
||||
const MaxTableNameLength = 27
|
||||
|
||||
type Conn struct {
|
||||
rawConn *nft.Conn
|
||||
}
|
||||
|
||||
func NewConn() *Conn {
|
||||
return &Conn{
|
||||
rawConn: &nft.Conn{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Conn) Raw() *nft.Conn {
|
||||
return this.rawConn
|
||||
}
|
||||
|
||||
func (this *Conn) GetTable(name string, family TableFamily) (*Table, error) {
|
||||
rawTables, err := this.rawConn.ListTables()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, rawTable := range rawTables {
|
||||
if rawTable.Name == name && rawTable.Family == family {
|
||||
return NewTable(this, rawTable), nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrTableNotFound
|
||||
}
|
||||
|
||||
func (this *Conn) AddTable(name string, family TableFamily) (*Table, error) {
|
||||
if len(name) > MaxTableNameLength {
|
||||
return nil, errors.New("table name too long (max " + types.String(MaxTableNameLength) + ")")
|
||||
}
|
||||
|
||||
var rawTable = this.rawConn.AddTable(&nft.Table{
|
||||
Family: family,
|
||||
Name: name,
|
||||
})
|
||||
|
||||
err := this.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewTable(this, rawTable), nil
|
||||
}
|
||||
|
||||
func (this *Conn) AddIPv4Table(name string) (*Table, error) {
|
||||
return this.AddTable(name, TableFamilyIPv4)
|
||||
}
|
||||
|
||||
func (this *Conn) AddIPv6Table(name string) (*Table, error) {
|
||||
return this.AddTable(name, TableFamilyIPv6)
|
||||
}
|
||||
|
||||
func (this *Conn) DeleteTable(name string, family TableFamily) error {
|
||||
table, err := this.GetTable(name, family)
|
||||
if err != nil {
|
||||
if err == ErrTableNotFound {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
this.rawConn.DelTable(table.Raw())
|
||||
return this.Commit()
|
||||
}
|
||||
|
||||
func (this *Conn) Commit() error {
|
||||
return this.rawConn.Flush()
|
||||
}
|
||||
77
EdgeDNS/internal/firewalls/nftables/conn_test.go
Normal file
77
EdgeDNS/internal/firewalls/nftables/conn_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/firewalls/nftables"
|
||||
executils "github.com/TeaOSLab/EdgeDNS/internal/utils/exec"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConn_Test(t *testing.T) {
|
||||
_, err := executils.LookPath("nft")
|
||||
if err == nil {
|
||||
t.Log("ok")
|
||||
return
|
||||
}
|
||||
t.Log(err)
|
||||
}
|
||||
|
||||
func TestConn_GetTable_NotFound(t *testing.T) {
|
||||
var conn = nftables.NewConn()
|
||||
|
||||
table, err := conn.GetTable("a", nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
if err == nftables.ErrTableNotFound {
|
||||
t.Log("table not found")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Log("table:", table)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_GetTable(t *testing.T) {
|
||||
var conn = nftables.NewConn()
|
||||
|
||||
table, err := conn.GetTable("myFilter", nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
if err == nftables.ErrTableNotFound {
|
||||
t.Log("table not found")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Log("table:", table)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_AddTable(t *testing.T) {
|
||||
var conn = nftables.NewConn()
|
||||
|
||||
{
|
||||
table, err := conn.AddIPv4Table("test_ipv4")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(table.Name())
|
||||
}
|
||||
{
|
||||
table, err := conn.AddIPv6Table("test_ipv6")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(table.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_DeleteTable(t *testing.T) {
|
||||
var conn = nftables.NewConn()
|
||||
err := conn.DeleteTable("test_ipv4", nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
7
EdgeDNS/internal/firewalls/nftables/element.go
Normal file
7
EdgeDNS/internal/firewalls/nftables/element.go
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables
|
||||
|
||||
type Element struct {
|
||||
}
|
||||
18
EdgeDNS/internal/firewalls/nftables/errors.go
Normal file
18
EdgeDNS/internal/firewalls/nftables/errors.go
Normal file
@@ -0,0 +1,18 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import "errors"
|
||||
|
||||
var ErrTableNotFound = errors.New("table not found")
|
||||
var ErrChainNotFound = errors.New("chain not found")
|
||||
var ErrSetNotFound = errors.New("set not found")
|
||||
var ErrRuleNotFound = errors.New("rule not found")
|
||||
|
||||
func IsNotFound(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return err == ErrTableNotFound || err == ErrChainNotFound || err == ErrSetNotFound || err == ErrRuleNotFound
|
||||
}
|
||||
19
EdgeDNS/internal/firewalls/nftables/family.go
Normal file
19
EdgeDNS/internal/firewalls/nftables/family.go
Normal file
@@ -0,0 +1,19 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
nft "github.com/google/nftables"
|
||||
)
|
||||
|
||||
type TableFamily = nft.TableFamily
|
||||
|
||||
const (
|
||||
TableFamilyINet TableFamily = nft.TableFamilyINet
|
||||
TableFamilyIPv4 TableFamily = nft.TableFamilyIPv4
|
||||
TableFamilyIPv6 TableFamily = nft.TableFamilyIPv6
|
||||
TableFamilyARP TableFamily = nft.TableFamilyARP
|
||||
TableFamilyNetdev TableFamily = nft.TableFamilyNetdev
|
||||
TableFamilyBridge TableFamily = nft.TableFamilyBridge
|
||||
)
|
||||
52
EdgeDNS/internal/firewalls/nftables/rule.go
Normal file
52
EdgeDNS/internal/firewalls/nftables/rule.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
nft "github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
)
|
||||
|
||||
type Rule struct {
|
||||
rawRule *nft.Rule
|
||||
}
|
||||
|
||||
func NewRule(rawRule *nft.Rule) *Rule {
|
||||
return &Rule{
|
||||
rawRule: rawRule,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Rule) Raw() *nft.Rule {
|
||||
return this.rawRule
|
||||
}
|
||||
|
||||
func (this *Rule) LookupSetName() string {
|
||||
for _, e := range this.rawRule.Exprs {
|
||||
exp, ok := e.(*expr.Lookup)
|
||||
if ok {
|
||||
return exp.SetName
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (this *Rule) VerDict() expr.VerdictKind {
|
||||
for _, e := range this.rawRule.Exprs {
|
||||
exp, ok := e.(*expr.Verdict)
|
||||
if ok {
|
||||
return exp.Kind
|
||||
}
|
||||
}
|
||||
|
||||
return -100
|
||||
}
|
||||
|
||||
func (this *Rule) Handle() uint64 {
|
||||
return this.rawRule.Handle
|
||||
}
|
||||
|
||||
func (this *Rule) UserData() []byte {
|
||||
return this.rawRule.UserData
|
||||
}
|
||||
160
EdgeDNS/internal/firewalls/nftables/set.go
Normal file
160
EdgeDNS/internal/firewalls/nftables/set.go
Normal file
@@ -0,0 +1,160 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/utils"
|
||||
nft "github.com/google/nftables"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const MaxSetNameLength = 15
|
||||
|
||||
type SetOptions struct {
|
||||
Id uint32
|
||||
HasTimeout bool
|
||||
Timeout time.Duration
|
||||
KeyType SetDataType
|
||||
DataType SetDataType
|
||||
Constant bool
|
||||
Interval bool
|
||||
Anonymous bool
|
||||
IsMap bool
|
||||
}
|
||||
|
||||
type ElementOptions struct {
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
type Set struct {
|
||||
conn *Conn
|
||||
rawSet *nft.Set
|
||||
batch *SetBatch
|
||||
}
|
||||
|
||||
func NewSet(conn *Conn, rawSet *nft.Set) *Set {
|
||||
return &Set{
|
||||
conn: conn,
|
||||
rawSet: rawSet,
|
||||
batch: &SetBatch{
|
||||
conn: conn,
|
||||
rawSet: rawSet,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Set) Raw() *nft.Set {
|
||||
return this.rawSet
|
||||
}
|
||||
|
||||
func (this *Set) Name() string {
|
||||
return this.rawSet.Name
|
||||
}
|
||||
|
||||
func (this *Set) AddElement(key []byte, options *ElementOptions) error {
|
||||
var rawElement = nft.SetElement{
|
||||
Key: key,
|
||||
}
|
||||
if options != nil {
|
||||
rawElement.Timeout = options.Timeout
|
||||
}
|
||||
err := this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
|
||||
rawElement,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = this.conn.Commit()
|
||||
if err != nil {
|
||||
// retry if exists
|
||||
if strings.Contains(err.Error(), "file exists") {
|
||||
deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
|
||||
{
|
||||
Key: key,
|
||||
},
|
||||
})
|
||||
if deleteErr == nil {
|
||||
err = this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
|
||||
rawElement,
|
||||
})
|
||||
if err == nil {
|
||||
err = this.conn.Commit()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *Set) AddIPElement(ip string, options *ElementOptions) error {
|
||||
var ipObj = net.ParseIP(ip)
|
||||
if ipObj == nil {
|
||||
return errors.New("invalid ip '" + ip + "'")
|
||||
}
|
||||
|
||||
if utils.IsIPv4(ip) {
|
||||
return this.AddElement(ipObj.To4(), options)
|
||||
} else {
|
||||
return this.AddElement(ipObj.To16(), options)
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Set) DeleteElement(key []byte) error {
|
||||
err := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
|
||||
{
|
||||
Key: key,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = this.conn.Commit()
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "no such file or directory") {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *Set) DeleteIPElement(ip string) error {
|
||||
var ipObj = net.ParseIP(ip)
|
||||
if ipObj == nil {
|
||||
return errors.New("invalid ip '" + ip + "'")
|
||||
}
|
||||
|
||||
if utils.IsIPv4(ip) {
|
||||
return this.DeleteElement(ipObj.To4())
|
||||
} else {
|
||||
return this.DeleteElement(ipObj.To16())
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Set) Batch() *SetBatch {
|
||||
return this.batch
|
||||
}
|
||||
|
||||
func (this *Set) GetIPElements() ([]string, error) {
|
||||
elements, err := this.conn.Raw().GetSetElements(this.rawSet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result = []string{}
|
||||
for _, element := range elements {
|
||||
result = append(result, net.IP(element.Key).String())
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// not work current time
|
||||
/**func (this *Set) Flush() error {
|
||||
this.conn.Raw().FlushSet(this.rawSet)
|
||||
return this.conn.Commit()
|
||||
}**/
|
||||
37
EdgeDNS/internal/firewalls/nftables/set_batch.go
Normal file
37
EdgeDNS/internal/firewalls/nftables/set_batch.go
Normal file
@@ -0,0 +1,37 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
nft "github.com/google/nftables"
|
||||
)
|
||||
|
||||
type SetBatch struct {
|
||||
conn *Conn
|
||||
rawSet *nft.Set
|
||||
}
|
||||
|
||||
func (this *SetBatch) AddElement(key []byte, options *ElementOptions) error {
|
||||
var rawElement = nft.SetElement{
|
||||
Key: key,
|
||||
}
|
||||
if options != nil {
|
||||
rawElement.Timeout = options.Timeout
|
||||
}
|
||||
return this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
|
||||
rawElement,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *SetBatch) DeleteElement(key []byte) error {
|
||||
return this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
|
||||
{
|
||||
Key: key,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (this *SetBatch) Commit() error {
|
||||
return this.conn.Commit()
|
||||
}
|
||||
58
EdgeDNS/internal/firewalls/nftables/set_data_type.go
Normal file
58
EdgeDNS/internal/firewalls/nftables/set_data_type.go
Normal file
@@ -0,0 +1,58 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import nft "github.com/google/nftables"
|
||||
|
||||
type SetDataType = nft.SetDatatype
|
||||
|
||||
var (
|
||||
TypeInvalid = nft.TypeInvalid
|
||||
TypeVerdict = nft.TypeVerdict
|
||||
TypeNFProto = nft.TypeNFProto
|
||||
TypeBitmask = nft.TypeBitmask
|
||||
TypeInteger = nft.TypeInteger
|
||||
TypeString = nft.TypeString
|
||||
TypeLLAddr = nft.TypeLLAddr
|
||||
TypeIPAddr = nft.TypeIPAddr
|
||||
TypeIP6Addr = nft.TypeIP6Addr
|
||||
TypeEtherAddr = nft.TypeEtherAddr
|
||||
TypeEtherType = nft.TypeEtherType
|
||||
TypeARPOp = nft.TypeARPOp
|
||||
TypeInetProto = nft.TypeInetProto
|
||||
TypeInetService = nft.TypeInetService
|
||||
TypeICMPType = nft.TypeICMPType
|
||||
TypeTCPFlag = nft.TypeTCPFlag
|
||||
TypeDCCPPktType = nft.TypeDCCPPktType
|
||||
TypeMHType = nft.TypeMHType
|
||||
TypeTime = nft.TypeTime
|
||||
TypeMark = nft.TypeMark
|
||||
TypeIFIndex = nft.TypeIFIndex
|
||||
TypeARPHRD = nft.TypeARPHRD
|
||||
TypeRealm = nft.TypeRealm
|
||||
TypeClassID = nft.TypeClassID
|
||||
TypeUID = nft.TypeUID
|
||||
TypeGID = nft.TypeGID
|
||||
TypeCTState = nft.TypeCTState
|
||||
TypeCTDir = nft.TypeCTDir
|
||||
TypeCTStatus = nft.TypeCTStatus
|
||||
TypeICMP6Type = nft.TypeICMP6Type
|
||||
TypeCTLabel = nft.TypeCTLabel
|
||||
TypePktType = nft.TypePktType
|
||||
TypeICMPCode = nft.TypeICMPCode
|
||||
TypeICMPV6Code = nft.TypeICMPV6Code
|
||||
TypeICMPXCode = nft.TypeICMPXCode
|
||||
TypeDevGroup = nft.TypeDevGroup
|
||||
TypeDSCP = nft.TypeDSCP
|
||||
TypeECN = nft.TypeECN
|
||||
TypeFIBAddr = nft.TypeFIBAddr
|
||||
TypeBoolean = nft.TypeBoolean
|
||||
TypeCTEventBit = nft.TypeCTEventBit
|
||||
TypeIFName = nft.TypeIFName
|
||||
TypeIGMPType = nft.TypeIGMPType
|
||||
TypeTimeDate = nft.TypeTimeDate
|
||||
TypeTimeHour = nft.TypeTimeHour
|
||||
TypeTimeDay = nft.TypeTimeDay
|
||||
TypeCGroupV2 = nft.TypeCGroupV2
|
||||
)
|
||||
111
EdgeDNS/internal/firewalls/nftables/set_test.go
Normal file
111
EdgeDNS/internal/firewalls/nftables/set_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/firewalls/nftables"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/mdlayher/netlink"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func getIPv4Set(t *testing.T) *nftables.Set {
|
||||
var table = getIPv4Table(t)
|
||||
set, err := table.GetSet("test_ipv4_set")
|
||||
if err != nil {
|
||||
if err == nftables.ErrSetNotFound {
|
||||
set, err = table.AddSet("test_ipv4_set", &nftables.SetOptions{
|
||||
KeyType: nftables.TypeIPAddr,
|
||||
HasTimeout: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return set
|
||||
}
|
||||
|
||||
func TestSet_AddElement(t *testing.T) {
|
||||
var set = getIPv4Set(t)
|
||||
err := set.AddElement(net.ParseIP("192.168.2.31").To4(), &nftables.ElementOptions{Timeout: 86400 * time.Second})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestSet_DeleteElement(t *testing.T) {
|
||||
var set = getIPv4Set(t)
|
||||
err := set.DeleteElement(net.ParseIP("192.168.2.31").To4())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestSet_Batch(t *testing.T) {
|
||||
var batch = getIPv4Set(t).Batch()
|
||||
|
||||
for _, ip := range []string{"192.168.2.30", "192.168.2.31", "192.168.2.32", "192.168.2.33", "192.168.2.34"} {
|
||||
var ipData = net.ParseIP(ip).To4()
|
||||
//err := batch.DeleteElement(ipData)
|
||||
//if err != nil {
|
||||
// t.Fatal(err)
|
||||
//}
|
||||
err := batch.AddElement(ipData, &nftables.ElementOptions{Timeout: 10 * time.Second})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err := batch.Commit()
|
||||
if err != nil {
|
||||
t.Logf("%#v", errors.Unwrap(err).(*netlink.OpError))
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestSet_Add_Many(t *testing.T) {
|
||||
var set = getIPv4Set(t)
|
||||
|
||||
for i := 0; i < 255; i++ {
|
||||
t.Log(i)
|
||||
for j := 0; j < 255; j++ {
|
||||
var ip = "192.167." + types.String(i) + "." + types.String(j)
|
||||
var ipData = net.ParseIP(ip).To4()
|
||||
err := set.Batch().AddElement(ipData, &nftables.ElementOptions{Timeout: 3600 * time.Second})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if j%10 == 0 {
|
||||
err = set.Batch().Commit()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
err := set.Batch().Commit()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
/**func TestSet_Flush(t *testing.T) {
|
||||
var set = getIPv4Set(t)
|
||||
err := set.Flush()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}**/
|
||||
156
EdgeDNS/internal/firewalls/nftables/table.go
Normal file
156
EdgeDNS/internal/firewalls/nftables/table.go
Normal file
@@ -0,0 +1,156 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
nft "github.com/google/nftables"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Table struct {
|
||||
conn *Conn
|
||||
rawTable *nft.Table
|
||||
}
|
||||
|
||||
func NewTable(conn *Conn, rawTable *nft.Table) *Table {
|
||||
return &Table{
|
||||
conn: conn,
|
||||
rawTable: rawTable,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Table) Raw() *nft.Table {
|
||||
return this.rawTable
|
||||
}
|
||||
|
||||
func (this *Table) Name() string {
|
||||
return this.rawTable.Name
|
||||
}
|
||||
|
||||
func (this *Table) Family() TableFamily {
|
||||
return this.rawTable.Family
|
||||
}
|
||||
|
||||
func (this *Table) GetChain(name string) (*Chain, error) {
|
||||
rawChains, err := this.conn.Raw().ListChains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, rawChain := range rawChains {
|
||||
// must compare table name
|
||||
if rawChain.Name == name && rawChain.Table.Name == this.rawTable.Name {
|
||||
return NewChain(this.conn, this.rawTable, rawChain), nil
|
||||
}
|
||||
}
|
||||
return nil, ErrChainNotFound
|
||||
}
|
||||
|
||||
func (this *Table) AddChain(name string, chainPolicy *ChainPolicy) (*Chain, error) {
|
||||
if len(name) > MaxChainNameLength {
|
||||
return nil, errors.New("chain name too long (max " + types.String(MaxChainNameLength) + ")")
|
||||
}
|
||||
|
||||
var rawChain = this.conn.Raw().AddChain(&nft.Chain{
|
||||
Name: name,
|
||||
Table: this.rawTable,
|
||||
Hooknum: nft.ChainHookInput,
|
||||
Priority: nft.ChainPriorityFilter,
|
||||
Type: nft.ChainTypeFilter,
|
||||
Policy: chainPolicy,
|
||||
})
|
||||
|
||||
err := this.conn.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewChain(this.conn, this.rawTable, rawChain), nil
|
||||
}
|
||||
|
||||
func (this *Table) AddAcceptChain(name string) (*Chain, error) {
|
||||
var policy = ChainPolicyAccept
|
||||
return this.AddChain(name, &policy)
|
||||
}
|
||||
|
||||
func (this *Table) AddDropChain(name string) (*Chain, error) {
|
||||
var policy = ChainPolicyDrop
|
||||
return this.AddChain(name, &policy)
|
||||
}
|
||||
|
||||
func (this *Table) DeleteChain(name string) error {
|
||||
chain, err := this.GetChain(name)
|
||||
if err != nil {
|
||||
if err == ErrChainNotFound {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
this.conn.Raw().DelChain(chain.Raw())
|
||||
return this.conn.Commit()
|
||||
}
|
||||
|
||||
func (this *Table) GetSet(name string) (*Set, error) {
|
||||
rawSet, err := this.conn.Raw().GetSetByName(this.rawTable, name)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "no such file or directory") {
|
||||
return nil, ErrSetNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewSet(this.conn, rawSet), nil
|
||||
}
|
||||
|
||||
func (this *Table) AddSet(name string, options *SetOptions) (*Set, error) {
|
||||
if len(name) > MaxSetNameLength {
|
||||
return nil, errors.New("set name too long (max " + types.String(MaxSetNameLength) + ")")
|
||||
}
|
||||
|
||||
if options == nil {
|
||||
options = &SetOptions{}
|
||||
}
|
||||
var rawSet = &nft.Set{
|
||||
Table: this.rawTable,
|
||||
ID: options.Id,
|
||||
Name: name,
|
||||
Anonymous: options.Anonymous,
|
||||
Constant: options.Constant,
|
||||
Interval: options.Interval,
|
||||
IsMap: options.IsMap,
|
||||
HasTimeout: options.HasTimeout,
|
||||
Timeout: options.Timeout,
|
||||
KeyType: options.KeyType,
|
||||
DataType: options.DataType,
|
||||
}
|
||||
err := this.conn.Raw().AddSet(rawSet, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = this.conn.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewSet(this.conn, rawSet), nil
|
||||
}
|
||||
|
||||
func (this *Table) DeleteSet(name string) error {
|
||||
set, err := this.GetSet(name)
|
||||
if err != nil {
|
||||
if err == ErrSetNotFound {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
this.conn.Raw().DelSet(set.Raw())
|
||||
return this.conn.Commit()
|
||||
}
|
||||
|
||||
func (this *Table) Flush() error {
|
||||
this.conn.Raw().FlushTable(this.rawTable)
|
||||
return this.conn.Commit()
|
||||
}
|
||||
139
EdgeDNS/internal/firewalls/nftables/table_test.go
Normal file
139
EdgeDNS/internal/firewalls/nftables/table_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
|
||||
package nftables_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/firewalls/nftables"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func getIPv4Table(t *testing.T) *nftables.Table {
|
||||
var conn = nftables.NewConn()
|
||||
table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
if err == nftables.ErrTableNotFound {
|
||||
table, err = conn.AddIPv4Table("test_ipv4")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return table
|
||||
}
|
||||
|
||||
func TestTable_AddChain(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
|
||||
{
|
||||
chain, err := table.AddChain("test_default_chain", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("created:", chain.Name())
|
||||
}
|
||||
|
||||
{
|
||||
chain, err := table.AddAcceptChain("test_accept_chain")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("created:", chain.Name())
|
||||
}
|
||||
|
||||
// Do not test drop chain before adding accept rule, you will drop yourself!!!!!!!
|
||||
/**{
|
||||
chain, err := table.AddDropChain("test_drop_chain")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("created:", chain.Name())
|
||||
}**/
|
||||
}
|
||||
|
||||
func TestTable_GetChain(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
for _, chainName := range []string{"not_found_chain", "test_default_chain"} {
|
||||
chain, err := table.GetChain(chainName)
|
||||
if err != nil {
|
||||
if err == nftables.ErrChainNotFound {
|
||||
t.Log(chainName, ":", "not found")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Log(chainName, ":", chain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTable_DeleteChain(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
err := table.DeleteChain("test_default_chain")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestTable_AddSet(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
{
|
||||
set, err := table.AddSet("ipv4_black_set", &nftables.SetOptions{
|
||||
HasTimeout: false,
|
||||
KeyType: nftables.TypeIPAddr,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(set.Name())
|
||||
}
|
||||
|
||||
{
|
||||
set, err := table.AddSet("ipv6_black_set", &nftables.SetOptions{
|
||||
HasTimeout: true,
|
||||
//Timeout: 3600 * time.Second,
|
||||
KeyType: nftables.TypeIP6Addr,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(set.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTable_GetSet(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
for _, setName := range []string{"not_found_set", "ipv4_black_set"} {
|
||||
set, err := table.GetSet(setName)
|
||||
if err != nil {
|
||||
if err == nftables.ErrSetNotFound {
|
||||
t.Log(setName, ": not found")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Log(setName, ":", set)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTable_DeleteSet(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
err := table.DeleteSet("ipv4_black_set")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestTable_Flush(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
err := table.Flush()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
12
EdgeDNS/internal/goman/instance.go
Normal file
12
EdgeDNS/internal/goman/instance.go
Normal file
@@ -0,0 +1,12 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package goman
|
||||
|
||||
import "time"
|
||||
|
||||
type Instance struct {
|
||||
Id uint64
|
||||
CreatedTime time.Time
|
||||
File string
|
||||
Line int
|
||||
}
|
||||
81
EdgeDNS/internal/goman/lib.go
Normal file
81
EdgeDNS/internal/goman/lib.go
Normal file
@@ -0,0 +1,81 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package goman
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var locker = &sync.Mutex{}
|
||||
var instanceMap = map[uint64]*Instance{} // id => *Instance
|
||||
var instanceId = uint64(0)
|
||||
|
||||
// New 新创建goroutine
|
||||
func New(f func()) {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
|
||||
go func() {
|
||||
locker.Lock()
|
||||
instanceId++
|
||||
|
||||
var instance = &Instance{
|
||||
Id: instanceId,
|
||||
CreatedTime: time.Now(),
|
||||
}
|
||||
|
||||
instance.File = file
|
||||
instance.Line = line
|
||||
|
||||
instanceMap[instanceId] = instance
|
||||
locker.Unlock()
|
||||
|
||||
// run function
|
||||
f()
|
||||
|
||||
locker.Lock()
|
||||
delete(instanceMap, instanceId)
|
||||
locker.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
// NewWithArgs 创建带有参数的goroutine
|
||||
func NewWithArgs(f func(args ...interface{}), args ...interface{}) {
|
||||
_, file, line, _ := runtime.Caller(1)
|
||||
|
||||
go func() {
|
||||
locker.Lock()
|
||||
instanceId++
|
||||
|
||||
var instance = &Instance{
|
||||
Id: instanceId,
|
||||
CreatedTime: time.Now(),
|
||||
}
|
||||
|
||||
instance.File = file
|
||||
instance.Line = line
|
||||
|
||||
instanceMap[instanceId] = instance
|
||||
locker.Unlock()
|
||||
|
||||
// run function
|
||||
f(args...)
|
||||
|
||||
locker.Lock()
|
||||
delete(instanceMap, instanceId)
|
||||
locker.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
// List 列出所有正在运行goroutine
|
||||
func List() []*Instance {
|
||||
locker.Lock()
|
||||
defer locker.Unlock()
|
||||
|
||||
var result = []*Instance{}
|
||||
for _, instance := range instanceMap {
|
||||
result = append(result, instance)
|
||||
}
|
||||
return result
|
||||
}
|
||||
28
EdgeDNS/internal/goman/lib_test.go
Normal file
28
EdgeDNS/internal/goman/lib_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package goman
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
New(func() {
|
||||
t.Log("Hello")
|
||||
|
||||
t.Log(List())
|
||||
})
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
t.Log(List())
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
func TestNewWithArgs(t *testing.T) {
|
||||
NewWithArgs(func(args ...interface{}) {
|
||||
t.Log(args[0], args[1])
|
||||
}, 1, 2)
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
9
EdgeDNS/internal/models/agent_ip.go
Normal file
9
EdgeDNS/internal/models/agent_ip.go
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package models
|
||||
|
||||
type AgentIP struct {
|
||||
Id int64
|
||||
IP string
|
||||
AgentCode string
|
||||
}
|
||||
15
EdgeDNS/internal/models/ns_domain.go
Normal file
15
EdgeDNS/internal/models/ns_domain.go
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build plus
|
||||
|
||||
package models
|
||||
|
||||
import "github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
|
||||
type NSDomain struct {
|
||||
Id int64
|
||||
ClusterId int64
|
||||
UserId int64
|
||||
Name string
|
||||
TSIG *dnsconfigs.NSTSIGConfig
|
||||
Version int64
|
||||
}
|
||||
13
EdgeDNS/internal/models/ns_key.go
Normal file
13
EdgeDNS/internal/models/ns_key.go
Normal file
@@ -0,0 +1,13 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package models
|
||||
|
||||
type NSKey struct {
|
||||
Id int64
|
||||
DomainId int64
|
||||
ZoneId int64
|
||||
Algo string
|
||||
Secret string
|
||||
SecretType string
|
||||
Version int64
|
||||
}
|
||||
27
EdgeDNS/internal/models/ns_keys.go
Normal file
27
EdgeDNS/internal/models/ns_keys.go
Normal file
@@ -0,0 +1,27 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package models
|
||||
|
||||
type NSKeys struct {
|
||||
m map[int64]*NSKey // keyId => *NSKey
|
||||
}
|
||||
|
||||
func NewNSKeys() *NSKeys {
|
||||
return &NSKeys{m: map[int64]*NSKey{}}
|
||||
}
|
||||
|
||||
func (this *NSKeys) Add(key *NSKey) {
|
||||
this.m[key.Id] = key
|
||||
}
|
||||
|
||||
func (this *NSKeys) Remove(keyId int64) {
|
||||
delete(this.m, keyId)
|
||||
}
|
||||
|
||||
func (this *NSKeys) All() []*NSKey {
|
||||
var result = []*NSKey{}
|
||||
for _, k := range this.m {
|
||||
result = append(result, k)
|
||||
}
|
||||
return result
|
||||
}
|
||||
166
EdgeDNS/internal/models/ns_record.go
Normal file
166
EdgeDNS/internal/models/ns_record.go
Normal file
@@ -0,0 +1,166 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/miekg/dns"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type NSRecord struct {
|
||||
Id int64
|
||||
Name string
|
||||
Type dnsconfigs.RecordType
|
||||
Value string
|
||||
|
||||
MXPriority int32
|
||||
|
||||
SRVPriority int32
|
||||
SRVWeight int32
|
||||
SRVPort int32
|
||||
|
||||
CAAFlag int32
|
||||
CAATag string
|
||||
|
||||
Ttl int32
|
||||
Weight int32
|
||||
Version int64
|
||||
RouteIds []string
|
||||
DomainId int64
|
||||
}
|
||||
|
||||
func (this *NSRecord) ToRRAnswer(questionName string, rrClass uint16) dns.RR {
|
||||
if this.Ttl <= 0 {
|
||||
this.Ttl = 60
|
||||
}
|
||||
|
||||
switch this.Type {
|
||||
case dnsconfigs.RecordTypeA:
|
||||
return &dns.A{
|
||||
Hdr: this.ToRRHeader(questionName, dns.TypeA, rrClass),
|
||||
A: net.ParseIP(this.Value),
|
||||
}
|
||||
case dnsconfigs.RecordTypeCNAME:
|
||||
var value = this.Value
|
||||
if !strings.HasSuffix(value, ".") {
|
||||
value += "."
|
||||
}
|
||||
return &dns.CNAME{
|
||||
Hdr: this.ToRRHeader(questionName, dns.TypeCNAME, rrClass),
|
||||
Target: value,
|
||||
}
|
||||
case dnsconfigs.RecordTypeAAAA:
|
||||
return &dns.AAAA{
|
||||
Hdr: this.ToRRHeader(questionName, dns.TypeAAAA, rrClass),
|
||||
AAAA: net.ParseIP(this.Value),
|
||||
}
|
||||
case dnsconfigs.RecordTypeNS:
|
||||
var value = this.Value
|
||||
if !strings.HasSuffix(value, ".") {
|
||||
value += "."
|
||||
}
|
||||
return &dns.NS{
|
||||
Hdr: this.ToRRHeader(questionName, dns.TypeNS, rrClass),
|
||||
Ns: value,
|
||||
}
|
||||
case dnsconfigs.RecordTypeMX:
|
||||
var value = this.Value
|
||||
if !strings.HasSuffix(value, ".") {
|
||||
value += "."
|
||||
}
|
||||
|
||||
var preference uint16 = 0
|
||||
var priority = this.MXPriority
|
||||
if priority >= 0 {
|
||||
if priority > 65535 {
|
||||
priority = 65535
|
||||
}
|
||||
preference = types.Uint16(priority)
|
||||
}
|
||||
|
||||
return &dns.MX{
|
||||
Hdr: this.ToRRHeader(questionName, dns.TypeMX, rrClass),
|
||||
Preference: preference,
|
||||
Mx: value,
|
||||
}
|
||||
case dnsconfigs.RecordTypeSRV:
|
||||
var priority uint16 = 10
|
||||
if this.SRVPriority > 0 {
|
||||
priority = uint16(this.SRVPriority)
|
||||
}
|
||||
|
||||
var weight uint16 = 10
|
||||
if this.SRVWeight > 0 {
|
||||
weight = uint16(this.SRVWeight)
|
||||
}
|
||||
|
||||
var port uint16 = 0
|
||||
if this.SRVPort > 0 {
|
||||
port = uint16(this.SRVPort)
|
||||
}
|
||||
|
||||
var value = this.Value
|
||||
if !strings.HasSuffix(value, ".") {
|
||||
value += "."
|
||||
}
|
||||
|
||||
return &dns.SRV{
|
||||
Hdr: this.ToRRHeader(questionName, dns.TypeSRV, rrClass),
|
||||
Priority: priority,
|
||||
Weight: weight,
|
||||
Port: port,
|
||||
Target: value,
|
||||
}
|
||||
case dnsconfigs.RecordTypeTXT:
|
||||
var values []string
|
||||
var runes = []rune(this.Value)
|
||||
const maxChars = 255
|
||||
for {
|
||||
if len(runes) <= maxChars {
|
||||
values = append(values, string(runes))
|
||||
break
|
||||
}
|
||||
values = append(values, string(runes[:maxChars]))
|
||||
runes = runes[maxChars:]
|
||||
if len(runes) == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return &dns.TXT{
|
||||
Hdr: this.ToRRHeader(questionName, dns.TypeTXT, rrClass),
|
||||
Txt: values, // TODO 可以添加多个
|
||||
}
|
||||
case dnsconfigs.RecordTypeCAA:
|
||||
var flag uint8 = 0
|
||||
if this.CAAFlag >= 0 && this.CAAFlag <= 128 {
|
||||
flag = uint8(this.CAAFlag)
|
||||
}
|
||||
|
||||
var tag = this.CAATag
|
||||
if tag != "issue" && tag != "issuewild" && tag != "iodef" {
|
||||
tag = "issue"
|
||||
}
|
||||
|
||||
return &dns.CAA{
|
||||
Hdr: this.ToRRHeader(questionName, dns.TypeCAA, rrClass),
|
||||
Flag: flag, // 0-128
|
||||
Tag: tag, // issue|issuewild|iodef
|
||||
Value: this.Value,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *NSRecord) ToRRHeader(questionName string, rrType uint16, rrClass uint16) dns.RR_Header {
|
||||
return dns.RR_Header{
|
||||
Name: questionName,
|
||||
Rrtype: rrType,
|
||||
Class: rrClass,
|
||||
Ttl: uint32(this.Ttl),
|
||||
}
|
||||
}
|
||||
45
EdgeDNS/internal/models/ns_route.go
Normal file
45
EdgeDNS/internal/models/ns_route.go
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build plus
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"net"
|
||||
)
|
||||
|
||||
type NSRoute struct {
|
||||
Id int64
|
||||
Ranges []dnsconfigs.NSRouteRangeInterface
|
||||
Priority int32
|
||||
Order int32
|
||||
UserId int64
|
||||
Version int64
|
||||
}
|
||||
|
||||
func (this *NSRoute) Contains(ip net.IP) bool {
|
||||
if len(ip) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 先执行IsReverse
|
||||
for _, r := range this.Ranges {
|
||||
if r.IsExcluding() && r.Contains(ip) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 再执行正常的
|
||||
for _, r := range this.Ranges {
|
||||
if !r.IsExcluding() && r.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RealCode 代号
|
||||
// TODO 支持自定义代号
|
||||
func (this *NSRoute) RealCode() string {
|
||||
return RouteIdString(this.Id)
|
||||
}
|
||||
58
EdgeDNS/internal/models/ranges.go
Normal file
58
EdgeDNS/internal/models/ranges.go
Normal file
@@ -0,0 +1,58 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build plus
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
)
|
||||
|
||||
// InitRangesFromJSON 从JSON中初始化线路范围
|
||||
func InitRangesFromJSON(rangesJSON []byte) (ranges []dnsconfigs.NSRouteRangeInterface, err error) {
|
||||
if len(rangesJSON) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var rangeMaps = []maps.Map{}
|
||||
err = json.Unmarshal(rangesJSON, &rangeMaps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, rangeMap := range rangeMaps {
|
||||
var rangeType = rangeMap.GetString("type")
|
||||
paramsJSON, err := json.Marshal(rangeMap.Get("params"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var r dnsconfigs.NSRouteRangeInterface
|
||||
|
||||
switch rangeType {
|
||||
case dnsconfigs.NSRouteRangeTypeIP:
|
||||
r = &dnsconfigs.NSRouteRangeIPRange{}
|
||||
case dnsconfigs.NSRouteRangeTypeCIDR:
|
||||
r = &dnsconfigs.NSRouteRangeCIDR{}
|
||||
case dnsconfigs.NSRouteRangeTypeRegion:
|
||||
r = &dnsconfigs.NSRouteRangeRegion{
|
||||
Connector: rangeMap.GetString("connector"),
|
||||
}
|
||||
r.SetRegionResolver(DefaultRegionResolver)
|
||||
default:
|
||||
return nil, errors.New("invalid route line type '" + rangeType + "'")
|
||||
}
|
||||
|
||||
err = json.Unmarshal(paramsJSON, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ranges = append(ranges, r)
|
||||
}
|
||||
return
|
||||
}
|
||||
172
EdgeDNS/internal/models/record_ids.go
Normal file
172
EdgeDNS/internal/models/record_ids.go
Normal file
@@ -0,0 +1,172 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build plus
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
type recordIdInfo struct {
|
||||
Id int64
|
||||
Weight int32
|
||||
}
|
||||
|
||||
type RecordIds struct {
|
||||
IdList []*recordIdInfo
|
||||
IdBucket []int64
|
||||
|
||||
RoundIndex int
|
||||
|
||||
totalWeight int64
|
||||
}
|
||||
|
||||
func NewRecordIds() *RecordIds {
|
||||
return &RecordIds{}
|
||||
}
|
||||
|
||||
func (this *RecordIds) IsEmpty() bool {
|
||||
return len(this.IdList) == 0
|
||||
}
|
||||
|
||||
func (this *RecordIds) Add(newId int64, weight int32) {
|
||||
if weight <= 0 {
|
||||
weight = 10
|
||||
}
|
||||
|
||||
const maxWeight = 999999
|
||||
if weight > maxWeight {
|
||||
weight = maxWeight
|
||||
}
|
||||
|
||||
// 检查是否存在
|
||||
for _, idInfo := range this.IdList {
|
||||
if idInfo.Id == newId {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 添加
|
||||
this.IdList = append(this.IdList, &recordIdInfo{
|
||||
Id: newId,
|
||||
Weight: weight,
|
||||
})
|
||||
|
||||
// 重置数据
|
||||
this.resetData()
|
||||
}
|
||||
|
||||
func (this *RecordIds) Remove(oldId int64) {
|
||||
defer this.resetData()
|
||||
|
||||
var newIdList = []*recordIdInfo{}
|
||||
for _, idInfo := range this.IdList {
|
||||
if idInfo.Id == oldId {
|
||||
continue
|
||||
}
|
||||
newIdList = append(newIdList, idInfo)
|
||||
}
|
||||
this.IdList = newIdList
|
||||
}
|
||||
|
||||
// NextId for round-robin
|
||||
func (this *RecordIds) NextId() int64 {
|
||||
var l = len(this.IdList)
|
||||
|
||||
if l == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
if l == 1 {
|
||||
return this.IdList[0].Id
|
||||
}
|
||||
|
||||
if this.RoundIndex > l-1 {
|
||||
this.RoundIndex = 0
|
||||
}
|
||||
|
||||
var id = this.IdList[this.RoundIndex].Id
|
||||
|
||||
this.RoundIndex++
|
||||
return id
|
||||
}
|
||||
|
||||
func (this *RecordIds) RandomIds(count int) []int64 {
|
||||
if count <= 0 {
|
||||
count = dnsconfigs.NSAnswerDefaultSize
|
||||
}
|
||||
|
||||
var totalRecords = len(this.IdList)
|
||||
|
||||
if totalRecords == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if totalRecords == 1 {
|
||||
return []int64{this.IdList[0].Id} // duplicate
|
||||
}
|
||||
|
||||
if totalRecords < count {
|
||||
count = totalRecords
|
||||
}
|
||||
|
||||
var totalIds = len(this.IdBucket)
|
||||
var startIndex = rands.Int(0, totalIds-1)
|
||||
var endIndex = startIndex + count - 1
|
||||
if endIndex <= totalIds-1 {
|
||||
return this.IdBucket[startIndex : endIndex+1]
|
||||
}
|
||||
return append(this.IdBucket[startIndex:totalIds], this.IdBucket[0:endIndex-totalIds+1]...)
|
||||
}
|
||||
|
||||
func (this *RecordIds) resetData() {
|
||||
this.resetWeight()
|
||||
}
|
||||
|
||||
func (this *RecordIds) resetWeight() {
|
||||
var totalWeight int64
|
||||
|
||||
var weightMap = map[int32]bool{} // weight => bool
|
||||
var hasUniqueWeights = false
|
||||
var ids []int64
|
||||
|
||||
for _, idInfo := range this.IdList {
|
||||
totalWeight += int64(idInfo.Weight)
|
||||
|
||||
// 检查是否有不同的权重
|
||||
if len(weightMap) > 0 && !weightMap[idInfo.Weight] {
|
||||
hasUniqueWeights = true
|
||||
}
|
||||
weightMap[idInfo.Weight] = true
|
||||
|
||||
ids = append(ids, idInfo.Id)
|
||||
}
|
||||
|
||||
// 根据权重,重新组织IDs
|
||||
if hasUniqueWeights {
|
||||
var newIds = []int64{}
|
||||
for _, idInfo := range this.IdList {
|
||||
for i := int32(0); i < idInfo.Weight; i++ {
|
||||
newIds = append(newIds, idInfo.Id)
|
||||
}
|
||||
}
|
||||
ids = newIds
|
||||
}
|
||||
|
||||
var countIds = len(ids)
|
||||
if countIds > 0 {
|
||||
rand.Shuffle(countIds, func(i, j int) {
|
||||
ids[i], ids[j] = ids[j], ids[i]
|
||||
})
|
||||
}
|
||||
|
||||
this.totalWeight = totalWeight
|
||||
this.IdBucket = ids
|
||||
}
|
||||
131
EdgeDNS/internal/models/record_ids_test.go
Normal file
131
EdgeDNS/internal/models/record_ids_test.go
Normal file
@@ -0,0 +1,131 @@
|
||||
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build plus
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRecordIds_RandomIds_Once(t *testing.T) {
|
||||
var recordIds = NewRecordIds()
|
||||
for id := 1; id <= 10; id++ {
|
||||
recordIds.Add(int64(id), 1)
|
||||
}
|
||||
t.Log("totalWeight:", recordIds.totalWeight)
|
||||
t.Log(recordIds.RandomIds(5))
|
||||
}
|
||||
|
||||
func TestRecordIds_RandomIds_Once2(t *testing.T) {
|
||||
var recordIds = NewRecordIds()
|
||||
for id := 1; id <= 10; id++ {
|
||||
var weight int32 = 1
|
||||
if id%3 == 0 {
|
||||
weight = 3
|
||||
}
|
||||
recordIds.Add(int64(id), weight)
|
||||
}
|
||||
t.Log("totalWeight:", recordIds.totalWeight)
|
||||
t.Log(recordIds.RandomIds(5))
|
||||
}
|
||||
|
||||
func TestRecordIds_RandomIds(t *testing.T) {
|
||||
var recordIds = NewRecordIds()
|
||||
for id := 1; id <= 10; id++ {
|
||||
recordIds.Add(int64(id), 1)
|
||||
}
|
||||
t.Log("totalWeight:", recordIds.totalWeight)
|
||||
|
||||
var statMap = map[int64]int{}
|
||||
for i := 0; i < 2_000_000; i++ {
|
||||
var resultIds = recordIds.RandomIds(5)
|
||||
for _, resultId := range resultIds {
|
||||
statMap[resultId]++
|
||||
}
|
||||
}
|
||||
logs.PrintAsJSON(statMap, t)
|
||||
}
|
||||
|
||||
func TestRecordIds_RandomIds_Weight1(t *testing.T) {
|
||||
var recordIds = NewRecordIds()
|
||||
for id := 1; id <= 10; id++ {
|
||||
var weight int32 = 10
|
||||
if id%3 == 0 {
|
||||
weight = 20
|
||||
}
|
||||
recordIds.Add(int64(id), weight)
|
||||
}
|
||||
t.Log("totalWeight:", recordIds.totalWeight)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
t.Log(recordIds.RandomIds(5))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordIds_RandomIds_Weight2(t *testing.T) {
|
||||
var recordIds = NewRecordIds()
|
||||
for id := 1; id <= 10; id++ {
|
||||
var weight int32 = 10
|
||||
if id%3 == 0 {
|
||||
weight = 20
|
||||
}
|
||||
recordIds.Add(int64(id), weight)
|
||||
}
|
||||
t.Log("totalWeight:", recordIds.totalWeight)
|
||||
|
||||
var statMap = map[int64]int{}
|
||||
for i := 0; i < 2_000_000; i++ {
|
||||
var resultIds = recordIds.RandomIds(5)
|
||||
for _, resultId := range resultIds {
|
||||
statMap[resultId]++
|
||||
break
|
||||
}
|
||||
}
|
||||
logs.PrintAsJSON(statMap, t)
|
||||
}
|
||||
|
||||
func TestRecordIds_RandomIds_Weight3(t *testing.T) {
|
||||
var recordIds = NewRecordIds()
|
||||
for id := 1; id <= 5; id++ {
|
||||
var weight int32 = 10
|
||||
if id%3 == 0 {
|
||||
weight = 20
|
||||
}
|
||||
recordIds.Add(int64(id), weight)
|
||||
}
|
||||
t.Log("totalWeight:", recordIds.totalWeight)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
t.Log(recordIds.RandomIds(5))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRecordIds_RandomIds_SAME_Weight(b *testing.B) {
|
||||
var recordIds = NewRecordIds()
|
||||
for id := 1; id <= 100; id++ {
|
||||
var weight int32 = 10
|
||||
recordIds.Add(int64(id), weight)
|
||||
}
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
recordIds.RandomIds(5)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRecordIds_RandomIds(b *testing.B) {
|
||||
var recordIds = NewRecordIds()
|
||||
for id := 1; id <= 100; id++ {
|
||||
var weight int32 = 10
|
||||
if id%3 == 0 {
|
||||
weight = 20
|
||||
}
|
||||
recordIds.Add(int64(id), weight)
|
||||
}
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
recordIds.RandomIds(5)
|
||||
}
|
||||
}
|
||||
12
EdgeDNS/internal/models/record_key.go
Normal file
12
EdgeDNS/internal/models/record_key.go
Normal file
@@ -0,0 +1,12 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package models
|
||||
|
||||
import "strings"
|
||||
|
||||
type RecordKey string
|
||||
|
||||
func NewRecordKey(recordName string, recordType string) RecordKey {
|
||||
// 记录名全部使用小写
|
||||
return RecordKey(strings.ToLower(recordName) + "|" + recordType)
|
||||
}
|
||||
80
EdgeDNS/internal/models/records_domain.go
Normal file
80
EdgeDNS/internal/models/records_domain.go
Normal file
@@ -0,0 +1,80 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build plus
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type DomainRecords struct {
|
||||
RecordsMap map[RecordKey]*RouteRecords // key => records
|
||||
Keys map[int64]RecordKey // recordId => key
|
||||
}
|
||||
|
||||
func NewDomainRecords() *DomainRecords {
|
||||
return &DomainRecords{
|
||||
RecordsMap: map[RecordKey]*RouteRecords{},
|
||||
Keys: map[int64]RecordKey{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *DomainRecords) Add(record *NSRecord) {
|
||||
var key = NewRecordKey(record.Name, record.Type)
|
||||
records, ok := this.RecordsMap[key]
|
||||
if !ok {
|
||||
records = NewRouteRecords()
|
||||
this.RecordsMap[key] = records
|
||||
}
|
||||
records.Add(record)
|
||||
|
||||
this.Keys[record.Id] = key
|
||||
}
|
||||
|
||||
func (this *DomainRecords) Find(routeCodes []string, recordName string, recordType string, config *dnsconfigs.NSAnswerConfig, strictMode bool) (record []*NSRecord, routeCode string) {
|
||||
// NAME.example.com
|
||||
var key = NewRecordKey(recordName, recordType)
|
||||
records, ok := this.RecordsMap[key]
|
||||
if ok {
|
||||
return records.Find(routeCodes, config, strictMode)
|
||||
}
|
||||
|
||||
// @.example.com
|
||||
if len(recordName) == 0 {
|
||||
records, ok = this.RecordsMap[NewRecordKey("@", recordType)]
|
||||
if ok {
|
||||
return records.Find(routeCodes, config, strictMode)
|
||||
}
|
||||
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
// *.NAME.example.com
|
||||
var dotIndex = strings.Index(recordName, ".")
|
||||
var wildcardNames = []string{}
|
||||
if dotIndex > 0 {
|
||||
wildcardNames = append(wildcardNames, "*."+recordName[dotIndex+1:])
|
||||
}
|
||||
wildcardNames = append(wildcardNames, "*")
|
||||
for _, wildcardName := range wildcardNames {
|
||||
records, ok = this.RecordsMap[NewRecordKey(wildcardName, recordType)]
|
||||
if ok {
|
||||
return records.Find(routeCodes, config, strictMode)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
func (this *DomainRecords) Remove(recordId int64) {
|
||||
key, ok := this.Keys[recordId]
|
||||
if ok {
|
||||
var recordsMap = this.RecordsMap[key]
|
||||
recordsMap.Remove(recordId)
|
||||
if recordsMap.IsEmpty() {
|
||||
delete(this.RecordsMap, key)
|
||||
}
|
||||
delete(this.Keys, recordId)
|
||||
}
|
||||
}
|
||||
57
EdgeDNS/internal/models/records_domains_test.go
Normal file
57
EdgeDNS/internal/models/records_domains_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build plus
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDomainRecords_Find(t *testing.T) {
|
||||
var records = NewDomainRecords()
|
||||
records.Add(&NSRecord{Id: 1, Name: "", Value: "1", Type: dnsconfigs.RecordTypeA})
|
||||
records.Add(&NSRecord{Id: 2, Name: "@", Value: "@", Type: dnsconfigs.RecordTypeA})
|
||||
records.Add(&NSRecord{Id: 3, Name: "*", Value: "*", Type: dnsconfigs.RecordTypeA})
|
||||
records.Add(&NSRecord{Id: 4, Name: "hello", Value: "HELLO", Type: dnsconfigs.RecordTypeA})
|
||||
records.Add(&NSRecord{Id: 5, Name: "*.world", Value: "*.world", Type: dnsconfigs.RecordTypeA})
|
||||
for _, name := range []string{"", "hello", "world", "hello.world", "hello.world2"} {
|
||||
record, routeCode := records.Find([]string{}, name, dnsconfigs.RecordTypeA, nil, false)
|
||||
t.Log(name, record, routeCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDomainRecords_Find_RouteIds(t *testing.T) {
|
||||
var records = NewDomainRecords()
|
||||
records.Add(&NSRecord{Id: 1, Name: "", Value: "1", Type: dnsconfigs.RecordTypeA, RouteIds: []string{RouteIdString(11), RouteIdString(22)}})
|
||||
records.Add(&NSRecord{Id: 2, Name: "@", Value: "@", Type: dnsconfigs.RecordTypeA})
|
||||
records.Add(&NSRecord{Id: 3, Name: "*", Value: "*", Type: dnsconfigs.RecordTypeA})
|
||||
records.Add(&NSRecord{Id: 4, Name: "hello", Value: "HELLO", Type: dnsconfigs.RecordTypeA})
|
||||
records.Add(&NSRecord{Id: 41, Name: "hello", Value: "HELLO1", Type: dnsconfigs.RecordTypeA})
|
||||
records.Add(&NSRecord{Id: 42, Name: "hello", Value: "HELLO2", Type: dnsconfigs.RecordTypeA})
|
||||
records.Add(&NSRecord{Id: 43, Name: "hello", Value: "HELLO3", Type: dnsconfigs.RecordTypeA})
|
||||
records.Add(&NSRecord{Id: 5, Name: "*.world", Value: "*.world", Type: dnsconfigs.RecordTypeA})
|
||||
for _, name := range []string{"", "hello", "world", "hello.world", "hello.world2"} {
|
||||
record, _ := records.Find([]string{RouteIdString(11), RouteIdString(22)}, name, dnsconfigs.RecordTypeA, nil, false)
|
||||
if record == nil {
|
||||
t.Fatal("'" + name + "' record should not be nil")
|
||||
}
|
||||
t.Log(name, record)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDomainRecords_Remove(t *testing.T) {
|
||||
var records = NewDomainRecords()
|
||||
records.Add(&NSRecord{Id: 1, Name: "", Value: "1", Type: dnsconfigs.RecordTypeA})
|
||||
records.Add(&NSRecord{Id: 2, Name: "@", Value: "@", Type: dnsconfigs.RecordTypeA})
|
||||
records.Add(&NSRecord{Id: 3, Name: "*", Value: "*", Type: dnsconfigs.RecordTypeA})
|
||||
records.Add(&NSRecord{Id: 4, Name: "hello", Value: "HELLO", Type: dnsconfigs.RecordTypeA})
|
||||
records.Add(&NSRecord{Id: 5, Name: "*.world", Value: "*.world", Type: dnsconfigs.RecordTypeA})
|
||||
records.Remove(1)
|
||||
records.Remove(2)
|
||||
records.Remove(3)
|
||||
records.Remove(4)
|
||||
//records.Remove(5)
|
||||
logs.PrintAsJSON(records, t)
|
||||
}
|
||||
134
EdgeDNS/internal/models/records_route.go
Normal file
134
EdgeDNS/internal/models/records_route.go
Normal file
@@ -0,0 +1,134 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build plus
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
)
|
||||
|
||||
func RouteIdString(routeId int64) string {
|
||||
return "id:" + types.String(routeId)
|
||||
}
|
||||
|
||||
type RouteRecords struct {
|
||||
routeRecordsMap map[string]*RecordIds // routeCode => { recordId1, recordId2, ... }
|
||||
recordsMap map[int64]*NSRecord // recordId => *NSRecord
|
||||
}
|
||||
|
||||
func NewRouteRecords() *RouteRecords {
|
||||
return &RouteRecords{
|
||||
routeRecordsMap: map[string]*RecordIds{},
|
||||
recordsMap: map[int64]*NSRecord{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *RouteRecords) Add(record *NSRecord) {
|
||||
// 先删除
|
||||
this.remove(record.Id)
|
||||
|
||||
// 添加记录
|
||||
this.recordsMap[record.Id] = record
|
||||
|
||||
// 添加线路
|
||||
var routeIds = record.RouteIds
|
||||
if len(routeIds) == 0 || (len(routeIds) == 1 && routeIds[0] == "") {
|
||||
routeIds = []string{"default"}
|
||||
}
|
||||
|
||||
for _, routeId := range routeIds {
|
||||
recordIds, ok := this.routeRecordsMap[routeId]
|
||||
if !ok {
|
||||
recordIds = NewRecordIds()
|
||||
this.routeRecordsMap[routeId] = recordIds
|
||||
}
|
||||
recordIds.Add(record.Id, record.Weight)
|
||||
}
|
||||
}
|
||||
|
||||
// Find 查找与线路匹配的记录
|
||||
// strictMode 表示是否严格匹配线路
|
||||
func (this *RouteRecords) Find(routeCodes []string, config *dnsconfigs.NSAnswerConfig, strictMode bool) (records []*NSRecord, routeCode string) {
|
||||
if config == nil {
|
||||
config = dnsconfigs.DefaultNSAnswerConfig()
|
||||
}
|
||||
|
||||
var maxSize = int(config.MaxSize)
|
||||
if maxSize <= 0 {
|
||||
maxSize = dnsconfigs.NSAnswerDefaultSize
|
||||
}
|
||||
|
||||
// 查找匹配的线路
|
||||
for _, routeId := range routeCodes {
|
||||
recordIds, ok := this.routeRecordsMap[routeId]
|
||||
if ok && !recordIds.IsEmpty() {
|
||||
return this.recordsWithIds(recordIds, config.Mode, maxSize), routeId
|
||||
}
|
||||
}
|
||||
|
||||
// 查找默认线路
|
||||
recordIds, ok := this.routeRecordsMap["default"]
|
||||
if ok && !recordIds.IsEmpty() {
|
||||
return this.recordsWithIds(recordIds, config.Mode, maxSize), "default"
|
||||
}
|
||||
|
||||
// 随机一个
|
||||
if !strictMode {
|
||||
for _, record := range this.recordsMap {
|
||||
return []*NSRecord{record}, "default"
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
func (this *RouteRecords) Remove(recordId int64) {
|
||||
this.remove(recordId)
|
||||
}
|
||||
|
||||
func (this *RouteRecords) IsEmpty() bool {
|
||||
return len(this.recordsMap) == 0
|
||||
}
|
||||
|
||||
func (this *RouteRecords) remove(recordId int64) {
|
||||
oldRecord, ok := this.recordsMap[recordId]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
delete(this.recordsMap, recordId)
|
||||
|
||||
var oldRouteIds = oldRecord.RouteIds
|
||||
if len(oldRouteIds) == 0 || (len(oldRouteIds) == 1 && oldRouteIds[0] == "") {
|
||||
oldRouteIds = []string{"default"}
|
||||
}
|
||||
|
||||
for _, routeId := range oldRouteIds {
|
||||
recordIds, ok := this.routeRecordsMap[routeId]
|
||||
if ok {
|
||||
recordIds.Remove(recordId)
|
||||
if recordIds.IsEmpty() {
|
||||
delete(this.routeRecordsMap, routeId)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *RouteRecords) recordsWithIds(recordIds *RecordIds, mode dnsconfigs.NSAnswerMode, maxSize int) (records []*NSRecord) {
|
||||
// round-robin
|
||||
if mode == dnsconfigs.NSAnswerModeRoundRobin {
|
||||
var recordId = recordIds.NextId()
|
||||
if recordId > 0 {
|
||||
return []*NSRecord{this.recordsMap[recordId]}
|
||||
}
|
||||
}
|
||||
|
||||
// random
|
||||
var randomIds = recordIds.RandomIds(maxSize)
|
||||
for _, randomId := range randomIds {
|
||||
records = append(records, this.recordsMap[randomId])
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
118
EdgeDNS/internal/models/records_route_test.go
Normal file
118
EdgeDNS/internal/models/records_route_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build plus
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRouteRecords_Add(t *testing.T) {
|
||||
var records = NewRouteRecords()
|
||||
{
|
||||
records.Add(&NSRecord{Id: 1, RouteIds: []string{"CN"}})
|
||||
records.Add(&NSRecord{Id: 2})
|
||||
logs.PrintAsJSON(records.routeRecordsMap, t)
|
||||
logs.PrintAsJSON(records.recordsMap, t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteRecords_Add2(t *testing.T) {
|
||||
var records = NewRouteRecords()
|
||||
{
|
||||
records.Add(&NSRecord{Id: 1, RouteIds: []string{RouteIdString(11), RouteIdString(22)}})
|
||||
records.Add(&NSRecord{Id: 2, RouteIds: []string{RouteIdString(11), RouteIdString(22)}})
|
||||
logs.PrintAsJSON(records.routeRecordsMap, t)
|
||||
logs.PrintAsJSON(records.recordsMap, t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteRecords_Add3(t *testing.T) {
|
||||
var records = NewRouteRecords()
|
||||
{
|
||||
records.Add(&NSRecord{Id: 1, RouteIds: []string{RouteIdString(11), RouteIdString(22)}})
|
||||
records.Add(&NSRecord{Id: 2, RouteIds: []string{RouteIdString(11), RouteIdString(22)}})
|
||||
records.Add(&NSRecord{Id: 2, RouteIds: []string{RouteIdString(33), RouteIdString(44)}}) // duplicated
|
||||
logs.PrintAsJSON(records.routeRecordsMap, t)
|
||||
logs.PrintAsJSON(records.recordsMap, t)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteRecords_Remove(t *testing.T) {
|
||||
var records = NewRouteRecords()
|
||||
records.Add(&NSRecord{Id: 1})
|
||||
records.Add(&NSRecord{Id: 2, RouteIds: []string{RouteIdString(11), RouteIdString(22)}})
|
||||
records.Add(&NSRecord{Id: 3, RouteIds: []string{RouteIdString(11), RouteIdString(22)}})
|
||||
records.Add(&NSRecord{Id: 4, RouteIds: []string{RouteIdString(11)}})
|
||||
t.Log("===before===")
|
||||
logs.PrintAsJSON(records.routeRecordsMap, t)
|
||||
logs.PrintAsJSON(records.recordsMap, t)
|
||||
t.Log("===after===")
|
||||
//records.Remove(1)
|
||||
records.Remove(2)
|
||||
logs.PrintAsJSON(records.routeRecordsMap, t)
|
||||
logs.PrintAsJSON(records.recordsMap, t)
|
||||
}
|
||||
|
||||
func TestRouteRecords_Find(t *testing.T) {
|
||||
var records = NewRouteRecords()
|
||||
records.Add(&NSRecord{Id: 1})
|
||||
records.Add(&NSRecord{Id: 2, RouteIds: []string{RouteIdString(11), RouteIdString(22)}})
|
||||
records.Add(&NSRecord{Id: 3, RouteIds: []string{RouteIdString(11), RouteIdString(22)}})
|
||||
records.Add(&NSRecord{Id: 4, RouteIds: []string{RouteIdString(11)}})
|
||||
for _, routeIds := range [][]string{
|
||||
{},
|
||||
{RouteIdString(11)},
|
||||
{RouteIdString(22)},
|
||||
{RouteIdString(100)},
|
||||
} {
|
||||
record, routeId := records.Find(routeIds, nil, false)
|
||||
t.Logf("routeIds: %v, record: %#v, route: %s", routeIds, record, routeId)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteRecords_Find_Balance(t *testing.T) {
|
||||
var records = NewRouteRecords()
|
||||
records.Add(&NSRecord{Id: 1, RouteIds: []string{"aa"}})
|
||||
records.Add(&NSRecord{Id: 2, RouteIds: []string{"aa", "bb", "default"}})
|
||||
records.Add(&NSRecord{Id: 3})
|
||||
records.Add(&NSRecord{Id: 4})
|
||||
|
||||
for _, route := range []string{"", "default", "aa", "bb", "cc"} {
|
||||
var m = map[int64]int{} // id => count
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
var records, _ = records.Find([]string{route}, nil, false)
|
||||
for _, record := range records {
|
||||
m[record.Id]++
|
||||
}
|
||||
}
|
||||
t.Logf("%s: %+v", route, m)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRouteRecords_Remove(b *testing.B) {
|
||||
var records = NewRouteRecords()
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
records.Add(&NSRecord{Id: int64(i), RouteIds: []string{RouteIdString(int64(i % 100))}})
|
||||
}
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
records.Remove(int64(rands.Int(0, 100)))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRouteRecords_Find(b *testing.B) {
|
||||
var records = NewRouteRecords()
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
records.Add(&NSRecord{Id: int64(i), RouteIds: []string{RouteIdString(int64(i % 100))}})
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = records.Find([]string{RouteIdString(int64(i % 200))}, nil, false)
|
||||
}
|
||||
}
|
||||
24
EdgeDNS/internal/models/region_resolver.go
Normal file
24
EdgeDNS/internal/models/region_resolver.go
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
|
||||
"net"
|
||||
)
|
||||
|
||||
var DefaultRegionResolver = &RegionResolver{}
|
||||
|
||||
type RegionResolver struct {
|
||||
}
|
||||
|
||||
func (this *RegionResolver) Resolve(ip net.IP) (countryId int64, provinceId int64, cityId int64, providerId int64) {
|
||||
var result = iplibrary.Lookup(ip)
|
||||
if result != nil && result.IsOk() {
|
||||
countryId = result.CountryId()
|
||||
provinceId = result.ProvinceId()
|
||||
cityId = result.CityId()
|
||||
providerId = result.ProviderId()
|
||||
}
|
||||
return
|
||||
}
|
||||
10
EdgeDNS/internal/monitor/value.go
Normal file
10
EdgeDNS/internal/monitor/value.go
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package monitor
|
||||
|
||||
// ItemValue 数据值定义
|
||||
type ItemValue struct {
|
||||
Item string
|
||||
ValueJSON []byte
|
||||
CreatedAt int64
|
||||
}
|
||||
89
EdgeDNS/internal/monitor/value_queue.go
Normal file
89
EdgeDNS/internal/monitor/value_queue.go
Normal file
@@ -0,0 +1,89 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/events"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"time"
|
||||
)
|
||||
|
||||
var SharedValueQueue = NewValueQueue()
|
||||
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
return
|
||||
}
|
||||
|
||||
events.On(events.EventStart, func() {
|
||||
go SharedValueQueue.Start()
|
||||
})
|
||||
}
|
||||
|
||||
// ValueQueue 数据记录队列
|
||||
type ValueQueue struct {
|
||||
valuesChan chan *ItemValue
|
||||
}
|
||||
|
||||
func NewValueQueue() *ValueQueue {
|
||||
return &ValueQueue{
|
||||
valuesChan: make(chan *ItemValue, 1024),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动队列
|
||||
func (this *ValueQueue) Start() {
|
||||
// 这里单次循环就行,因为Loop里已经使用了Range通道
|
||||
err := this.Loop()
|
||||
if err != nil {
|
||||
remotelogs.Error("MONITOR_QUEUE", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Add 添加数据
|
||||
func (this *ValueQueue) Add(item string, value maps.Map) {
|
||||
valueJSON, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
remotelogs.Error("MONITOR_QUEUE", "marshal value error: "+err.Error())
|
||||
return
|
||||
}
|
||||
select {
|
||||
case this.valuesChan <- &ItemValue{
|
||||
Item: item,
|
||||
ValueJSON: valueJSON,
|
||||
CreatedAt: time.Now().Unix(),
|
||||
}:
|
||||
default:
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Loop 单次循环
|
||||
func (this *ValueQueue) Loop() error {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for value := range this.valuesChan {
|
||||
_, err = rpcClient.NodeValueRPC.CreateNodeValue(rpcClient.Context(), &pb.CreateNodeValueRequest{
|
||||
Item: value.Item,
|
||||
ValueJSON: value.ValueJSON,
|
||||
CreatedAt: value.CreatedAt,
|
||||
})
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("MONITOR", err.Error())
|
||||
} else {
|
||||
remotelogs.Error("MONITOR", err.Error())
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
256
EdgeDNS/internal/nodes/api_stream.go
Normal file
256
EdgeDNS/internal/nodes/api_stream.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/messageconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/events"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/utils"
|
||||
executils "github.com/TeaOSLab/EdgeDNS/internal/utils/exec"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type APIStream struct {
|
||||
stream pb.NSNodeService_NsNodeStreamClient
|
||||
|
||||
isQuiting bool
|
||||
cancelFunc context.CancelFunc
|
||||
}
|
||||
|
||||
func NewAPIStream() *APIStream {
|
||||
return &APIStream{}
|
||||
}
|
||||
|
||||
func (this *APIStream) Start() {
|
||||
events.On(events.EventQuit, func() {
|
||||
this.isQuiting = true
|
||||
if this.cancelFunc != nil {
|
||||
this.cancelFunc()
|
||||
}
|
||||
})
|
||||
for {
|
||||
if this.isQuiting {
|
||||
return
|
||||
}
|
||||
err := this.loop()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("API_STREAM", err.Error())
|
||||
} else {
|
||||
remotelogs.Error("API_STREAM", err.Error())
|
||||
}
|
||||
time.Sleep(10 * time.Second)
|
||||
continue
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func (this *APIStream) loop() error {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(rpcClient.Context())
|
||||
this.cancelFunc = cancelFunc
|
||||
|
||||
defer func() {
|
||||
cancelFunc()
|
||||
}()
|
||||
|
||||
nodeStream, err := rpcClient.NSNodeRPC.NsNodeStream(ctx)
|
||||
if err != nil {
|
||||
if this.isQuiting {
|
||||
return nil
|
||||
}
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
|
||||
this.stream = nodeStream
|
||||
|
||||
for {
|
||||
if this.isQuiting {
|
||||
logs.Println("API_STREAM", "quit")
|
||||
break
|
||||
}
|
||||
|
||||
message, err := nodeStream.Recv()
|
||||
if err != nil {
|
||||
if this.isQuiting {
|
||||
remotelogs.Println("API_STREAM", "quit")
|
||||
return nil
|
||||
}
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
|
||||
// 处理消息
|
||||
switch message.Code {
|
||||
case messageconfigs.NSMessageCodeConnectedAPINode: // 连接API节点成功
|
||||
err = this.handleConnectedAPINode(message)
|
||||
case messageconfigs.NSMessageCodeNewNodeTask: // 有新的任务
|
||||
err = this.handleNewNodeTask(message)
|
||||
case messageconfigs.NSMessageCodeCheckSystemdService: // 检查Systemd服务
|
||||
err = this.handleCheckSystemdService(message)
|
||||
case messageconfigs.MessageCodeCheckLocalFirewall: // 检查本地防火墙
|
||||
err = this.handleCheckLocalFirewall(message)
|
||||
default:
|
||||
err = this.handleUnknownMessage(message)
|
||||
}
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("API_STREAM", "handle message failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("API_STREAM", "handle message failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 连接API节点成功
|
||||
func (this *APIStream) handleConnectedAPINode(message *pb.NSNodeStreamMessage) error {
|
||||
// 更改连接的APINode信息
|
||||
if len(message.DataJSON) == 0 {
|
||||
return nil
|
||||
}
|
||||
msg := &messageconfigs.ConnectedAPINodeMessage{}
|
||||
err := json.Unmarshal(message.DataJSON, msg)
|
||||
if err != nil {
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
|
||||
_, err = rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return errors.Wrap(err)
|
||||
}
|
||||
|
||||
remotelogs.Println("API_STREAM", "connected to api node '"+strconv.FormatInt(msg.APINodeId, 10)+"'")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 处理配置变化
|
||||
func (this *APIStream) handleNewNodeTask(message *pb.NSNodeStreamMessage) error {
|
||||
select {
|
||||
case nodeTaskNotify <- true:
|
||||
default:
|
||||
|
||||
}
|
||||
this.replyOk(message.RequestId, "ok")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查Systemd服务
|
||||
func (this *APIStream) handleCheckSystemdService(message *pb.NSNodeStreamMessage) error {
|
||||
systemctl, err := executils.LookPath("systemctl")
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "'systemctl' not found")
|
||||
return nil
|
||||
}
|
||||
if len(systemctl) == 0 {
|
||||
this.replyFail(message.RequestId, "'systemctl' not found")
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := utils.NewCommandExecutor()
|
||||
shortName := teaconst.SystemdServiceName
|
||||
cmd.Add(systemctl, "is-enabled", shortName)
|
||||
output, err := cmd.Run()
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "'systemctl' command error: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
if output == "enabled" {
|
||||
this.replyOk(message.RequestId, "ok")
|
||||
} else {
|
||||
this.replyFail(message.RequestId, "not installed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查本地防火墙
|
||||
func (this *APIStream) handleCheckLocalFirewall(message *pb.NSNodeStreamMessage) error {
|
||||
var dataMessage = &messageconfigs.CheckLocalFirewallMessage{}
|
||||
err := json.Unmarshal(message.DataJSON, dataMessage)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
// nft
|
||||
if dataMessage.Name == "nftables" {
|
||||
if runtime.GOOS != "linux" {
|
||||
this.replyFail(message.RequestId, "not Linux system")
|
||||
return nil
|
||||
}
|
||||
|
||||
nft, err := executils.LookPath("nft")
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "'nft' not found: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
var cmd = exec.Command(nft, "--version")
|
||||
var output = &bytes.Buffer{}
|
||||
cmd.Stdout = output
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "get version failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
var outputString = output.String()
|
||||
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
|
||||
if len(versionMatches) <= 1 {
|
||||
this.replyFail(message.RequestId, "can not get nft version")
|
||||
return nil
|
||||
}
|
||||
var version = versionMatches[1]
|
||||
|
||||
var result = maps.Map{
|
||||
"version": version,
|
||||
}
|
||||
|
||||
var protectionConfig = sharedNodeConfig.DDoSProtection
|
||||
err = firewalls.SharedDDoSProtectionManager.Apply(protectionConfig)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, dataMessage.Name+"was installed, but apply DDoS protection config failed: "+err.Error())
|
||||
} else {
|
||||
this.replyOk(message.RequestId, string(result.AsJSON()))
|
||||
}
|
||||
} else {
|
||||
this.replyFail(message.RequestId, "invalid firewall name '"+dataMessage.Name+"'")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 处理未知消息
|
||||
func (this *APIStream) handleUnknownMessage(message *pb.NSNodeStreamMessage) error {
|
||||
this.replyFail(message.RequestId, "unknown message code '"+message.Code+"'")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 回复失败
|
||||
func (this *APIStream) replyFail(requestId int64, message string) {
|
||||
_ = this.stream.Send(&pb.NSNodeStreamMessage{RequestId: requestId, IsOk: false, Message: message})
|
||||
}
|
||||
|
||||
// 回复成功
|
||||
func (this *APIStream) replyOk(requestId int64, message string) {
|
||||
_ = this.stream.Send(&pb.NSNodeStreamMessage{RequestId: requestId, IsOk: true, Message: message})
|
||||
}
|
||||
455
EdgeDNS/internal/nodes/dns_node.go
Normal file
455
EdgeDNS/internal/nodes/dns_node.go
Normal file
@@ -0,0 +1,455 @@
|
||||
//go:build plus
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/agents"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/configs"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/events"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/utils"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/iwind/gosock/pkg/gosock"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
var DaemonIsOn = false
|
||||
var DaemonPid = 0
|
||||
var nodeTaskNotify = make(chan bool, 8)
|
||||
|
||||
var sharedDomainManager *DomainManager
|
||||
var sharedRecordManager *RecordManager
|
||||
var sharedRouteManager *RouteManager
|
||||
var sharedKeyManager *KeyManager
|
||||
var sharedNodeConfig = &dnsconfigs.NSNodeConfig{}
|
||||
|
||||
func NewDNSNode() *DNSNode {
|
||||
return &DNSNode{
|
||||
sock: gosock.NewTmpSock(teaconst.ProcessName),
|
||||
}
|
||||
}
|
||||
|
||||
type DNSNode struct {
|
||||
sock *gosock.Sock
|
||||
|
||||
RPC *rpc.RPCClient
|
||||
}
|
||||
|
||||
func (this *DNSNode) Start() {
|
||||
// 设置netdns
|
||||
// 这个需要放在所有网络访问的最前面
|
||||
_ = os.Setenv("GODEBUG", "netdns=go")
|
||||
|
||||
// 判断是否在守护进程下
|
||||
_, ok := os.LookupEnv("EdgeDaemon")
|
||||
if ok {
|
||||
remotelogs.Println("DNS_NODE", "start from daemon")
|
||||
DaemonIsOn = true
|
||||
DaemonPid = os.Getppid()
|
||||
}
|
||||
|
||||
// 设置DNS解析库
|
||||
err := os.Setenv("GODEBUG", "netdns=go")
|
||||
if err != nil {
|
||||
remotelogs.Error("DNS_NODE", "[DNS_RESOLVER]set env failed: "+err.Error())
|
||||
}
|
||||
|
||||
// 处理异常
|
||||
this.handlePanic()
|
||||
|
||||
// 监听signal
|
||||
this.listenSignals()
|
||||
|
||||
// 本地Sock
|
||||
err = this.listenSock()
|
||||
if err != nil {
|
||||
logs.Println("[DNS_NODE]" + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 启动IP库
|
||||
remotelogs.Println("DNS_NODE", "initializing ip library ...")
|
||||
err = iplibrary.InitPlus()
|
||||
if err != nil {
|
||||
remotelogs.Error("DNS_NODE", "initialize ip library failed: "+err.Error())
|
||||
}
|
||||
runtime.GC()
|
||||
debug.FreeOSMemory()
|
||||
|
||||
// 触发事件
|
||||
events.Notify(events.EventStart)
|
||||
|
||||
// 监控状态
|
||||
go NewNodeStatusExecutor().Listen()
|
||||
|
||||
// 连接API
|
||||
go NewAPIStream().Start()
|
||||
|
||||
// 启动
|
||||
go this.start()
|
||||
|
||||
// Hold住进程
|
||||
logs.Println("[DNS_NODE]started")
|
||||
select {}
|
||||
}
|
||||
|
||||
// Daemon 实现守护进程
|
||||
func (this *DNSNode) Daemon() {
|
||||
var isDebug = lists.ContainsString(os.Args, "debug")
|
||||
for {
|
||||
conn, err := this.sock.Dial()
|
||||
if err != nil {
|
||||
if isDebug {
|
||||
log.Println("[DAEMON]starting ...")
|
||||
}
|
||||
|
||||
// 尝试启动
|
||||
err = func() error {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 可以标记当前是从守护进程启动的
|
||||
_ = os.Setenv("EdgeDaemon", "on")
|
||||
_ = os.Setenv("EdgeBackground", "on")
|
||||
|
||||
var cmd = exec.Command(exe)
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = cmd.Wait()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
if isDebug {
|
||||
log.Println("[DAEMON]", err)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
} else {
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
} else {
|
||||
_ = conn.Close()
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test 测试配置
|
||||
func (this *DNSNode) Test() error {
|
||||
// 检查是否能连接API
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return fmt.Errorf("test rpc failed: %w", err)
|
||||
}
|
||||
_, err = rpcClient.APINodeRPC.FindCurrentAPINodeVersion(rpcClient.Context(), &pb.FindCurrentAPINodeVersionRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("test rpc failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InstallSystemService 安装系统服务
|
||||
func (this *DNSNode) InstallSystemService() error {
|
||||
shortName := teaconst.SystemdServiceName
|
||||
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manager := utils.NewServiceManager(shortName, teaconst.ProductName)
|
||||
err = manager.Install(exe, []string{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 监听一些信号
|
||||
func (this *DNSNode) listenSignals() {
|
||||
var queue = make(chan os.Signal, 8)
|
||||
signal.Notify(queue, syscall.SIGTERM, syscall.SIGINT, syscall.SIGKILL, syscall.SIGQUIT)
|
||||
goman.New(func() {
|
||||
for range queue {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
utils.Exit()
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 监听本地sock
|
||||
func (this *DNSNode) listenSock() error {
|
||||
// 检查是否在运行
|
||||
if this.sock.IsListening() {
|
||||
reply, err := this.sock.Send(&gosock.Command{Code: "pid"})
|
||||
if err == nil {
|
||||
return errors.New("error: the process is already running, pid: " + maps.NewMap(reply.Params).GetString("pid"))
|
||||
} else {
|
||||
return errors.New("error: the process is already running")
|
||||
}
|
||||
}
|
||||
|
||||
// 启动监听
|
||||
go func() {
|
||||
this.sock.OnCommand(func(cmd *gosock.Command) {
|
||||
switch cmd.Code {
|
||||
case "pid":
|
||||
_ = cmd.Reply(&gosock.Command{
|
||||
Code: "pid",
|
||||
Params: map[string]interface{}{
|
||||
"pid": os.Getpid(),
|
||||
},
|
||||
})
|
||||
case "info":
|
||||
exePath, _ := os.Executable()
|
||||
_ = cmd.Reply(&gosock.Command{
|
||||
Code: "info",
|
||||
Params: map[string]interface{}{
|
||||
"pid": os.Getpid(),
|
||||
"version": teaconst.Version,
|
||||
"path": exePath,
|
||||
},
|
||||
})
|
||||
case "stop":
|
||||
_ = cmd.ReplyOk()
|
||||
|
||||
// 退出主进程
|
||||
events.Notify(events.EventQuit)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
os.Exit(0)
|
||||
case "gc":
|
||||
runtime.GC()
|
||||
debug.FreeOSMemory()
|
||||
_ = cmd.ReplyOk()
|
||||
}
|
||||
})
|
||||
|
||||
err := this.sock.Listen()
|
||||
if err != nil {
|
||||
logs.Println("NODE", err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
events.On(events.EventQuit, func() {
|
||||
logs.Println("[DNS_NODE]", "quit unix sock")
|
||||
_ = this.sock.Close()
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 启动
|
||||
func (this *DNSNode) start() {
|
||||
client, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
remotelogs.Error("DNS_NODE", err.Error())
|
||||
return
|
||||
}
|
||||
this.RPC = client
|
||||
tryTimes := 0
|
||||
var configJSON []byte
|
||||
for {
|
||||
resp, err := client.NSNodeRPC.FindCurrentNSNodeConfig(client.Context(), &pb.FindCurrentNSNodeConfigRequest{})
|
||||
if err != nil {
|
||||
tryTimes++
|
||||
if tryTimes%10 == 0 {
|
||||
remotelogs.Error("NODE", "read config from API failed: "+err.Error())
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// 不做长时间的无意义的重试
|
||||
if tryTimes > 1000 {
|
||||
remotelogs.Error("NODE", "load failed: unable to read config from API")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
configJSON = resp.NsNodeJSON
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(configJSON) == 0 {
|
||||
remotelogs.Error("NODE", "can not find node config")
|
||||
return
|
||||
}
|
||||
var config = &dnsconfigs.NSNodeConfig{}
|
||||
err = json.Unmarshal(configJSON, config)
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "decode config failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
err = config.Init(context.TODO())
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "init config failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sharedNodeConfig = config
|
||||
|
||||
configs.SharedNodeConfig = config
|
||||
events.Notify(events.EventReload)
|
||||
|
||||
sharedNodeConfigManager.reload(config)
|
||||
|
||||
apiConfig, _ := configs.SharedAPIConfig()
|
||||
if apiConfig != nil {
|
||||
apiConfig.NumberId = config.Id
|
||||
}
|
||||
|
||||
var db = dbs.NewDB(Tea.Root + "/data/data-" + types.String(config.Id) + "-" + config.NodeId + "-v0.1.0.db")
|
||||
err = db.Init()
|
||||
if err != nil {
|
||||
remotelogs.Error("NODE", "init database failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
go sharedNodeConfigManager.Start()
|
||||
|
||||
sharedDomainManager = NewDomainManager(db, config.ClusterId)
|
||||
go sharedDomainManager.Start()
|
||||
|
||||
sharedRecordManager = NewRecordManager(db)
|
||||
go sharedRecordManager.Start()
|
||||
|
||||
sharedRouteManager = NewRouteManager(db)
|
||||
go sharedRouteManager.Start()
|
||||
|
||||
sharedKeyManager = NewKeyManager(db)
|
||||
go sharedKeyManager.Start()
|
||||
|
||||
agents.SharedManager = agents.NewManager(db)
|
||||
go agents.SharedManager.Start()
|
||||
|
||||
// 发送通知,这里发送通知需要在DomainManager、RecordeManager等加载完成之后
|
||||
time.Sleep(1 * time.Second)
|
||||
events.Notify(events.EventLoaded)
|
||||
|
||||
// 启动循环
|
||||
go this.loop()
|
||||
}
|
||||
|
||||
// 更新配置Loop
|
||||
func (this *DNSNode) loop() {
|
||||
var ticker = time.NewTicker(60 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-nodeTaskNotify:
|
||||
}
|
||||
|
||||
err := this.processTasks()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("DNS_NODE", "process tasks: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("DNS_NODE", "process tasks: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理任务
|
||||
func (this *DNSNode) processTasks() error {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 所有的任务
|
||||
tasksResp, err := rpcClient.NodeTaskRPC.FindNodeTasks(rpcClient.Context(), &pb.FindNodeTasksRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, task := range tasksResp.NodeTasks {
|
||||
switch task.Type {
|
||||
case "nsConfigChanged":
|
||||
sharedNodeConfigManager.NotifyChange()
|
||||
case "nsDomainChanged":
|
||||
sharedDomainManager.NotifyUpdate()
|
||||
case "nsRecordChanged":
|
||||
sharedRecordManager.NotifyUpdate()
|
||||
case "nsRouteChanged":
|
||||
sharedRouteManager.NotifyUpdate()
|
||||
case "nsKeyChanged":
|
||||
sharedKeyManager.NotifyUpdate()
|
||||
case "nsDDoSProtectionChanged":
|
||||
err := this.updateDDoS(rpcClient)
|
||||
if err != nil {
|
||||
remotelogs.Error("DNS_NODE", "apply DDoS config failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
_, err = rpcClient.NodeTaskRPC.ReportNodeTaskDone(rpcClient.Context(), &pb.ReportNodeTaskDoneRequest{
|
||||
NodeTaskId: task.Id,
|
||||
IsOk: true,
|
||||
Error: "",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *DNSNode) updateDDoS(rpcClient *rpc.RPCClient) error {
|
||||
resp, err := rpcClient.NSNodeRPC.FindNSNodeDDoSProtection(rpcClient.Context(), &pb.FindNSNodeDDoSProtectionRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resp.DdosProtectionJSON) == 0 {
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.DDoSProtection = nil
|
||||
}
|
||||
} else {
|
||||
var ddosProtectionConfig = &ddosconfigs.ProtectionConfig{}
|
||||
err = json.Unmarshal(resp.DdosProtectionJSON, ddosProtectionConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode DDoS protection config failed: %w", err)
|
||||
}
|
||||
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.DDoSProtection = ddosProtectionConfig
|
||||
}
|
||||
|
||||
err = firewalls.SharedDDoSProtectionManager.Apply(ddosProtectionConfig)
|
||||
if err != nil {
|
||||
// 不阻塞
|
||||
remotelogs.Error("NODE", "apply DDoS protection failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
306
EdgeDNS/internal/nodes/dns_node_test.go
Normal file
306
EdgeDNS/internal/nodes/dns_node_test.go
Normal file
@@ -0,0 +1,306 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build plus
|
||||
|
||||
package nodes_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"github.com/miekg/dns"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDNS_Query_A_UDP(t *testing.T) {
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
c := new(dns.Client)
|
||||
|
||||
for _, query := range []query{
|
||||
{"hello.world.teaos.cn", dns.TypeA},
|
||||
{"cdn.teaos.cn", dns.TypeA},
|
||||
{"hello.teaos.cn", dns.TypeCNAME},
|
||||
{"hello.teaos.cn", dns.TypeA},
|
||||
{"edgecdn.teaos.cn", dns.TypeA},
|
||||
} {
|
||||
var m = new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
r, _, err := c.Exchange(m, "127.0.0.1:53")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
t.Log(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_Query_A_Many(t *testing.T) {
|
||||
type queryDef struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
var c = new(dns.Client)
|
||||
|
||||
for i := 0; i < 10000; i++ {
|
||||
for _, query := range []queryDef{
|
||||
{"hello.goedge.cn", dns.TypeA},
|
||||
} {
|
||||
var m = new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
r, _, err := c.Exchange(m, "127.0.0.1:53")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
_ = r
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_Query_CNAME(t *testing.T) {
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
c := new(dns.Client)
|
||||
|
||||
for _, query := range []query{
|
||||
{"hello.teaos.cn", dns.TypeA},
|
||||
} {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
r, _, err := c.Exchange(m, "127.0.0.1:53")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
t.Log(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_Query_A_TCP(t *testing.T) {
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
c := new(dns.Client)
|
||||
|
||||
for _, query := range []query{
|
||||
{"hello.world.teaos.cn", dns.TypeA},
|
||||
{"cdn.teaos.cn", dns.TypeA},
|
||||
{"hello.teaos.cn", dns.TypeCNAME},
|
||||
{"goedge.cn", dns.TypeA},
|
||||
{"www.goedge.cn", dns.TypeA},
|
||||
} {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
//r, _, err := c.Exchange(m, "127.0.0.1:54")
|
||||
conn, err := dns.Dial("tcp", "127.0.0.1:53")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r, _, err := c.ExchangeWithConn(m, conn)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
t.Log(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_Query_A_TLS(t *testing.T) {
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
c := new(dns.Client)
|
||||
|
||||
for _, query := range []query{
|
||||
{"hello.world.teaos.cn", dns.TypeA},
|
||||
{"cdn.teaos.cn", dns.TypeA},
|
||||
{"hello.teaos.cn", dns.TypeCNAME},
|
||||
} {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
//r, _, err := c.Exchange(m, "127.0.0.1:54")
|
||||
conn, err := dns.DialWithTLS("tcp", "127.0.0.1:853", &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r, _, err := c.ExchangeWithConn(m, conn)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
t.Log(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_Query_Internet(t *testing.T) {
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
c := new(dns.Client)
|
||||
for _, query := range []query{
|
||||
{"1.goedge.cn", dns.TypeA},
|
||||
} {
|
||||
m := new(dns.Msg)
|
||||
//m.RecursionDesired = true
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
|
||||
conn, err := dns.Dial("udp", "ns1.teaos.cn:53")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r, _, err := c.ExchangeWithConn(m, conn)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
t.Log(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_Query_TSIG(t *testing.T) {
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
c := new(dns.Client)
|
||||
c.TsigSecret = map[string]string{"teaos.": "NzhhZDExMzM5NWMwN2Q5OWM5YTFhMzgxZWNkZGMwMDA2ODUzODdiYTM2ODA5N2I2YjYwZWRlNmNlNjlhMzdmM2JmNjcxZmQ4NzVjMjI1Y2QwOTQ2Njk5OWY0MzRkMTJkNTczNjFlZDgwYmQxZWZjZDM4ZjAxNDNmM2Y2NTU1YjE="}
|
||||
|
||||
for _, query := range []query{
|
||||
{"hello.cdn.teaos.cn", dns.TypeA},
|
||||
{"cdn.teaos.cn", dns.TypeA},
|
||||
{"hello.teaos.cn", dns.TypeCNAME},
|
||||
{"hello.teaos.cn", dns.TypeA},
|
||||
{"edgecdn.teaos.cn", dns.TypeA},
|
||||
} {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
m.SetTsig("teaos.", dns.HmacSHA512, 300, time.Now().Unix())
|
||||
r, _, err := c.Exchange(m, "127.0.0.1:53")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
t.Log(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_Query_A_Route(t *testing.T) {
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
c := new(dns.Client)
|
||||
|
||||
for _, query := range []query{
|
||||
{"route.com", dns.TypeA},
|
||||
} {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
r, _, err := c.Exchange(m, "127.0.0.1:53")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
t.Log(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_Query_Recursion(t *testing.T) {
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
c := new(dns.Client)
|
||||
|
||||
for _, query := range []query{
|
||||
{"example.org", dns.TypeA},
|
||||
} {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
r, _, err := c.Exchange(m, "127.0.0.1:53")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
t.Log(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNS_Flood(t *testing.T) {
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
var c = new(dns.Client)
|
||||
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
for _, query := range []query{
|
||||
{"hello.world.teaos.cn", dns.TypeA},
|
||||
{"cdn.teaos.cn", dns.TypeA},
|
||||
{"hello.teaos.cn", dns.TypeCNAME},
|
||||
{"hello.teaos.cn", dns.TypeA},
|
||||
{"edgecdn.teaos.cn", dns.TypeA},
|
||||
} {
|
||||
var m = new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
r, _, err := c.Exchange(m, "192.168.2.60:58")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to exchange: %v", err)
|
||||
}
|
||||
_ = r
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDNSNode(b *testing.B) {
|
||||
var c = new(dns.Client)
|
||||
conn, err := c.Dial("192.168.2.60:58")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
type query struct {
|
||||
Domain string
|
||||
Type uint16
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, query := range []query{
|
||||
{"cdn.teaos.cn", dns.TypeA},
|
||||
} {
|
||||
var m = new(dns.Msg)
|
||||
m.SetQuestion(query.Domain+".", query.Type)
|
||||
_, _, err := c.ExchangeWithConn(m, conn)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
179
EdgeDNS/internal/nodes/http_writer.go
Normal file
179
EdgeDNS/internal/nodes/http_writer.go
Normal file
@@ -0,0 +1,179 @@
|
||||
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/miekg/dns"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type HTTPWriter struct {
|
||||
rawConn net.Conn
|
||||
rawWriter http.ResponseWriter
|
||||
contentType string
|
||||
}
|
||||
|
||||
func NewHTTPWriter(rawWriter http.ResponseWriter, rawConn net.Conn, contentType string) *HTTPWriter {
|
||||
return &HTTPWriter{
|
||||
rawWriter: rawWriter,
|
||||
rawConn: rawConn,
|
||||
contentType: contentType,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) LocalAddr() net.Addr {
|
||||
return this.rawConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) RemoteAddr() net.Addr {
|
||||
return this.rawConn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) WriteMsg(msg *dns.Msg) error {
|
||||
if msg == nil {
|
||||
return errors.New("'msg' should not be nil")
|
||||
}
|
||||
|
||||
msgData, err := this.encodeMsg(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
this.rawWriter.Header().Set("Content-Length", types.String(len(msgData)))
|
||||
this.rawWriter.Header().Set("Content-Type", this.contentType)
|
||||
|
||||
// cache-control
|
||||
if len(msg.Answer) > 0 {
|
||||
var minTtl uint32
|
||||
for _, answer := range msg.Answer {
|
||||
var header = answer.Header()
|
||||
if header != nil && header.Ttl > 0 && (minTtl == 0 || header.Ttl < minTtl) {
|
||||
minTtl = header.Ttl
|
||||
}
|
||||
}
|
||||
if minTtl > 0 {
|
||||
this.rawWriter.Header().Set("Cache-Control", "max-age="+types.String(minTtl))
|
||||
}
|
||||
}
|
||||
|
||||
this.rawWriter.WriteHeader(http.StatusOK)
|
||||
_, err = this.rawWriter.Write(msgData)
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) Write(p []byte) (int, error) {
|
||||
this.rawWriter.Header().Set("Content-Length", types.String(len(p)))
|
||||
this.rawWriter.WriteHeader(http.StatusOK)
|
||||
return this.rawWriter.Write(p)
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) TsigStatus() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) TsigTimersOnly(timersOnly bool) {
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) Hijack() {
|
||||
hijacker, ok := this.rawWriter.(http.Hijacker)
|
||||
if ok {
|
||||
_, _, _ = hijacker.Hijack()
|
||||
}
|
||||
}
|
||||
|
||||
func (this *HTTPWriter) encodeMsg(msg *dns.Msg) ([]byte, error) {
|
||||
if this.contentType == "application/x-javascript" || this.contentType == "application/json" {
|
||||
var result = map[string]any{
|
||||
"Status": 0,
|
||||
"TC": msg.Truncated,
|
||||
"RD": msg.RecursionDesired,
|
||||
"RA": msg.RecursionAvailable,
|
||||
"AD": msg.AuthenticatedData,
|
||||
"CD": msg.CheckingDisabled,
|
||||
}
|
||||
|
||||
// questions
|
||||
var questionMaps = []map[string]any{}
|
||||
for _, question := range msg.Question {
|
||||
questionMaps = append(questionMaps, map[string]any{
|
||||
"name": question.Name,
|
||||
"type": question.Qtype,
|
||||
})
|
||||
}
|
||||
result["Question"] = questionMaps
|
||||
|
||||
// answers
|
||||
var answerMaps = []map[string]any{}
|
||||
for _, answer := range msg.Answer {
|
||||
var answerMap = map[string]any{
|
||||
"name": answer.Header().Name,
|
||||
"type": answer.Header().Rrtype,
|
||||
"TTL": answer.Header().Ttl,
|
||||
}
|
||||
|
||||
switch x := answer.(type) {
|
||||
case *dns.A:
|
||||
answerMap["data"] = x.A.String()
|
||||
case *dns.AAAA:
|
||||
answerMap["data"] = x.AAAA.String()
|
||||
case *dns.CNAME:
|
||||
answerMap["data"] = x.Target
|
||||
case *dns.TXT:
|
||||
answerMap["data"] = x.Txt
|
||||
case *dns.NS:
|
||||
answerMap["data"] = x.Ns
|
||||
case *dns.MX:
|
||||
answerMap["data"] = x.Mx
|
||||
answerMap["preference"] = x.Preference
|
||||
default:
|
||||
var answerValue = reflect.ValueOf(answer).Elem()
|
||||
var answerType = answerValue.Type()
|
||||
|
||||
var countFields = answerType.NumField()
|
||||
for i := 0; i < countFields; i++ {
|
||||
var fieldName = answerType.Field(i).Name
|
||||
var fieldValue = answerValue.FieldByName(fieldName)
|
||||
if !fieldValue.IsValid() {
|
||||
continue
|
||||
}
|
||||
var fieldInterface = fieldValue.Interface()
|
||||
if fieldInterface == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
_, ok := fieldInterface.(dns.RR_Header)
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if countFields == 2 {
|
||||
answerMap["data"] = fieldValue.Interface()
|
||||
} else {
|
||||
answerMap[fieldName] = fieldValue.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
answerMaps = append(answerMaps, answerMap)
|
||||
}
|
||||
result["Answer"] = answerMaps
|
||||
|
||||
if Tea.IsTesting() {
|
||||
return json.MarshalIndent(result, "", " ")
|
||||
} else {
|
||||
return json.Marshal(result)
|
||||
}
|
||||
} else {
|
||||
return msg.Pack()
|
||||
}
|
||||
}
|
||||
306
EdgeDNS/internal/nodes/listen_manager.go
Normal file
306
EdgeDNS/internal/nodes/listen_manager.go
Normal file
@@ -0,0 +1,306 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
teaconst "github.com/TeaOSLab/EdgeDNS/internal/const"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/events"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/utils"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var sharedListenManager *ListenManager = nil
|
||||
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
return
|
||||
}
|
||||
|
||||
sharedListenManager = NewListenManager()
|
||||
|
||||
events.On(events.EventReload, func() {
|
||||
sharedListenManager.Update(sharedNodeConfig)
|
||||
})
|
||||
events.On(events.EventQuit, func() {
|
||||
_ = sharedListenManager.ShutdownAll()
|
||||
})
|
||||
}
|
||||
|
||||
// ListenManager 端口监听管理器
|
||||
type ListenManager struct {
|
||||
locker sync.Mutex
|
||||
serverMap map[string]*Server // addr => *Server
|
||||
|
||||
firewalld *firewalls.Firewalld
|
||||
lastPortStrings string
|
||||
lastTCPPortRanges [][2]int
|
||||
lastUDPPortRanges [][2]int
|
||||
}
|
||||
|
||||
// NewListenManager 获取新对象
|
||||
func NewListenManager() *ListenManager {
|
||||
return &ListenManager{
|
||||
serverMap: map[string]*Server{},
|
||||
firewalld: firewalls.NewFirewalld(),
|
||||
}
|
||||
}
|
||||
|
||||
// Update 修改配置
|
||||
func (this *ListenManager) Update(config *dnsconfigs.NSNodeConfig) {
|
||||
// 构造服务配置
|
||||
var serverConfigs = []*ServerConfig{}
|
||||
var serverAddrs = []string{}
|
||||
|
||||
// 如果没有配置,则配置一些默认的端口
|
||||
if config.TCP == nil && config.TLS == nil && config.UDP == nil {
|
||||
config.TCP = &serverconfigs.TCPProtocolConfig{}
|
||||
config.TCP.IsOn = true
|
||||
config.TCP.Listen = []*serverconfigs.NetworkAddressConfig{
|
||||
{
|
||||
Protocol: serverconfigs.ProtocolTCP,
|
||||
MinPort: 53,
|
||||
MaxPort: 53,
|
||||
},
|
||||
}
|
||||
|
||||
config.UDP = &serverconfigs.UDPProtocolConfig{}
|
||||
config.UDP.IsOn = true
|
||||
config.UDP.Listen = []*serverconfigs.NetworkAddressConfig{
|
||||
{
|
||||
Protocol: serverconfigs.ProtocolUDP,
|
||||
MinPort: 53,
|
||||
MaxPort: 53,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 读取配置
|
||||
if config.TCP != nil && config.TCP.IsOn {
|
||||
for _, listen := range config.TCP.Listen {
|
||||
for port := listen.MinPort; port <= listen.MaxPort; port++ {
|
||||
serverConfigs = append(serverConfigs, &ServerConfig{
|
||||
Protocol: listen.Protocol,
|
||||
Host: listen.Host,
|
||||
Port: port,
|
||||
SSLPolicy: nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
if config.TLS != nil && config.TLS.IsOn {
|
||||
for _, listen := range config.TLS.Listen {
|
||||
if config.TLS.SSLPolicy == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for port := listen.MinPort; port <= listen.MaxPort; port++ {
|
||||
serverConfigs = append(serverConfigs, &ServerConfig{
|
||||
Protocol: listen.Protocol,
|
||||
Host: listen.Host,
|
||||
Port: port,
|
||||
SSLPolicy: config.TLS.SSLPolicy,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
if config.DoH != nil && config.DoH.IsOn {
|
||||
for _, listen := range config.DoH.Listen {
|
||||
if config.DoH.SSLPolicy == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for port := listen.MinPort; port <= listen.MaxPort; port++ {
|
||||
serverConfigs = append(serverConfigs, &ServerConfig{
|
||||
Protocol: listen.Protocol,
|
||||
Host: listen.Host,
|
||||
Port: port,
|
||||
SSLPolicy: config.DoH.SSLPolicy,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
if config.UDP != nil && config.UDP.IsOn {
|
||||
for _, listen := range config.UDP.Listen {
|
||||
for port := listen.MinPort; port <= listen.MaxPort; port++ {
|
||||
serverConfigs = append(serverConfigs, &ServerConfig{
|
||||
Protocol: listen.Protocol,
|
||||
Host: listen.Host,
|
||||
Port: port,
|
||||
SSLPolicy: nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 启动新的
|
||||
var addrMap = map[string]bool{} // addr => bool
|
||||
for _, serverConfig := range serverConfigs {
|
||||
var fullAddr = serverConfig.FullAddr()
|
||||
serverAddrs = append(serverAddrs, fullAddr)
|
||||
addrMap[fullAddr] = true
|
||||
this.locker.Lock()
|
||||
server, ok := this.serverMap[fullAddr]
|
||||
this.locker.Unlock()
|
||||
if !ok {
|
||||
// 启动新的
|
||||
var err error
|
||||
server, err = NewServer(serverConfig)
|
||||
if err != nil {
|
||||
remotelogs.Error("LISTEN_MANAGER", "create listener '"+fullAddr+"' failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
this.locker.Lock()
|
||||
this.serverMap[fullAddr] = server
|
||||
this.locker.Unlock()
|
||||
go func() {
|
||||
remotelogs.Println("LISTEN_MANAGER", "listen '"+fullAddr+"'")
|
||||
err = server.ListenAndServe()
|
||||
if err != nil {
|
||||
this.locker.Lock()
|
||||
delete(this.serverMap, fullAddr)
|
||||
this.locker.Unlock()
|
||||
remotelogs.Error("LISTEN_MANAGER", "listen '"+fullAddr+"' failed: "+err.Error())
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
// 更新配置
|
||||
server.Reload(serverConfig)
|
||||
}
|
||||
}
|
||||
|
||||
// 停止老的
|
||||
this.locker.Lock()
|
||||
for fullAddr, server := range this.serverMap {
|
||||
_, ok := addrMap[fullAddr]
|
||||
if !ok {
|
||||
delete(this.serverMap, fullAddr)
|
||||
|
||||
remotelogs.Println("LISTEN_MANAGER", "shutdown "+fullAddr)
|
||||
err := server.Shutdown()
|
||||
if err != nil {
|
||||
remotelogs.Error("LISTEN_MANAGER", "shutdown listener '"+fullAddr+"' failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
this.locker.Unlock()
|
||||
|
||||
// 添加端口到firewalld
|
||||
go func() {
|
||||
this.addToFirewalld(serverAddrs)
|
||||
}()
|
||||
}
|
||||
|
||||
// ShutdownAll 关闭所有的监听端口
|
||||
func (this *ListenManager) ShutdownAll() error {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
var lastErr error
|
||||
for _, server := range this.serverMap {
|
||||
err := server.Shutdown()
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (this *ListenManager) addToFirewalld(serverAddrs []string) {
|
||||
if runtime.GOOS != "linux" {
|
||||
return
|
||||
}
|
||||
|
||||
if this.firewalld == nil || !this.firewalld.IsReady() {
|
||||
return
|
||||
}
|
||||
|
||||
// 组合端口号
|
||||
var portStrings = []string{}
|
||||
var udpPorts = []int{}
|
||||
var tcpPorts = []int{}
|
||||
for _, addr := range serverAddrs {
|
||||
var protocol = "tcp"
|
||||
if strings.HasPrefix(addr, "udp") {
|
||||
protocol = "udp"
|
||||
}
|
||||
|
||||
var lastIndex = strings.LastIndex(addr, ":")
|
||||
if lastIndex > 0 {
|
||||
var portString = addr[lastIndex+1:]
|
||||
portStrings = append(portStrings, portString+"/"+protocol)
|
||||
|
||||
switch protocol {
|
||||
case "tcp":
|
||||
tcpPorts = append(tcpPorts, types.Int(portString))
|
||||
case "udp":
|
||||
udpPorts = append(udpPorts, types.Int(portString))
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(portStrings) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否有变化
|
||||
sort.Strings(portStrings)
|
||||
var newPortStrings = strings.Join(portStrings, ",")
|
||||
if newPortStrings == this.lastPortStrings {
|
||||
return
|
||||
}
|
||||
this.lastPortStrings = newPortStrings
|
||||
|
||||
remotelogs.Println("FIREWALLD", "opening ports automatically ...")
|
||||
defer func() {
|
||||
remotelogs.Println("FIREWALLD", "open ports successfully")
|
||||
}()
|
||||
|
||||
// 合并端口
|
||||
var tcpPortRanges = utils.MergePorts(tcpPorts)
|
||||
var udpPortRanges = utils.MergePorts(udpPorts)
|
||||
|
||||
defer func() {
|
||||
this.lastTCPPortRanges = tcpPortRanges
|
||||
this.lastUDPPortRanges = udpPortRanges
|
||||
}()
|
||||
|
||||
// 删除老的不存在的端口
|
||||
var tcpPortRangesMap = map[string]bool{}
|
||||
var udpPortRangesMap = map[string]bool{}
|
||||
for _, portRange := range tcpPortRanges {
|
||||
tcpPortRangesMap[this.firewalld.PortRangeString(portRange, "tcp")] = true
|
||||
}
|
||||
for _, portRange := range udpPortRanges {
|
||||
udpPortRangesMap[this.firewalld.PortRangeString(portRange, "udp")] = true
|
||||
}
|
||||
|
||||
for _, portRange := range this.lastTCPPortRanges {
|
||||
var s = this.firewalld.PortRangeString(portRange, "tcp")
|
||||
_, ok := tcpPortRangesMap[s]
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
remotelogs.Println("FIREWALLD", "remove port '"+s+"'")
|
||||
_ = this.firewalld.RemovePortRangePermanently(portRange, "tcp")
|
||||
}
|
||||
for _, portRange := range this.lastUDPPortRanges {
|
||||
var s = this.firewalld.PortRangeString(portRange, "udp")
|
||||
_, ok := udpPortRangesMap[s]
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
remotelogs.Println("FIREWALLD", "remove port '"+s+"'")
|
||||
_ = this.firewalld.RemovePortRangePermanently(portRange, "udp")
|
||||
}
|
||||
|
||||
// 添加新的
|
||||
_ = this.firewalld.AllowPortRangesPermanently(tcpPortRanges, "tcp")
|
||||
_ = this.firewalld.AllowPortRangesPermanently(udpPortRanges, "udp")
|
||||
}
|
||||
319
EdgeDNS/internal/nodes/manager_domain.go
Normal file
319
EdgeDNS/internal/nodes/manager_domain.go
Normal file
@@ -0,0 +1,319 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/models"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DomainManager 域名管理器
|
||||
type DomainManager struct {
|
||||
domainMap map[int64]*models.NSDomain // domainId => domain
|
||||
namesMap map[string]map[int64]*models.NSDomain // domain name => { domainId => domain }
|
||||
clusterId int64
|
||||
|
||||
db *dbs.DB
|
||||
version int64
|
||||
locker *sync.RWMutex
|
||||
|
||||
notifier chan bool
|
||||
}
|
||||
|
||||
// NewDomainManager 获取域名管理器对象
|
||||
func NewDomainManager(db *dbs.DB, clusterId int64) *DomainManager {
|
||||
return &DomainManager{
|
||||
db: db,
|
||||
domainMap: map[int64]*models.NSDomain{},
|
||||
namesMap: map[string]map[int64]*models.NSDomain{},
|
||||
clusterId: clusterId,
|
||||
notifier: make(chan bool, 8),
|
||||
locker: &sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动自动任务
|
||||
func (this *DomainManager) Start() {
|
||||
remotelogs.Println("DOMAIN_MANAGER", "starting ...")
|
||||
|
||||
// 从本地数据库中加载数据
|
||||
err := this.Load()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("DOMAIN_MANAGER", "load failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "load failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化运行
|
||||
err = this.LoopAll()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("DOMAIN_MANAGER", "loop failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "loop failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 更新
|
||||
var ticker = time.NewTicker(20 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-this.notifier:
|
||||
}
|
||||
|
||||
err = this.LoopAll()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("DOMAIN_MANAGER", "loop failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "loop failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (this *DomainManager) LoopAll() error {
|
||||
for {
|
||||
hasNext, err := this.Loop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasNext {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load 从数据库中加载数据
|
||||
func (this *DomainManager) Load() error {
|
||||
var offset = 0
|
||||
var size = 10000
|
||||
for {
|
||||
domains, err := this.db.ListDomains(this.clusterId, offset, size)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(domains) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
for _, domain := range domains {
|
||||
this.domainMap[domain.Id] = domain
|
||||
|
||||
nameMap, ok := this.namesMap[domain.Name]
|
||||
if ok {
|
||||
nameMap[domain.Id] = domain
|
||||
} else {
|
||||
this.namesMap[domain.Name] = map[int64]*models.NSDomain{
|
||||
domain.Id: domain,
|
||||
}
|
||||
}
|
||||
|
||||
if domain.Version > this.version {
|
||||
this.version = domain.Version
|
||||
}
|
||||
}
|
||||
this.locker.Unlock()
|
||||
|
||||
offset += size
|
||||
}
|
||||
|
||||
if this.version > 0 {
|
||||
this.version++
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Loop 单次循环任务
|
||||
func (this *DomainManager) Loop() (hasNext bool, err error) {
|
||||
client, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
resp, err := client.NSDomainRPC.ListNSDomainsAfterVersion(client.Context(), &pb.ListNSDomainsAfterVersionRequest{
|
||||
Version: this.version,
|
||||
Size: 20000,
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var domains = resp.NsDomains
|
||||
if len(domains) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
for _, domain := range domains {
|
||||
this.processDomain(domain)
|
||||
if domain.Version > this.version {
|
||||
this.version = domain.Version
|
||||
}
|
||||
}
|
||||
this.version++
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// FindDomain 根据名称查找域名
|
||||
func (this *DomainManager) FindDomain(name string) (domain *models.NSDomain, ok bool) {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
|
||||
nameMap, ok := this.namesMap[name]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
for _, domain2 := range nameMap {
|
||||
return domain2, true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// FindDomainWithId 根据域名ID查询域名
|
||||
func (this *DomainManager) FindDomainWithId(domainId int64) (domain *models.NSDomain) {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
|
||||
return this.domainMap[domainId]
|
||||
}
|
||||
|
||||
// NotifyUpdate 通知更新
|
||||
func (this *DomainManager) NotifyUpdate() {
|
||||
select {
|
||||
case this.notifier <- true:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// SplitDomain 分解域名
|
||||
func (this *DomainManager) SplitDomain(fullDomainName string) (rootDomain *models.NSDomain, recordName string) {
|
||||
if len(fullDomainName) == 0 {
|
||||
return
|
||||
}
|
||||
fullDomainName = strings.TrimSuffix(fullDomainName, ".") // 去除尾部的点(.)
|
||||
fullDomainName = strings.ToLower(fullDomainName) // 转换为小写
|
||||
var domainName = fullDomainName
|
||||
var domain, ok = this.FindDomain(domainName)
|
||||
if !ok {
|
||||
for {
|
||||
var index = strings.Index(domainName, ".")
|
||||
if index < 0 {
|
||||
break
|
||||
}
|
||||
domainName = domainName[index+1:]
|
||||
domain, ok = this.FindDomain(domainName)
|
||||
if ok {
|
||||
recordName = fullDomainName[:len(fullDomainName)-len(domainName)-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return domain, recordName
|
||||
}
|
||||
|
||||
// 处理域名
|
||||
func (this *DomainManager) processDomain(domain *pb.NSDomain) {
|
||||
if !domain.IsOn || domain.IsDeleted || domain.Status != dnsconfigs.NSDomainStatusVerified {
|
||||
this.locker.Lock()
|
||||
delete(this.domainMap, domain.Id)
|
||||
|
||||
nameMap, ok := this.namesMap[domain.Name]
|
||||
if ok {
|
||||
delete(nameMap, domain.Id)
|
||||
if len(nameMap) == 0 {
|
||||
delete(this.namesMap, domain.Name)
|
||||
}
|
||||
}
|
||||
|
||||
this.locker.Unlock()
|
||||
|
||||
// 从数据库中删除
|
||||
if this.db != nil {
|
||||
err := this.db.DeleteDomain(domain.Id)
|
||||
if err != nil {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "delete domain from db failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 存入数据库
|
||||
if this.db != nil {
|
||||
exists, err := this.db.ExistsDomain(domain.Id)
|
||||
if err != nil {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "query failed: "+err.Error())
|
||||
} else {
|
||||
if exists {
|
||||
err = this.db.UpdateDomain(domain.Id, domain.NsCluster.Id, domain.UserId, domain.Name, domain.TsigJSON, domain.Version)
|
||||
if err != nil {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "update failed: "+err.Error())
|
||||
}
|
||||
} else {
|
||||
err = this.db.InsertDomain(domain.Id, domain.NsCluster.Id, domain.UserId, domain.Name, domain.TsigJSON, domain.Version)
|
||||
if err != nil {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "insert failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 同集群的才需要加载
|
||||
if this.clusterId == domain.NsCluster.Id {
|
||||
this.locker.Lock()
|
||||
var tsigConfig = &dnsconfigs.NSTSIGConfig{}
|
||||
if len(domain.TsigJSON) > 0 {
|
||||
err := json.Unmarshal(domain.TsigJSON, tsigConfig)
|
||||
if err != nil {
|
||||
remotelogs.Error("DOMAIN_MANAGER", "decode TSIG json failed: "+err.Error()+", domain: "+domain.Name+", domainId: "+types.String(domain.Id)+", JSON: "+string(domain.TsigJSON))
|
||||
}
|
||||
}
|
||||
|
||||
var nsDomain = &models.NSDomain{
|
||||
Id: domain.Id,
|
||||
ClusterId: domain.NsCluster.Id,
|
||||
UserId: domain.UserId,
|
||||
Name: domain.Name,
|
||||
TSIG: tsigConfig,
|
||||
Version: domain.Version,
|
||||
}
|
||||
this.domainMap[domain.Id] = nsDomain
|
||||
|
||||
nameMap, ok := this.namesMap[domain.Name]
|
||||
if ok {
|
||||
nameMap[nsDomain.Id] = nsDomain
|
||||
} else {
|
||||
this.namesMap[domain.Name] = map[int64]*models.NSDomain{
|
||||
nsDomain.Id: nsDomain,
|
||||
}
|
||||
}
|
||||
this.locker.Unlock()
|
||||
} else {
|
||||
// 不同集群的删除域名
|
||||
this.locker.Lock()
|
||||
delete(this.domainMap, domain.Id)
|
||||
|
||||
nameMap, ok := this.namesMap[domain.Name]
|
||||
if ok {
|
||||
delete(nameMap, domain.Id)
|
||||
if len(nameMap) == 0 {
|
||||
delete(this.namesMap, domain.Name)
|
||||
}
|
||||
}
|
||||
|
||||
this.locker.Unlock()
|
||||
}
|
||||
}
|
||||
63
EdgeDNS/internal/nodes/manager_domain_test.go
Normal file
63
EdgeDNS/internal/nodes/manager_domain_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDomainManager_Loop(t *testing.T) {
|
||||
var db = dbs.NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var manager = NewDomainManager(db, 1)
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err := manager.Loop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
logs.PrintAsJSON(manager.domainMap, t)
|
||||
}
|
||||
|
||||
func TestDomainManager_Load(t *testing.T) {
|
||||
var db = dbs.NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
manager := NewDomainManager(db, 2)
|
||||
err = manager.Load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
logs.PrintAsJSON(manager.domainMap, t)
|
||||
t.Log("version:", manager.version)
|
||||
}
|
||||
|
||||
func TestDomainManager_FindDomain(t *testing.T) {
|
||||
var db = dbs.NewDB(Tea.Root + "/data/data.db")
|
||||
err := db.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var manager = NewDomainManager(db, 2)
|
||||
err = manager.Load()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, name := range []string{"hello.com", "teaos.cn"} {
|
||||
domain, ok := manager.FindDomain(name)
|
||||
t.Log(name, ok, domain)
|
||||
}
|
||||
}
|
||||
284
EdgeDNS/internal/nodes/manager_key.go
Normal file
284
EdgeDNS/internal/nodes/manager_key.go
Normal file
@@ -0,0 +1,284 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/models"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// KeyManager 密钥管理器
|
||||
type KeyManager struct {
|
||||
domainKeyMap map[int64]*models.NSKeys // domainId => *NSKeys
|
||||
zoneKeyMap map[int64]*models.NSKeys // zoneId => *NSKeys
|
||||
|
||||
db *dbs.DB
|
||||
locker sync.RWMutex
|
||||
version int64
|
||||
|
||||
notifier chan bool
|
||||
}
|
||||
|
||||
// NewKeyManager 获取密钥管理器
|
||||
func NewKeyManager(db *dbs.DB) *KeyManager {
|
||||
return &KeyManager{
|
||||
domainKeyMap: map[int64]*models.NSKeys{},
|
||||
zoneKeyMap: map[int64]*models.NSKeys{},
|
||||
db: db,
|
||||
notifier: make(chan bool, 8),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动自动任务
|
||||
func (this *KeyManager) Start() {
|
||||
remotelogs.Println("KEY_MANAGER", "starting ...")
|
||||
|
||||
// 从本地数据库中加载数据
|
||||
err := this.Load()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("KEY_MANAGER", "load failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("KEY_MANAGER", "load failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化运行
|
||||
err = this.LoopAll()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("KEY_MANAGER", "loop failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("KEY_MANAGER", "loop failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 更新
|
||||
var ticker = time.NewTicker(1 * time.Minute)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-this.notifier:
|
||||
}
|
||||
|
||||
err := this.LoopAll()
|
||||
if err != nil {
|
||||
if rpc.IsConnError(err) {
|
||||
remotelogs.Debug("KEY_MANAGER", "loop failed: "+err.Error())
|
||||
} else {
|
||||
remotelogs.Error("KEY_MANAGER", "loop failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load 从数据库中加载数据
|
||||
func (this *KeyManager) Load() error {
|
||||
var offset = 0
|
||||
var size = 10000
|
||||
for {
|
||||
keys, err := this.db.ListKeys(offset, size)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
for _, key := range keys {
|
||||
if key.ZoneId > 0 {
|
||||
keyList, ok := this.zoneKeyMap[key.ZoneId]
|
||||
if ok {
|
||||
keyList.Add(key)
|
||||
} else {
|
||||
keyList = models.NewNSKeys()
|
||||
keyList.Add(key)
|
||||
this.zoneKeyMap[key.ZoneId] = keyList
|
||||
}
|
||||
} else if key.DomainId > 0 {
|
||||
keyList, ok := this.domainKeyMap[key.DomainId]
|
||||
if ok {
|
||||
keyList.Add(key)
|
||||
} else {
|
||||
keyList = models.NewNSKeys()
|
||||
keyList.Add(key)
|
||||
this.domainKeyMap[key.DomainId] = keyList
|
||||
}
|
||||
}
|
||||
|
||||
if key.Version > this.version {
|
||||
this.version = key.Version
|
||||
}
|
||||
}
|
||||
this.locker.Unlock()
|
||||
|
||||
offset += size
|
||||
}
|
||||
|
||||
if this.version > 0 {
|
||||
this.version++
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *KeyManager) LoopAll() error {
|
||||
for {
|
||||
hasNext, err := this.Loop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasNext {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Loop 单次循环任务
|
||||
func (this *KeyManager) Loop() (hasNext bool, err error) {
|
||||
client, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
resp, err := client.NSKeyRPC.ListNSKeysAfterVersion(client.Context(), &pb.ListNSKeysAfterVersionRequest{
|
||||
Version: this.version,
|
||||
Size: 20000,
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var keys = resp.NsKeys
|
||||
if len(keys) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
for _, key := range keys {
|
||||
this.processKey(key)
|
||||
|
||||
if key.Version > this.version {
|
||||
this.version = key.Version
|
||||
}
|
||||
}
|
||||
this.version++
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (this *KeyManager) FindKeysWithDomain(domainId int64) []*models.NSKey {
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
|
||||
keys, ok := this.domainKeyMap[domainId]
|
||||
if ok {
|
||||
return keys.All()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NotifyUpdate 通知更新
|
||||
func (this *KeyManager) NotifyUpdate() {
|
||||
select {
|
||||
case this.notifier <- true:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// 处理Key
|
||||
func (this *KeyManager) processKey(key *pb.NSKey) {
|
||||
if key.NsDomain == nil && key.NsZone == nil {
|
||||
return
|
||||
}
|
||||
if !key.IsOn || key.IsDeleted {
|
||||
this.locker.Lock()
|
||||
if key.NsDomain != nil {
|
||||
list, ok := this.domainKeyMap[key.NsDomain.Id]
|
||||
if ok {
|
||||
list.Remove(key.Id)
|
||||
}
|
||||
}
|
||||
if key.NsZone != nil {
|
||||
list, ok := this.zoneKeyMap[key.NsZone.Id]
|
||||
if ok {
|
||||
list.Remove(key.Id)
|
||||
}
|
||||
}
|
||||
this.locker.Unlock()
|
||||
|
||||
// 从数据库中删除
|
||||
if this.db != nil {
|
||||
err := this.db.DeleteKey(key.Id)
|
||||
if err != nil {
|
||||
remotelogs.Error("KEY_MANAGER", "delete key from db failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var domainId int64
|
||||
var zoneId int64
|
||||
if key.NsDomain != nil {
|
||||
domainId = key.NsDomain.Id
|
||||
}
|
||||
if key.NsZone != nil {
|
||||
zoneId = key.NsZone.Id
|
||||
}
|
||||
|
||||
// 存入数据库
|
||||
if this.db != nil {
|
||||
exists, err := this.db.ExistsKey(key.Id)
|
||||
if err != nil {
|
||||
remotelogs.Error("KEY_MANAGER", "query failed: "+err.Error())
|
||||
} else {
|
||||
if exists {
|
||||
err = this.db.UpdateKey(key.Id, domainId, zoneId, key.Algo, key.Secret, key.SecretType, key.Version)
|
||||
if err != nil {
|
||||
remotelogs.Error("KEY_MANAGER", "update failed: "+err.Error())
|
||||
}
|
||||
} else {
|
||||
err = this.db.InsertKey(key.Id, domainId, zoneId, key.Algo, key.Secret, key.SecretType, key.Version)
|
||||
if err != nil {
|
||||
remotelogs.Error("KEY_MANAGER", "insert failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 加入缓存Map
|
||||
this.locker.Lock()
|
||||
var nsKey = &models.NSKey{
|
||||
Id: key.Id,
|
||||
DomainId: domainId,
|
||||
ZoneId: zoneId,
|
||||
Algo: key.Algo,
|
||||
Secret: key.Secret,
|
||||
SecretType: key.SecretType,
|
||||
Version: key.Version,
|
||||
}
|
||||
|
||||
if zoneId > 0 {
|
||||
keyList, ok := this.zoneKeyMap[zoneId]
|
||||
if ok {
|
||||
keyList.Add(nsKey)
|
||||
} else {
|
||||
keyList = models.NewNSKeys()
|
||||
keyList.Add(nsKey)
|
||||
this.zoneKeyMap[zoneId] = keyList
|
||||
}
|
||||
} else if domainId > 0 {
|
||||
keyList, ok := this.domainKeyMap[domainId]
|
||||
if ok {
|
||||
keyList.Add(nsKey)
|
||||
} else {
|
||||
keyList = models.NewNSKeys()
|
||||
keyList.Add(nsKey)
|
||||
this.domainKeyMap[domainId] = keyList
|
||||
}
|
||||
}
|
||||
this.locker.Unlock()
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user