Files
waf-platform/EdgeHttpDNS/internal/nodes/resolve_server.go
2026-02-27 10:35:22 +08:00

1120 lines
30 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package nodes
import (
"context"
"crypto/hmac"
"crypto/sha256"
"crypto/tls"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeHttpDNS/internal/configs"
"github.com/TeaOSLab/EdgeHttpDNS/internal/rpc"
"github.com/iwind/TeaGo/rands"
"github.com/miekg/dns"
)
// sharedRecursionDNSClient is the DNS client shared by every upstream
// (recursive) fallback lookup, mirroring the EdgeDNS implementation.
var (
	sharedRecursionDNSClient = &dns.Client{
		// Bound on each exchange with the upstream resolver.
		Timeout: 3 * time.Second,
	}
)
// Result codes placed in resolveResponse.Code.
const (
	// httpdnsCodeSuccess marks a successful resolution.
	httpdnsCodeSuccess = "SUCCESS"

	// Request / authorization problems.
	httpdnsCodeMethodNotAllow  = "METHOD_NOT_ALLOWED"
	httpdnsCodeInvalidArgument = "INVALID_ARGUMENT"
	httpdnsCodeAppInvalid      = "APP_NOT_FOUND_OR_DISABLED"
	httpdnsCodeDomainNotBound  = "DOMAIN_NOT_BOUND"
	httpdnsCodeSignInvalid     = "SIGN_INVALID"

	// Resolution outcomes other than success.
	httpdnsCodeNoRecords     = "NO_RECORDS"
	httpdnsCodeInternalError = "RESOLVE_TIMEOUT_OR_INTERNAL"
)
// JSON payload types of the /resolve endpoint. Field order is significant: it
// fixes the key order of the encoded JSON.
type (
	// resolveClientInfo echoes what the server inferred about the caller.
	resolveClientInfo struct {
		IP      string `json:"ip"`
		Region  string `json:"region"`
		Carrier string `json:"carrier"`
		Country string `json:"country"`
	}

	// resolveRecord is one answer (an A or AAAA address) with its routing labels.
	resolveRecord struct {
		Type   string `json:"type"`
		IP     string `json:"ip"`
		Weight int32  `json:"weight,omitempty"`
		Line   string `json:"line,omitempty"`
		Region string `json:"region,omitempty"`
	}

	// resolveData is the body of a successful response.
	resolveData struct {
		Domain  string             `json:"domain"`
		QType   string             `json:"qtype"`
		TTL     int32              `json:"ttl"`
		Records []*resolveRecord   `json:"records"`
		Client  *resolveClientInfo `json:"client"`
		Summary string             `json:"summary"`
	}

	// resolveResponse is the envelope of every /resolve reply; Data is omitted on failure.
	resolveResponse struct {
		Code      string       `json:"code"`
		Message   string       `json:"message"`
		RequestID string       `json:"requestId"`
		Data      *resolveData `json:"data,omitempty"`
	}
)
type (
	// clientRouteProfile describes a client IP in the terms used for line
	// routing; it is populated by buildClientRouteProfile.
	clientRouteProfile struct {
		IP          string // normalized client IP ("" when it could not be parsed)
		Country     string // canonical country name (normalizeCountryName)
		Province    string // province without administrative suffix (normalizeProvinceName)
		Carrier     string // carrier line label (see buildClientRouteProfile)
		ProviderRaw string // provider name as reported by the IP library
		Region      string // mainland macro-region such as "华东" (normalizeChinaRegion)
		Continent   string // continent name (normalizeContinent)
		RegionText  string // human-readable location summary for replies and logs
	}

	// ResolveServer is the HTTPS endpoint of the HTTPDNS node. It answers
	// /resolve from the snapshot held by snapshotManager and ships access logs
	// to the API in batches.
	ResolveServer struct {
		quitCh          <-chan struct{}           // receiving from it stops the server and the log flusher
		snapshotManager *SnapshotManager          // source of the current app/domain/rule snapshot
		listenAddr      string                    // HTTPS listen address (":443" unless configured)
		certFile        string                    // certificate path passed to ListenAndServeTLS
		keyFile         string                    // key path passed to ListenAndServeTLS
		server          *http.Server              // underlying HTTP server
		logQueue        chan *pb.HTTPDNSAccessLog // buffered queue drained by startAccessLogFlusher
	}
)
// NewResolveServer builds a ResolveServer that serves /resolve and /healthz.
//
// The listen address defaults to ":443" and the certificate/key paths to "";
// when the shared API config is available, its HTTPS settings replace them
// (an empty HTTPSListenAddr keeps the default). The server is not started;
// call Start.
func NewResolveServer(quitCh <-chan struct{}, snapshotManager *SnapshotManager) *ResolveServer {
	rs := &ResolveServer{
		quitCh:          quitCh,
		snapshotManager: snapshotManager,
		listenAddr:      ":443",
		logQueue:        make(chan *pb.HTTPDNSAccessLog, 8192),
	}

	if apiConfig, err := configs.SharedAPIConfig(); err == nil && apiConfig != nil {
		if apiConfig.HTTPSListenAddr != "" {
			rs.listenAddr = apiConfig.HTTPSListenAddr
		}
		rs.certFile, rs.keyFile = apiConfig.HTTPSCert, apiConfig.HTTPSKey
	}

	routes := http.NewServeMux()
	routes.HandleFunc("/resolve", rs.handleResolve)
	routes.HandleFunc("/healthz", rs.handleHealth)

	rs.server = &http.Server{
		Addr:              rs.listenAddr,
		Handler:           routes,
		ReadHeaderTimeout: 3 * time.Second,
		ReadTimeout:       5 * time.Second,
		WriteTimeout:      5 * time.Second,
		IdleTimeout:       75 * time.Second,
		MaxHeaderBytes:    8 * 1024,
		TLSConfig:         &tls.Config{MinVersion: tls.VersionTLS12},
	}
	return rs
}
// Start launches the access-log flusher and the shutdown watcher, then serves
// HTTPS until the server is closed. A listen error other than
// http.ErrServerClosed is logged, not returned.
func (s *ResolveServer) Start() {
	go s.startAccessLogFlusher()
	go s.waitForShutdown()

	log.Println("[HTTPDNS_NODE][resolve]listening HTTPS on", s.listenAddr)
	err := s.server.ListenAndServeTLS(s.certFile, s.keyFile)
	if err == nil || errors.Is(err, http.ErrServerClosed) {
		return
	}
	log.Println("[HTTPDNS_NODE][resolve]listen failed:", err.Error())
}
// waitForShutdown blocks until quitCh yields, then shuts the HTTP server down
// gracefully, allowing at most 3 seconds. The shutdown error is deliberately
// ignored: the process is stopping and nothing can act on it.
func (s *ResolveServer) waitForShutdown() {
	<-s.quitCh

	shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancelShutdown()
	if err := s.server.Shutdown(shutdownCtx); err != nil {
		return
	}
}
// handleHealth answers /healthz with status 200 and the body "ok".
func (s *ResolveServer) handleHealth(writer http.ResponseWriter, _ *http.Request) {
	body := []byte("ok")
	writer.WriteHeader(http.StatusOK)
	// A failed write cannot be reported to anyone, so its error is dropped.
	_, _ = writer.Write(body)
}
// handleResolve serves GET /resolve and writes a resolveResponse as JSON.
//
// Query parameters:
//   - appId (required): looked up lower-cased in the current snapshot.
//   - dn (required): the domain; lower-cased, trailing dot removed.
//   - qtype: "A" (default) or "AAAA", case-insensitive.
//   - cip: optional client IP used for line routing (see detectClientIP).
//   - nonce, exp, sign: required when the app has signing enabled.
//   - sdk_version, os: recorded in the access log only.
//
// Malformed requests (wrong method, bad qtype, missing parameters) and a
// missing snapshot are answered with 405/400/503 and are not access-logged.
// Every later failure goes through writeFailedResolve. When no custom rule
// yields records, the domain is resolved through an upstream DNS server.
func (s *ResolveServer) handleResolve(writer http.ResponseWriter, request *http.Request) {
	startAt := time.Now()
	requestID := "rid-" + rands.HexString(16)
	if request.Method != http.MethodGet {
		s.writeResolveJSON(writer, http.StatusMethodNotAllowed, &resolveResponse{
			Code:      httpdnsCodeMethodNotAllow,
			Message:   "只允许使用 GET 方法",
			RequestID: requestID,
		})
		return
	}

	query := request.URL.Query()
	appID := strings.TrimSpace(query.Get("appId"))
	domain := strings.TrimSuffix(strings.ToLower(strings.TrimSpace(query.Get("dn"))), ".")
	qtype := strings.ToUpper(strings.TrimSpace(query.Get("qtype")))
	if len(qtype) == 0 {
		qtype = "A"
	}
	if qtype != "A" && qtype != "AAAA" {
		s.writeResolveJSON(writer, http.StatusBadRequest, &resolveResponse{
			Code:      httpdnsCodeInvalidArgument,
			Message:   "qtype 参数仅支持 A 或 AAAA",
			RequestID: requestID,
		})
		return
	}
	if len(appID) == 0 || len(domain) == 0 {
		s.writeResolveJSON(writer, http.StatusBadRequest, &resolveResponse{
			Code:      httpdnsCodeInvalidArgument,
			Message:   "缺少必填参数: appId 和 dn",
			RequestID: requestID,
		})
		return
	}

	snapshot := s.snapshotManager.Current()
	if snapshot == nil {
		s.writeResolveJSON(writer, http.StatusServiceUnavailable, &resolveResponse{
			Code:      httpdnsCodeInternalError,
			Message:   "服务节点尚未准备就绪,请稍后再试",
			RequestID: requestID,
		})
		return
	}

	loadedApp := snapshot.Apps[strings.ToLower(appID)]
	if loadedApp == nil || loadedApp.App == nil || !loadedApp.App.GetIsOn() {
		s.writeFailedResolve(writer, requestID, snapshot, nil, domain, qtype, httpdnsCodeAppInvalid, "找不到指定的应用,或该应用已下线", startAt, request, query)
		return
	}
	if snapshot.ClusterID > 0 &&
		loadedApp.App.GetPrimaryClusterId() != snapshot.ClusterID &&
		loadedApp.App.GetBackupClusterId() != snapshot.ClusterID {
		s.writeFailedResolve(writer, requestID, snapshot, loadedApp.App, domain, qtype, httpdnsCodeAppInvalid, "当前应用未绑定到该解析集群", startAt, request, query)
		return
	}

	// Verify the signature before looking at the domain binding. Otherwise an
	// unsigned caller could tell bound domains (DOMAIN_NOT_BOUND) from unbound
	// ones (SIGN_INVALID) and enumerate an app's domains without the secret.
	if loadedApp.App.GetSignEnabled() &&
		!validateResolveSign(loadedApp.App.GetSignSecret(), loadedApp.App.GetAppId(), domain, qtype, query.Get("nonce"), query.Get("exp"), query.Get("sign")) {
		s.writeFailedResolve(writer, requestID, snapshot, loadedApp.App, domain, qtype, httpdnsCodeSignInvalid, "请求鉴权失败:签名无效或已过期", startAt, request, query)
		return
	}

	loadedDomain := loadedApp.Domains[domain]
	if loadedDomain == nil || loadedDomain.Domain == nil || !loadedDomain.Domain.GetIsOn() {
		s.writeFailedResolve(writer, requestID, snapshot, loadedApp.App, domain, qtype, httpdnsCodeDomainNotBound, "应用尚未绑定该域名,或域名解析已暂停", startAt, request, query)
		return
	}

	clientProfile := buildClientRouteProfile(detectClientIP(request, query.Get("cip")))
	clusterTTL := pickDefaultTTL(snapshot, loadedApp.App)
	rule, records, ttl := pickRuleRecords(loadedDomain.Rules, qtype, clientProfile, clusterTTL)
	if len(records) == 0 {
		// No custom rule applies: fall back to the real records from an upstream DNS server.
		fallbackRecords, fallbackTTL, fallbackErr := fallbackResolve(domain, qtype)
		if fallbackErr != nil || len(fallbackRecords) == 0 {
			errMsg := "未找到解析记录"
			if fallbackErr != nil {
				errMsg = "未找到解析记录 (上游回源失败: " + fallbackErr.Error() + ")"
			}
			s.writeFailedResolve(writer, requestID, snapshot, loadedApp.App, domain, qtype, httpdnsCodeNoRecords, errMsg, startAt, request, query)
			return
		}
		// A non-positive upstream TTL is replaced by clusterTTL just below.
		records, ttl = fallbackRecords, fallbackTTL
	}
	if ttl <= 0 {
		ttl = clusterTTL
	}
	if ttl <= 0 {
		ttl = 30
	}

	resultIPs := make([]string, 0, len(records))
	for _, record := range records {
		resultIPs = append(resultIPs, record.IP)
	}
	summary := fmt.Sprintf("%s|%s(%s)|%s|%s %s -> %s|success|%dms",
		time.Now().Format("2006-01-02 15:04:05"),
		loadedApp.App.GetName(),
		loadedApp.App.GetAppId(),
		clientProfile.IP,
		qtype,
		domain,
		strings.Join(resultIPs, ", "),
		time.Since(startAt).Milliseconds(),
	)
	if rule != nil {
		if ruleName := strings.TrimSpace(rule.GetRuleName()); len(ruleName) > 0 {
			summary += "|rule:" + ruleName
		}
	}

	s.writeResolveJSON(writer, http.StatusOK, &resolveResponse{
		Code:      httpdnsCodeSuccess,
		Message:   "ok",
		RequestID: requestID,
		Data: &resolveData{
			Domain:  domain,
			QType:   qtype,
			TTL:     ttl,
			Records: records,
			Client: &resolveClientInfo{
				IP:      clientProfile.IP,
				Region:  clientProfile.RegionText,
				Carrier: clientProfile.Carrier,
				Country: clientProfile.Country,
			},
			Summary: summary,
		},
	})

	// One timestamp for the log entry so CreatedAt and Day cannot disagree
	// when the request straddles midnight.
	loggedAt := time.Now()
	s.enqueueAccessLog(&pb.HTTPDNSAccessLog{
		RequestId:    requestID,
		ClusterId:    snapshot.ClusterID,
		NodeId:       snapshot.NodeID,
		AppId:        loadedApp.App.GetAppId(),
		AppName:      loadedApp.App.GetName(),
		Domain:       domain,
		Qtype:        qtype,
		ClientIP:     clientProfile.IP,
		ClientRegion: clientProfile.RegionText,
		Carrier:      clientProfile.Carrier,
		SdkVersion:   strings.TrimSpace(query.Get("sdk_version")),
		Os:           strings.TrimSpace(query.Get("os")),
		ResultIPs:    strings.Join(resultIPs, ","),
		Status:       "success",
		ErrorCode:    "none",
		CostMs:       int32(loggedAt.Sub(startAt).Milliseconds()),
		CreatedAt:    loggedAt.Unix(),
		Day:          loggedAt.Format("20060102"),
		Summary:      summary,
	})
}
// pickDefaultTTL returns the default TTL (seconds) for answers to app.
//
// It uses the first cluster with a positive default TTL, trying in order:
// the node's own cluster (when snapshot.ClusterID > 0), then the app's
// primary cluster, then its backup cluster. If none qualifies, or snapshot is
// nil, it returns 30.
func pickDefaultTTL(snapshot *LoadedSnapshot, app *pb.HTTPDNSApp) int32 {
	const defaultTTLSeconds int32 = 30
	if snapshot == nil {
		return defaultTTLSeconds
	}

	clusterIDs := make([]int64, 0, 3)
	if snapshot.ClusterID > 0 {
		clusterIDs = append(clusterIDs, snapshot.ClusterID)
	}
	if app != nil {
		clusterIDs = append(clusterIDs, app.GetPrimaryClusterId(), app.GetBackupClusterId())
	}

	for _, clusterID := range clusterIDs {
		cluster := snapshot.Clusters[clusterID]
		if cluster != nil && cluster.GetDefaultTTL() > 0 {
			return cluster.GetDefaultTTL()
		}
	}
	return defaultTTLSeconds
}
// writeFailedResolve reports a resolution failure to the caller and records
// it in the access log.
//
// The reply always has HTTP status 200; the failure is carried by errorCode
// and message in the JSON body. app may be nil when the app could not be
// identified, and snapshot may be nil (cluster/node ids are then logged as 0).
// The client IP is re-derived from request and the cip query parameter.
func (s *ResolveServer) writeFailedResolve(
	writer http.ResponseWriter,
	requestID string,
	snapshot *LoadedSnapshot,
	app *pb.HTTPDNSApp,
	domain string,
	qtype string,
	errorCode string,
	message string,
	startAt time.Time,
	request *http.Request,
	query url.Values,
) {
	clientProfile := buildClientRouteProfile(detectClientIP(request, query.Get("cip")))
	appID := ""
	appName := ""
	if app != nil {
		appID = app.GetAppId()
		appName = app.GetName()
	}
	summary := fmt.Sprintf("%s|%s(%s)|%s|%s %s -> [none]|failed(%s)|%dms",
		time.Now().Format("2006-01-02 15:04:05"),
		appName,
		appID,
		clientProfile.IP,
		qtype,
		domain,
		errorCode,
		time.Since(startAt).Milliseconds(),
	)
	s.writeResolveJSON(writer, http.StatusOK, &resolveResponse{
		Code:      errorCode,
		Message:   message,
		RequestID: requestID,
	})

	clusterID := int64(0)
	nodeID := int64(0)
	if snapshot != nil {
		clusterID = snapshot.ClusterID
		nodeID = snapshot.NodeID
	}

	// One timestamp for the log entry so CreatedAt and Day cannot disagree
	// when the request straddles midnight.
	loggedAt := time.Now()
	s.enqueueAccessLog(&pb.HTTPDNSAccessLog{
		RequestId:    requestID,
		ClusterId:    clusterID,
		NodeId:       nodeID,
		AppId:        appID,
		AppName:      appName,
		Domain:       domain,
		Qtype:        qtype,
		ClientIP:     clientProfile.IP,
		ClientRegion: clientProfile.RegionText,
		Carrier:      clientProfile.Carrier,
		SdkVersion:   strings.TrimSpace(query.Get("sdk_version")),
		Os:           strings.TrimSpace(query.Get("os")),
		ResultIPs:    "",
		Status:       "failed",
		ErrorCode:    errorCode,
		CostMs:       int32(loggedAt.Sub(startAt).Milliseconds()),
		CreatedAt:    loggedAt.Unix(),
		Day:          loggedAt.Format("20060102"),
		Summary:      summary,
	})
}
// writeResolveJSON writes resp as a JSON body with the given HTTP status.
//
// The response is marshalled before the status line is sent, so that an
// encoding failure can still be reported as 500 with a minimal error body.
// Previously the header was committed first and the error body went out under
// the caller's status (typically 200). Write errors are ignored because the
// client connection is the only place they could be reported.
func (s *ResolveServer) writeResolveJSON(writer http.ResponseWriter, status int, resp *resolveResponse) {
	data, err := json.Marshal(resp)
	if err != nil {
		status = http.StatusInternalServerError
		data = []byte(`{"code":"RESOLVE_TIMEOUT_OR_INTERNAL","message":"encode response failed"}`)
	}
	writer.Header().Set("Content-Type", "application/json; charset=utf-8")
	writer.WriteHeader(status)
	_, _ = writer.Write(data)
}
// detectClientIP returns the IP address to treat as the client's for line
// routing, or "" if no source holds a parsable IP.
//
// Sources are tried in order, and the first that normalizes to a valid IP
// wins:
//  1. the cip query parameter;
//  2. each comma-separated X-Forwarded-For entry, left to right;
//  3. X-Real-IP, X-Client-IP, CF-Connecting-IP, True-Client-IP;
//  4. the connection's RemoteAddr.
//
// NOTE(review): cip and the proxy headers are taken at face value, so any
// caller can choose the IP used for routing. Confirm this is intended (e.g.
// SDKs resolving on behalf of another host) rather than restricting the
// headers to trusted proxies.
func detectClientIP(request *http.Request, cip string) string {
	candidates := []string{cip}
	if forwarded := strings.TrimSpace(request.Header.Get("X-Forwarded-For")); forwarded != "" {
		candidates = append(candidates, strings.Split(forwarded, ",")...)
	}
	for _, header := range [...]string{"X-Real-IP", "X-Client-IP", "CF-Connecting-IP", "True-Client-IP"} {
		candidates = append(candidates, request.Header.Get(header))
	}
	for _, candidate := range candidates {
		if ip := normalizeIPCandidate(candidate); ip != "" {
			return ip
		}
	}
	return normalizeIPCandidate(request.RemoteAddr)
}
// validateResolveSign checks a signed /resolve request.
//
// The expected signature is the hex-encoded HMAC-SHA256, keyed with
// signSecret, of
//
//	appID + "|" + lower(domain) + "|" + upper(qtype) + "|" + exp + "|" + nonce
//
// where exp is a Unix timestamp in seconds. All of signSecret, nonce, exp and
// sign must be non-empty after trimming. exp is rejected when it lies more
// than 30 seconds in the past or more than 24 hours in the future. sign is
// accepted in upper- or lower-case hex.
//
// The MAC is compared with hmac.Equal (constant time) instead of
// strings.EqualFold, whose early exit leaks how many leading characters of a
// forged signature were correct.
//
// NOTE(review): nonces are not remembered, so a captured signed URL can be
// replayed until exp passes.
func validateResolveSign(signSecret string, appID string, domain string, qtype string, nonce string, exp string, sign string) bool {
	signSecret = strings.TrimSpace(signSecret)
	nonce = strings.TrimSpace(nonce)
	exp = strings.TrimSpace(exp)
	sign = strings.TrimSpace(sign)
	if len(signSecret) == 0 || len(nonce) == 0 || len(exp) == 0 || len(sign) == 0 {
		return false
	}
	expireAt, err := strconv.ParseInt(exp, 10, 64)
	if err != nil {
		return false
	}
	now := time.Now().Unix()
	if expireAt <= now-30 || expireAt > now+86400 {
		return false
	}
	// hex.DecodeString accepts both letter cases, preserving the old
	// case-insensitive acceptance; malformed hex can never match.
	provided, err := hex.DecodeString(sign)
	if err != nil {
		return false
	}
	raw := appID + "|" + strings.ToLower(domain) + "|" + strings.ToUpper(qtype) + "|" + exp + "|" + nonce
	mac := hmac.New(sha256.New, []byte(signSecret))
	_, _ = mac.Write([]byte(raw)) // hash.Hash.Write never returns an error
	return hmac.Equal(mac.Sum(nil), provided)
}
func buildClientRouteProfile(ip string) *clientRouteProfile {
profile := &clientRouteProfile{
IP: normalizeIPCandidate(ip),
}
if net.ParseIP(profile.IP) == nil {
return profile
}
result := iplibrary.LookupIP(profile.IP)
if result == nil || !result.IsOk() {
return profile
}
profile.Country = normalizeCountryName(strings.TrimSpace(result.CountryName()))
profile.Province = normalizeProvinceName(strings.TrimSpace(result.ProvinceName()))
profile.ProviderRaw = strings.TrimSpace(result.ProviderName())
profile.Carrier = normalizeCarrier(profile.ProviderRaw, profile.Country)
if len(profile.Carrier) == 0 {
if isMainlandChinaCountry(profile.Country) {
profile.Carrier = "默认"
} else {
if len(profile.ProviderRaw) > 0 {
profile.Carrier = profile.ProviderRaw
} else {
profile.Carrier = "默认"
}
}
}
profile.Region = normalizeChinaRegion(profile.Province)
profile.Continent = normalizeContinent(profile.Country)
profile.RegionText = strings.TrimSpace(result.RegionSummary())
if len(profile.RegionText) == 0 {
pieces := make([]string, 0, 4)
if len(profile.Country) > 0 {
pieces = append(pieces, profile.Country)
}
if len(profile.Province) > 0 {
pieces = append(pieces, profile.Province)
}
if len(profile.Region) > 0 {
pieces = append(pieces, profile.Region)
}
if len(profile.Carrier) > 0 {
pieces = append(pieces, profile.Carrier)
}
profile.RegionText = strings.Join(pieces, " ")
}
return profile
}
// httpdnsCarrierKeywordTable lists, in priority order, the carrier line labels
// and the substrings that identify them in a lower-cased provider name.
// "specific" keywords (Chinese words and China-specific brand names) are
// trusted for every country. "generic" keywords are ordinary English words or
// short abbreviations that also occur in unrelated ISP names abroad
// ("T-Mobile", "Comcast Cable", "DirectNet"…).
var httpdnsCarrierKeywordTable = []struct {
	label    string
	specific []string
	generic  []string
}{
	{"电信", []string{"电信", "天翼", "chinanet", "chinatelecom"}, []string{"telecom", "ctnet", "cn2"}},
	{"联通", []string{"联通", "网通", "chinaunicom", "china169"}, []string{"unicom", "cucc", "cnc"}},
	{"移动", []string{"移动", "chinamobile"}, []string{"mobile", "cmcc", "cmnet"}},
	{"教育网", []string{"教育", "cernet"}, []string{"edu", "education"}},
	{"鹏博士", []string{"鹏博士", "drpeng", "dr.peng", "dr_peng"}, nil},
	{"广电", []string{"广电"}, []string{"broadcast", "cable", "radio"}},
}

