This commit is contained in:
unknown
2026-02-04 20:27:13 +08:00
commit 3b042d1dad
9410 changed files with 1488147 additions and 0 deletions

View File

@@ -0,0 +1,5 @@
# IPList
List Check Order:
~~~
Global List --> Node List--> Server List --> WAF List --> Bind List
~~~

View File

@@ -0,0 +1,31 @@
package iplibrary
import (
"encoding/json"
"github.com/iwind/TeaGo/maps"
"net/http"
)
type BaseAction struct {
}
func (this *BaseAction) Close() error {
return nil
}
// DoHTTP 处理HTTP请求
func (this *BaseAction) DoHTTP(req *http.Request, resp http.ResponseWriter) (goNext bool, err error) {
return true, nil
}
func (this *BaseAction) convertParams(params maps.Map, ptr interface{}) error {
data, err := json.Marshal(params)
if err != nil {
return err
}
err = json.Unmarshal(data, ptr)
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,22 @@
package iplibrary
// FataError 是否是致命错误
type FataError struct {
err string
}
func (this *FataError) Error() string {
return this.err
}
func NewFataError(err string) error {
return &FataError{err: err}
}
func IsFatalError(err error) bool {
if err == nil {
return false
}
_, ok := err.(*FataError)
return ok
}

View File

@@ -0,0 +1,153 @@
package iplibrary
import (
"errors"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"runtime"
"time"
)
// FirewalldAction Firewalld动作管理
// 常用命令:
// - 查询列表: firewall-cmd --list-all
// - 添加IPfirewall-cmd --add-rich-rule="rule family='ipv4' source address='192.168.2.32' reject" --timeout=30s
// - 删除IPfirewall-cmd --remove-rich-rule="rule family='ipv4' source address='192.168.2.32' reject" --timeout=30s
type FirewalldAction struct {
BaseAction
config *firewallconfigs.FirewallActionFirewalldConfig
firewalldNotFound bool
}
func NewFirewalldAction() *FirewalldAction {
return &FirewalldAction{}
}
func (this *FirewalldAction) Init(config *firewallconfigs.FirewallActionConfig) error {
this.config = &firewallconfigs.FirewallActionFirewalldConfig{}
err := this.convertParams(config.Params, this.config)
if err != nil {
return err
}
return nil
}
func (this *FirewalldAction) AddItem(listType IPListType, item *pb.IPItem) error {
return this.runAction("addItem", listType, item)
}
func (this *FirewalldAction) DeleteItem(listType IPListType, item *pb.IPItem) error {
return this.runAction("deleteItem", listType, item)
}
func (this *FirewalldAction) runAction(action string, listType IPListType, item *pb.IPItem) error {
if item.Type == "all" {
return nil
}
if len(item.IpTo) == 0 {
return this.runActionSingleIP(action, listType, item)
}
cidrList, err := iPv4RangeToCIDRRange(item.IpFrom, item.IpTo)
if err != nil {
// 不合法的范围不予处理即可
return nil
}
if len(cidrList) == 0 {
return nil
}
for _, cidr := range cidrList {
item.IpFrom = cidr
item.IpTo = ""
err := this.runActionSingleIP(action, listType, item)
if err != nil {
return err
}
}
return nil
}
func (this *FirewalldAction) runActionSingleIP(action string, listType IPListType, item *pb.IPItem) error {
timestamp := time.Now().Unix()
if item.ExpiredAt > 0 && timestamp > item.ExpiredAt {
return nil
}
path := this.config.Path
var err error
if len(path) == 0 {
path, err = executils.LookPath("firewall-cmd")
if err != nil {
if this.firewalldNotFound {
return nil
}
this.firewalldNotFound = true
return err
}
}
if len(path) == 0 {
return errors.New("can not find 'firewall-cmd'")
}
opt := ""
switch action {
case "addItem":
opt = "--add-rich-rule"
case "deleteItem":
opt = "--remove-rich-rule"
default:
return errors.New("invalid action '" + action + "'")
}
opt += "=rule family='"
switch item.Type {
case "ipv4":
opt += "ipv4"
case "ipv6":
opt += "ipv6"
default:
// 我们忽略不能识别的Family
return nil
}
opt += "' source address='"
if len(item.IpFrom) == 0 {
return errors.New("invalid ip from")
}
opt += item.IpFrom + "' "
switch listType {
case IPListTypeWhite:
opt += " accept"
case IPListTypeBlack:
opt += " reject"
default:
// 我们忽略不能识别的列表类型
return nil
}
args := []string{opt}
if action == "addItem" {
if item.ExpiredAt > timestamp {
args = append(args, "--timeout="+fmt.Sprintf("%d", item.ExpiredAt-timestamp)+"s")
} else {
// TODO 思考是否需要permanent不然--reload之后会丢失
}
}
if runtime.GOOS == "darwin" {
// MAC OS直接返回
return nil
}
cmd := executils.NewTimeoutCmd(30*time.Second, path, args...)
cmd.WithStderr()
err = cmd.Run()
if err != nil {
return fmt.Errorf("%w, output: %s", err, cmd.Stderr())
}
return nil
}

View File

@@ -0,0 +1,79 @@
package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"testing"
"time"
)
func TestFirewalldAction_AddItem(t *testing.T) {
{
action := NewFirewalldAction()
action.config = &firewallconfigs.FirewallActionFirewalldConfig{
Path: "/usr/bin/firewalld",
}
err := action.AddItem(IPListTypeWhite, &pb.IPItem{
Type: "ipv4",
Id: 1,
IpFrom: "192.168.1.100",
ExpiredAt: time.Now().Unix() + 30,
})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
{
action := NewFirewalldAction()
action.config = &firewallconfigs.FirewallActionFirewalldConfig{
Path: "/usr/bin/firewalld",
}
err := action.AddItem(IPListTypeBlack, &pb.IPItem{
Type: "ipv4",
Id: 1,
IpFrom: "192.168.1.101",
ExpiredAt: time.Now().Unix() + 30,
})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
}
func TestFirewalldAction_DeleteItem(t *testing.T) {
action := NewFirewalldAction()
action.config = &firewallconfigs.FirewallActionFirewalldConfig{
Path: "/usr/bin/firewalld",
}
err := action.DeleteItem(IPListTypeWhite, &pb.IPItem{
Type: "ipv4",
Id: 1,
IpFrom: "192.168.1.100",
ExpiredAt: time.Now().Unix() + 30,
})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestFirewalldAction_MultipleItem(t *testing.T) {
action := NewFirewalldAction()
action.config = &firewallconfigs.FirewallActionFirewalldConfig{
Path: "/usr/bin/firewalld",
}
err := action.AddItem(IPListTypeBlack, &pb.IPItem{
Type: "ipv4",
Id: 1,
IpFrom: "192.168.1.30",
IpTo: "192.168.1.200",
ExpiredAt: time.Now().Unix() + 30,
})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -0,0 +1,55 @@
package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"net/http"
)
// HTMLAction HTML动作
type HTMLAction struct {
BaseAction
config *firewallconfigs.FirewallActionHTMLConfig
}
// NewHTMLAction 获取新对象
func NewHTMLAction() *HTMLAction {
return &HTMLAction{}
}
// Init 初始化
func (this *HTMLAction) Init(config *firewallconfigs.FirewallActionConfig) error {
this.config = &firewallconfigs.FirewallActionHTMLConfig{}
err := this.convertParams(config.Params, this.config)
if err != nil {
return err
}
return nil
}
// AddItem 添加
func (this *HTMLAction) AddItem(listType IPListType, item *pb.IPItem) error {
return nil
}
// DeleteItem 删除
func (this *HTMLAction) DeleteItem(listType IPListType, item *pb.IPItem) error {
return nil
}
// Close 关闭
func (this *HTMLAction) Close() error {
return nil
}
// DoHTTP 处理HTTP请求
func (this *HTMLAction) DoHTTP(req *http.Request, resp http.ResponseWriter) (goNext bool, err error) {
if this.config == nil {
goNext = true
return
}
resp.WriteHeader(http.StatusForbidden) // TODO改成可以配置
_, _ = resp.Write([]byte(this.config.Content))
return false, nil
}

View File

@@ -0,0 +1,81 @@
package iplibrary
import (
"bytes"
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/iwind/TeaGo/maps"
"net/http"
"time"
)
var httpAPIClient = &http.Client{
Timeout: 5 * time.Second,
}
type HTTPAPIAction struct {
BaseAction
config *firewallconfigs.FirewallActionHTTPAPIConfig
}
func NewHTTPAPIAction() *HTTPAPIAction {
return &HTTPAPIAction{}
}
func (this *HTTPAPIAction) Init(config *firewallconfigs.FirewallActionConfig) error {
this.config = &firewallconfigs.FirewallActionHTTPAPIConfig{}
err := this.convertParams(config.Params, this.config)
if err != nil {
return err
}
if len(this.config.URL) == 0 {
return NewFataError("'url' should not be empty")
}
return nil
}
func (this *HTTPAPIAction) AddItem(listType IPListType, item *pb.IPItem) error {
return this.runAction("addItem", listType, item)
}
func (this *HTTPAPIAction) DeleteItem(listType IPListType, item *pb.IPItem) error {
return this.runAction("deleteItem", listType, item)
}
func (this *HTTPAPIAction) runAction(action string, listType IPListType, item *pb.IPItem) error {
if item == nil {
return nil
}
// TODO 增加节点ID等信息
m := maps.Map{
"action": action,
"listType": listType,
"item": maps.Map{
"type": item.Type,
"ipFrom": item.IpFrom,
"ipTo": item.IpTo,
"expiredAt": item.ExpiredAt,
},
}
mJSON, err := json.Marshal(m)
if err != nil {
return err
}
req, err := http.NewRequest(http.MethodPost, this.config.URL, bytes.NewReader(mJSON))
if err != nil {
return err
}
req.Header.Set("User-Agent", teaconst.GlobalProductName+"-Node/"+teaconst.Version)
resp, err := httpAPIClient.Do(req)
if err != nil {
return err
}
_ = resp.Body.Close()
return nil
}

View File

@@ -0,0 +1,50 @@
package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"testing"
)
func TestHTTPAPIAction_AddItem(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var action = NewHTTPAPIAction()
action.config = &firewallconfigs.FirewallActionHTTPAPIConfig{
URL: "http://127.0.0.1:2345/post",
TimeoutSeconds: 0,
}
err := action.AddItem(IPListTypeBlack, &pb.IPItem{
Type: "ipv4",
Id: 1,
IpFrom: "192.168.1.100",
})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestHTTPAPIAction_DeleteItem(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var action = NewHTTPAPIAction()
action.config = &firewallconfigs.FirewallActionHTTPAPIConfig{
URL: "http://127.0.0.1:2345/post",
TimeoutSeconds: 0,
}
err := action.DeleteItem(IPListTypeBlack, &pb.IPItem{
Type: "ipv4",
Id: 1,
IpFrom: "192.168.1.100",
})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -0,0 +1,24 @@
package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"net/http"
)
type ActionInterface interface {
// Init 初始化
Init(config *firewallconfigs.FirewallActionConfig) error
// AddItem 添加
AddItem(listType IPListType, item *pb.IPItem) error
// DeleteItem 删除
DeleteItem(listType IPListType, item *pb.IPItem) error
// Close 关闭
Close() error
// DoHTTP 处理HTTP请求
DoHTTP(req *http.Request, resp http.ResponseWriter) (goNext bool, err error)
}

View File

@@ -0,0 +1,358 @@
package iplibrary
import (
"errors"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"github.com/iwind/TeaGo/types"
"runtime"
"strconv"
"strings"
"time"
)
// IPSetAction IPSet动作
// 相关命令:
// - 利用Firewalld管理set
// - 添加firewall-cmd --permanent --new-ipset=edge_ip_list --type=hash:ip --option="timeout=0"
// - 删除firewall-cmd --permanent --delete-ipset=edge_ip_list
// - 重载firewall-cmd --reload
// - firewalld+ipset: firewall-cmd --permanent --add-rich-rule="rule source ipset='edge_ip_list' reject"
// - 利用IPTables管理set
// - 添加iptables -A INPUT -m set --match-set edge_ip_list src -j REJECT
// - 添加Itemipset add edge_ip_list 192.168.2.32 timeout 30
// - 删除Item: ipset del edge_ip_list 192.168.2.32
// - 创建setipset create edge_ip_list hash:ip timeout 0
// - 查看统计ipset -t list edge_black_list
// - 删除setipset destroy edge_black_list
type IPSetAction struct {
BaseAction
config *firewallconfigs.FirewallActionIPSetConfig
ipsetNotfound bool
}
func NewIPSetAction() *IPSetAction {
return &IPSetAction{}
}
func (this *IPSetAction) Init(config *firewallconfigs.FirewallActionConfig) error {
this.config = &firewallconfigs.FirewallActionIPSetConfig{}
err := this.convertParams(config.Params, this.config)
if err != nil {
return err
}
if len(this.config.WhiteName) == 0 {
return NewFataError("white list name should not be empty")
}
if len(this.config.BlackName) == 0 {
return NewFataError("black list name should not be empty")
}
// 创建ipset
{
path, err := executils.LookPath("ipset")
if err != nil {
return err
}
// ipv4
for _, listName := range []string{this.config.WhiteName, this.config.BlackName} {
if len(listName) == 0 {
continue
}
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "create", listName, "hash:ip", "timeout", "0", "maxelem", "1000000")
cmd.WithStderr()
err := cmd.Run()
if err != nil {
var output = cmd.Stderr()
if !strings.Contains(output, "already exists") {
return fmt.Errorf("create ipset '%s': %w, output: %s", listName, err, output)
} else {
err = nil
}
}
}
// ipv6
for _, listName := range []string{this.config.WhiteNameIPv6, this.config.BlackNameIPv6} {
if len(listName) == 0 {
continue
}
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "create", listName, "hash:ip", "family", "inet6", "timeout", "0", "maxelem", "1000000")
cmd.WithStderr()
err := cmd.Run()
if err != nil {
var output = cmd.Stderr()
if !strings.Contains(output, "already exists") {
return fmt.Errorf("create ipset '%s': %w, output: %s", listName, err, output)
} else {
err = nil
}
}
}
}
// firewalld
if this.config.AutoAddToFirewalld {
path, err := executils.LookPath("firewall-cmd")
if err != nil {
return err
}
// ipv4
for _, listName := range []string{this.config.WhiteName, this.config.BlackName} {
if len(listName) == 0 {
continue
}
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "--permanent", "--new-ipset="+listName, "--type=hash:ip", "--option=timeout=0", "--option=maxelem=1000000")
cmd.WithStderr()
err := cmd.Run()
if err != nil {
var output = cmd.Stderr()
if strings.Contains(output, "NAME_CONFLICT") {
err = nil
} else {
return fmt.Errorf("firewall-cmd add ipset '%s': %w, output: %s", listName, err, output)
}
}
}
// ipv6
for _, listName := range []string{this.config.WhiteNameIPv6, this.config.BlackNameIPv6} {
if len(listName) == 0 {
continue
}
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "--permanent", "--new-ipset="+listName, "--type=hash:ip", "--option=family=inet6", "--option=timeout=0", "--option=maxelem=1000000")
cmd.WithStderr()
err := cmd.Run()
if err != nil {
var output = cmd.Stderr()
if strings.Contains(output, "NAME_CONFLICT") {
err = nil
} else {
return fmt.Errorf("firewall-cmd add ipset '%s': %w, output: %s", listName, err, output)
}
}
}
// accept
for _, listName := range []string{this.config.WhiteName, this.config.WhiteNameIPv6} {
if len(listName) == 0 {
continue
}
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "--permanent", "--add-rich-rule=rule source ipset='"+listName+"' accept")
cmd.WithStderr()
err := cmd.Run()
if err != nil {
return fmt.Errorf("firewall-cmd add rich rule '%s': %w, output: %s", listName, err, cmd.Stderr())
}
}
// reject
for _, listName := range []string{this.config.BlackName, this.config.BlackNameIPv6} {
if len(listName) == 0 {
continue
}
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "--permanent", "--add-rich-rule=rule source ipset='"+listName+"' reject")
cmd.WithStderr()
err := cmd.Run()
if err != nil {
return fmt.Errorf("firewall-cmd add rich rule '%s': %w, output: %s", listName, err, cmd.Stderr())
}
}
// reload
{
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "--reload")
cmd.WithStderr()
err := cmd.Run()
if err != nil {
return fmt.Errorf("firewall-cmd reload: %w, output: %s", err, cmd.Stderr())
}
}
}
// iptables
if this.config.AutoAddToIPTables {
path, err := executils.LookPath("iptables")
if err != nil {
return err
}
// accept
for _, listName := range []string{this.config.WhiteName, this.config.WhiteNameIPv6} {
if len(listName) == 0 {
continue
}
// 检查规则是否存在
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "-C", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "ACCEPT")
err := cmd.Run()
var exists = err == nil
// 添加规则
if !exists {
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "-A", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "ACCEPT")
cmd.WithStderr()
err := cmd.Run()
if err != nil {
return fmt.Errorf("iptables add rule: %w, output: %s", err, cmd.Stderr())
}
}
}
// reject
for _, listName := range []string{this.config.BlackName, this.config.BlackNameIPv6} {
if len(listName) == 0 {
continue
}
// 检查规则是否存在
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "-C", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "REJECT")
err := cmd.Run()
var exists = err == nil
if !exists {
var cmd = executils.NewTimeoutCmd(30*time.Second, path, "-A", "INPUT", "-m", "set", "--match-set", listName, "src", "-j", "REJECT")
cmd.WithStderr()
err := cmd.Run()
if err != nil {
return fmt.Errorf("iptables add rule: %w, output: %s", err, cmd.Stderr())
}
}
}
}
return nil
}
func (this *IPSetAction) AddItem(listType IPListType, item *pb.IPItem) error {
return this.runAction("addItem", listType, item)
}
func (this *IPSetAction) DeleteItem(listType IPListType, item *pb.IPItem) error {
return this.runAction("deleteItem", listType, item)
}
func (this *IPSetAction) runAction(action string, listType IPListType, item *pb.IPItem) error {
if item.Type == "all" {
return nil
}
if len(item.IpTo) == 0 {
return this.runActionSingleIP(action, listType, item)
}
cidrList, err := iPv4RangeToCIDRRange(item.IpFrom, item.IpTo)
if err != nil {
// 不合法的范围不予处理即可
return nil
}
if len(cidrList) == 0 {
return nil
}
for _, cidr := range cidrList {
var index = strings.Index(cidr, "/")
if index <= 0 {
continue
}
// 只支持/24以下的
if types.Int(cidr[index+1:]) < 24 {
continue
}
item.IpFrom = cidr
item.IpTo = ""
err := this.runActionSingleIP(action, listType, item)
if err != nil {
return err
}
}
return nil
}
func (this *IPSetAction) SetConfig(config *firewallconfigs.FirewallActionIPSetConfig) {
this.config = config
}
func (this *IPSetAction) runActionSingleIP(action string, listType IPListType, item *pb.IPItem) error {
if item.Type == "all" {
return nil
}
var listName string
var isIPv6 = strings.Contains(item.IpFrom, ":")
switch listType {
case IPListTypeWhite:
if isIPv6 {
listName = this.config.WhiteNameIPv6
} else {
listName = this.config.WhiteName
}
case IPListTypeBlack:
if isIPv6 {
listName = this.config.BlackNameIPv6
} else {
listName = this.config.BlackName
}
default:
// 不支持的类型
return nil
}
if len(listName) == 0 {
return nil
}
var path = this.config.Path
var err error
if len(path) == 0 {
path, err = executils.LookPath("ipset")
if err != nil {
// 找不到ipset命令错误只提示一次
if this.ipsetNotfound {
return nil
}
this.ipsetNotfound = true
return err
}
}
// ipset add edge_ip_list 192.168.2.32 timeout 30
var args = []string{}
switch action {
case "addItem":
args = append(args, "add")
case "deleteItem":
args = append(args, "del")
}
args = append(args, listName, item.IpFrom)
if action == "addItem" {
var timestamp = time.Now().Unix()
if item.ExpiredAt > timestamp {
args = append(args, "timeout", strconv.FormatInt(item.ExpiredAt-timestamp, 10))
}
}
if runtime.GOOS == "darwin" {
// MAC OS直接返回
return nil
}
var cmd = executils.NewTimeoutCmd(30*time.Second, path, args...)
cmd.WithStderr()
err = cmd.Run()
if err != nil {
var errString = cmd.Stderr()
if action == "deleteItem" && strings.Contains(errString, "not added") {
return nil
}
return errors.New(strings.TrimSpace(errString))
}
return nil
}

