482 lines
12 KiB
Go
482 lines
12 KiB
Go
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||
|
||
package nodes
|
||
|
||
import (
|
||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||
"github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
|
||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||
"github.com/TeaOSLab/EdgeDNS/internal/agents"
|
||
"github.com/TeaOSLab/EdgeDNS/internal/dbs"
|
||
"github.com/TeaOSLab/EdgeDNS/internal/models"
|
||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||
"github.com/iwind/TeaGo/lists"
|
||
"github.com/iwind/TeaGo/types"
|
||
"net"
|
||
"sort"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
)
|
||
|
||
// RouteManager 线路管理器
|
||
type RouteManager struct {
|
||
allRouteMap map[int64]*models.NSRoute // routeId => route
|
||
userRouteMap map[int64][]int64 // userId => sorted routeIds
|
||
|
||
db *dbs.DB
|
||
version int64
|
||
locker sync.RWMutex
|
||
|
||
notifier chan bool
|
||
|
||
ispRouteMap map[string]string // name => code
|
||
chinaRouteMap map[string]string // name => code
|
||
worldRouteMap map[string]string // name => code
|
||
}
|
||
|
||
// NewRouteManager 获取新线路管理器对象
|
||
func NewRouteManager(db *dbs.DB) *RouteManager {
|
||
return &RouteManager{
|
||
db: db,
|
||
allRouteMap: map[int64]*models.NSRoute{},
|
||
userRouteMap: map[int64][]int64{},
|
||
|
||
notifier: make(chan bool, 8),
|
||
|
||
ispRouteMap: map[string]string{},
|
||
chinaRouteMap: map[string]string{},
|
||
worldRouteMap: map[string]string{},
|
||
}
|
||
}
|
||
|
||
// Start 启动自动任务
|
||
func (this *RouteManager) Start() {
|
||
remotelogs.Println("ROUTE_MANAGER", "starting ...")
|
||
|
||
// 初始化公共线路
|
||
this.loadDefaultRoutes()
|
||
|
||
// 从本地数据库中加载数据
|
||
err := this.Load()
|
||
if err != nil {
|
||
if rpc.IsConnError(err) {
|
||
remotelogs.Debug("ROUTE_MANAGER", "load failed: "+err.Error())
|
||
} else {
|
||
remotelogs.Error("ROUTE_MANAGER", "load failed: "+err.Error())
|
||
}
|
||
}
|
||
|
||
// 初始化运行
|
||
err = this.LoopAll()
|
||
if err != nil {
|
||
if rpc.IsConnError(err) {
|
||
remotelogs.Debug("ROUTE_MANAGER", "loop failed: "+err.Error())
|
||
} else {
|
||
remotelogs.Error("ROUTE_MANAGER", "loop failed: "+err.Error())
|
||
}
|
||
}
|
||
|
||
// 更新
|
||
var ticker = time.NewTicker(1 * time.Minute)
|
||
for {
|
||
select {
|
||
case <-ticker.C:
|
||
case <-this.notifier:
|
||
}
|
||
|
||
err := this.LoopAll()
|
||
if err != nil {
|
||
if rpc.IsConnError(err) {
|
||
remotelogs.Debug("ROUTE_MANAGER", "loop failed: "+err.Error())
|
||
} else {
|
||
remotelogs.Error("ROUTE_MANAGER", "loop failed: "+err.Error())
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// Load 从数据库中加载数据
|
||
func (this *RouteManager) Load() error {
|
||
var offset int64 = 0
|
||
var size int64 = 10000
|
||
for {
|
||
routes, err := this.db.ListRoutes(offset, size)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if len(routes) == 0 {
|
||
break
|
||
}
|
||
|
||
this.locker.Lock()
|
||
for _, route := range routes {
|
||
this.addRoute(route)
|
||
|
||
if route.Version > this.version {
|
||
this.version = route.Version
|
||
}
|
||
}
|
||
this.locker.Unlock()
|
||
|
||
offset += size
|
||
}
|
||
|
||
if this.version > 0 {
|
||
this.version++
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (this *RouteManager) LoopAll() error {
|
||
for {
|
||
hasNext, err := this.Loop()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if !hasNext {
|
||
break
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// Loop 单次循环任务
|
||
func (this *RouteManager) Loop() (hasNext bool, err error) {
|
||
client, err := rpc.SharedRPC()
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
resp, err := client.NSRouteRPC.ListNSRoutesAfterVersion(client.Context(), &pb.ListNSRoutesAfterVersionRequest{
|
||
Version: this.version,
|
||
Size: 20000,
|
||
})
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
var routes = resp.NsRoutes
|
||
if len(routes) == 0 {
|
||
return false, nil
|
||
}
|
||
for _, route := range routes {
|
||
this.processRoute(route)
|
||
|
||
if route.Version > this.version {
|
||
this.version = route.Version
|
||
}
|
||
}
|
||
this.version++
|
||
return true, nil
|
||
}
|
||
|
||
// FindRouteCodes 查找一个地址对应的线路
|
||
func (this *RouteManager) FindRouteCodes(ip string, domainUserId int64) (result []string) {
|
||
var netIP = net.ParseIP(ip)
|
||
if len(netIP) == 0 {
|
||
return nil
|
||
}
|
||
|
||
// 自定义route
|
||
this.locker.RLock()
|
||
|
||
// 先查找用户自定义的
|
||
if domainUserId > 0 {
|
||
var userRouteIds = this.userRouteMap[domainUserId]
|
||
for _, routeId := range userRouteIds {
|
||
route, ok := this.allRouteMap[routeId]
|
||
if ok && route.Contains(netIP) {
|
||
result = append(result, route.RealCode())
|
||
}
|
||
}
|
||
}
|
||
|
||
// 再查找公共的
|
||
var publicRouteIds = this.userRouteMap[0]
|
||
for _, routeId := range publicRouteIds {
|
||
route, ok := this.allRouteMap[routeId]
|
||
if ok && route.Contains(netIP) {
|
||
result = append(result, route.RealCode())
|
||
}
|
||
}
|
||
|
||
this.locker.RUnlock()
|
||
|
||
// 解析公用线路
|
||
var ipResult = iplibrary.LookupIP(ip)
|
||
if ipResult != nil && ipResult.IsOk() {
|
||
// 运营商
|
||
for _, providerCode := range ipResult.ProviderCodes() {
|
||
code, ok := this.ispRouteMap[providerCode]
|
||
if ok {
|
||
result = append(result, code)
|
||
|
||
// 单次只能有一个匹配
|
||
break
|
||
}
|
||
}
|
||
|
||
// 省|州|城市
|
||
if ipResult.ProvinceId() > 0 {
|
||
result = append(result, "region:province:"+types.String(ipResult.ProvinceId()))
|
||
}
|
||
if ipResult.CityId() > 0 {
|
||
result = append(result, "region:city:"+types.String(ipResult.CityId()))
|
||
}
|
||
if ipResult.TownId() > 0 {
|
||
result = append(result, "region:town:"+types.String(ipResult.TownId()))
|
||
}
|
||
|
||
// 中国省市
|
||
for _, provinceCode := range ipResult.ProvinceCodes() {
|
||
// 中国
|
||
code, ok := this.chinaRouteMap[provinceCode]
|
||
if ok {
|
||
result = append(result, code)
|
||
|
||
// 兼容以前的拼写错误
|
||
switch code {
|
||
case "china:province:hebei":
|
||
result = append(result, "china:province:heibei")
|
||
case "china:province:heibei":
|
||
result = append(result, "china:province:hebei")
|
||
case "china:jilin":
|
||
result = append(result, "china:province:jilin")
|
||
case "china:province:jilin":
|
||
result = append(result, "china:jilin")
|
||
}
|
||
|
||
// 香港
|
||
switch code {
|
||
case dnsconfigs.ChinaProvinceCodeHK:
|
||
result = append(result, dnsconfigs.WorldRegionCodeHK, dnsconfigs.WorldRegionCodeChinaAbroad)
|
||
|
||
// 澳门
|
||
case dnsconfigs.ChinaProvinceCodeMO:
|
||
result = append(result, dnsconfigs.WorldRegionCodeMO, dnsconfigs.WorldRegionCodeChinaAbroad)
|
||
|
||
// 台湾
|
||
case dnsconfigs.ChinaProvinceCodeTW:
|
||
result = append(result, dnsconfigs.WorldRegionCodeTW, dnsconfigs.WorldRegionCodeChinaAbroad)
|
||
default:
|
||
result = append(result, dnsconfigs.WorldRegionCodeChinaMainland)
|
||
}
|
||
|
||
// 单次只能有一个匹配
|
||
break
|
||
}
|
||
}
|
||
|
||
// 国家/地区
|
||
for _, countryCode := range ipResult.CountryCodes() {
|
||
code, ok := this.worldRouteMap[countryCode]
|
||
if ok {
|
||
// 中国全境(world:CN)必须优先于「海外」匹配,否则前面中国省市若误判为 HK/MO/TW 已加入 world:CN:abroad 时会先命中海外线
|
||
if code == dnsconfigs.WorldRegionCodeChina {
|
||
result = append([]string{code}, result...)
|
||
} else {
|
||
result = append(result, code)
|
||
// 中国海外
|
||
if code != dnsconfigs.WorldRegionCodeChina {
|
||
result = append(result, dnsconfigs.WorldRegionCodeChinaAbroad)
|
||
}
|
||
}
|
||
|
||
// 单次只能有一个匹配
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
// 搜索引擎线路
|
||
if agents.SharedManager != nil {
|
||
var agentCode = agents.SharedManager.LookupIP(ip)
|
||
if len(agentCode) > 0 {
|
||
result = append(result, "agent:"+agentCode, "agent" /** 所有搜索引擎 **/)
|
||
}
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
// NotifyUpdate 通知更新
|
||
func (this *RouteManager) NotifyUpdate() {
|
||
select {
|
||
case this.notifier <- true:
|
||
default:
|
||
}
|
||
}
|
||
|
||
func (this *RouteManager) loadDefaultRoutes() {
|
||
for _, route := range dnsconfigs.AllDefaultISPRoutes {
|
||
for _, name := range route.AliasNames {
|
||
this.ispRouteMap[name] = route.Code
|
||
}
|
||
}
|
||
for _, route := range dnsconfigs.AllDefaultChinaProvinceRoutes {
|
||
for _, name := range route.AliasNames {
|
||
this.chinaRouteMap[name] = route.Code
|
||
}
|
||
}
|
||
for _, route := range dnsconfigs.AllDefaultWorldRegionRoutes {
|
||
for _, name := range route.AliasNames {
|
||
this.worldRouteMap[name] = route.Code
|
||
}
|
||
// 用线路 Code 中的国家/地区 ISO 码做映射,使 IP 库返回的 ISO 码(如 US、CN、HK)能命中对应线路
|
||
if strings.HasPrefix(route.Code, "world:") {
|
||
parts := strings.Split(route.Code, ":")
|
||
if len(parts) == 2 && len(parts[1]) == 2 {
|
||
this.worldRouteMap[parts[1]] = route.Code
|
||
}
|
||
if len(parts) == 3 && len(parts[2]) == 2 {
|
||
this.worldRouteMap[strings.ToUpper(parts[2])] = route.Code
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 添加线路
|
||
func (this *RouteManager) addRoute(route *models.NSRoute) {
|
||
// 不需要加锁,因为此函数均在锁内调用
|
||
|
||
// 从老的用户中删除
|
||
oldRoute, ok := this.allRouteMap[route.Id]
|
||
if ok {
|
||
var oldUserId = oldRoute.UserId
|
||
if oldUserId != route.UserId {
|
||
userRouteIds, ok := this.userRouteMap[oldUserId]
|
||
if ok {
|
||
this.userRouteMap[oldUserId] = this.removeId(userRouteIds, route.Id)
|
||
if len(userRouteIds) == 0 {
|
||
delete(this.userRouteMap, oldUserId)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 添加
|
||
this.allRouteMap[route.Id] = route
|
||
|
||
userRouteIds, ok := this.userRouteMap[route.UserId]
|
||
if ok {
|
||
// 重新按优先级、排序、ID排序
|
||
var userRoutes = []*models.NSRoute{}
|
||
for _, userRouteId := range userRouteIds {
|
||
userRoute, ok := this.allRouteMap[userRouteId]
|
||
if ok {
|
||
userRoutes = append(userRoutes, userRoute)
|
||
}
|
||
}
|
||
if !lists.ContainsInt64(userRouteIds, route.Id) {
|
||
userRoutes = append(userRoutes, route)
|
||
}
|
||
|
||
sort.Slice(userRoutes, func(i, j int) bool {
|
||
var userRoute1 = userRoutes[i]
|
||
var userRoute2 = userRoutes[j]
|
||
if userRoute1.Priority != userRoute2.Priority {
|
||
return userRoute1.Priority > userRoute2.Priority
|
||
}
|
||
if userRoute1.Order != userRoute2.Order {
|
||
return userRoute1.Order > userRoute2.Order
|
||
}
|
||
return userRoute1.Id < userRoute2.Id
|
||
})
|
||
|
||
var newUserRouteIds = []int64{}
|
||
for _, userRoute := range userRoutes {
|
||
newUserRouteIds = append(newUserRouteIds, userRoute.Id)
|
||
}
|
||
this.userRouteMap[route.UserId] = newUserRouteIds
|
||
} else {
|
||
this.userRouteMap[route.UserId] = []int64{route.Id}
|
||
}
|
||
}
|
||
|
||
// 删除线路
|
||
func (this *RouteManager) removePBRoute(route *pb.NSRoute) {
|
||
delete(this.allRouteMap, route.Id)
|
||
userRouteIds, ok := this.userRouteMap[route.UserId]
|
||
if ok {
|
||
userRouteIds = this.removeId(userRouteIds, route.Id)
|
||
if len(userRouteIds) == 0 {
|
||
delete(this.userRouteMap, route.UserId)
|
||
} else {
|
||
this.userRouteMap[route.UserId] = userRouteIds
|
||
}
|
||
}
|
||
}
|
||
|
||
// 处理线路
|
||
func (this *RouteManager) processRoute(route *pb.NSRoute) {
|
||
if !route.IsOn || route.IsDeleted {
|
||
this.locker.Lock()
|
||
this.removePBRoute(route)
|
||
this.locker.Unlock()
|
||
|
||
// 从数据库中删除
|
||
if this.db != nil {
|
||
err := this.db.DeleteRoute(route.Id)
|
||
if err != nil {
|
||
remotelogs.Error("ROUTE_MANAGER", "delete route from db failed: "+err.Error())
|
||
}
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
// 存入数据库
|
||
if this.db != nil {
|
||
exists, err := this.db.ExistsRoute(route.Id)
|
||
if err != nil {
|
||
remotelogs.Error("ROUTE_MANAGER", "query failed: "+err.Error())
|
||
} else {
|
||
if exists {
|
||
err = this.db.UpdateRoute(route.Id, route.UserId, route.RangesJSON, route.Order, route.Priority, route.Version)
|
||
if err != nil {
|
||
remotelogs.Error("ROUTE_MANAGER", "update failed: "+err.Error())
|
||
}
|
||
} else {
|
||
err = this.db.InsertRoute(route.Id, route.UserId, route.RangesJSON, route.Order, route.Priority, route.Version)
|
||
if err != nil {
|
||
remotelogs.Error("ROUTE_MANAGER", "insert failed: "+err.Error())
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
ranges, err := models.InitRangesFromJSON(route.RangesJSON)
|
||
if err != nil {
|
||
remotelogs.Error("ROUTE_MANAGER", "decode routes '"+strconv.FormatInt(route.Id, 10)+"' failed: "+err.Error())
|
||
return
|
||
}
|
||
|
||
var nsRoute = &models.NSRoute{
|
||
Id: route.Id,
|
||
Ranges: ranges,
|
||
Priority: route.Priority,
|
||
Order: route.Order,
|
||
UserId: route.UserId,
|
||
Version: route.Version,
|
||
}
|
||
|
||
this.locker.Lock()
|
||
this.addRoute(nsRoute)
|
||
this.locker.Unlock()
|
||
}
|
||
|
||
// 从一组ID中删除某个ID
|
||
func (this *RouteManager) removeId(ids []int64, id int64) []int64 {
|
||
var result = []int64{}
|
||
for _, id2 := range ids {
|
||
if id2 == id {
|
||
continue
|
||
}
|
||
result = append(result, id2)
|
||
}
|
||
return result
|
||
}
|