// normalizeCarrier maps a raw ISP/provider name to a carrier line label
// ("电信", "联通", "移动", "教育网", "鹏博士" or "广电").
//
// Specific keywords are matched against the lower-cased name with spaces
// removed, so "China Mobile Hong Kong" maps to "移动". Generic keywords are
// only used when country is mainland China or unknown (empty). Abroad they
// would mislabel unrelated ISPs such as "T-Mobile" or "Comcast Cable" as
// Chinese carriers.
//
// When nothing matches, mainland names yield "" (callers substitute a default
// line) and other names are returned trimmed but otherwise unchanged. An
// empty provider yields "".
func normalizeCarrier(provider string, country string) string {
	value := strings.TrimSpace(provider)
	if len(value) == 0 {
		return ""
	}
	lower := strings.ToLower(value)
	compact := strings.ReplaceAll(lower, " ", "")
	mainland := isMainlandChinaCountry(country)
	trustGeneric := mainland || len(strings.TrimSpace(country)) == 0

	for _, entry := range httpdnsCarrierKeywordTable {
		for _, keyword := range entry.specific {
			if strings.Contains(compact, keyword) {
				return entry.label
			}
		}
		if !trustGeneric {
			continue
		}
		for _, keyword := range entry.generic {
			if strings.Contains(lower, keyword) {
				return entry.label
			}
		}
	}
	if mainland {
		return ""
	}
	return value
}
// httpdnsChinaRegionByProvince maps a short mainland province name to its
// macro-region.
var httpdnsChinaRegionByProvince = map[string]string{
	"辽宁": "东北", "吉林": "东北", "黑龙江": "东北",
	"北京": "华北", "天津": "华北", "河北": "华北", "山西": "华北", "内蒙古": "华北",
	"上海": "华东", "江苏": "华东", "浙江": "华东", "安徽": "华东", "福建": "华东", "江西": "华东", "山东": "华东",
	"广东": "华南", "广西": "华南", "海南": "华南",
	"河南": "华中", "湖北": "华中", "湖南": "华中",
	"陕西": "西北", "甘肃": "西北", "青海": "西北", "宁夏": "西北", "新疆": "西北",
	"重庆": "西南", "四川": "西南", "贵州": "西南", "云南": "西南", "西藏": "西南",
}