View File

@@ -0,0 +1,123 @@
package iplibrary_test
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"github.com/iwind/TeaGo/maps"
"testing"
"time"
)
func TestIPSetAction_Init(t *testing.T) {
_, lookupErr := executils.LookPath("iptables")
if lookupErr != nil {
return
}
var action = iplibrary.NewIPSetAction()
err := action.Init(&firewallconfigs.FirewallActionConfig{
Params: maps.Map{
"path": "/usr/bin/iptables",
"whiteName": "white-list",
"blackName": "black-list",
},
})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestIPSetAction_AddItem(t *testing.T) {
_, lookupErr := executils.LookPath("iptables")
if lookupErr != nil {
return
}
var action = iplibrary.NewIPSetAction()
action.SetConfig(&firewallconfigs.FirewallActionIPSetConfig{
Path: "/usr/bin/iptables",
WhiteName: "white-list",
BlackName: "black-list",
WhiteNameIPv6: "white-list-ipv6",
BlackNameIPv6: "black-list-ipv6",
})
{
err := action.AddItem(iplibrary.IPListTypeWhite, &pb.IPItem{
Type: "ipv4",
Id: 1,
IpFrom: "192.168.1.100",
ExpiredAt: time.Now().Unix() + 30,
})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
{
err := action.AddItem(iplibrary.IPListTypeWhite, &pb.IPItem{
Type: "ipv4",
Id: 1,
IpFrom: "1:2:3:4",
ExpiredAt: time.Now().Unix() + 30,
})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
{
err := action.AddItem(iplibrary.IPListTypeBlack, &pb.IPItem{
Type: "ipv4",
Id: 1,
IpFrom: "192.168.1.100",
ExpiredAt: time.Now().Unix() + 30,
})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
{
err := action.AddItem(iplibrary.IPListTypeBlack, &pb.IPItem{
Type: "ipv4",
Id: 1,
IpFrom: "1:2:3:4",
ExpiredAt: time.Now().Unix() + 30,
})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
}
func TestIPSetAction_DeleteItem(t *testing.T) {
_, lookupErr := executils.LookPath("firewalld")
if lookupErr != nil {
return
}
var action = iplibrary.NewIPSetAction()
err := action.Init(&firewallconfigs.FirewallActionConfig{
Params: maps.Map{
"path": "/usr/bin/firewalld",
"whiteName": "white-list",
},
})
if err != nil {
t.Fatal(err)
}
err = action.DeleteItem(iplibrary.IPListTypeWhite, &pb.IPItem{
Type: "ipv4",
Id: 1,
IpFrom: "192.168.1.100",
ExpiredAt: time.Now().Unix() + 30,
})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -0,0 +1,135 @@
package iplibrary
import (
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"runtime"
"strings"
"time"
)
// IPTablesAction IPTables动作
// 相关命令:
//
// iptables -A INPUT -s "192.168.2.32" -j ACCEPT
// iptables -A INPUT -s "192.168.2.32" -j REJECT
// iptables -D INPUT ...
// iptables -F INPUT
type IPTablesAction struct {
BaseAction
config *firewallconfigs.FirewallActionIPTablesConfig
iptablesNotFound bool
}
func NewIPTablesAction() *IPTablesAction {
return &IPTablesAction{}
}
func (this *IPTablesAction) Init(config *firewallconfigs.FirewallActionConfig) error {
this.config = &firewallconfigs.FirewallActionIPTablesConfig{}
err := this.convertParams(config.Params, this.config)
if err != nil {
return err
}
return nil
}
func (this *IPTablesAction) AddItem(listType IPListType, item *pb.IPItem) error {
return this.runAction("addItem", listType, item)
}
func (this *IPTablesAction) DeleteItem(listType IPListType, item *pb.IPItem) error {
return this.runAction("deleteItem", listType, item)
}
func (this *IPTablesAction) runAction(action string, listType IPListType, item *pb.IPItem) error {
if item.Type == "all" {
return nil
}
if len(item.IpTo) == 0 {
return this.runActionSingleIP(action, listType, item)
}
cidrList, err := iPv4RangeToCIDRRange(item.IpFrom, item.IpTo)
if err != nil {
// 不合法的范围不予处理即可
return nil
}
if len(cidrList) == 0 {
return nil
}
for _, cidr := range cidrList {
item.IpFrom = cidr
item.IpTo = ""
err := this.runActionSingleIP(action, listType, item)
if err != nil {
return err
}
}
return nil
}
func (this *IPTablesAction) runActionSingleIP(action string, listType IPListType, item *pb.IPItem) error {
// 暂时不支持ipv6
// TODO 将来支持ipv6
if utils.IsIPv6(item.IpFrom) {
return nil
}
if item.Type == "all" {
return nil
}
var path = this.config.Path
var err error
if len(path) == 0 {
path, err = executils.LookPath("iptables")
if err != nil {
if this.iptablesNotFound {
return nil
}
this.iptablesNotFound = true
return err
}
this.config.Path = path
}
iptablesAction := ""
switch action {
case "addItem":
iptablesAction = "-A"
case "deleteItem":
iptablesAction = "-D"
default:
return nil
}
args := []string{iptablesAction, "INPUT", "-s", item.IpFrom, "-j"}
switch listType {
case IPListTypeWhite:
args = append(args, "ACCEPT")
case IPListTypeBlack:
args = append(args, "REJECT")
default:
return nil
}
if runtime.GOOS == "darwin" {
// MAC OS直接返回
return nil
}
var cmd = executils.NewTimeoutCmd(30*time.Second, path, args...)
cmd.WithStderr()
err = cmd.Run()
if err != nil {
var output = cmd.Stderr()
if strings.Contains(output, "No chain/target/match") {
err = nil
} else {
return fmt.Errorf("%w, output: %s", err, output)
}
}
return nil
}

View File

@@ -0,0 +1,68 @@
package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"testing"
"time"
)
func TestIPTablesAction_AddItem(t *testing.T) {
_, lookupErr := executils.LookPath("iptables")
if lookupErr != nil {
return
}
var action = NewIPTablesAction()
action.config = &firewallconfigs.FirewallActionIPTablesConfig{
Path: "/usr/bin/iptables",
}
{
err := action.AddItem(IPListTypeWhite, &pb.IPItem{
Type: "ipv4",
Id: 1,
IpFrom: "192.168.1.100",
ExpiredAt: time.Now().Unix() + 30,
})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
{
err := action.AddItem(IPListTypeBlack, &pb.IPItem{
Type: "ipv4",
Id: 1,
IpFrom: "192.168.1.100",
ExpiredAt: time.Now().Unix() + 30,
})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
}
func TestIPTablesAction_DeleteItem(t *testing.T) {
_, lookupErr := executils.LookPath("firewalld")
if lookupErr != nil {
return
}
var action = NewIPTablesAction()
action.config = &firewallconfigs.FirewallActionIPTablesConfig{
Path: "/usr/bin/firewalld",
}
err := action.DeleteItem(IPListTypeWhite, &pb.IPItem{
Type: "ipv4",
Id: 1,
IpFrom: "192.168.1.100",
ExpiredAt: time.Now().Unix() + 30,
})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -0,0 +1,173 @@
package iplibrary
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"strconv"
"sync"
)
var SharedActionManager = NewActionManager()
// ActionManager 动作管理器定义
type ActionManager struct {
locker sync.Mutex
eventMap map[string][]ActionInterface // eventLevel => []instance
configMap map[int64]*firewallconfigs.FirewallActionConfig // id => config
instanceMap map[int64]ActionInterface // id => instance
}
// NewActionManager 获取动作管理对象
func NewActionManager() *ActionManager {
return &ActionManager{
configMap: map[int64]*firewallconfigs.FirewallActionConfig{},
instanceMap: map[int64]ActionInterface{},
}
}
// UpdateActions 更新配置
func (this *ActionManager) UpdateActions(actions []*firewallconfigs.FirewallActionConfig) {
this.locker.Lock()
defer this.locker.Unlock()
// 关闭不存在的
newActionsMap := map[int64]*firewallconfigs.FirewallActionConfig{}
for _, action := range actions {
newActionsMap[action.Id] = action
}
for _, oldAction := range this.configMap {
_, ok := newActionsMap[oldAction.Id]
if !ok {
instance, ok := this.instanceMap[oldAction.Id]
if ok {
_ = instance.Close()
delete(this.instanceMap, oldAction.Id)
remotelogs.Println("IPLIBRARY/ACTION_MANAGER", "close action "+strconv.FormatInt(oldAction.Id, 10))
}
}
}
// 添加新的或者更新老的
for _, newAction := range newActionsMap {
oldInstance, ok := this.instanceMap[newAction.Id]
if ok {
// 检查配置是否一致
oldConfigJSON, err := json.Marshal(this.configMap[newAction.Id])
if err != nil {
remotelogs.Error("IPLIBRARY/ACTION_MANAGER", "action "+strconv.FormatInt(newAction.Id, 10)+", type:"+newAction.Type+": "+err.Error())
continue
}
newConfigJSON, err := json.Marshal(newAction)
if err != nil {
remotelogs.Error("IPLIBRARY/ACTION_MANAGER", "action "+strconv.FormatInt(newAction.Id, 10)+", type:"+newAction.Type+": "+err.Error())
continue
}
if !bytes.Equal(newConfigJSON, oldConfigJSON) {
_ = oldInstance.Close()
// 重新创建
// 之所以要重新创建,是因为前后的动作类型可能有变化,完全重建可以避免不必要的麻烦
newInstance, err := this.createInstance(newAction)
if err != nil {
remotelogs.Error("IPLIBRARY/ACTION_MANAGER", "reload action "+strconv.FormatInt(newAction.Id, 10)+", type:"+newAction.Type+": "+err.Error())
continue
}
remotelogs.Println("IPLIBRARY/ACTION_MANAGER", "reloaded "+strconv.FormatInt(newAction.Id, 10)+", type:"+newAction.Type)
this.instanceMap[newAction.Id] = newInstance
}
} else {
// 创建
instance, err := this.createInstance(newAction)
if err != nil {
remotelogs.Error("IPLIBRARY/ACTION_MANAGER", "load new action "+strconv.FormatInt(newAction.Id, 10)+", type:"+newAction.Type+": "+err.Error())
continue
}
remotelogs.Println("IPLIBRARY/ACTION_MANAGER", "loaded action "+strconv.FormatInt(newAction.Id, 10)+", type:"+newAction.Type)
this.instanceMap[newAction.Id] = instance
}
}
// 更新配置
this.configMap = newActionsMap
this.eventMap = map[string][]ActionInterface{}
for _, action := range this.configMap {
instance, ok := this.instanceMap[action.Id]
if !ok {
continue
}
var instances = this.eventMap[action.EventLevel]
instances = append(instances, instance)
this.eventMap[action.EventLevel] = instances
}
}
// FindEventActions 查找事件对应的动作
func (this *ActionManager) FindEventActions(eventLevel string) []ActionInterface {
this.locker.Lock()
defer this.locker.Unlock()
return this.eventMap[eventLevel]
}
// AddItem 执行添加IP动作
func (this *ActionManager) AddItem(listType IPListType, item *pb.IPItem) {
instances, ok := this.eventMap[item.EventLevel]
if ok {
for _, instance := range instances {
err := instance.AddItem(listType, item)
if err != nil {
remotelogs.Error("IPLIBRARY/ACTION_MANAGER", "add item '"+fmt.Sprintf("%d", item.Id)+"': "+err.Error())
}
}
}
}
// DeleteItem 执行删除IP动作
func (this *ActionManager) DeleteItem(listType IPListType, item *pb.IPItem) {
instances, ok := this.eventMap[item.EventLevel]
if ok {
for _, instance := range instances {
err := instance.DeleteItem(listType, item)
if err != nil {
remotelogs.Error("IPLIBRARY/ACTION_MANAGER", "delete item '"+fmt.Sprintf("%d", item.Id)+"': "+err.Error())
}
}
}
}
func (this *ActionManager) createInstance(config *firewallconfigs.FirewallActionConfig) (ActionInterface, error) {
var instance ActionInterface
switch config.Type {
case firewallconfigs.FirewallActionTypeIPSet:
instance = NewIPSetAction()
case firewallconfigs.FirewallActionTypeFirewalld:
instance = NewFirewalldAction()
case firewallconfigs.FirewallActionTypeIPTables:
instance = NewIPTablesAction()
case firewallconfigs.FirewallActionTypeScript:
instance = NewScriptAction()
case firewallconfigs.FirewallActionTypeHTTPAPI:
instance = NewHTTPAPIAction()
case firewallconfigs.FirewallActionTypeHTML:
instance = NewHTMLAction()
}
if instance == nil {
return nil, errors.New("can not create instance for type '" + config.Type + "'")
}
err := instance.Init(config)
if err != nil {
// 如果是警告错误,我们只是提示
if !IsFatalError(err) {
remotelogs.Error("IPLIBRARY/ACTION_MANAGER/CREATE_INSTANCE", "init '"+config.Type+"' failed: "+err.Error())
} else {
return nil, err
}
}
return instance, nil
}

View File

@@ -0,0 +1,55 @@
package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/iwind/TeaGo/maps"
"testing"
)
func TestActionManager_UpdateActions(t *testing.T) {
var manager = NewActionManager()
manager.UpdateActions([]*firewallconfigs.FirewallActionConfig{
{
Id: 1,
Type: "ipset",
Params: maps.Map{
"whiteName": "edge-white-list",
"blackName": "edge-black-list",
},
},
})
t.Log("===config===")
for _, c := range manager.configMap {
t.Log(c.Id, c.Type)
}
t.Log("===instance===")
for id, c := range manager.instanceMap {
t.Log(id, c)
}
manager.UpdateActions([]*firewallconfigs.FirewallActionConfig{
{
Id: 1,
Type: "ipset",
Params: maps.Map{
"whiteName": "edge-white-list",
"blackName": "edge-black-list",
},
},
{
Id: 2,
Type: "iptables",
Params: maps.Map{},
},
})
t.Log("===config===")
for _, c := range manager.configMap {
t.Log(c.Id, c.Type)
}
t.Log("===instance===")
for id, c := range manager.instanceMap {
t.Logf("%d: %#v", id, c)
}
}

View File

@@ -0,0 +1,67 @@
package iplibrary
import (
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
"path/filepath"
"time"
)
// ScriptAction 脚本命令动作
type ScriptAction struct {
BaseAction
config *firewallconfigs.FirewallActionScriptConfig
}
func NewScriptAction() *ScriptAction {
return &ScriptAction{}
}
func (this *ScriptAction) Init(config *firewallconfigs.FirewallActionConfig) error {
this.config = &firewallconfigs.FirewallActionScriptConfig{}
err := this.convertParams(config.Params, this.config)
if err != nil {
return err
}
if len(this.config.Path) == 0 {
return NewFataError("'path' should not be empty")
}
return nil
}
func (this *ScriptAction) AddItem(listType IPListType, item *pb.IPItem) error {
return this.runAction("addItem", listType, item)
}
func (this *ScriptAction) DeleteItem(listType IPListType, item *pb.IPItem) error {
return this.runAction("deleteItem", listType, item)
}
func (this *ScriptAction) runAction(action string, listType IPListType, item *pb.IPItem) error {
// TODO 智能支持 .sh 脚本文件
var cmd = executils.NewTimeoutCmd(30*time.Second, this.config.Path)
cmd.WithEnv([]string{
"ACTION=" + action,
"TYPE=" + item.Type,
"IP_FROM=" + item.IpFrom,
"IP_TO=" + item.IpTo,
"EXPIRED_AT=" + fmt.Sprintf("%d", item.ExpiredAt),
"LIST_TYPE=" + listType,
})
if len(this.config.Cwd) > 0 {
cmd.WithDir(this.config.Cwd)
} else {
cmd.WithDir(filepath.Dir(this.config.Path))
}
cmd.WithStderr()
err := cmd.Run()
if err != nil {
return fmt.Errorf("%w, output: %s", err, cmd.Stderr())
}
return nil
}

View File

@@ -0,0 +1,55 @@
package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"testing"
"time"
)
func TestScriptAction_AddItem(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
action := NewScriptAction()
action.config = &firewallconfigs.FirewallActionScriptConfig{
Path: "/tmp/ip-item.sh",
Cwd: "",
Args: nil,
}
err := action.AddItem(IPListTypeBlack, &pb.IPItem{
Type: "ipv4",
Id: 1,
IpFrom: "192.168.1.100",
ExpiredAt: time.Now().Unix(),
})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestScriptAction_DeleteItem(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
action := NewScriptAction()
action.config = &firewallconfigs.FirewallActionScriptConfig{
Path: "/tmp/ip-item.sh",
Cwd: "",
Args: nil,
}
err := action.DeleteItem(IPListTypeBlack, &pb.IPItem{
Type: "ipv4",
Id: 1,
IpFrom: "192.168.1.100",
ExpiredAt: time.Now().Unix(),
})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -0,0 +1,85 @@
package iplibrary
import (
"fmt"
"log"
"math"
"strconv"
"strings"
)
// Convert IPv4 range into CIDR
// 来自https://gist.github.com/P-A-R-U-S/a090dd90c5104ce85a29c32669dac107
func iPv4RangeToCIDRRange(ipStart string, ipEnd string) (cidrs []string, err error) {
cidr2mask := []uint32{
0x00000000, 0x80000000, 0xC0000000,
0xE0000000, 0xF0000000, 0xF8000000,
0xFC000000, 0xFE000000, 0xFF000000,
0xFF800000, 0xFFC00000, 0xFFE00000,
0xFFF00000, 0xFFF80000, 0xFFFC0000,
0xFFFE0000, 0xFFFF0000, 0xFFFF8000,
0xFFFFC000, 0xFFFFE000, 0xFFFFF000,
0xFFFFF800, 0xFFFFFC00, 0xFFFFFE00,
0xFFFFFF00, 0xFFFFFF80, 0xFFFFFFC0,
0xFFFFFFE0, 0xFFFFFFF0, 0xFFFFFFF8,
0xFFFFFFFC, 0xFFFFFFFE, 0xFFFFFFFF,
}
ipStartUint32 := iPv4ToUint32(ipStart)
ipEndUint32 := iPv4ToUint32(ipEnd)
if ipStartUint32 > ipEndUint32 {
log.Fatalf("start IP:%s must be less than end IP:%s", ipStart, ipEnd)
}
for ipEndUint32 >= ipStartUint32 {
maxSize := 32
for maxSize > 0 {
maskedBase := ipStartUint32 & cidr2mask[maxSize-1]
if maskedBase != ipStartUint32 {
break
}
maxSize--
}
x := math.Log(float64(ipEndUint32-ipStartUint32+1)) / math.Log(2)
maxDiff := 32 - int(math.Floor(x))
if maxSize < maxDiff {
maxSize = maxDiff
}
cidrs = append(cidrs, uInt32ToIPv4(ipStartUint32)+"/"+strconv.Itoa(maxSize))
ipStartUint32 += uint32(math.Exp2(float64(32 - maxSize)))
}
return cidrs, err
}
// Convert IPv4 to uint32
func iPv4ToUint32(iPv4 string) uint32 {
ipOctets := [4]uint64{}
for i, v := range strings.SplitN(iPv4, ".", 4) {
ipOctets[i], _ = strconv.ParseUint(v, 10, 32)
}
result := (ipOctets[0] << 24) | (ipOctets[1] << 16) | (ipOctets[2] << 8) | ipOctets[3]
return uint32(result)
}
// Convert uint32 to IP
func uInt32ToIPv4(iPuInt32 uint32) (iP string) {
iP = fmt.Sprintf("%d.%d.%d.%d",
iPuInt32>>24,
(iPuInt32&0x00FFFFFF)>>16,
(iPuInt32&0x0000FFFF)>>8,
iPuInt32&0x000000FF)
return iP
}

View File

@@ -0,0 +1,7 @@
package iplibrary
import "testing"
func TestIPv4RangeToCIDRRange(t *testing.T) {
t.Log(iPv4RangeToCIDRRange("192.168.0.0", "192.168.255.255"))
}

View File

@@ -0,0 +1,5 @@
package iplibrary
func init() {
}

View File

@@ -0,0 +1,64 @@
package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/iputils"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
)
type IPItemType = string
const (
IPItemTypeIPv4 IPItemType = "ipv4" // IPv4
IPItemTypeIPv6 IPItemType = "ipv6" // IPv6
IPItemTypeAll IPItemType = "all" // 所有IP
)
// IPItem IP条目
type IPItem struct {
Type string `json:"type"`
Id uint64 `json:"id"`
IPFrom []byte `json:"ipFrom"`
IPTo []byte `json:"ipTo"`
ExpiredAt int64 `json:"expiredAt"`
EventLevel string `json:"eventLevel"`
}
// Contains 检查是否包含某个IP
func (this *IPItem) Contains(ipBytes []byte) bool {
switch this.Type {
case IPItemTypeIPv4:
return this.containsIP(ipBytes)
case IPItemTypeIPv6:
return this.containsIP(ipBytes)
case IPItemTypeAll:
return this.containsAll()
default:
return this.containsIP(ipBytes)
}
}
// 检查是否包含某个
func (this *IPItem) containsIP(ipBytes []byte) bool {
if IsZero(this.IPTo) {
if iputils.CompareBytes(this.IPFrom, ipBytes) != 0 {
return false
}
} else {
if iputils.CompareBytes(this.IPFrom, ipBytes) > 0 || iputils.CompareBytes(this.IPTo, ipBytes) < 0 {
return false
}
}
if this.ExpiredAt > 0 && this.ExpiredAt < fasttime.Now().Unix() {
return false
}
return true
}
// 检查是否包所有IP
func (this *IPItem) containsAll() bool {
if this.ExpiredAt > 0 && this.ExpiredAt < fasttime.Now().Unix() {
return false
}
return true
}

View File

@@ -0,0 +1,125 @@
package iplibrary_test
import (
"github.com/TeaOSLab/EdgeCommon/pkg/iputils"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/iwind/TeaGo/assert"
"math/rand"
"runtime"
"strconv"
"testing"
"time"
)
func TestIPItem_Contains(t *testing.T) {
var a = assert.NewAssertion(t)
{
var item = &iplibrary.IPItem{
IPFrom: iputils.ToBytes("192.168.1.100"),
IPTo: nil,
ExpiredAt: 0,
}
a.IsTrue(item.Contains(iputils.ToBytes("192.168.1.100")))
}
{
var item = &iplibrary.IPItem{
IPFrom: iputils.ToBytes("192.168.1.100"),
IPTo: nil,
ExpiredAt: time.Now().Unix() + 1,
}
a.IsTrue(item.Contains(iputils.ToBytes("192.168.1.100")))
}
{
var item = &iplibrary.IPItem{
IPFrom: iputils.ToBytes("192.168.1.100"),
IPTo: nil,
ExpiredAt: time.Now().Unix() - 1,
}
a.IsFalse(item.Contains(iputils.ToBytes("192.168.1.100")))
}
{
var item = &iplibrary.IPItem{
IPFrom: iputils.ToBytes("192.168.1.100"),
IPTo: nil,
ExpiredAt: 0,
}
a.IsFalse(item.Contains(iputils.ToBytes("192.168.1.101")))
}
{
var item = &iplibrary.IPItem{
IPFrom: iputils.ToBytes("192.168.1.1"),
IPTo: iputils.ToBytes("192.168.1.101"),
ExpiredAt: 0,
}
a.IsTrue(item.Contains(iputils.ToBytes("192.168.1.100")))
}
{
var item = &iplibrary.IPItem{
IPFrom: iputils.ToBytes("192.168.1.1"),
IPTo: iputils.ToBytes("192.168.1.100"),
ExpiredAt: 0,
}
a.IsTrue(item.Contains(iputils.ToBytes("192.168.1.100")))
}
{
var item = &iplibrary.IPItem{
IPFrom: iputils.ToBytes("192.168.1.1"),
IPTo: iputils.ToBytes("192.168.1.101"),
ExpiredAt: 0,
}
a.IsTrue(item.Contains(iputils.ToBytes("192.168.1.1")))
}
}
func TestIPItem_Memory(t *testing.T) {
var isSingleTest = testutils.IsSingleTesting()
var list = iplibrary.NewIPList()
var count = 100
if isSingleTest {
count = 2_000_000
}
for i := 0; i < count; i++ {
list.Add(&iplibrary.IPItem{
Type: "ip",
Id: uint64(i),
IPFrom: iputils.ToBytes("192.168.1.1"),
IPTo: nil,
ExpiredAt: time.Now().Unix(),
EventLevel: "",
})
}
runtime.GC()
t.Log("waiting")
if isSingleTest {
time.Sleep(10 * time.Second)
}
}
func BenchmarkIPItem_Contains(b *testing.B) {
runtime.GOMAXPROCS(1)
var item = &iplibrary.IPItem{
IPFrom: iputils.ToBytes("192.168.1.1"),
IPTo: iputils.ToBytes("192.168.1.101"),
ExpiredAt: 0,
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
var ip = iputils.ToBytes("192.168.1." + strconv.Itoa(rand.Int()%255))
item.Contains(ip)
}
})
}

View File

@@ -0,0 +1,350 @@
package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/iputils"
"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"sort"
"sync"
)
var GlobalBlackIPList = NewIPList()
var GlobalWhiteIPList = NewIPList()
// IPList IP名单
// TODO 对ipMap进行分区
type IPList struct {
isDeleted bool
itemsMap map[uint64]*IPItem // id => item
sortedRangeItems []*IPItem
ipMap map[string]*IPItem // ipFrom => IPItem
bufferItemsMap map[uint64]*IPItem // id => IPItem
allItemsMap map[uint64]*IPItem // id => item
expireList *expires.List
mu sync.RWMutex
}
func NewIPList() *IPList {
var list = &IPList{
itemsMap: map[uint64]*IPItem{},
bufferItemsMap: map[uint64]*IPItem{},
allItemsMap: map[uint64]*IPItem{},
ipMap: map[string]*IPItem{},
}
var expireList = expires.NewList()
expireList.OnGC(func(itemId uint64) {
list.Delete(itemId)
})
list.expireList = expireList
return list
}
func (this *IPList) Add(item *IPItem) {
if this.isDeleted {
return
}
this.addItem(item, true, true)
}
func (this *IPList) AddDelay(item *IPItem) {
if this.isDeleted || item == nil {
return
}
if !IsZero(item.IPTo) {
this.mu.Lock()
this.bufferItemsMap[item.Id] = item
this.mu.Unlock()
} else {
this.addItem(item, true, true)
}
}
func (this *IPList) Sort() {
this.mu.Lock()
this.sortRangeItems(false)
this.mu.Unlock()
}
func (this *IPList) Delete(itemId uint64) {
this.mu.Lock()
this.deleteItem(itemId)
this.mu.Unlock()
}
// Contains 判断是否包含某个IP
func (this *IPList) Contains(ipBytes []byte) bool {
if this.isDeleted {
return false
}
this.mu.RLock()
defer this.mu.RUnlock()
if len(this.allItemsMap) > 0 {
return true
}
var item = this.lookupIP(ipBytes)
return item != nil
}
// ContainsExpires 判断是否包含某个IP
func (this *IPList) ContainsExpires(ipBytes []byte) (expiresAt int64, ok bool) {
if this.isDeleted {
return
}
this.mu.RLock()
defer this.mu.RUnlock()
if len(this.allItemsMap) > 0 {
return 0, true
}
var item = this.lookupIP(ipBytes)
if item == nil {
return
}
return item.ExpiredAt, true
}
// ContainsIPStrings 是否包含一组IP中的任意一个并返回匹配的第一个Item
func (this *IPList) ContainsIPStrings(ipStrings []string) (item *IPItem, found bool) {
if this.isDeleted {
return
}
if len(ipStrings) == 0 {
return
}
this.mu.RLock()
defer this.mu.RUnlock()
if len(this.allItemsMap) > 0 {
for _, allItem := range this.allItemsMap {
item = allItem
break
}
if item != nil {
found = true
return
}
}
for _, ipString := range ipStrings {
if len(ipString) == 0 {
continue
}
item = this.lookupIP(iputils.ToBytes(ipString))
if item != nil {
found = true
return
}
}
return
}
func (this *IPList) SetDeleted() {
this.isDeleted = true
}
func (this *IPList) SortedRangeItems() []*IPItem {
return this.sortedRangeItems
}
func (this *IPList) IPMap() map[string]*IPItem {
return this.ipMap
}
func (this *IPList) ItemsMap() map[uint64]*IPItem {
return this.itemsMap
}
func (this *IPList) AllItemsMap() map[uint64]*IPItem {
return this.allItemsMap
}
func (this *IPList) BufferItemsMap() map[uint64]*IPItem {
return this.bufferItemsMap
}
func (this *IPList) addItem(item *IPItem, lock bool, sortable bool) {
if item == nil {
return
}
if item.ExpiredAt > 0 && item.ExpiredAt < fasttime.Now().Unix() {
return
}
var shouldSort bool
if iputils.CompareBytes(item.IPFrom, item.IPTo) == 0 {
item.IPTo = nil
}
if IsZero(item.IPFrom) && IsZero(item.IPTo) {
if item.Type != IPItemTypeAll {
return
}
} else if !IsZero(item.IPTo) {
if iputils.CompareBytes(item.IPFrom, item.IPTo) > 0 {
item.IPFrom, item.IPTo = item.IPTo, item.IPFrom
} else if IsZero(item.IPFrom) {
item.IPFrom = item.IPTo
item.IPTo = nil
}
}
if lock {
this.mu.Lock()
defer this.mu.Unlock()
}
// 是否已经存在
_, ok := this.itemsMap[item.Id]
if ok {
this.deleteItem(item.Id)
}
this.itemsMap[item.Id] = item
// 展开
if item.Type == IPItemTypeAll {
this.allItemsMap[item.Id] = item
} else if !IsZero(item.IPFrom) {
if !IsZero(item.IPTo) {
this.sortedRangeItems = append(this.sortedRangeItems, item)
shouldSort = true
} else {
this.ipMap[ToHex(item.IPFrom)] = item
}
}
if item.ExpiredAt > 0 {
this.expireList.Add(item.Id, item.ExpiredAt)
}
if shouldSort && sortable {
this.sortRangeItems(true)
}
}
// 对列表进行排序
func (this *IPList) sortRangeItems(force bool) {
if len(this.bufferItemsMap) > 0 {
for _, item := range this.bufferItemsMap {
this.addItem(item, false, false)
}
this.bufferItemsMap = map[uint64]*IPItem{}
force = true
}
if force {
sort.Slice(this.sortedRangeItems, func(i, j int) bool {
var item1 = this.sortedRangeItems[i]
var item2 = this.sortedRangeItems[j]
if iputils.CompareBytes(item1.IPFrom, item2.IPFrom) == 0 {
return iputils.CompareBytes(item1.IPTo, item2.IPTo) < 0
}
return iputils.CompareBytes(item1.IPFrom, item2.IPFrom) < 0
})
}
}
// 不加锁的情况下查找Item
func (this *IPList) lookupIP(ipBytes []byte) *IPItem {
{
item, ok := this.ipMap[ToHex(ipBytes)]
if ok && (item.ExpiredAt == 0 || item.ExpiredAt > fasttime.Now().Unix()) {
return item
}
}
if len(this.sortedRangeItems) == 0 {
return nil
}
var count = len(this.sortedRangeItems)
var resultIndex = -1
sort.Search(count, func(i int) bool {
var item = this.sortedRangeItems[i]
var cmp = iputils.CompareBytes(item.IPFrom, ipBytes)
if cmp < 0 {
if iputils.CompareBytes(item.IPTo, ipBytes) >= 0 {
resultIndex = i
}
return false
} else if cmp == 0 {
resultIndex = i
return false
}
return true
})
if resultIndex < 0 || resultIndex >= count {
return nil
}
var item = this.sortedRangeItems[resultIndex]
if item.ExpiredAt == 0 || item.ExpiredAt > fasttime.Now().Unix() {
return item
}
return nil
}
// 在不加锁的情况下删除某个Item
// 将会被别的方法引用,切记不能加锁
func (this *IPList) deleteItem(itemId uint64) {
// 从buffer中删除
delete(this.bufferItemsMap, itemId)
// 从all items中删除
_, ok := this.allItemsMap[itemId]
if ok {
delete(this.allItemsMap, itemId)
}
// 检查是否存在
oldItem, existsOld := this.itemsMap[itemId]
if !existsOld {
return
}
// 从ipMap中删除
if IsZero(oldItem.IPTo) {
var ipHex = ToHex(oldItem.IPFrom)
ipItem, ok := this.ipMap[ipHex]
if ok && ipItem.Id == itemId {
delete(this.ipMap, ipHex)
}
}
delete(this.itemsMap, itemId)
// 删除排序中的Item
if !IsZero(oldItem.IPTo) {
var index = -1
for itemIndex, item := range this.sortedRangeItems {
if item.Id == itemId {
index = itemIndex
break
}
}
if index >= 0 {
copy(this.sortedRangeItems[index:], this.sortedRangeItems[index+1:])
this.sortedRangeItems = this.sortedRangeItems[:len(this.sortedRangeItems)-1]
}
}
}

View File

@@ -0,0 +1,14 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package iplibrary
import "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
type IPListDB interface {
Name() string
DeleteExpiredItems() error
ReadMaxVersion() (int64, error)
UpdateMaxVersion(version int64) error
ReadItems(offset int64, size int64) (items []*pb.IPItem, goNext bool, err error)
AddItem(item *pb.IPItem) error
}

View File

@@ -0,0 +1,233 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package iplibrary
import (
"encoding/binary"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
"github.com/TeaOSLab/EdgeNode/internal/utils/idles"
"github.com/TeaOSLab/EdgeNode/internal/utils/kvstore"
"testing"
"time"
)
type KVIPList struct {
ipTable *kvstore.Table[*pb.IPItem]
versionsTable *kvstore.Table[int64]
encoder *IPItemEncoder[*pb.IPItem]
cleanTicker *time.Ticker
isClosed bool
offsetItemKey string
}
func NewKVIPList() (*KVIPList, error) {
var db = &KVIPList{
cleanTicker: time.NewTicker(24 * time.Hour),
encoder: &IPItemEncoder[*pb.IPItem]{},
}
err := db.init()
return db, err
}
func (this *KVIPList) init() error {
store, storeErr := kvstore.DefaultStore()
if storeErr != nil {
return storeErr
}
db, dbErr := store.NewDB("ip_list")
if dbErr != nil {
return dbErr
}
{
table, err := kvstore.NewTable[*pb.IPItem]("ip_items", this.encoder)
if err != nil {
return err
}
this.ipTable = table
err = table.AddFields("expiresAt")
if err != nil {
return err
}
db.AddTable(table)
}
{
table, err := kvstore.NewTable[int64]("versions", kvstore.NewIntValueEncoder[int64]())
if err != nil {
return err
}
this.versionsTable = table
db.AddTable(table)
}
goman.New(func() {
events.OnClose(func() {
_ = this.Close()
this.cleanTicker.Stop()
})
idles.RunTicker(this.cleanTicker, func() {
if this.isClosed {
return
}
deleteErr := this.DeleteExpiredItems()
if deleteErr != nil {
remotelogs.Error("IP_LIST_DB", "clean expired items failed: "+deleteErr.Error())
}
})
})
return nil
}
// Name 数据库名称代号
func (this *KVIPList) Name() string {
return "kvstore"
}
// DeleteExpiredItems 删除过期的条目
func (this *KVIPList) DeleteExpiredItems() error {
if this.isClosed {
return nil
}
for {
var found bool
var currentTime = fasttime.Now().Unix()
err := this.ipTable.
Query().
FieldAsc("expiresAt").
ForUpdate().
Limit(1000).
FindAll(func(tx *kvstore.Tx[*pb.IPItem], item kvstore.Item[*pb.IPItem]) (goNext bool, err error) {
if !item.Value.IsDeleted && item.Value.ExpiredAt == 0 { // never expires
return kvstore.Skip()
}
if item.Value.ExpiredAt < currentTime-7*86400 /** keep for 7 days **/ {
err = tx.Delete(item.Key)
if err != nil {
return false, err
}
found = true
return true, nil
}
found = false
return false, nil
})
if err != nil {
return err
}
if !found {
break
}
}
return nil
}
func (this *KVIPList) AddItem(item *pb.IPItem) error {
if this.isClosed {
return nil
}
// 先删除
var key = this.encoder.EncodeKey(item)
err := this.ipTable.Delete(key)
if err != nil {
return err
}
// 如果是删除,则不再创建新记录
if item.IsDeleted {
return this.UpdateMaxVersion(item.Version)
}
err = this.ipTable.Set(key, item)
if err != nil {
return err
}
return this.UpdateMaxVersion(item.Version)
}
func (this *KVIPList) ReadItems(offset int64, size int64) (items []*pb.IPItem, goNextLoop bool, err error) {
if this.isClosed {
return
}
err = this.ipTable.
Query().
Offset(this.offsetItemKey).
Limit(int(size)).
FindAll(func(tx *kvstore.Tx[*pb.IPItem], item kvstore.Item[*pb.IPItem]) (goNext bool, err error) {
this.offsetItemKey = item.Key
goNextLoop = true
if !item.Value.IsDeleted {
items = append(items, item.Value)
}
return true, nil
})
return
}
// ReadMaxVersion 读取当前最大版本号
func (this *KVIPList) ReadMaxVersion() (int64, error) {
if this.isClosed {
return 0, errors.New("database has been closed")
}
version, err := this.versionsTable.Get("version")
if err != nil {
if kvstore.IsNotFound(err) {
return 0, nil
}
return 0, err
}
return version, nil
}
// UpdateMaxVersion 修改版本号
func (this *KVIPList) UpdateMaxVersion(version int64) error {
if this.isClosed {
return nil
}
return this.versionsTable.Set("version", version)
}
func (this *KVIPList) TestInspect(t *testing.T) error {
return this.ipTable.
Query().
FindAll(func(tx *kvstore.Tx[*pb.IPItem], item kvstore.Item[*pb.IPItem]) (goNext bool, err error) {
if len(item.Key) != 8 {
return false, errors.New("invalid key '" + item.Key + "'")
}
t.Log(binary.BigEndian.Uint64([]byte(item.Key)), "=>", item.Value)
return true, nil
})
}
// Flush to disk
func (this *KVIPList) Flush() error {
return this.ipTable.DB().Store().Flush()
}
func (this *KVIPList) Close() error {
this.isClosed = true
return nil
}

View File

@@ -0,0 +1,55 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package iplibrary
import (
"encoding/binary"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"google.golang.org/protobuf/proto"
"math"
)
type IPItemEncoder[T interface{ *pb.IPItem }] struct {
}
func NewIPItemEncoder[T interface{ *pb.IPItem }]() *IPItemEncoder[T] {
return &IPItemEncoder[T]{}
}
func (this *IPItemEncoder[T]) Encode(value T) ([]byte, error) {
return proto.Marshal(any(value).(*pb.IPItem))
}
func (this *IPItemEncoder[T]) EncodeField(value T, fieldName string) ([]byte, error) {
switch fieldName {
case "expiresAt":
var expiresAt = any(value).(*pb.IPItem).ExpiredAt
if expiresAt < 0 || expiresAt > int64(math.MaxUint32) {
expiresAt = 0
}
var b = make([]byte, 4)
binary.BigEndian.PutUint32(b, uint32(expiresAt))
return b, nil
}
return nil, errors.New("field '" + fieldName + "' not found")
}
func (this *IPItemEncoder[T]) Decode(valueBytes []byte) (value T, err error) {
var item = &pb.IPItem{}
err = proto.Unmarshal(valueBytes, item)
value = item
return
}
// EncodeKey generate key for ip item
func (this *IPItemEncoder[T]) EncodeKey(item *pb.IPItem) string {
var b = make([]byte, 8)
if item.Id < 0 {
item.Id = 0
}
binary.BigEndian.PutUint64(b, uint64(item.Id))
return string(b)
}

View File

@@ -0,0 +1,221 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package iplibrary_test
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/TeaOSLab/EdgeNode/internal/utils/zero"
"testing"
"time"
)
func TestKVIPList_AddItem(t *testing.T) {
kv, err := iplibrary.NewKVIPList()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = kv.Flush()
}()
{
err = kv.AddItem(&pb.IPItem{
Id: 1,
IpFrom: "192.168.1.101",
IpTo: "",
Version: 1,
ExpiredAt: fasttime.NewFastTime().Unix() + 60,
ListId: 1,
IsDeleted: false,
ListType: "white",
})
if err != nil {
t.Fatal(err)
}
}
{
err = kv.AddItem(&pb.IPItem{
Id: 2,
IpFrom: "192.168.1.102",
IpTo: "",
Version: 2,
ExpiredAt: fasttime.NewFastTime().Unix() + 60,
ListId: 1,
IsDeleted: false,
ListType: "white",
})
if err != nil {
t.Fatal(err)
}
}
{
err = kv.AddItem(&pb.IPItem{
Id: 3,
IpFrom: "192.168.1.103",
IpTo: "",
Version: 3,
ExpiredAt: fasttime.NewFastTime().Unix() + 60,
ListId: 1,
IsDeleted: false,
ListType: "white",
})
if err != nil {
t.Fatal(err)
}
}
}
func TestKVIPList_AddItems_Many(t *testing.T) {
kv, err := iplibrary.NewKVIPList()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = kv.Flush()
}()
var count = 2
var from = 1
if testutils.IsSingleTesting() {
count = 2_000_000
}
var before = time.Now()
defer func() {
t.Logf("cost: %.2f s", time.Since(before).Seconds())
}()
for i := from; i <= from+count; i++ {
err = kv.AddItem(&pb.IPItem{
Id: int64(i),
IpFrom: testutils.RandIP(),
IpTo: "",
Version: int64(i),
ExpiredAt: fasttime.NewFastTime().Unix() + 86400,
ListId: 1,
IsDeleted: false,
ListType: "white",
})
if err != nil {
t.Fatal(err)
}
}
}
func TestKVIPList_DeleteExpiredItems(t *testing.T) {
kv, err := iplibrary.NewKVIPList()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = kv.Flush()
}()
err = kv.DeleteExpiredItems()
if err != nil {
t.Fatal(err)
}
}
func TestKVIPList_UpdateMaxVersion(t *testing.T) {
kv, err := iplibrary.NewKVIPList()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = kv.Flush()
}()
err = kv.UpdateMaxVersion(101)
if err != nil {
t.Fatal(err)
}
maxVersion, err := kv.ReadMaxVersion()
if err != nil {
t.Fatal(err)
}
t.Log("version:", maxVersion)
}
func TestKVIPList_ReadMaxVersion(t *testing.T) {
kv, err := iplibrary.NewKVIPList()
if err != nil {
t.Fatal(err)
}
maxVersion, err := kv.ReadMaxVersion()
if err != nil {
t.Fatal(err)
}
t.Log("version:", maxVersion)
}
func TestKVIPList_ReadItems(t *testing.T) {
kv, err := iplibrary.NewKVIPList()
if err != nil {
t.Fatal(err)
}
for {
items, goNext, readErr := kv.ReadItems(0, 2)
if readErr != nil {
t.Fatal(readErr)
}
t.Log("====")
for _, item := range items {
t.Log(item.Id)
}
if !goNext {
break
}
}
}
func TestKVIPList_CountItems(t *testing.T) {
kv, err := iplibrary.NewKVIPList()
if err != nil {
t.Fatal(err)
}
var count int
var m = map[int64]zero.Zero{}
for {
items, goNext, readErr := kv.ReadItems(0, 1000)
if readErr != nil {
t.Fatal(readErr)
}
for _, item := range items {
count++
m[item.Id] = zero.Zero{}
}
if !goNext {
break
}
}
t.Log("count:", count, "len:", len(m))
}
func TestKVIPList_Inspect(t *testing.T) {
kv, err := iplibrary.NewKVIPList()
if err != nil {
t.Fatal(err)
}
err = kv.TestInspect(t)
if err != nil {
t.Fatal(err)
}
}