// normalizeChinaRegion returns the macro-region (东北, 华北, 华东, 华南, 华中,
// 西北, 西南) of a mainland province, or "" when the province is unknown.
// The name is first shortened with normalizeProvinceName.
func normalizeChinaRegion(province string) string {
	return httpdnsChinaRegionByProvince[normalizeProvinceName(province)]
}
// httpdnsContinentByCountry maps a canonical country name to its continent.
var httpdnsContinentByCountry = map[string]string{
	"中国": "亚洲", "中国香港": "亚洲", "中国澳门": "亚洲", "中国台湾": "亚洲", "日本": "亚洲",
	"韩国": "亚洲", "新加坡": "亚洲", "印度": "亚洲", "泰国": "亚洲", "越南": "亚洲",
	"美国": "北美洲", "加拿大": "北美洲", "墨西哥": "北美洲",
	"巴西": "南美洲", "阿根廷": "南美洲", "智利": "南美洲", "哥伦比亚": "南美洲",
	"德国": "欧洲", "英国": "欧洲", "法国": "欧洲", "荷兰": "欧洲", "西班牙": "欧洲", "意大利": "欧洲", "俄罗斯": "欧洲",
	"南非": "非洲", "埃及": "非洲", "尼日利亚": "非洲", "肯尼亚": "非洲", "摩洛哥": "非洲",
	"澳大利亚": "大洋洲", "新西兰": "大洋洲",
}

// normalizeContinent returns the continent of a country after canonicalizing
// it with normalizeCountryName, or "" for countries not in the table.
func normalizeContinent(country string) string {
	return httpdnsContinentByCountry[normalizeCountryName(country)]
}
// pickRuleRecords selects the custom rule whose line best matches the client.
// It returns that rule, its records of the requested qtype, and the TTL to use.
//
// Nil and disabled rules are ignored, as are rules whose line does not match
// profile (matchRuleLine) and rules without a non-empty record of qtype. Among
// the rest, the highest matchRuleLine score wins; on a tie the earlier rule is
// kept. The TTL is the winner's TTL when positive, otherwise defaultTTL. With
// no winner the result is (nil, nil, defaultTTL).
func pickRuleRecords(rules []*pb.HTTPDNSCustomRule, qtype string, profile *clientRouteProfile, defaultTTL int32) (*pb.HTTPDNSCustomRule, []*resolveRecord, int32) {
	// Record types are compared upper-cased; normalize the requested type once.
	qtype = strings.ToUpper(strings.TrimSpace(qtype))

	bestScore := -1
	var bestRule *pb.HTTPDNSCustomRule
	var bestRecords []*resolveRecord
	bestTTL := defaultTTL
	for _, rule := range rules {
		if rule == nil || !rule.GetIsOn() {
			continue
		}
		score, ok := matchRuleLine(rule, profile)
		// Only a strictly higher score replaces the current best, so do not
		// build records for a rule that cannot win.
		if !ok || score <= bestScore {
			continue
		}

		// The line/region labels are identical for every record of a rule.
		line := ruleLineSummary(rule)
		region := ruleRegionSummary(rule)
		var records []*resolveRecord
		for _, item := range rule.GetRecords() {
			if item == nil || strings.ToUpper(strings.TrimSpace(item.GetRecordType())) != qtype {
				continue
			}
			value := strings.TrimSpace(item.GetRecordValue())
			if len(value) == 0 {
				continue
			}
			records = append(records, &resolveRecord{
				Type:   qtype,
				IP:     value,
				Weight: item.GetWeight(),
				Line:   line,
				Region: region,
			})
		}
		if len(records) == 0 {
			continue
		}

		bestScore = score
		bestRule = rule
		bestRecords = records
		bestTTL = defaultTTL
		if rule.GetTtl() > 0 {
			bestTTL = rule.GetTtl()
		}
	}
	return bestRule, bestRecords, bestTTL
}
// fallbackResolve queries an upstream recursive DNS server for domain. It is
// used when no custom rule produced records, following the EdgeDNS approach.
//
// Only qtype "A" and "AAAA" are supported; any other qtype returns
// (nil, 0, nil). The upstream is a random nameserver from /etc/resolv.conf
// (with the file's port, default 53). If the file cannot be read or lists no
// servers, 223.5.5.5:53 is used. A truncated UDP answer is retried once over
// TCP so that large record sets are not silently cut short.
//
// It returns the records of the requested type, each labelled via the local
// IP library. It also returns the smallest positive TTL among them (0 if
// none), and an error when the exchange fails or the upstream answers with a
// non-success rcode.
func fallbackResolve(domain string, qtype string) ([]*resolveRecord, int32, error) {
	var dnsType uint16
	switch qtype {
	case "A":
		dnsType = dns.TypeA
	case "AAAA":
		dnsType = dns.TypeAAAA
	default:
		return nil, 0, nil // only A and AAAA are resolved upstream
	}

	m := new(dns.Msg)
	m.SetQuestion(dns.Fqdn(domain), dnsType)
	m.RecursionDesired = true

	upstream := "223.5.5.5:53"
	if resolveConfig, confErr := dns.ClientConfigFromFile("/etc/resolv.conf"); confErr == nil && len(resolveConfig.Servers) > 0 {
		port := resolveConfig.Port
		if len(port) == 0 {
			port = "53"
		}
		server := resolveConfig.Servers[rands.Int(0, len(resolveConfig.Servers)-1)]
		// JoinHostPort brackets IPv6 nameservers ("::1" -> "[::1]:53");
		// plain concatenation produced an unusable address for them.
		upstream = net.JoinHostPort(server, port)
	}

	r, _, err := sharedRecursionDNSClient.Exchange(m, upstream)
	if err == nil && r != nil && r.Truncated {
		// The answer did not fit in a UDP datagram; ask again over TCP.
		tcpClient := &dns.Client{Net: "tcp", Timeout: sharedRecursionDNSClient.Timeout}
		r, _, err = tcpClient.Exchange(m, upstream)
	}
	if err != nil {
		return nil, 0, err
	}
	if r == nil {
		return nil, 0, errors.New("upstream returned an empty response")
	}
	if r.Rcode != dns.RcodeSuccess {
		return nil, 0, fmt.Errorf("upstream rcode: %d", r.Rcode)
	}

	// RFC 2181 section 8: a TTL with the most significant bit set is treated
	// as zero. Such values would otherwise overflow int32 into a negative TTL.
	const maxTTL = 1<<31 - 1
	var records []*resolveRecord
	var responseTTL int32
	collect := func(ip net.IP, headerTTL uint32) {
		if headerTTL > 0 && headerTTL <= maxTTL {
			ttl := int32(headerTTL)
			if responseTTL == 0 || ttl < responseTTL {
				responseTTL = ttl
			}
		}
		address := ip.String()
		records = append(records, &resolveRecord{
			Type:   qtype,
			IP:     address,
			Line:   lookupIPLineLabel(address),
			Region: lookupIPRegionSummary(address),
		})
	}
	// CNAME and other record types in the answer section are skipped; only
	// addresses of the requested family are returned.
	for _, answer := range r.Answer {
		switch rr := answer.(type) {
		case *dns.A:
			if dnsType == dns.TypeA {
				collect(rr.A, rr.Hdr.Ttl)
			}
		case *dns.AAAA:
			if dnsType == dns.TypeAAAA {
				collect(rr.AAAA, rr.Hdr.Ttl)
			}
		}
	}
	return records, responseTTL, nil
}
// lookupIPRegionSummary returns the IP library's trimmed region summary for
// ip. It returns "" when ip is not a valid address or the lookup has no
// result.
func lookupIPRegionSummary(ip string) string {
	address := strings.TrimSpace(ip)
	if address != "" && net.ParseIP(address) != nil {
		if info := iplibrary.LookupIP(address); info != nil && info.IsOk() {
			return strings.TrimSpace(info.RegionSummary())
		}
	}
	return ""
}
func lookupIPLineLabel(ip string) string {
address := strings.TrimSpace(ip)
if len(address) == 0 || net.ParseIP(address) == nil {
return "上游DNS"
}
result := iplibrary.LookupIP(address)
if result == nil || !result.IsOk() {
return "上游DNS"
}
provider := strings.TrimSpace(result.ProviderName())
country := normalizeCountryName(strings.TrimSpace(result.CountryName()))
carrier := strings.TrimSpace(normalizeCarrier(provider, country))
if isMainlandChinaCountry(country) {
if len(carrier) > 0 {
return carrier
}
return "默认"
}
if len(carrier) > 0 {
return carrier
}
if len(provider) > 0 {
return provider
}
return "上游DNS"
}
// isMainlandChinaCountry reports whether country canonicalizes (via
// normalizeCountryName) to mainland China. Hong Kong, Macao and Taiwan
// canonicalize to distinct names and are therefore not included.
func isMainlandChinaCountry(country string) bool {
	return normalizeCountryName(country) == "中国"
}
// matchRuleLine reports whether rule's line (routing scope) applies to the
// client described by profile. When it applies, it also returns how specific
// the match is: every non-default rule field that matches adds 1 to the score.
//
// Rules with line scope "overseas" apply only to clients outside mainland China
// and are matched on continent and country. Any other scope is the mainland
// scope. It applies only to mainland clients and is matched on country,
// continent, carrier (normalized label or raw provider name), region and
// province.
//
// Countries are compared by canonical name (matchRuleCountryField), not by
// matchRuleField's two-way substring test. That test let "印度" (India) match
// "印度尼西亚" (Indonesia), and a "中国香港" rule match mainland "中国" clients.
func matchRuleLine(rule *pb.HTTPDNSCustomRule, profile *clientRouteProfile) (int, bool) {
	score := 0
	// accumulate adds a field score and reports whether that field matched.
	accumulate := func(fieldScore int, ok bool) bool {
		score += fieldScore
		return ok
	}

	if strings.ToLower(strings.TrimSpace(rule.GetLineScope())) == "overseas" {
		// Overseas rules only match sources outside mainland China.
		if isMainlandChinaCountry(profile.Country) {
			return 0, false
		}
		if !accumulate(matchRuleField(rule.GetLineContinent(), profile.Continent)) ||
			!accumulate(matchRuleCountryField(rule.GetLineCountry(), profile.Country)) {
			return 0, false
		}
		return score, true
	}

	// Mainland-scope rules only match sources inside mainland China.
	if !isMainlandChinaCountry(profile.Country) {
		return 0, false
	}
	if !accumulate(matchRuleCountryField(rule.GetLineCountry(), profile.Country)) ||
		!accumulate(matchRuleField(rule.GetLineContinent(), profile.Continent)) ||
		!accumulate(matchRuleField(rule.GetLineCarrier(), profile.Carrier, profile.ProviderRaw)) ||
		!accumulate(matchRuleField(rule.GetLineRegion(), profile.Region)) ||
		!accumulate(matchRuleField(rule.GetLineProvince(), profile.Province)) {
		return 0, false
	}
	return score, true
}