View File

@@ -0,0 +1,313 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils/dbs"
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
"github.com/TeaOSLab/EdgeNode/internal/utils/idles"
"github.com/iwind/TeaGo/Tea"
"os"
"path/filepath"
"time"
)
type SQLiteIPList struct {
db *dbs.DB
itemTableName string
versionTableName string
deleteExpiredItemsStmt *dbs.Stmt
deleteItemStmt *dbs.Stmt
insertItemStmt *dbs.Stmt
selectItemsStmt *dbs.Stmt
selectMaxItemVersionStmt *dbs.Stmt
selectVersionStmt *dbs.Stmt
updateVersionStmt *dbs.Stmt
cleanTicker *time.Ticker
dir string
isClosed bool
}
func NewSQLiteIPList() (*SQLiteIPList, error) {
var db = &SQLiteIPList{
itemTableName: "ipItems",
versionTableName: "versions",
dir: filepath.Clean(Tea.Root + "/data"),
cleanTicker: time.NewTicker(24 * time.Hour),
}
err := db.init()
return db, err
}
func (this *SQLiteIPList) init() error {
// 检查目录是否存在
_, err := os.Stat(this.dir)
if err != nil {
err = os.MkdirAll(this.dir, 0777)
if err != nil {
return err
}
remotelogs.Println("IP_LIST_DB", "create data dir '"+this.dir+"'")
}
var path = this.dir + "/ip_list.db"
db, err := dbs.OpenWriter("file:" + path + "?cache=shared&mode=rwc&_journal_mode=WAL&_sync=" + dbs.SyncMode + "&_locking_mode=EXCLUSIVE")
if err != nil {
return err
}
db.SetMaxOpenConns(1)
//_, err = db.Exec("VACUUM")
//if err != nil {
// return err
//}
this.db = db
// 恢复数据库
var recoverEnv, _ = os.LookupEnv("EdgeRecover")
if len(recoverEnv) > 0 {
for _, indexName := range []string{"ip_list_itemId", "ip_list_expiredAt"} {
_, _ = db.Exec(`REINDEX "` + indexName + `"`)
}
}
// 初始化数据库
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.itemTableName + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"listId" integer DEFAULT 0,
"listType" varchar(32),
"isGlobal" integer(1) DEFAULT 0,
"type" varchar(16),
"itemId" integer DEFAULT 0,
"ipFrom" varchar(64) DEFAULT 0,
"ipTo" varchar(64) DEFAULT 0,
"expiredAt" integer DEFAULT 0,
"eventLevel" varchar(32),
"isDeleted" integer(1) DEFAULT 0,
"version" integer DEFAULT 0,
"nodeId" integer DEFAULT 0,
"serverId" integer DEFAULT 0
);
CREATE INDEX IF NOT EXISTS "ip_list_itemId"
ON "` + this.itemTableName + `" (
"itemId" ASC
);
CREATE INDEX IF NOT EXISTS "ip_list_expiredAt"
ON "` + this.itemTableName + `" (
"expiredAt" ASC
);
`)
if err != nil {
return err
}
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS "` + this.versionTableName + `" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"version" integer DEFAULT 0
);
`)
if err != nil {
return err
}
// 初始化SQL语句
this.deleteExpiredItemsStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "expiredAt">0 AND "expiredAt"<?`)
if err != nil {
return err
}
this.deleteItemStmt, err = this.db.Prepare(`DELETE FROM "` + this.itemTableName + `" WHERE "itemId"=?`)
if err != nil {
return err
}
this.insertItemStmt, err = this.db.Prepare(`INSERT INTO "` + this.itemTableName + `" ("listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
if err != nil {
return err
}
this.selectItemsStmt, err = this.db.Prepare(`SELECT "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId" FROM "` + this.itemTableName + `" WHERE isDeleted=0 ORDER BY "version" ASC, "itemId" ASC LIMIT ?, ?`)
if err != nil {
return err
}
this.selectMaxItemVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.itemTableName + `" ORDER BY "id" DESC LIMIT 1`)
if err != nil {
return err
}
this.selectVersionStmt, err = this.db.Prepare(`SELECT "version" FROM "` + this.versionTableName + `" LIMIT 1`)
if err != nil {
return err
}
this.updateVersionStmt, err = this.db.Prepare(`REPLACE INTO "` + this.versionTableName + `" ("id", "version") VALUES (1, ?)`)
if err != nil {
return err
}
this.db = db
goman.New(func() {
events.OnClose(func() {
_ = this.Close()
this.cleanTicker.Stop()
})
idles.RunTicker(this.cleanTicker, func() {
deleteErr := this.DeleteExpiredItems()
if deleteErr != nil {
remotelogs.Error("IP_LIST_DB", "clean expired items failed: "+deleteErr.Error())
}
})
})
return nil
}
// Name 数据库名称代号
func (this *SQLiteIPList) Name() string {
return "sqlite"
}
// DeleteExpiredItems 删除过期的条目
func (this *SQLiteIPList) DeleteExpiredItems() error {
if this.isClosed {
return nil
}
_, err := this.deleteExpiredItemsStmt.Exec(time.Now().Unix() - 7*86400)
return err
}
func (this *SQLiteIPList) AddItem(item *pb.IPItem) error {
if this.isClosed {
return nil
}
_, err := this.deleteItemStmt.Exec(item.Id)
if err != nil {
return err
}
// 如果是删除,则不再创建新记录
if item.IsDeleted {
return this.UpdateMaxVersion(item.Version)
}
_, err = this.insertItemStmt.Exec(item.ListId, item.ListType, item.IsGlobal, item.Type, item.Id, item.IpFrom, item.IpTo, item.ExpiredAt, item.EventLevel, item.IsDeleted, item.Version, item.NodeId, item.ServerId)
if err != nil {
return err
}
return this.UpdateMaxVersion(item.Version)
}
func (this *SQLiteIPList) ReadItems(offset int64, size int64) (items []*pb.IPItem, goNext bool, err error) {
if this.isClosed {
return
}
rows, err := this.selectItemsStmt.Query(offset, size)
if err != nil {
return nil, false, err
}
defer func() {
_ = rows.Close()
}()
for rows.Next() {
// "listId", "listType", "isGlobal", "type", "itemId", "ipFrom", "ipTo", "expiredAt", "eventLevel", "isDeleted", "version", "nodeId", "serverId"
var pbItem = &pb.IPItem{}
err = rows.Scan(&pbItem.ListId, &pbItem.ListType, &pbItem.IsGlobal, &pbItem.Type, &pbItem.Id, &pbItem.IpFrom, &pbItem.IpTo, &pbItem.ExpiredAt, &pbItem.EventLevel, &pbItem.IsDeleted, &pbItem.Version, &pbItem.NodeId, &pbItem.ServerId)
if err != nil {
return nil, false, err
}
items = append(items, pbItem)
}
goNext = int64(len(items)) == size
return
}
// ReadMaxVersion 读取当前最大版本号
func (this *SQLiteIPList) ReadMaxVersion() (int64, error) {
if this.isClosed {
return 0, nil
}
// from version table
{
var row = this.selectVersionStmt.QueryRow()
if row == nil {
return 0, nil
}
var version int64
err := row.Scan(&version)
if err == nil {
return version, nil
}
}
// from items table
{
var row = this.selectMaxItemVersionStmt.QueryRow()
if row == nil {
return 0, nil
}
var version int64
err := row.Scan(&version)
if err != nil {
return 0, nil
}
return version, nil
}
}
// UpdateMaxVersion 修改版本号
func (this *SQLiteIPList) UpdateMaxVersion(version int64) error {
if this.isClosed {
return nil
}
_, err := this.updateVersionStmt.Exec(version)
return err
}
func (this *SQLiteIPList) Close() error {
this.isClosed = true
if this.db != nil {
for _, stmt := range []*dbs.Stmt{
this.deleteExpiredItemsStmt,
this.deleteItemStmt,
this.insertItemStmt,
this.selectItemsStmt,
this.selectMaxItemVersionStmt, // ipItems table
this.selectVersionStmt, // versions table
this.updateVersionStmt,
} {
if stmt != nil {
_ = stmt.Close()
}
}
return this.db.Close()
}
return nil
}