// matchRuleCountryField is the country counterpart of matchRuleField. A
// default/wildcard rule value matches with score 0. Otherwise both sides are
// canonicalized with normalizeCountryName and normalizeLineValue and must be
// equal (score 1); an unknown client country never matches a concrete value.
func matchRuleCountryField(ruleValue string, country string) (int, bool) {
	if isDefaultLineValue(ruleValue) {
		return 0, true
	}
	want := normalizeLineValue(normalizeCountryName(ruleValue))
	got := normalizeLineValue(normalizeCountryName(country))
	if len(got) == 0 || want != got {
		return 0, false
	}
	return 1, true
}
// matchRuleField compares one line field of a rule with the client's values.
//
// A default/wildcard rule value (see isDefaultLineValue) matches with score 0.
// Otherwise, after normalizeLineValue on both sides, the field matches with
// score 1 if any non-empty candidate equals the rule value or either one
// contains the other. With no such candidate it does not match.
func matchRuleField(ruleValue string, candidates ...string) (int, bool) {
	if isDefaultLineValue(ruleValue) {
		return 0, true
	}
	want := normalizeLineValue(ruleValue)
	if want == "" {
		return 0, true
	}
	overlaps := func(got string) bool {
		return got == want || strings.Contains(got, want) || strings.Contains(want, got)
	}
	for _, candidate := range candidates {
		// An unknown client attribute must not match a concrete rule value;
		// strings.Contains(want, "") would otherwise always be true.
		if got := normalizeLineValue(candidate); got != "" && overlaps(got) {
			return 1, true
		}
	}
	return 0, false
}
// httpdnsDefaultLineValues holds the normalized spellings meaning "any value".
var httpdnsDefaultLineValues = map[string]bool{
	"": true, "default": true, "all": true, "*": true, "any": true,
	"默认": true, "全部": true, "不限": true,
}