View File

@@ -0,0 +1,107 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package iplibrary_test
import (
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
_ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/logs"
"testing"
"time"
)
func TestSQLiteIPList_AddItem(t *testing.T) {
db, err := iplibrary.NewSQLiteIPList()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = db.AddItem(&pb.IPItem{
Id: 1,
IpFrom: "192.168.1.101",
IpTo: "",
Version: 1024,
ExpiredAt: time.Now().Unix() + 3600,
Reason: "",
ListId: 2,
IsDeleted: false,
Type: "ipv4",
EventLevel: "error",
ListType: "black",
IsGlobal: true,
CreatedAt: 0,
NodeId: 11,
ServerId: 22,
SourceNodeId: 0,
SourceServerId: 0,
SourceHTTPFirewallPolicyId: 0,
SourceHTTPFirewallRuleGroupId: 0,
SourceHTTPFirewallRuleSetId: 0,
SourceServer: nil,
SourceHTTPFirewallPolicy: nil,
SourceHTTPFirewallRuleGroup: nil,
SourceHTTPFirewallRuleSet: nil,
})
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestSQLiteIPList_ReadItems(t *testing.T) {
db, err := iplibrary.NewSQLiteIPList()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
defer func() {
_ = db.Close()
}()
items, goNext, err := db.ReadItems(0, 2)
if err != nil {
t.Fatal(err)
}
t.Log("goNext:", goNext)
logs.PrintAsJSON(items, t)
}
func TestSQLiteIPList_ReadMaxVersion(t *testing.T) {
db, err := iplibrary.NewSQLiteIPList()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
t.Log(db.ReadMaxVersion())
}
func TestSQLiteIPList_UpdateMaxVersion(t *testing.T) {
db, err := iplibrary.NewSQLiteIPList()
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = db.UpdateMaxVersion(1027)
if err != nil {
t.Fatal(err)
}
t.Log(db.ReadMaxVersion())
}

View File

@@ -0,0 +1,495 @@
package iplibrary_test
import (
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/iputils"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/rands"
"math/rand"
"runtime"
"runtime/debug"
"strconv"
"sync"
"testing"
"time"
)
func TestIPList_Add_Empty(t *testing.T) {
var ipList = iplibrary.NewIPList()
ipList.Add(&iplibrary.IPItem{
Id: 1,
})
logs.PrintAsJSON(ipList.ItemsMap(), t)
logs.PrintAsJSON(ipList.AllItemsMap(), t)
logs.PrintAsJSON(ipList.IPMap(), t)
}
func TestIPList_Add_One(t *testing.T) {
var a = assert.NewAssertion(t)
var ipList = iplibrary.NewIPList()
ipList.Add(&iplibrary.IPItem{
Id: 1,
IPFrom: iputils.ToBytes("192.168.1.1"),
})
ipList.Add(&iplibrary.IPItem{
Id: 2,
IPTo: iputils.ToBytes("192.168.1.2"),
})
ipList.Add(&iplibrary.IPItem{
Id: 3,
IPFrom: iputils.ToBytes("192.168.0.2"),
})
ipList.Add(&iplibrary.IPItem{
Id: 4,
IPFrom: iputils.ToBytes("192.168.0.2"),
IPTo: iputils.ToBytes("192.168.0.1"),
})
ipList.Add(&iplibrary.IPItem{
Id: 5,
IPFrom: iputils.ToBytes("2001:db8:0:1::101"),
})
ipList.Add(&iplibrary.IPItem{
Id: 6,
IPFrom: nil,
Type: "all",
})
t.Log("===items===")
logs.PrintAsJSON(ipList.ItemsMap(), t)
t.Log("===sorted items===")
logs.PrintAsJSON(ipList.SortedRangeItems(), t)
t.Log("===all items===")
a.IsTrue(len(ipList.AllItemsMap()) == 1)
logs.PrintAsJSON(ipList.AllItemsMap(), t) // ip => items
t.Log("===ip items===")
logs.PrintAsJSON(ipList.IPMap())
}
func TestIPList_Update(t *testing.T) {
var ipList = iplibrary.NewIPList()
ipList.Add(&iplibrary.IPItem{
Id: 1,
IPFrom: iputils.ToBytes("192.168.1.1"),
})
t.Log("===before===")
logs.PrintAsJSON(ipList.ItemsMap(), t)
logs.PrintAsJSON(ipList.SortedRangeItems(), t)
logs.PrintAsJSON(ipList.IPMap(), t)
/**ipList.Add(&iplibrary.IPItem{
Id: 2,
IPFrom: iputils.ToBytes("192.168.1.1"),
})**/
ipList.Add(&iplibrary.IPItem{
Id: 1,
//IPFrom: 123,
IPTo: iputils.ToBytes("192.168.1.2"),
})
t.Log("===after===")
logs.PrintAsJSON(ipList.ItemsMap(), t)
logs.PrintAsJSON(ipList.SortedRangeItems(), t)
logs.PrintAsJSON(ipList.IPMap(), t)
}
func TestIPList_Update_AllItems(t *testing.T) {
var ipList = iplibrary.NewIPList()
ipList.Add(&iplibrary.IPItem{
Id: 1,
Type: iplibrary.IPItemTypeAll,
IPFrom: nil,
})
ipList.Add(&iplibrary.IPItem{
Id: 1,
IPTo: nil,
})
t.Log("===items map===")
logs.PrintAsJSON(ipList.ItemsMap(), t)
t.Log("===all items map===")
logs.PrintAsJSON(ipList.AllItemsMap(), t)
t.Log("===ip map===")
logs.PrintAsJSON(ipList.IPMap())
}
func TestIPList_Add_Range(t *testing.T) {
var a = assert.NewAssertion(t)
var ipList = iplibrary.NewIPList()
ipList.Add(&iplibrary.IPItem{
Id: 1,
IPFrom: iputils.ToBytes("192.168.1.1"),
IPTo: iputils.ToBytes("192.168.2.1"),
})
ipList.Add(&iplibrary.IPItem{
Id: 2,
IPTo: iputils.ToBytes("192.168.1.2"),
})
ipList.Add(&iplibrary.IPItem{
Id: 3,
IPFrom: iputils.ToBytes("192.168.0.1"),
IPTo: iputils.ToBytes("192.168.0.2"),
})
a.IsTrue(len(ipList.SortedRangeItems()) == 2)
t.Log(len(ipList.ItemsMap()), "ips")
t.Log("===items map===")
logs.PrintAsJSON(ipList.ItemsMap(), t)
t.Log("===sorted range items===")
logs.PrintAsJSON(ipList.SortedRangeItems())
t.Log("===all items map===")
logs.PrintAsJSON(ipList.AllItemsMap(), t)
t.Log("===ip map===")
logs.PrintAsJSON(ipList.IPMap(), t)
}
func TestNewIPList_Memory(t *testing.T) {
var list = iplibrary.NewIPList()
var count = 100
if testutils.IsSingleTesting() {
count = 2_000_000
}
var stat1 = testutils.ReadMemoryStat()
for i := 0; i < count; i++ {
list.AddDelay(&iplibrary.IPItem{
Id: uint64(i),
IPFrom: iputils.ToBytes(testutils.RandIP()),
IPTo: iputils.ToBytes(testutils.RandIP()),
ExpiredAt: time.Now().Unix(),
})
}
list.Sort()
runtime.GC()
var stat2 = testutils.ReadMemoryStat()
t.Log((stat2.HeapInuse-stat1.HeapInuse)>>20, "MB")
}
func TestIPList_Contains(t *testing.T) {
var a = assert.NewAssertion(t)
var list = iplibrary.NewIPList()
for i := 0; i < 255; i++ {
list.Add(&iplibrary.IPItem{
Id: uint64(i),
IPFrom: iputils.ToBytes(strconv.Itoa(i) + ".168.0.1"),
IPTo: iputils.ToBytes(strconv.Itoa(i) + ".168.255.1"),
ExpiredAt: 0,
})
}
for i := 0; i < 255; i++ {
list.Add(&iplibrary.IPItem{
Id: uint64(1000 + i),
IPFrom: iputils.ToBytes("192.167.2." + strconv.Itoa(i)),
})
}
list.Add(&iplibrary.IPItem{
Id: 10000,
IPFrom: iputils.ToBytes("::1"),
})
list.Add(&iplibrary.IPItem{
Id: 10001,
IPFrom: iputils.ToBytes("::2"),
IPTo: iputils.ToBytes("::5"),
})
t.Log(len(list.ItemsMap()), "ip")
var before = time.Now()
a.IsTrue(list.Contains(iputils.ToBytes("192.168.1.100")))
a.IsTrue(list.Contains(iputils.ToBytes("192.168.2.100")))
a.IsFalse(list.Contains(iputils.ToBytes("192.169.3.100")))
a.IsFalse(list.Contains(iputils.ToBytes("192.167.3.100")))
a.IsTrue(list.Contains(iputils.ToBytes("192.167.2.100")))
a.IsTrue(list.Contains(iputils.ToBytes("::1")))
a.IsTrue(list.Contains(iputils.ToBytes("::3")))
a.IsFalse(list.Contains(iputils.ToBytes("::8")))
t.Log(time.Since(before).Seconds()*1000, "ms")
}
func TestIPList_Contains_Many(t *testing.T) {
var list = iplibrary.NewIPList()
for i := 0; i < 1_000_000; i++ {
list.AddDelay(&iplibrary.IPItem{
Id: uint64(i),
IPFrom: iputils.ToBytes(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255))),
IPTo: iputils.ToBytes(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255))),
ExpiredAt: 0,
})
}
var before = time.Now()
list.Sort()
t.Log("sort cost:", time.Since(before).Seconds()*1000, "ms")
t.Log(len(list.ItemsMap()), "ip")
before = time.Now()
_ = list.Contains(iputils.ToBytes("192.168.1.100"))
t.Log("contains cost:", time.Since(before).Seconds()*1000, "ms")
}
func TestIPList_ContainsAll(t *testing.T) {
var a = assert.NewAssertion(t)
{
var list = iplibrary.NewIPList()
list.Add(&iplibrary.IPItem{
Id: 1,
Type: "all",
IPFrom: nil,
})
var b = list.Contains(iputils.ToBytes("192.168.1.1"))
a.IsTrue(b)
list.Delete(1)
b = list.Contains(iputils.ToBytes("192.168.1.1"))
a.IsFalse(b)
}
{
var list = iplibrary.NewIPList()
list.Add(&iplibrary.IPItem{
Id: 1,
Type: "all",
IPFrom: iputils.ToBytes("0.0.0.0"),
})
var b = list.Contains(iputils.ToBytes("192.168.1.1"))
a.IsTrue(b)
list.Delete(1)
b = list.Contains(iputils.ToBytes("192.168.1.1"))
a.IsFalse(b)
}
}
func TestIPList_ContainsIPStrings(t *testing.T) {
var a = assert.NewAssertion(t)
var list = iplibrary.NewIPList()
for i := 0; i < 255; i++ {
list.Add(&iplibrary.IPItem{
Id: uint64(i),
IPFrom: iputils.ToBytes(strconv.Itoa(i) + ".168.0.1"),
IPTo: iputils.ToBytes(strconv.Itoa(i) + ".168.255.1"),
ExpiredAt: 0,
})
}
t.Log(len(list.ItemsMap()), "ip")
{
item, ok := list.ContainsIPStrings([]string{"192.168.1.100"})
t.Log("item:", item)
a.IsTrue(ok)
}
{
item, ok := list.ContainsIPStrings([]string{"192.167.1.100"})
t.Log("item:", item)
a.IsFalse(ok)
}
}
func TestIPList_Delete(t *testing.T) {
var list = iplibrary.NewIPList()
list.Add(&iplibrary.IPItem{
Id: 1,
IPFrom: iputils.ToBytes("192.168.0.1"),
ExpiredAt: 0,
})
list.Add(&iplibrary.IPItem{
Id: 2,
IPFrom: iputils.ToBytes("192.168.0.1"),
ExpiredAt: 0,
})
list.Add(&iplibrary.IPItem{
Id: 3,
IPFrom: iputils.ToBytes("192.168.1.1"),
IPTo: iputils.ToBytes("192.168.2.1"),
ExpiredAt: 0,
})
t.Log("===before===")
logs.PrintAsJSON(list.ItemsMap(), t)
logs.PrintAsJSON(list.AllItemsMap(), t)
logs.PrintAsJSON(list.SortedRangeItems())
logs.PrintAsJSON(list.IPMap(), t)
{
var found bool
for _, item := range list.SortedRangeItems() {
if item.Id == 3 {
found = true
break
}
}
if !found {
t.Fatal("should be found")
}
}
list.Delete(1)
t.Log("===after===")
logs.PrintAsJSON(list.ItemsMap(), t)
logs.PrintAsJSON(list.AllItemsMap(), t)
logs.PrintAsJSON(list.SortedRangeItems())
logs.PrintAsJSON(list.IPMap(), t)
list.Delete(3)
{
var found bool
for _, item := range list.SortedRangeItems() {
if item.Id == 3 {
found = true
break
}
}
if found {
t.Fatal("should be not found")
}
}
}
func TestIPList_GC(t *testing.T) {
var a = assert.NewAssertion(t)
var list = iplibrary.NewIPList()
list.Add(&iplibrary.IPItem{
Id: 1,
IPFrom: iputils.ToBytes("192.168.1.100"),
IPTo: iputils.ToBytes("192.168.1.101"),
ExpiredAt: time.Now().Unix() + 1,
})
list.Add(&iplibrary.IPItem{
Id: 2,
IPFrom: iputils.ToBytes("192.168.1.102"),
IPTo: iputils.ToBytes("192.168.1.103"),
ExpiredAt: 0,
})
logs.PrintAsJSON(list.ItemsMap(), t)
logs.PrintAsJSON(list.AllItemsMap(), t)
time.Sleep(3 * time.Second)
t.Log("===AFTER GC===")
logs.PrintAsJSON(list.ItemsMap(), t)
logs.PrintAsJSON(list.SortedRangeItems(), t)
a.IsTrue(len(list.ItemsMap()) == 1)
a.IsTrue(len(list.SortedRangeItems()) == 1)
}
func TestManyLists(t *testing.T) {
debug.SetMaxThreads(20)
var lists = []*iplibrary.IPList{}
var locker = &sync.Mutex{}
for i := 0; i < 1000; i++ {
locker.Lock()
lists = append(lists, iplibrary.NewIPList())
locker.Unlock()
}
if testutils.IsSingleTesting() {
time.Sleep(3 * time.Second)
}
t.Log(runtime.NumGoroutine())
t.Log(len(lists), "lists")
}
func BenchmarkIPList_Add(b *testing.B) {
runtime.GOMAXPROCS(1)
var list = iplibrary.NewIPList()
for i := 1; i < 200_000; i++ {
list.AddDelay(&iplibrary.IPItem{
Id: uint64(i),
IPFrom: iputils.ToBytes(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1"),
IPTo: iputils.ToBytes(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1"),
ExpiredAt: time.Now().Unix() + 60,
})
}
list.Sort()
b.Log(len(list.ItemsMap()), "ip")
b.ResetTimer()
for i := 0; i < b.N; i++ {
var ip = fmt.Sprintf("%d.%d.%d.%d", rand.Int()%255, rand.Int()%255, rand.Int()%255, rand.Int()%255)
list.Add(&iplibrary.IPItem{
Type: "",
Id: uint64(i % 1_000_000),
IPFrom: iputils.ToBytes(ip),
IPTo: nil,
ExpiredAt: fasttime.Now().Unix() + 3600,
EventLevel: "",
})
}
}
func BenchmarkIPList_Contains(b *testing.B) {
runtime.GOMAXPROCS(1)
var list = iplibrary.NewIPList()
for i := 1; i < 1_000_000; i++ {
var item = &iplibrary.IPItem{
Id: uint64(i),
IPFrom: iputils.ToBytes(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1"),
ExpiredAt: time.Now().Unix() + 60,
}
if i%100 == 0 {
item.IPTo = iputils.ToBytes(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1")
}
list.Add(item)
}
//b.Log(len(list.ItemsMap()), "ip")
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = list.Contains(iputils.ToBytes(testutils.RandIP()))
}
})
}
func BenchmarkIPList_Sort(b *testing.B) {
var list = iplibrary.NewIPList()
for i := 0; i < 1_000_000; i++ {
var item = &iplibrary.IPItem{
Id: uint64(i),
IPFrom: iputils.ToBytes(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1"),
ExpiredAt: time.Now().Unix() + 60,
}
if i%100 == 0 {
item.IPTo = iputils.ToBytes(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1")
}
list.AddDelay(item)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
list.Sort()
}
})
}

View File

@@ -0,0 +1,8 @@
package iplibrary
type IPListType = string
const (
IPListTypeWhite IPListType = "white"
IPListTypeBlack IPListType = "black"
)

View File

@@ -0,0 +1,99 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package iplibrary
import (
"encoding/hex"
"github.com/TeaOSLab/EdgeCommon/pkg/iputils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/iwind/TeaGo/Tea"
)
// AllowIP 检查IP是否被允许访问
// 如果一个IP不在任何名单中则允许访问
func AllowIP(ip string, serverId int64) (canGoNext bool, inAllowList bool, expiresAt int64) {
if !Tea.IsTesting() { // 如果在测试环境,我们不加入一些白名单,以便于可以在本地和局域网正常测试
// 放行lo
if ip == "127.0.0.1" || ip == "::1" {
return true, true, 0
}
// check node
nodeConfig, err := nodeconfigs.SharedNodeConfig()
if err == nil && nodeConfig.IPIsAutoAllowed(ip) {
return true, true, 0
}
}
var ipBytes = iputils.ToBytes(ip)
if IsZero(ipBytes) {
return false, false, 0
}
// check white lists
if GlobalWhiteIPList.Contains(ipBytes) {
return true, true, 0
}
if serverId > 0 {
var list = SharedServerListManager.FindWhiteList(serverId, false)
if list != nil && list.Contains(ipBytes) {
return true, true, 0
}
}
// check black lists
expiresAt, ok := GlobalBlackIPList.ContainsExpires(ipBytes)
if ok {
return false, false, expiresAt
}
if serverId > 0 {
var list = SharedServerListManager.FindBlackList(serverId, false)
if list != nil {
expiresAt, ok = list.ContainsExpires(ipBytes)
if ok {
return false, false, expiresAt
}
}
}
return true, false, 0
}
// IsInWhiteList 检查IP是否在白名单中
func IsInWhiteList(ip string) bool {
var ipBytes = iputils.ToBytes(ip)
if IsZero(ipBytes) {
return false
}
// check white lists
return GlobalWhiteIPList.Contains(ipBytes)
}
// AllowIPStrings 检查一组IP是否被允许访问
func AllowIPStrings(ipStrings []string, serverId int64) bool {
if len(ipStrings) == 0 {
return true
}
for _, ip := range ipStrings {
isAllowed, _, _ := AllowIP(ip, serverId)
if !isAllowed {
return false
}
}
return true
}
func IsZero(ipBytes []byte) bool {
return len(ipBytes) == 0
}
func ToHex(b []byte) string {
if len(b) == 0 {
return ""
}
return hex.EncodeToString(b)
}

View File

@@ -0,0 +1,25 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package iplibrary
import (
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"testing"
"time"
)
func TestIPIsAllowed(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var manager = NewIPListManager()
manager.Init()
var before = time.Now()
defer func() {
t.Log(time.Since(before).Seconds()*1000, "ms")
}()
t.Log(AllowIP("127.0.0.1", 0))
t.Log(AllowIP("127.0.0.1", 23))
}

View File

@@ -0,0 +1,354 @@
package iplibrary
import (
"github.com/TeaOSLab/EdgeCommon/pkg/iputils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
"github.com/TeaOSLab/EdgeNode/internal/utils/idles"
"github.com/TeaOSLab/EdgeNode/internal/utils/trackers"
"github.com/TeaOSLab/EdgeNode/internal/utils/zero"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/types"
"os"
"sync"
"time"
)
var SharedIPListManager = NewIPListManager()
var IPListUpdateNotify = make(chan bool, 1)
func init() {
if !teaconst.IsMain {
return
}
events.On(events.EventLoaded, func() {
goman.New(func() {
SharedIPListManager.Start()
})
})
events.OnClose(func() {
SharedIPListManager.Stop()
})
var ticker = time.NewTicker(24 * time.Hour)
goman.New(func() {
idles.RunTicker(ticker, func() {
SharedIPListManager.DeleteExpiredItems()
})
})
}
// IPListManager IP名单管理
type IPListManager struct {
ticker *time.Ticker
db IPListDB
lastVersion int64
fetchPageSize int64
listMap map[int64]*IPList
mu sync.RWMutex
isFirstTime bool
}
func NewIPListManager() *IPListManager {
return &IPListManager{
fetchPageSize: 5_000,
listMap: map[int64]*IPList{},
isFirstTime: true,
}
}
func (this *IPListManager) Start() {
this.Init()
// 第一次读取
err := this.Loop()
if err != nil {
remotelogs.ErrorObject("IP_LIST_MANAGER", err)
}
this.ticker = time.NewTicker(60 * time.Second)
if Tea.IsTesting() {
this.ticker = time.NewTicker(10 * time.Second)
}
var countErrors = 0
for {
select {
case <-this.ticker.C:
case <-IPListUpdateNotify:
}
err = this.Loop()
if err != nil {
countErrors++
remotelogs.ErrorObject("IP_LIST_MANAGER", err)
// 连续错误小于3次的我们立即重试
if countErrors <= 3 {
select {
case IPListUpdateNotify <- true:
default:
}
}
} else {
countErrors = 0
}
}
}
func (this *IPListManager) Stop() {
if this.ticker != nil {
this.ticker.Stop()
}
}
func (this *IPListManager) Init() {
// 从数据库中当中读取数据
// 检查sqlite文件是否存在以便决定使用sqlite还是kv
var sqlitePath = Tea.Root + "/data/ip_list.db"
_, sqliteErr := os.Stat(sqlitePath)
var db IPListDB
var err error
if sqliteErr == nil || !teaconst.EnableKVCacheStore {
db, err = NewSQLiteIPList()
} else {
db, err = NewKVIPList()
}
if err != nil {
remotelogs.Error("IP_LIST_MANAGER", "create ip list local database failed: "+err.Error())
} else {
this.db = db
// 删除本地数据库中过期的条目
_ = db.DeleteExpiredItems()
// 本地数据库中最大版本号
this.lastVersion, err = db.ReadMaxVersion()
if err != nil {
remotelogs.Error("IP_LIST_MANAGER", "find max version failed: "+err.Error())
this.lastVersion = 0
}
remotelogs.Println("IP_LIST_MANAGER", "starting from '"+db.Name()+"' version '"+types.String(this.lastVersion)+"' ...")
// 从本地数据库中加载
var offset int64 = 0
var size int64 = 2_000
var tr = trackers.Begin("IP_LIST_MANAGER:load")
defer tr.End()
for {
items, goNext, readErr := db.ReadItems(offset, size)
var l = len(items)
if readErr != nil {
remotelogs.Error("IP_LIST_MANAGER", "read ip list from local database failed: "+readErr.Error())
} else {
this.processItems(items, false)
if !goNext {
break
}
}
offset += int64(l)
}
}
}
func (this *IPListManager) Loop() error {
// 是否同步IP名单
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
if nodeConfig != nil && !nodeConfig.EnableIPLists {
return nil
}
// 第一次同步则打印信息
if this.isFirstTime {
remotelogs.Println("IP_LIST_MANAGER", "initializing ip items ...")
}
for {
hasNext, err := this.fetch()
if err != nil {
return err
}
if !hasNext {
break
}
time.Sleep(1 * time.Second)
}
// 第一次同步则打印信息
if this.isFirstTime {
this.isFirstTime = false
remotelogs.Println("IP_LIST_MANAGER", "finished initializing ip items")
}
return nil
}
func (this *IPListManager) fetch() (hasNext bool, err error) {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return false, err
}
itemsResp, err := rpcClient.IPItemRPC.ListIPItemsAfterVersion(rpcClient.Context(), &pb.ListIPItemsAfterVersionRequest{
Version: this.lastVersion,
Size: this.fetchPageSize,
})
if err != nil {
if rpc.IsConnError(err) {
remotelogs.Debug("IP_LIST_MANAGER", "rpc connection error: "+err.Error())
return false, nil
}
return false, err
}
// 更新版本号
defer func() {
if itemsResp.Version > this.lastVersion {
this.lastVersion = itemsResp.Version
err = this.db.UpdateMaxVersion(itemsResp.Version)
if err != nil {
remotelogs.Error("IP_LIST_MANAGER", "update max version to database: "+err.Error())
}
}
}()
var items = itemsResp.IpItems
if len(items) == 0 {
return false, nil
}
// 保存到本地数据库
if this.db != nil {
for _, item := range items {
err = this.db.AddItem(item)
if err != nil {
remotelogs.Error("IP_LIST_MANAGER", "insert item to local database failed: "+err.Error())
}
}
}
this.processItems(items, true)
return true, nil
}
func (this *IPListManager) FindList(listId int64) *IPList {
this.mu.RLock()
var list = this.listMap[listId]
this.mu.RUnlock()
return list
}
func (this *IPListManager) DeleteExpiredItems() {
if this.db != nil {
_ = this.db.DeleteExpiredItems()
}
}
func (this *IPListManager) ListMap() map[int64]*IPList {
return this.listMap
}
// 处理IP条目
func (this *IPListManager) processItems(items []*pb.IPItem, fromRemote bool) {
var changedLists = map[*IPList]zero.Zero{}
for _, item := range items {
// 调试
if Tea.IsTesting() {
this.debugItem(item)
}
var list *IPList
// TODO 实现节点专有List
if item.ServerId > 0 { // 服务专有List
switch item.ListType {
case "black":
list = SharedServerListManager.FindBlackList(item.ServerId, true)
case "white":
list = SharedServerListManager.FindWhiteList(item.ServerId, true)
}
} else if item.IsGlobal { // 全局List
switch item.ListType {
case "black":
list = GlobalBlackIPList
case "white":
list = GlobalWhiteIPList
}
} else { // 其他List
this.mu.Lock()
list = this.listMap[item.ListId]
this.mu.Unlock()
}
if list == nil {
list = NewIPList()
this.mu.Lock()
this.listMap[item.ListId] = list
this.mu.Unlock()
}
changedLists[list] = zero.New()
if item.IsDeleted {
list.Delete(uint64(item.Id))
// 从WAF名单中删除
waf.SharedIPBlackList.RemoveIP(item.IpFrom, item.ServerId, fromRemote)
// 操作事件
if fromRemote {
SharedActionManager.DeleteItem(item.ListType, item)
}
continue
}
list.AddDelay(&IPItem{
Id: uint64(item.Id),
Type: item.Type,
IPFrom: iputils.ToBytes(item.IpFrom),
IPTo: iputils.ToBytes(item.IpTo),
ExpiredAt: item.ExpiredAt,
EventLevel: item.EventLevel,
})
// 事件操作
if fromRemote {
SharedActionManager.DeleteItem(item.ListType, item)
SharedActionManager.AddItem(item.ListType, item)
}
}
if len(changedLists) > 0 {
for changedList := range changedLists {
changedList.Sort()
}
}
}
// 调试IP信息
func (this *IPListManager) debugItem(item *pb.IPItem) {
var ipRange = item.IpFrom
if len(item.IpTo) > 0 {
ipRange += " - " + item.IpTo
}
if item.IsDeleted {
remotelogs.Debug("IP_ITEM_DEBUG", "delete '"+ipRange+"'")
} else {
remotelogs.Debug("IP_ITEM_DEBUG", "add '"+ipRange+"'")
}
}