// isDefaultLineValue reports whether a rule field value is a wildcard
// (empty, "default", "all", "*", "any", "默认", "全部", "不限"). The value is
// normalized with normalizeLineValue first, so case, surrounding whitespace
// and the separators it strips are ignored.
func isDefaultLineValue(value string) bool {
	return httpdnsDefaultLineValues[normalizeLineValue(value)]
}
// httpdnsLineValueReplacer deletes the separators ignored when comparing line
// values: spaces, hyphens, underscores and slashes.
var httpdnsLineValueReplacer = strings.NewReplacer(" ", "", "-", "", "_", "", "/", "")

// normalizeLineValue canonicalizes a line/location value for comparison. It
// trims surrounding whitespace, lower-cases the value and removes every
// space, '-', '_' and '/'. Other inner whitespace such as tabs is kept.
func normalizeLineValue(value string) string {
	return httpdnsLineValueReplacer.Replace(strings.ToLower(strings.TrimSpace(value)))
}
// normalizeIPCandidate extracts a canonical IP address string from a raw value
// taken from a query parameter, a header or RemoteAddr.
//
// Accepted forms (surrounding whitespace ignored): "1.2.3.4", "1.2.3.4:80",
// "2001:db8::1", "[2001:db8::1]:80" and the bracketed literal without a port,
// "[2001:db8::1]". The last one was previously rejected. The result is the
// parsed address's net.IP.String() form, or "" when nothing parses.
func normalizeIPCandidate(ip string) string {
	value := strings.TrimSpace(ip)
	if len(value) == 0 {
		return ""
	}
	if host, _, err := net.SplitHostPort(value); err == nil {
		value = strings.TrimSpace(host)
	} else if strings.HasPrefix(value, "[") && strings.HasSuffix(value, "]") {
		// Bracketed IPv6 literal without a port: SplitHostPort rejects it.
		value = value[1 : len(value)-1]
	}
	if parsed := net.ParseIP(value); parsed != nil {
		return parsed.String()
	}
	return ""
}
// httpdnsCountryAliases maps normalizeLineValue-normalized spellings of a
// country (Chinese or English) to the canonical Chinese name used for routing.
// NOTE(review): "korea" is mapped to South Korea (韩国); confirm this is the
// intended reading of the ambiguous alias.
var httpdnsCountryAliases = map[string]string{
	"中国香港": "中国香港", "香港": "中国香港", "hongkong": "中国香港",
	"中国澳门": "中国澳门", "澳门": "中国澳门", "macao": "中国澳门", "macau": "中国澳门",
	"中国台湾": "中国台湾", "台湾": "中国台湾", "taiwan": "中国台湾",
	"中国": "中国", "中国大陆": "中国", "中国内地": "中国", "中华人民共和国": "中国", "prc": "中国", "cn": "中国",
	"china": "中国", "mainlandchina": "中国", "peoplesrepublicofchina": "中国", "thepeoplesrepublicofchina": "中国",
	"美国": "美国", "usa": "美国", "unitedstates": "美国", "unitedstatesofamerica": "美国",
	"加拿大": "加拿大", "canada": "加拿大",
	"墨西哥": "墨西哥", "mexico": "墨西哥",
	"日本": "日本", "japan": "日本",
	"韩国": "韩国", "southkorea": "韩国", "korea": "韩国",
	"新加坡": "新加坡", "singapore": "新加坡",
	"印度": "印度", "india": "印度",
	"泰国": "泰国", "thailand": "泰国",
	"越南": "越南", "vietnam": "越南",
	"德国": "德国", "germany": "德国",
	"英国": "英国", "uk": "英国", "unitedkingdom": "英国", "greatbritain": "英国", "britain": "英国",
	"法国": "法国", "france": "法国",
	"荷兰": "荷兰", "netherlands": "荷兰",
	"西班牙": "西班牙", "spain": "西班牙",
	"意大利": "意大利", "italy": "意大利",
	"俄罗斯": "俄罗斯", "russia": "俄罗斯",
	"巴西": "巴西", "brazil": "巴西",
	"阿根廷": "阿根廷", "argentina": "阿根廷",
	"智利": "智利", "chile": "智利",
	"哥伦比亚": "哥伦比亚", "colombia": "哥伦比亚",
	"南非": "南非", "southafrica": "南非",
	"埃及": "埃及", "egypt": "埃及",
	"尼日利亚": "尼日利亚", "nigeria": "尼日利亚",
	"肯尼亚": "肯尼亚", "kenya": "肯尼亚",
	"摩洛哥": "摩洛哥", "morocco": "摩洛哥",
	"澳大利亚": "澳大利亚", "australia": "澳大利亚",
	"新西兰": "新西兰", "newzealand": "新西兰",
}