View File

@@ -0,0 +1,51 @@
package iplibrary_test
import (
"github.com/TeaOSLab/EdgeCommon/pkg/iputils"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/utils/testutils"
"github.com/iwind/TeaGo/logs"
"testing"
"time"
)
func TestIPListManager_init(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var manager = iplibrary.NewIPListManager()
manager.Init()
t.Log(manager.ListMap())
t.Log(iplibrary.SharedServerListManager.BlackMap())
logs.PrintAsJSON(iplibrary.GlobalBlackIPList.SortedRangeItems(), t)
}
func TestIPListManager_check(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var manager = iplibrary.NewIPListManager()
manager.Init()
var before = time.Now()
defer func() {
t.Log(time.Since(before).Seconds()*1000, "ms")
}()
t.Log(iplibrary.SharedServerListManager.FindBlackList(23, true).Contains(iputils.ToBytes("127.0.0.2")))
t.Log(iplibrary.GlobalBlackIPList.Contains(iputils.ToBytes("127.0.0.6")))
}
func TestIPListManager_loop(t *testing.T) {
if !testutils.IsSingleTesting() {
return
}
var manager = iplibrary.NewIPListManager()
manager.Start()
err := manager.Loop()
if err != nil {
t.Fatal(err)
}
}

View File

@@ -0,0 +1,10 @@
package iplibrary
type Result struct {
CityId int64
Country string
Region string
Province string
City string
ISP string
}

View File

@@ -0,0 +1,65 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package iplibrary
import "sync"
var SharedServerListManager = NewServerListManager()
// ServerListManager 服务相关名单
type ServerListManager struct {
whiteMap map[int64]*IPList // serverId => *List
blackMap map[int64]*IPList // serverId => *List
locker sync.RWMutex
}
func NewServerListManager() *ServerListManager {
return &ServerListManager{
whiteMap: map[int64]*IPList{},
blackMap: map[int64]*IPList{},
}
}
func (this *ServerListManager) FindWhiteList(serverId int64, autoCreate bool) *IPList {
this.locker.RLock()
list, ok := this.whiteMap[serverId]
this.locker.RUnlock()
if ok {
return list
}
if autoCreate {
list = NewIPList()
this.locker.Lock()
this.whiteMap[serverId] = list
this.locker.Unlock()
return list
}
return nil
}
func (this *ServerListManager) FindBlackList(serverId int64, autoCreate bool) *IPList {
this.locker.RLock()
list, ok := this.blackMap[serverId]
this.locker.RUnlock()
if ok {
return list
}
if autoCreate {
list = NewIPList()
this.locker.Lock()
this.blackMap[serverId] = list
this.locker.Unlock()
return list
}
return nil
}
func (this *ServerListManager) BlackMap() map[int64]*IPList {
return this.blackMap
}