// normalizeCountryName returns the canonical Chinese name of a country when
// its normalized spelling is a known alias. Otherwise it returns the input
// with surrounding whitespace trimmed (original case kept). Blank input yields
// "".
func normalizeCountryName(country string) string {
	trimmed := strings.TrimSpace(country)
	if trimmed == "" {
		return ""
	}
	if canonical, known := httpdnsCountryAliases[normalizeLineValue(trimmed)]; known {
		return canonical
	}
	return trimmed
}
// httpdnsProvinceAliases maps full official names whose short form cannot be
// obtained by suffix stripping alone.
var httpdnsProvinceAliases = map[string]string{
	"内蒙古自治区":   "内蒙古",
	"广西壮族自治区":  "广西",
	"宁夏回族自治区":  "宁夏",
	"新疆维吾尔自治区": "新疆",
	"西藏自治区":    "西藏",
	"香港特别行政区":  "香港",
	"澳门特别行政区":  "澳门",
}

// httpdnsProvinceSuffixes are administrative suffixes removed, in this order,
// from province names. Every entry is tried once, so several can be stripped
// from the same name.
var httpdnsProvinceSuffixes = []string{"维吾尔自治区", "回族自治区", "壮族自治区", "自治区", "特别行政区", "省", "市"}

// normalizeProvinceName shortens a province name to its common form, e.g.
// "广东省" -> "广东", "北京市" -> "北京", "广西壮族自治区" -> "广西".
// Surrounding whitespace is trimmed; blank input yields "".
func normalizeProvinceName(province string) string {
	name := strings.TrimSpace(province)
	if name == "" {
		return ""
	}
	if short, ok := httpdnsProvinceAliases[name]; ok {
		return short
	}
	for _, suffix := range httpdnsProvinceSuffixes {
		name = strings.TrimSuffix(name, suffix)
	}
	return name
}
func ruleLineSummary(rule *pb.HTTPDNSCustomRule) string {
if rule == nil {
return ""
}
scope := strings.ToLower(strings.TrimSpace(rule.GetLineScope()))
if scope == "overseas" {
pieces := make([]string, 0, 2)
if !isDefaultLineValue(rule.GetLineContinent()) {
pieces = append(pieces, strings.TrimSpace(rule.GetLineContinent()))
}
if !isDefaultLineValue(rule.GetLineCountry()) {
pieces = append(pieces, strings.TrimSpace(rule.GetLineCountry()))
}
return strings.Join(pieces, "/")
}
pieces := make([]string, 0, 3)
if !isDefaultLineValue(rule.GetLineCarrier()) {
pieces = append(pieces, strings.TrimSpace(rule.GetLineCarrier()))
}
if !isDefaultLineValue(rule.GetLineRegion()) {
pieces = append(pieces, strings.TrimSpace(rule.GetLineRegion()))
}
if !isDefaultLineValue(rule.GetLineProvince()) {
pieces = append(pieces, strings.TrimSpace(rule.GetLineProvince()))
}
return strings.Join(pieces, "/")
}
// ruleRegionSummary returns the first non-default (trimmed) line field of a
// rule, in the order province, country, region, continent. It returns "" when
// all are default or rule is nil.
func ruleRegionSummary(rule *pb.HTTPDNSCustomRule) string {
	if rule == nil {
		return ""
	}
	ordered := [...]string{
		rule.GetLineProvince(),
		rule.GetLineCountry(),
		rule.GetLineRegion(),
		rule.GetLineContinent(),
	}
	for _, field := range ordered {
		if !isDefaultLineValue(field) {
			return strings.TrimSpace(field)
		}
	}
	return ""
}
// enqueueAccessLog hands item to the access-log flusher without blocking.
// When the queue is full the entry is dropped and its request id is logged.
// A nil item is ignored.
func (s *ResolveServer) enqueueAccessLog(item *pb.HTTPDNSAccessLog) {
	if item == nil {
		return
	}
	select {
	case s.logQueue <- item:
		return
	default:
	}
	log.Println("[HTTPDNS_NODE][resolve]access log queue is full, drop request:", item.GetRequestId())
}
func (s *ResolveServer) startAccessLogFlusher() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
batch := make([]*pb.HTTPDNSAccessLog, 0, 128)
flush := func() {
if len(batch) == 0 {
return
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
log.Println("[HTTPDNS_NODE][resolve]access-log rpc unavailable:", err.Error())
return
}
_, err = rpcClient.HTTPDNSAccessLogRPC.CreateHTTPDNSAccessLogs(rpcClient.Context(), &pb.CreateHTTPDNSAccessLogsRequest{
Logs: batch,
})
if err != nil {
log.Println("[HTTPDNS_NODE][resolve]flush access logs failed:", err.Error())
return
}
batch = batch[:0]
}
for {
select {
case item := <-s.logQueue:
if item != nil {
batch = append(batch, item)
}
if len(batch) > 4096 {
log.Println("[HTTPDNS_NODE][resolve]access log flush backlog too large, trim:", len(batch))
batch = batch[len(batch)-2048:]
}
if len(batch) >= 128 {
flush()
}
case <-ticker.C:
flush()
case <-s.quitCh:
flush()
return
}
}
}