Files
waf-platform/EdgeHttpDNS/internal/nodes/resolve_server.go
2026-03-02 20:07:53 +08:00

1473 lines
39 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package nodes
import (
"context"
"crypto/hmac"
"crypto/sha256"
"crypto/tls"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
"github.com/TeaOSLab/EdgeHttpDNS/internal/accesslogs"
"github.com/TeaOSLab/EdgeHttpDNS/internal/configs"
"github.com/iwind/TeaGo/rands"
"github.com/miekg/dns"
)
// sharedRecursionDNSClient is the DNS client shared by all upstream
// (recursive) fallback lookups, aligned with the EdgeDNS implementation.
// Every exchange made through it times out after 3 seconds.
var sharedRecursionDNSClient = &dns.Client{
	Timeout: 3 * time.Second,
}
// Result codes placed in resolveResponse.Code. httpdnsCodeInternalError is
// also returned when the node has not loaded a snapshot yet.
const (
	httpdnsCodeSuccess         = "SUCCESS"
	httpdnsCodeAppInvalid      = "APP_NOT_FOUND_OR_DISABLED"
	httpdnsCodeDomainNotBound  = "DOMAIN_NOT_BOUND"
	httpdnsCodeSignInvalid     = "SIGN_INVALID"
	httpdnsCodeNoRecords       = "NO_RECORDS"
	httpdnsCodeInternalError   = "RESOLVE_TIMEOUT_OR_INTERNAL"
	httpdnsCodeMethodNotAllow  = "METHOD_NOT_ALLOWED"
	httpdnsCodeInvalidArgument = "INVALID_ARGUMENT"
)
// resolveClientInfo describes the requesting client as the resolver saw it;
// it is returned inside successful /resolve responses.
type resolveClientInfo struct {
	IP      string `json:"ip"`
	Region  string `json:"region"`
	Carrier string `json:"carrier"`
	Country string `json:"country"`
}

// resolveRecord is one answer returned to the client.
type resolveRecord struct {
	Type   string `json:"type"`
	IP     string `json:"ip"`
	Weight int32  `json:"weight,omitempty"`
	Line   string `json:"line,omitempty"`   // carrier/line label of the answer IP
	Region string `json:"region,omitempty"` // IP-library region summary of the answer IP
}

// resolveData is the payload of a successful /resolve response.
type resolveData struct {
	Domain  string             `json:"domain"`
	QType   string             `json:"qtype"`
	TTL     int32              `json:"ttl"`
	Records []*resolveRecord   `json:"records"`
	Client  *resolveClientInfo `json:"client"`
	Summary string             `json:"summary"`
}

// resolveResponse is the JSON envelope of every /resolve response. Data is
// only populated on success and is omitted from the JSON otherwise.
type resolveResponse struct {
	Code      string       `json:"code"`
	Message   string       `json:"message"`
	RequestID string       `json:"requestId"`
	Data      *resolveData `json:"data,omitempty"`
}
// clientRouteProfile holds the attributes of a client IP that custom rule
// lines are matched against (see matchRuleLine). It is filled by
// buildClientRouteProfile from the IP library.
type clientRouteProfile struct {
	IP          string
	Country     string // canonical Chinese country name (normalizeCountryName)
	Province    string // short province name (normalizeProvinceName)
	Carrier     string // normalized carrier, or a fallback label
	ProviderRaw string // provider name exactly as reported by the IP library
	Region      string // Chinese macro-region (e.g. 华东); empty outside mainland provinces
	Continent   string
	RegionText  string // human-readable region summary
}

// tlsListener is one HTTPS listener together with the http.Server serving it.
type tlsListener struct {
	addr     string // e.g. ":443"
	listener net.Listener
	server   *http.Server
}
// ResolveServer is the node's HTTPDNS endpoint. It serves /resolve and
// /healthz over HTTPS on one or more listeners. Listen addresses and the
// certificate come from the cluster TLS policy in the current snapshot, with
// local-config fallbacks.
type ResolveServer struct {
	quitCh          <-chan struct{}
	snapshotManager *SnapshotManager

	// Local config fallback
	fallbackAddr string
	certFile     string
	keyFile      string

	handler   http.Handler // shared mux
	tlsConfig *tls.Config  // shared TLS config (with GetCertificate)

	logWriter *accesslogs.HTTPDNSFileWriter
	logQueue  chan *pb.HTTPDNSAccessLog // access logs waiting for startAccessLogFlusher

	// TLS certificate hot-reload; currentCert and certSnapshotAt are guarded by certMu.
	certMu         sync.RWMutex
	currentCert    *tls.Certificate
	certSnapshotAt int64 // LoadedAt of the last snapshot processed by reloadCertFromSnapshot

	// Listener hot-reload; listeners is guarded by listenerMu.
	listenerMu sync.Mutex
	listeners  map[string]*tlsListener // key: addr (e.g. ":443")
}
// NewResolveServer builds a ResolveServer backed by snapshotManager.
//
// The local API config is read once for three things: the fallback HTTPS
// listen address (default ":443"), the fallback certificate/key files and the
// access-log directory. It then initializes the shared access-log file writer
// and registers the /resolve and /healthz handlers. No socket is opened here;
// Start does that.
func NewResolveServer(quitCh <-chan struct{}, snapshotManager *SnapshotManager) *ResolveServer {
	fallbackAddr := ":443"
	certFile := ""
	keyFile := ""
	logDir := ""
	// Derive every local setting from a single config lookup.
	if apiConfig, err := configs.SharedAPIConfig(); err == nil && apiConfig != nil {
		if len(apiConfig.HTTPSListenAddr) > 0 {
			fallbackAddr = apiConfig.HTTPSListenAddr
		}
		certFile = apiConfig.HTTPSCert
		keyFile = apiConfig.HTTPSKey
		logDir = strings.TrimSpace(apiConfig.LogDir)
	}

	logWriter := accesslogs.SharedHTTPDNSFileWriter()
	if len(logDir) > 0 {
		logWriter.SetDir(logDir)
	}
	if err := logWriter.EnsureInit(); err != nil {
		// Non-fatal: resolving keeps working even if access logs cannot be written.
		log.Println("[HTTPDNS_NODE][resolve]init access log file writer failed:", err.Error())
	}

	instance := &ResolveServer{
		quitCh:          quitCh,
		snapshotManager: snapshotManager,
		fallbackAddr:    fallbackAddr,
		certFile:        certFile,
		keyFile:         keyFile,
		logWriter:       logWriter,
		logQueue:        make(chan *pb.HTTPDNSAccessLog, 8192),
		listeners:       make(map[string]*tlsListener),
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/resolve", instance.handleResolve)
	mux.HandleFunc("/healthz", instance.handleHealth)
	instance.handler = mux

	instance.tlsConfig = &tls.Config{
		// NOTE(review): TLS 1.1 is deprecated (RFC 8996); presumably kept for
		// older SDK clients — confirm before raising the floor.
		MinVersion: tls.VersionTLS11,
		NextProtos: []string{"http/1.1"},
		// The certificate is looked up per handshake so it can be hot-reloaded.
		GetCertificate: instance.getCertificate,
	}
	return instance
}
// Start launches the access-log flusher, installs the initial TLS certificate,
// opens the HTTPS listeners and then blocks in watchLoop until quitCh fires.
//
// Certificate precedence: the cluster snapshot's TLS policy overrides the local
// cert/key files. Listen addresses come from the snapshot, or s.fallbackAddr
// when no snapshot is loaded yet.
func (s *ResolveServer) Start() {
	go s.startAccessLogFlusher()

	// 1. Load initial certificate from file (fallback).
	if len(s.certFile) > 0 && len(s.keyFile) > 0 {
		cert, err := tls.LoadX509KeyPair(s.certFile, s.keyFile)
		if err != nil {
			log.Println("[HTTPDNS_NODE][resolve]load cert file failed:", err.Error())
		} else {
			// Take certMu like every other writer/reader of currentCert.
			s.certMu.Lock()
			s.currentCert = &cert
			s.certMu.Unlock()
			log.Println("[HTTPDNS_NODE][resolve]loaded initial TLS cert from file")
		}
	}

	// Read the snapshot once so the certificate and the listeners are derived
	// from the same version.
	snapshot := s.snapshotManager.Current()

	// 2. Certificate from the cluster snapshot takes priority over the file.
	if snapshot != nil {
		s.reloadCertFromSnapshot(snapshot)
	}

	s.certMu.RLock()
	hasCert := s.currentCert != nil
	s.certMu.RUnlock()
	if !hasCert {
		log.Println("[HTTPDNS_NODE][resolve]WARNING: no TLS certificate available, HTTPS will fail")
	}

	// 3. Start the initial listeners.
	if snapshot != nil {
		s.syncListeners(s.desiredAddrs(snapshot))
	} else {
		s.syncListeners([]string{s.fallbackAddr})
	}

	// 4. Watch for changes (blocks until quit).
	s.watchLoop()
}
// getCertificate is the tls.Config.GetCertificate hook. It returns the
// currently installed certificate for every handshake, whatever SNI name the
// client sent, and fails the handshake when no certificate is installed.
func (s *ResolveServer) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
	s.certMu.RLock()
	defer s.certMu.RUnlock()
	if s.currentCert == nil {
		return nil, errors.New("no TLS certificate available")
	}
	return s.currentCert, nil
}
// snapshotTLSConfig is the decoded form of a cluster's tlsPolicyJSON: the
// HTTPS listen addresses and the SSL policy providing the certificate.
type snapshotTLSConfig struct {
	Listen    []*serverconfigs.NetworkAddressConfig `json:"listen"`
	SSLPolicy *sslconfigs.SSLPolicy                 `json:"sslPolicy"`
}
// parseTLSConfig decodes the TLS policy JSON of this node's cluster from
// snapshot. It returns nil when the node has no cluster, the cluster is not in
// the snapshot, the policy is empty, or the JSON fails to decode (logged).
func (s *ResolveServer) parseTLSConfig(snapshot *LoadedSnapshot) *snapshotTLSConfig {
	if snapshot.ClusterID > 0 {
		if cluster := snapshot.Clusters[snapshot.ClusterID]; cluster != nil {
			if raw := cluster.GetTlsPolicyJSON(); len(raw) > 0 {
				cfg := new(snapshotTLSConfig)
				decodeErr := json.Unmarshal(raw, cfg)
				if decodeErr == nil {
					return cfg
				}
				log.Println("[HTTPDNS_NODE][resolve]parse tlsPolicyJSON failed:", decodeErr.Error())
			}
		}
	}
	return nil
}
// desiredAddrs returns the sorted, de-duplicated HTTPS listen addresses from
// the snapshot's cluster TLS policy. Listen entries that fail Init are logged
// and skipped. When the policy yields no address at all, the result is
// []string{s.fallbackAddr}.
func (s *ResolveServer) desiredAddrs(snapshot *LoadedSnapshot) []string {
	fallback := []string{s.fallbackAddr}
	cfg := s.parseTLSConfig(snapshot)
	if cfg == nil || len(cfg.Listen) == 0 {
		return fallback
	}

	unique := map[string]struct{}{}
	for _, listenCfg := range cfg.Listen {
		if listenCfg == nil {
			continue
		}
		if err := listenCfg.Init(); err != nil {
			log.Println("[HTTPDNS_NODE][resolve]init listen config failed:", err.Error())
			continue
		}
		for _, addr := range listenCfg.Addresses() {
			unique[addr] = struct{}{}
		}
	}
	if len(unique) == 0 {
		return fallback
	}

	result := make([]string, 0, len(unique))
	for addr := range unique {
		result = append(result, addr)
	}
	sort.Strings(result)
	return result
}
// reloadCertFromSnapshot installs the first certificate of the snapshot's
// cluster SSL policy.
//
// In every case it records snapshot.LoadedAt as processed, so watchLoop does
// not retry the same snapshot. When the policy is missing, has no certs, fails
// to initialize (logged) or yields no certificate, the current certificate is
// kept unchanged.
func (s *ResolveServer) reloadCertFromSnapshot(snapshot *LoadedSnapshot) {
	// commit stamps the snapshot as processed and, when cert is non-nil,
	// swaps it in within the same critical section.
	commit := func(cert *tls.Certificate) {
		s.certMu.Lock()
		if cert != nil {
			s.currentCert = cert
		}
		s.certSnapshotAt = snapshot.LoadedAt
		s.certMu.Unlock()
	}

	cfg := s.parseTLSConfig(snapshot)
	if cfg == nil || cfg.SSLPolicy == nil || len(cfg.SSLPolicy.Certs) == 0 {
		commit(nil)
		return
	}
	if err := cfg.SSLPolicy.Init(context.Background()); err != nil {
		log.Println("[HTTPDNS_NODE][resolve]init SSLPolicy failed:", err.Error())
		commit(nil)
		return
	}

	cert := cfg.SSLPolicy.FirstCert()
	commit(cert)
	if cert != nil {
		log.Println("[HTTPDNS_NODE][resolve]TLS certificate reloaded from snapshot")
	}
}
// startListener opens a TCP listener on addr, wraps it with the shared TLS
// config, registers it in s.listeners, and serves s.handler on it in a new
// goroutine. A serve error other than http.ErrServerClosed is only logged.
//
// The caller must hold s.listenerMu (it mutates s.listeners); syncListeners does.
func (s *ResolveServer) startListener(addr string) error {
	tcpListener, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("listen on %s: %w", addr, err)
	}

	entry := &tlsListener{
		addr:     addr,
		listener: tls.NewListener(tcpListener, s.tlsConfig),
		server: &http.Server{
			Handler:           s.handler,
			ReadTimeout:       5 * time.Second,
			ReadHeaderTimeout: 3 * time.Second,
			WriteTimeout:      5 * time.Second,
			IdleTimeout:       75 * time.Second,
			MaxHeaderBytes:    8 * 1024,
			// A non-nil, empty map disables the server's automatic HTTP/2 support.
			TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){},
		},
	}
	s.listeners[addr] = entry

	go func(l *tlsListener) {
		serveErr := l.server.Serve(l.listener)
		if serveErr == nil || errors.Is(serveErr, http.ErrServerClosed) {
			return
		}
		log.Println("[HTTPDNS_NODE][resolve]serve failed on", l.addr, ":", serveErr.Error())
	}(entry)

	log.Println("[HTTPDNS_NODE][resolve]listening HTTPS on", addr)
	return nil
}
// stopListener gracefully shuts down the listener registered for addr, waiting
// up to 3 seconds for in-flight requests, and removes it from s.listeners.
// When the graceful shutdown does not finish in time, the error is logged and
// the server is force-closed so no connection outlives its listener. Unknown
// addresses are ignored.
//
// The caller must hold s.listenerMu (it mutates s.listeners).
func (s *ResolveServer) stopListener(addr string) {
	tl, ok := s.listeners[addr]
	if !ok {
		return
	}
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	if err := tl.server.Shutdown(ctx); err != nil {
		log.Println("[HTTPDNS_NODE][resolve]graceful shutdown failed on", addr, ":", err.Error())
		// Best effort: Close drops whatever connections Shutdown left open.
		_ = tl.server.Close()
	}
	delete(s.listeners, addr)
	log.Println("[HTTPDNS_NODE][resolve]stopped listener on", addr)
}
// syncListeners reconciles the running HTTPS listeners with desired. It stops
// listeners whose address is no longer wanted, then starts one for each
// wanted address that is not running yet. Start failures are logged and will
// be retried on the next sync.
func (s *ResolveServer) syncListeners(desired []string) {
	s.listenerMu.Lock()
	defer s.listenerMu.Unlock()

	wanted := make(map[string]bool, len(desired))
	for _, addr := range desired {
		wanted[addr] = true
	}

	var stale []string
	for addr := range s.listeners {
		if !wanted[addr] {
			stale = append(stale, addr)
		}
	}
	for _, addr := range stale {
		s.stopListener(addr)
	}

	for _, addr := range desired {
		if _, running := s.listeners[addr]; running {
			continue
		}
		if err := s.startListener(addr); err != nil {
			log.Println("[HTTPDNS_NODE][resolve]start listener failed:", err.Error())
		}
	}
}
// watchLoop polls the snapshot manager every 5 seconds. When the current
// snapshot's LoadedAt differs from the last one processed, it re-syncs the
// listeners and reloads the certificate. When quitCh delivers a value or is
// closed, it shuts down all listeners and returns.
func (s *ResolveServer) watchLoop() {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	alreadyApplied := func(snapshot *LoadedSnapshot) bool {
		s.certMu.RLock()
		defer s.certMu.RUnlock()
		return snapshot.LoadedAt == s.certSnapshotAt
	}

	for {
		select {
		case <-s.quitCh:
			s.shutdownAll()
			return
		case <-ticker.C:
			snapshot := s.snapshotManager.Current()
			if snapshot == nil || alreadyApplied(snapshot) {
				continue
			}
			// Snapshot changed: listeners first, then the certificate;
			// reloadCertFromSnapshot records LoadedAt as processed.
			s.syncListeners(s.desiredAddrs(snapshot))
			s.reloadCertFromSnapshot(snapshot)
		}
	}
}
// shutdownAll stops every registered HTTPS listener while holding listenerMu.
func (s *ResolveServer) shutdownAll() {
	s.listenerMu.Lock()
	defer s.listenerMu.Unlock()

	addrs := make([]string, 0, len(s.listeners))
	for addr := range s.listeners {
		addrs = append(addrs, addr)
	}
	for _, addr := range addrs {
		s.stopListener(addr)
	}
}
// handleHealth answers /healthz with HTTP 200 and the plain body "ok".
func (s *ResolveServer) handleHealth(writer http.ResponseWriter, _ *http.Request) {
	const healthyBody = "ok"
	writer.WriteHeader(http.StatusOK)
	_, _ = writer.Write([]byte(healthyBody))
}
// handleResolve serves GET /resolve.
//
// Query parameters read here:
//   - appId, dn (domain), qtype (A or AAAA, default A)
//   - cip: optional client IP override
//   - nonce, exp, sign: only when the app has signing enabled
//   - sdk_version, os: only recorded in the access log
//
// Processing steps:
//  1. Validate the method and arguments.
//  2. Require a loaded snapshot.
//  3. Look up the app and check it is bound to this node's cluster.
//  4. Check the domain is bound to the app and enabled.
//  5. Verify the signature.
//  6. Pick the best-matching custom rule for the client.
//  7. If no rule yields records, fall back to upstream DNS.
//
// Method, argument and missing-snapshot errors are answered directly without
// an access log. Every later outcome enqueues an access-log entry, either via
// writeFailedResolve or at the end of the success path.
func (s *ResolveServer) handleResolve(writer http.ResponseWriter, request *http.Request) {
	startAt := time.Now()
	requestID := "rid-" + rands.HexString(16)
	if request.Method != http.MethodGet {
		s.writeResolveJSON(writer, http.StatusMethodNotAllowed, &resolveResponse{
			Code:      httpdnsCodeMethodNotAllow,
			Message:   "只允许使用 GET 方法",
			RequestID: requestID,
		})
		return
	}
	query := request.URL.Query()
	appID := strings.TrimSpace(query.Get("appId"))
	// Domains are matched lower-cased and without a trailing root dot.
	domain := strings.TrimSuffix(strings.ToLower(strings.TrimSpace(query.Get("dn"))), ".")
	qtype := strings.ToUpper(strings.TrimSpace(query.Get("qtype")))
	if len(qtype) == 0 {
		qtype = "A"
	}
	if qtype != "A" && qtype != "AAAA" {
		s.writeResolveJSON(writer, http.StatusBadRequest, &resolveResponse{
			Code:      httpdnsCodeInvalidArgument,
			Message:   "qtype 参数仅支持 A 或 AAAA",
			RequestID: requestID,
		})
		return
	}
	if len(appID) == 0 || len(domain) == 0 {
		s.writeResolveJSON(writer, http.StatusBadRequest, &resolveResponse{
			Code:      httpdnsCodeInvalidArgument,
			Message:   "缺少必填参数: appId 和 dn",
			RequestID: requestID,
		})
		return
	}
	snapshot := s.snapshotManager.Current()
	if snapshot == nil {
		s.writeResolveJSON(writer, http.StatusServiceUnavailable, &resolveResponse{
			Code:      httpdnsCodeInternalError,
			Message:   "服务节点尚未准备就绪,请稍后再试",
			RequestID: requestID,
		})
		return
	}
	// App IDs are looked up lower-cased.
	loadedApp := snapshot.Apps[strings.ToLower(appID)]
	if loadedApp == nil || loadedApp.App == nil || !loadedApp.App.GetIsOn() {
		s.writeFailedResolve(writer, requestID, snapshot, nil, domain, qtype, httpdnsCodeAppInvalid, "找不到指定的应用,或该应用已下线", startAt, request, query)
		return
	}
	// When this node belongs to a cluster, the app must list that cluster in
	// its ClusterIdsJSON. Decode errors are ignored, which normally leaves the
	// list empty so the app is treated as unbound.
	if snapshot.ClusterID > 0 {
		var appClusterIds []int64
		if len(loadedApp.App.GetClusterIdsJSON()) > 0 {
			_ = json.Unmarshal(loadedApp.App.GetClusterIdsJSON(), &appClusterIds)
		}
		var clusterBound bool
		for _, cid := range appClusterIds {
			if cid == snapshot.ClusterID {
				clusterBound = true
				break
			}
		}
		if !clusterBound {
			s.writeFailedResolve(writer, requestID, snapshot, loadedApp.App, domain, qtype, httpdnsCodeAppInvalid, "当前应用未绑定到该解析集群", startAt, request, query)
			return
		}
	}
	loadedDomain := loadedApp.Domains[domain]
	if loadedDomain == nil || loadedDomain.Domain == nil || !loadedDomain.Domain.GetIsOn() {
		s.writeFailedResolve(writer, requestID, snapshot, loadedApp.App, domain, qtype, httpdnsCodeDomainNotBound, "应用尚未绑定该域名,或域名解析已暂停", startAt, request, query)
		return
	}
	if loadedApp.App.GetSignEnabled() {
		if !validateResolveSign(loadedApp.App.GetSignSecret(), loadedApp.App.GetAppId(), domain, qtype, query.Get("nonce"), query.Get("exp"), query.Get("sign")) {
			s.writeFailedResolve(writer, requestID, snapshot, loadedApp.App, domain, qtype, httpdnsCodeSignInvalid, "请求鉴权失败:签名无效或已过期", startAt, request, query)
			return
		}
	}
	clientIP := detectClientIP(request, query.Get("cip"))
	clientProfile := buildClientRouteProfile(clientIP)
	clusterTTL := pickDefaultTTL(snapshot, loadedApp.App)
	rule, records, ttl := pickRuleRecords(loadedDomain.Rules, qtype, clientProfile, clusterTTL)
	if len(records) == 0 {
		// Fallback: query the upstream DNS for the real records.
		fallbackRecords, fallbackTTL, fallbackErr := fallbackResolve(domain, qtype)
		if fallbackErr != nil || len(fallbackRecords) == 0 {
			errMsg := "未找到解析记录"
			if fallbackErr != nil {
				errMsg = "未找到解析记录 (上游回源失败: " + fallbackErr.Error() + ")"
			}
			s.writeFailedResolve(writer, requestID, snapshot, loadedApp.App, domain, qtype, httpdnsCodeNoRecords, errMsg, startAt, request, query)
			return
		}
		records = fallbackRecords
		// Prefer the TTL reported by the upstream answer when it is positive.
		if fallbackTTL > 0 {
			ttl = fallbackTTL
		} else {
			ttl = clusterTTL
		}
	}
	// Final TTL guards: cluster/app default, then 30 seconds.
	if ttl <= 0 {
		ttl = clusterTTL
	}
	if ttl <= 0 {
		ttl = 30
	}
	resultIPs := make([]string, 0, len(records))
	for _, record := range records {
		resultIPs = append(resultIPs, record.IP)
	}
	// Pipe-separated one-line summary:
	// time|app(appId)|clientIP|qtype domain -> ips|success|cost[|rule:name]
	summary := fmt.Sprintf("%s|%s(%s)|%s|%s %s -> %s|success|%dms",
		time.Now().Format("2006-01-02 15:04:05"),
		loadedApp.App.GetName(),
		loadedApp.App.GetAppId(),
		clientProfile.IP,
		qtype,
		domain,
		strings.Join(resultIPs, ", "),
		time.Since(startAt).Milliseconds(),
	)
	if rule != nil && len(strings.TrimSpace(rule.GetRuleName())) > 0 {
		summary += "|rule:" + strings.TrimSpace(rule.GetRuleName())
	}
	s.writeResolveJSON(writer, http.StatusOK, &resolveResponse{
		Code:      httpdnsCodeSuccess,
		Message:   "ok",
		RequestID: requestID,
		Data: &resolveData{
			Domain:  domain,
			QType:   qtype,
			TTL:     ttl,
			Records: records,
			Client: &resolveClientInfo{
				IP:      clientProfile.IP,
				Region:  clientProfile.RegionText,
				Carrier: clientProfile.Carrier,
				Country: clientProfile.Country,
			},
			Summary: summary,
		},
	})
	// The access log is queued after the response body has been written.
	s.enqueueAccessLog(&pb.HTTPDNSAccessLog{
		RequestId:    requestID,
		ClusterId:    snapshot.ClusterID,
		NodeId:       snapshot.NodeID,
		AppId:        loadedApp.App.GetAppId(),
		AppName:      loadedApp.App.GetName(),
		Domain:       domain,
		Qtype:        qtype,
		ClientIP:     clientProfile.IP,
		ClientRegion: clientProfile.RegionText,
		Carrier:      clientProfile.Carrier,
		SdkVersion:   strings.TrimSpace(query.Get("sdk_version")),
		Os:           strings.TrimSpace(query.Get("os")),
		ResultIPs:    strings.Join(resultIPs, ","),
		Status:       "success",
		ErrorCode:    "none",
		CostMs:       int32(time.Since(startAt).Milliseconds()),
		CreatedAt:    time.Now().Unix(),
		Day:          time.Now().Format("20060102"),
		Summary:      summary,
	})
}
// pickDefaultTTL returns the first positive default TTL found, in this order:
//  1. this node's cluster, when the snapshot has a cluster;
//  2. each cluster listed in the app's ClusterIdsJSON, in listed order.
//
// It returns 30 when no positive TTL is found or snapshot is nil. JSON decode
// errors on the app's cluster list are ignored.
func pickDefaultTTL(snapshot *LoadedSnapshot, app *pb.HTTPDNSApp) int32 {
	const fallbackTTL int32 = 30
	if snapshot == nil {
		return fallbackTTL
	}

	candidateIDs := make([]int64, 0, 4)
	if snapshot.ClusterID > 0 {
		candidateIDs = append(candidateIDs, snapshot.ClusterID)
	}
	if app != nil && len(app.GetClusterIdsJSON()) > 0 {
		var appClusterIDs []int64
		_ = json.Unmarshal(app.GetClusterIdsJSON(), &appClusterIDs)
		candidateIDs = append(candidateIDs, appClusterIDs...)
	}

	for _, clusterID := range candidateIDs {
		if cluster := snapshot.Clusters[clusterID]; cluster != nil && cluster.GetDefaultTTL() > 0 {
			return cluster.GetDefaultTTL()
		}
	}
	return fallbackTTL
}
// writeFailedResolve answers a failed resolve and records it.
//
// The response is HTTP 200 with errorCode and message in the JSON body and no
// Data. A "failed" access-log entry is then enqueued. app and snapshot may be
// nil; the matching log fields are then left empty or zero. The client IP is
// derived again from request and the "cip" query parameter.
func (s *ResolveServer) writeFailedResolve(
	writer http.ResponseWriter,
	requestID string,
	snapshot *LoadedSnapshot,
	app *pb.HTTPDNSApp,
	domain string,
	qtype string,
	errorCode string,
	message string,
	startAt time.Time,
	request *http.Request,
	query url.Values,
) {
	clientIP := detectClientIP(request, query.Get("cip"))
	clientProfile := buildClientRouteProfile(clientIP)
	appID := ""
	appName := ""
	if app != nil {
		appID = app.GetAppId()
		appName = app.GetName()
	}
	// Pipe-separated one-line summary:
	// time|app(appId)|clientIP|qtype domain -> [none]|failed(code)|cost
	summary := fmt.Sprintf("%s|%s(%s)|%s|%s %s -> [none]|failed(%s)|%dms",
		time.Now().Format("2006-01-02 15:04:05"),
		appName,
		appID,
		clientProfile.IP,
		qtype,
		domain,
		errorCode,
		time.Since(startAt).Milliseconds(),
	)
	// Failures are signalled through the JSON code; the HTTP status stays 200.
	s.writeResolveJSON(writer, http.StatusOK, &resolveResponse{
		Code:      errorCode,
		Message:   message,
		RequestID: requestID,
	})
	clusterID := int64(0)
	nodeID := int64(0)
	if snapshot != nil {
		clusterID = snapshot.ClusterID
		nodeID = snapshot.NodeID
	}
	s.enqueueAccessLog(&pb.HTTPDNSAccessLog{
		RequestId:    requestID,
		ClusterId:    clusterID,
		NodeId:       nodeID,
		AppId:        appID,
		AppName:      appName,
		Domain:       domain,
		Qtype:        qtype,
		ClientIP:     clientProfile.IP,
		ClientRegion: clientProfile.RegionText,
		Carrier:      clientProfile.Carrier,
		SdkVersion:   strings.TrimSpace(query.Get("sdk_version")),
		Os:           strings.TrimSpace(query.Get("os")),
		ResultIPs:    "",
		Status:       "failed",
		ErrorCode:    errorCode,
		CostMs:       int32(time.Since(startAt).Milliseconds()),
		CreatedAt:    time.Now().Unix(),
		Day:          time.Now().Format("20060102"),
		Summary:      summary,
	})
}
// writeResolveJSON writes resp as a JSON body with the given HTTP status.
//
// The response is marshalled before any header is sent. If encoding fails, a
// fixed internal-error body is sent with 500 instead of the intended status,
// so a success status is never paired with an error payload.
func (s *ResolveServer) writeResolveJSON(writer http.ResponseWriter, status int, resp *resolveResponse) {
	data, err := json.Marshal(resp)
	if err != nil {
		status = http.StatusInternalServerError
		data = []byte(`{"code":"RESOLVE_TIMEOUT_OR_INTERNAL","message":"encode response failed"}`)
	}
	writer.Header().Set("Content-Type", "application/json; charset=utf-8")
	writer.WriteHeader(status)
	// Write errors mean the client went away; nothing useful can be done here.
	_, _ = writer.Write(data)
}
// detectClientIP returns the first valid IP among these sources, in order:
//  1. the explicit cip query value;
//  2. each comma-separated X-Forwarded-For entry;
//  3. the X-Real-IP, X-Client-IP, CF-Connecting-IP and True-Client-IP headers;
//  4. the connection's RemoteAddr.
//
// It returns "" when none of them holds a valid IP.
//
// NOTE(review): all header sources are client-controlled; this only affects
// routing/log attribution for the requester itself, which also controls cip.
func detectClientIP(request *http.Request, cip string) string {
	candidates := []string{cip}
	candidates = append(candidates, strings.Split(request.Header.Get("X-Forwarded-For"), ",")...)
	for _, header := range [...]string{"X-Real-IP", "X-Client-IP", "CF-Connecting-IP", "True-Client-IP"} {
		candidates = append(candidates, request.Header.Get(header))
	}
	candidates = append(candidates, request.RemoteAddr)

	for _, candidate := range candidates {
		if ip := normalizeIPCandidate(candidate); ip != "" {
			return ip
		}
	}
	return ""
}
// validateResolveSign reports whether sign is a valid HMAC-SHA256 signature of
// a resolve request. The signed message is:
//
//	appID|lower(domain)|upper(qtype)|exp|nonce
//
// The key is signSecret. sign is hex-encoded; either letter case is accepted.
// exp is a Unix timestamp: it must not be more than 30 seconds in the past nor
// more than one day in the future. Blank secret, nonce, exp or sign always fail.
//
// The comparison is done on the decoded MAC bytes with hmac.Equal, which is
// constant-time, so response timing does not leak how much of a forged
// signature matched.
func validateResolveSign(signSecret string, appID string, domain string, qtype string, nonce string, exp string, sign string) bool {
	signSecret = strings.TrimSpace(signSecret)
	nonce = strings.TrimSpace(nonce)
	exp = strings.TrimSpace(exp)
	sign = strings.TrimSpace(sign)
	if len(signSecret) == 0 || len(nonce) == 0 || len(exp) == 0 || len(sign) == 0 {
		return false
	}
	expireAt, err := strconv.ParseInt(exp, 10, 64)
	if err != nil {
		return false
	}
	now := time.Now().Unix()
	// 30s grace after expiry for clock skew; reject signatures valid for more than a day.
	if expireAt <= now-30 || expireAt > now+86400 {
		return false
	}
	// hex.DecodeString accepts both upper- and lower-case digits, matching the
	// previous case-insensitive comparison.
	provided, err := hex.DecodeString(sign)
	if err != nil {
		return false
	}
	raw := appID + "|" + strings.ToLower(domain) + "|" + strings.ToUpper(qtype) + "|" + exp + "|" + nonce
	mac := hmac.New(sha256.New, []byte(signSecret))
	_, _ = mac.Write([]byte(raw)) // hash.Hash.Write never returns an error
	return hmac.Equal(mac.Sum(nil), provided)
}
// buildClientRouteProfile looks ip up in the IP library and returns the
// normalized attributes used for rule matching.
//
// If ip is not a valid IP, or the lookup fails, only the IP field is set.
// When no carrier can be normalized, Carrier becomes "默认" for mainland China.
// Elsewhere it becomes the raw provider name, or "默认" when that is empty.
// RegionText is the library's region summary; when that is empty it is built
// from the country, province, region and carrier, joined with spaces.
func buildClientRouteProfile(ip string) *clientRouteProfile {
	profile := &clientRouteProfile{IP: normalizeIPCandidate(ip)}
	if net.ParseIP(profile.IP) == nil {
		return profile
	}
	result := iplibrary.LookupIP(profile.IP)
	if result == nil || !result.IsOk() {
		return profile
	}

	country := normalizeCountryName(strings.TrimSpace(result.CountryName()))
	province := normalizeProvinceName(strings.TrimSpace(result.ProvinceName()))
	providerRaw := strings.TrimSpace(result.ProviderName())

	carrier := normalizeCarrier(providerRaw, country)
	if carrier == "" {
		carrier = "默认"
		if !isMainlandChinaCountry(country) && providerRaw != "" {
			carrier = providerRaw
		}
	}

	profile.Country = country
	profile.Province = province
	profile.ProviderRaw = providerRaw
	profile.Carrier = carrier
	profile.Region = normalizeChinaRegion(province)
	profile.Continent = normalizeContinent(country)

	regionText := strings.TrimSpace(result.RegionSummary())
	if regionText == "" {
		parts := make([]string, 0, 4)
		for _, piece := range []string{country, province, profile.Region, carrier} {
			if piece != "" {
				parts = append(parts, piece)
			}
		}
		regionText = strings.Join(parts, " ")
	}
	profile.RegionText = regionText
	return profile
}
// carrierKeywordRules maps provider-name keywords to a canonical carrier label.
// Rules are tried in order and the first rule with any matching keyword wins.
// Within a rule, raw keywords are matched against the trimmed provider as-is;
// lower keywords are matched against its lower-cased form.
var carrierKeywordRules = []struct {
	label string
	raw   []string
	lower []string
}{
	{"电信", []string{"电信", "天翼"}, []string{"telecom", "chinanet", "chinatelecom", "ctnet", "cn2"}},
	{"联通", []string{"联通", "网通"}, []string{"unicom", "chinaunicom", "cucc", "china169", "cnc"}},
	{"移动", []string{"移动"}, []string{"mobile", "chinamobile", "cmcc", "cmnet"}},
	{"教育网", []string{"教育"}, []string{"cernet", "edu", "education"}},
	{"鹏博士", []string{"鹏博士"}, []string{"drpeng", "dr.peng", "dr_peng"}},
	{"广电", []string{"广电"}, []string{"broadcast", "cable", "radio"}},
}

// normalizeCarrier maps an IP-library provider name to a canonical carrier
// label using carrierKeywordRules.
//
// When no rule matches, it returns "" for mainland-China addresses (so the
// caller picks a default) and the trimmed provider name otherwise. An empty
// provider always yields "".
func normalizeCarrier(provider string, country string) string {
	trimmed := strings.TrimSpace(provider)
	if trimmed == "" {
		return ""
	}
	lowered := strings.ToLower(trimmed)
	for _, rule := range carrierKeywordRules {
		for _, keyword := range rule.raw {
			if strings.Contains(trimmed, keyword) {
				return rule.label
			}
		}
		for _, keyword := range rule.lower {
			if strings.Contains(lowered, keyword) {
				return rule.label
			}
		}
	}
	if isMainlandChinaCountry(country) {
		return ""
	}
	return trimmed
}
// chinaRegionByProvince indexes short province names to their Chinese
// macro-region (东北, 华北, 华东, 华南, 华中, 西北, 西南).
var chinaRegionByProvince = func() map[string]string {
	groups := []struct {
		region    string
		provinces []string
	}{
		{"东北", []string{"辽宁", "吉林", "黑龙江"}},
		{"华北", []string{"北京", "天津", "河北", "山西", "内蒙古"}},
		{"华东", []string{"上海", "江苏", "浙江", "安徽", "福建", "江西", "山东"}},
		{"华南", []string{"广东", "广西", "海南"}},
		{"华中", []string{"河南", "湖北", "湖南"}},
		{"西北", []string{"陕西", "甘肃", "青海", "宁夏", "新疆"}},
		{"西南", []string{"重庆", "四川", "贵州", "云南", "西藏"}},
	}
	index := make(map[string]string, 34)
	for _, group := range groups {
		for _, province := range group.provinces {
			index[province] = group.region
		}
	}
	return index
}()

// normalizeChinaRegion returns the macro-region of a mainland province,
// accepting full or short province names. Provinces outside the table,
// including Hong Kong, Macao and Taiwan, yield "".
func normalizeChinaRegion(province string) string {
	return chinaRegionByProvince[normalizeProvinceName(province)]
}
// continentByCountry indexes canonical Chinese country names (as produced by
// normalizeCountryName) to their continent.
var continentByCountry = func() map[string]string {
	groups := []struct {
		continent string
		countries []string
	}{
		{"亚洲", []string{"中国", "中国香港", "中国澳门", "中国台湾", "日本", "韩国", "新加坡", "印度", "泰国", "越南",
			"印度尼西亚", "马来西亚", "菲律宾", "柬埔寨", "缅甸", "老挝", "斯里兰卡", "孟加拉国", "巴基斯坦", "尼泊尔",
			"阿联酋", "沙特阿拉伯", "土耳其", "以色列", "伊朗", "伊拉克", "卡塔尔", "科威特", "蒙古"}},
		{"北美洲", []string{"美国", "加拿大", "墨西哥", "巴拿马", "哥斯达黎加", "古巴"}},
		{"南美洲", []string{"巴西", "阿根廷", "智利", "哥伦比亚", "秘鲁", "委内瑞拉", "厄瓜多尔"}},
		{"欧洲", []string{"德国", "英国", "法国", "荷兰", "西班牙", "意大利", "俄罗斯",
			"波兰", "瑞典", "瑞士", "挪威", "芬兰", "丹麦", "葡萄牙", "爱尔兰", "比利时", "奥地利",
			"乌克兰", "捷克", "罗马尼亚", "匈牙利", "希腊"}},
		{"非洲", []string{"南非", "埃及", "尼日利亚", "肯尼亚", "摩洛哥", "阿尔及利亚", "坦桑尼亚", "埃塞俄比亚", "加纳", "突尼斯"}},
		{"大洋洲", []string{"澳大利亚", "新西兰"}},
	}
	index := make(map[string]string, 80)
	for _, group := range groups {
		for _, country := range group.countries {
			index[country] = group.continent
		}
	}
	return index
}()

// normalizeContinent returns the continent of a country name, accepting any
// spelling understood by normalizeCountryName. Unknown countries yield "".
func normalizeContinent(country string) string {
	return continentByCountry[normalizeCountryName(country)]
}
// pickRuleRecords selects the custom rule that best matches the client and
// returns it, its records of the requested qtype, and its TTL.
//
// A rule is a candidate when all of these hold:
//   - it is enabled;
//   - its line constraints match profile (see matchRuleLine);
//   - it has at least one non-empty record whose type equals qtype.
//
// The candidate with the highest match score wins; on a tie the earliest rule
// is kept. The returned TTL is the winner's TTL when positive, otherwise
// defaultTTL. With no candidate it returns (nil, nil, defaultTTL).
//
// Line/Region annotations need IP-library lookups, so they are computed once,
// for the winning rule's records only, instead of for every candidate.
func pickRuleRecords(rules []*pb.HTTPDNSCustomRule, qtype string, profile *clientRouteProfile, defaultTTL int32) (*pb.HTTPDNSCustomRule, []*resolveRecord, int32) {
	bestScore := -1
	var bestRule *pb.HTTPDNSCustomRule
	var bestRecords []*resolveRecord
	bestTTL := defaultTTL
	for _, rule := range rules {
		if rule == nil || !rule.GetIsOn() {
			continue
		}
		score, ok := matchRuleLine(rule, profile)
		// A rule that cannot strictly beat the current best would never be
		// selected, so its records are not even built.
		if !ok || score <= bestScore {
			continue
		}
		var records []*resolveRecord
		for _, item := range rule.GetRecords() {
			if item == nil {
				continue
			}
			if strings.ToUpper(strings.TrimSpace(item.GetRecordType())) != qtype {
				continue
			}
			value := strings.TrimSpace(item.GetRecordValue())
			if len(value) == 0 {
				continue
			}
			records = append(records, &resolveRecord{
				Type:   qtype,
				IP:     value,
				Weight: item.GetWeight(),
			})
		}
		if len(records) == 0 {
			continue
		}
		bestScore = score
		bestRule = rule
		bestRecords = records
		bestTTL = defaultTTL
		if rule.GetTtl() > 0 {
			bestTTL = rule.GetTtl()
		}
	}
	for _, record := range bestRecords {
		record.Line = lookupIPLineLabel(record.IP)
		record.Region = lookupIPRegionSummary(record.IP)
	}
	return bestRule, bestRecords, bestTTL
}
// fallbackResolve queries an upstream recursive DNS server for domain when no
// custom rule produced records (aligned with the EdgeDNS approach).
//
// Only A and AAAA are handled; any other qtype returns (nil, 0, nil). The
// upstream is a random nameserver from /etc/resolv.conf, or 223.5.5.5:53 when
// that file is unreadable or lists no servers.
//
// It returns the answer records of the requested type and the smallest
// positive TTL among them (0 if none). Other answer types, e.g. CNAMEs, are
// skipped. A non-NOERROR rcode is reported as an error.
func fallbackResolve(domain string, qtype string) ([]*resolveRecord, int32, error) {
	var dnsType uint16
	switch qtype {
	case "A":
		dnsType = dns.TypeA
	case "AAAA":
		dnsType = dns.TypeAAAA
	default:
		return nil, 0, nil // only A and AAAA are resolved via fallback
	}

	m := new(dns.Msg)
	m.SetQuestion(dns.Fqdn(domain), dnsType)
	m.RecursionDesired = true

	// Prefer the host's own resolvers from /etc/resolv.conf (as EdgeDNS does).
	upstream := "223.5.5.5:53"
	if resolveConfig, confErr := dns.ClientConfigFromFile("/etc/resolv.conf"); confErr == nil && len(resolveConfig.Servers) > 0 {
		port := resolveConfig.Port
		if len(port) == 0 {
			port = "53"
		}
		server := resolveConfig.Servers[rands.Int(0, len(resolveConfig.Servers)-1)]
		// JoinHostPort brackets IPv6 nameservers ("::1" -> "[::1]:53");
		// plain concatenation would produce an address that cannot be dialed.
		upstream = net.JoinHostPort(server, port)
	}

	r, _, err := sharedRecursionDNSClient.Exchange(m, upstream)
	if err != nil {
		return nil, 0, err
	}
	if r.Rcode != dns.RcodeSuccess {
		return nil, 0, fmt.Errorf("upstream rcode: %d", r.Rcode)
	}

	var records []*resolveRecord
	var responseTTL int32
	addRecord := func(ttl uint32, ip string) {
		// RFC 2181 §8: a TTL with the most significant bit set is treated as
		// zero; converting it to int32 would otherwise yield a negative TTL.
		if ttl > 0 && ttl <= 1<<31-1 {
			if candidate := int32(ttl); responseTTL == 0 || candidate < responseTTL {
				responseTTL = candidate
			}
		}
		records = append(records, &resolveRecord{
			Type:   qtype,
			IP:     ip,
			Line:   lookupIPLineLabel(ip),
			Region: lookupIPRegionSummary(ip),
		})
	}
	for _, ans := range r.Answer {
		switch rr := ans.(type) {
		case *dns.A:
			if dnsType == dns.TypeA {
				addRecord(rr.Hdr.Ttl, rr.A.String())
			}
		case *dns.AAAA:
			if dnsType == dns.TypeAAAA {
				addRecord(rr.Hdr.Ttl, rr.AAAA.String())
			}
		}
	}
	return records, responseTTL, nil
}
// lookupIPRegionSummary returns the IP library's trimmed region summary for
// ip, or "" when ip is not a valid IP or the lookup fails.
func lookupIPRegionSummary(ip string) string {
	address := strings.TrimSpace(ip)
	if net.ParseIP(address) != nil {
		if result := iplibrary.LookupIP(address); result != nil && result.IsOk() {
			return strings.TrimSpace(result.RegionSummary())
		}
	}
	return ""
}
// lookupIPLineLabel returns a carrier/line label for an answer IP:
//   - the normalized carrier when one is recognized;
//   - otherwise "默认" for mainland-China IPs;
//   - otherwise the raw provider name;
//   - otherwise "上游DNS", which is also returned for invalid IPs or failed lookups.
func lookupIPLineLabel(ip string) string {
	const unknownLabel = "上游DNS"

	address := strings.TrimSpace(ip)
	if net.ParseIP(address) == nil {
		return unknownLabel
	}
	result := iplibrary.LookupIP(address)
	if result == nil || !result.IsOk() {
		return unknownLabel
	}

	provider := strings.TrimSpace(result.ProviderName())
	country := normalizeCountryName(strings.TrimSpace(result.CountryName()))
	if carrier := strings.TrimSpace(normalizeCarrier(provider, country)); carrier != "" {
		return carrier
	}
	switch {
	case isMainlandChinaCountry(country):
		return "默认"
	case provider != "":
		return provider
	default:
		return unknownLabel
	}
}
// isMainlandChinaCountry reports whether country normalizes to mainland China
// ("中国"). Hong Kong, Macao and Taiwan normalize to distinct names and
// therefore return false.
func isMainlandChinaCountry(country string) bool {
	return normalizeCountryName(country) == "中国"
}
// matchRuleLine checks whether rule's line constraints accept the client
// described by profile. It returns the match score (one point per
// non-default field that matched) and whether the rule applies at all.
//
// A rule with scope "overseas" only applies to clients outside mainland China
// and checks continent and country. Any other scope only applies to mainland
// clients and checks country, continent, carrier (against the normalized
// carrier or the raw provider), region and province.
func matchRuleLine(rule *pb.HTTPDNSCustomRule, profile *clientRouteProfile) (int, bool) {
	overseasRule := strings.ToLower(strings.TrimSpace(rule.GetLineScope())) == "overseas"
	// Reject when the rule's scope and the client's origin disagree.
	if overseasRule == isMainlandChinaCountry(profile.Country) {
		return 0, false
	}

	type fieldCheck struct {
		ruleValue  string
		candidates []string
	}
	var checks []fieldCheck
	if overseasRule {
		checks = []fieldCheck{
			{rule.GetLineContinent(), []string{profile.Continent}},
			{rule.GetLineCountry(), []string{profile.Country}},
		}
	} else {
		checks = []fieldCheck{
			{rule.GetLineCountry(), []string{profile.Country}},
			{rule.GetLineContinent(), []string{profile.Continent}},
			{rule.GetLineCarrier(), []string{profile.Carrier, profile.ProviderRaw}},
			{rule.GetLineRegion(), []string{profile.Region}},
			{rule.GetLineProvince(), []string{profile.Province}},
		}
	}

	total := 0
	for _, check := range checks {
		fieldScore, ok := matchRuleField(check.ruleValue, check.candidates...)
		if !ok {
			return 0, false
		}
		total += fieldScore
	}
	return total, true
}
// matchRuleField matches one rule field against the client's candidate values.
//
// A default/wildcard rule value matches anything and scores 0. Otherwise the
// field matches, scoring 1, when some non-empty candidate equals the rule
// value, contains it, or is contained in it. Both sides are compared in
// normalizeLineValue form. When nothing matches it returns (0, false).
func matchRuleField(ruleValue string, candidates ...string) (int, bool) {
	if isDefaultLineValue(ruleValue) {
		return 0, true
	}
	want := normalizeLineValue(ruleValue)
	if want == "" {
		return 0, true
	}
	for _, candidate := range candidates {
		got := normalizeLineValue(candidate)
		switch {
		case got == "":
			// An empty client attribute never satisfies a concrete rule value;
			// without this guard strings.Contains(want, "") would always be true.
			continue
		case strings.Contains(got, want), strings.Contains(want, got):
			return 1, true
		}
	}
	return 0, false
}
// isDefaultLineValue reports whether value, after normalizeLineValue, is empty
// or one of the wildcard spellings meaning "no constraint" (e.g. "default",
// "all", "*", "默认", "不限").
func isDefaultLineValue(value string) bool {
	normalized := normalizeLineValue(value)
	for _, wildcard := range [...]string{"", "default", "all", "*", "any", "默认", "全部", "不限"} {
		if normalized == wildcard {
			return true
		}
	}
	return false
}
// normalizeLineValue canonicalizes a line/location label for comparison: it
// trims surrounding whitespace, lower-cases, and removes every space, '-',
// '_' and '/' character.
func normalizeLineValue(value string) string {
	lowered := strings.ToLower(strings.TrimSpace(value))
	return strings.Map(func(r rune) rune {
		switch r {
		case ' ', '-', '_', '/':
			return -1 // drop separator characters
		}
		return r
	}, lowered)
}
// normalizeIPCandidate extracts a canonical IP string from ip, which may be a
// bare address or a "host:port" / "[v6]:port" pair, with surrounding
// whitespace allowed. It returns "" when no valid IP can be parsed.
func normalizeIPCandidate(ip string) string {
	candidate := strings.TrimSpace(ip)
	if candidate == "" {
		return ""
	}
	host, _, splitErr := net.SplitHostPort(candidate)
	if splitErr == nil {
		candidate = strings.TrimSpace(host)
	}
	parsed := net.ParseIP(candidate)
	if parsed == nil {
		return ""
	}
	return parsed.String()
}
// normalizeCountryName maps Chinese and English spellings of a country onto a
// single canonical Chinese name.
//
// Matching is done on the normalizeLineValue form (lower-cased, with spaces,
// '-', '_' and '/' removed), e.g. "United States" -> "unitedstates" -> "美国".
// Hong Kong, Macao and Taiwan are checked first and map to "中国香港",
// "中国澳门" and "中国台湾", distinct from mainland "中国". Unrecognized names
// are returned trimmed but otherwise unchanged; "" stays "".
func normalizeCountryName(country string) string {
	value := strings.TrimSpace(country)
	if len(value) == 0 {
		return ""
	}
	normalized := normalizeLineValue(value)
	// Special administrative regions / Taiwan first, so they never collapse into "中国".
	switch normalized {
	case "中国香港", "香港", "hongkong":
		return "中国香港"
	case "中国澳门", "澳门", "macao", "macau":
		return "中国澳门"
	case "中国台湾", "台湾", "taiwan":
		return "中国台湾"
	}
	switch normalized {
	case "中国", "中国大陆", "中国内地", "中华人民共和国", "prc", "cn", "china", "mainlandchina", "peoplesrepublicofchina", "thepeoplesrepublicofchina":
		return "中国"
	case "美国", "usa", "unitedstates", "unitedstatesofamerica":
		return "美国"
	case "加拿大", "canada":
		return "加拿大"
	case "墨西哥", "mexico":
		return "墨西哥"
	case "日本", "japan":
		return "日本"
	case "韩国", "southkorea", "korea":
		return "韩国"
	case "新加坡", "singapore":
		return "新加坡"
	case "印度", "india":
		return "印度"
	case "泰国", "thailand":
		return "泰国"
	case "越南", "vietnam":
		return "越南"
	case "德国", "germany":
		return "德国"
	case "英国", "uk", "unitedkingdom", "greatbritain", "britain":
		return "英国"
	case "法国", "france":
		return "法国"
	case "荷兰", "netherlands":
		return "荷兰"
	case "西班牙", "spain":
		return "西班牙"
	case "意大利", "italy":
		return "意大利"
	case "俄罗斯", "russia":
		return "俄罗斯"
	case "巴西", "brazil":
		return "巴西"
	case "阿根廷", "argentina":
		return "阿根廷"
	case "智利", "chile":
		return "智利"
	case "哥伦比亚", "colombia":
		return "哥伦比亚"
	// --- Asia (added) ---
	case "印度尼西亚", "indonesia":
		return "印度尼西亚"
	case "马来西亚", "malaysia":
		return "马来西亚"
	case "菲律宾", "philippines":
		return "菲律宾"
	case "柬埔寨", "cambodia":
		return "柬埔寨"
	case "缅甸", "myanmar", "burma":
		return "缅甸"
	case "老挝", "laos":
		return "老挝"
	case "斯里兰卡", "srilanka":
		return "斯里兰卡"
	case "孟加拉国", "孟加拉", "bangladesh":
		return "孟加拉国"
	case "巴基斯坦", "pakistan":
		return "巴基斯坦"
	case "尼泊尔", "nepal":
		return "尼泊尔"
	case "阿联酋", "阿拉伯联合酋长国", "uae", "unitedarabemirates":
		return "阿联酋"
	case "沙特阿拉伯", "沙特", "saudiarabia", "saudi":
		return "沙特阿拉伯"
	case "土耳其", "turkey", "türkiye", "turkiye":
		return "土耳其"
	case "以色列", "israel":
		return "以色列"
	case "伊朗", "iran":
		return "伊朗"
	case "伊拉克", "iraq":
		return "伊拉克"
	case "卡塔尔", "qatar":
		return "卡塔尔"
	case "科威特", "kuwait":
		return "科威特"
	case "蒙古", "mongolia":
		return "蒙古"
	// --- Europe (added) ---
	case "波兰", "poland":
		return "波兰"
	case "瑞典", "sweden":
		return "瑞典"
	case "瑞士", "switzerland":
		return "瑞士"
	case "挪威", "norway":
		return "挪威"
	case "芬兰", "finland":
		return "芬兰"
	case "丹麦", "denmark":
		return "丹麦"
	case "葡萄牙", "portugal":
		return "葡萄牙"
	case "爱尔兰", "ireland":
		return "爱尔兰"
	case "比利时", "belgium":
		return "比利时"
	case "奥地利", "austria":
		return "奥地利"
	case "乌克兰", "ukraine":
		return "乌克兰"
	case "捷克", "czech", "czechrepublic", "czechia":
		return "捷克"
	case "罗马尼亚", "romania":
		return "罗马尼亚"
	case "匈牙利", "hungary":
		return "匈牙利"
	case "希腊", "greece":
		return "希腊"
	// --- North America (added) ---
	case "巴拿马", "panama":
		return "巴拿马"
	case "哥斯达黎加", "costarica":
		return "哥斯达黎加"
	case "古巴", "cuba":
		return "古巴"
	// --- South America (added) ---
	case "秘鲁", "peru":
		return "秘鲁"
	case "委内瑞拉", "venezuela":
		return "委内瑞拉"
	case "厄瓜多尔", "ecuador":
		return "厄瓜多尔"
	// --- Africa ---
	case "南非", "southafrica":
		return "南非"
	case "埃及", "egypt":
		return "埃及"
	case "尼日利亚", "nigeria":
		return "尼日利亚"
	case "肯尼亚", "kenya":
		return "肯尼亚"
	case "摩洛哥", "morocco":
		return "摩洛哥"
	case "阿尔及利亚", "algeria":
		return "阿尔及利亚"
	case "坦桑尼亚", "tanzania":
		return "坦桑尼亚"
	case "埃塞俄比亚", "ethiopia":
		return "埃塞俄比亚"
	case "加纳", "ghana":
		return "加纳"
	case "突尼斯", "tunisia":
		return "突尼斯"
	// --- Oceania ---
	case "澳大利亚", "australia":
		return "澳大利亚"
	case "新西兰", "newzealand":
		return "新西兰"
	default:
		return value
	}
}
// provinceFullNameAliases maps full official names that need an exact match
// to their short form.
var provinceFullNameAliases = map[string]string{
	"内蒙古自治区":   "内蒙古",
	"广西壮族自治区":  "广西",
	"宁夏回族自治区":  "宁夏",
	"新疆维吾尔自治区": "新疆",
	"西藏自治区":    "西藏",
	"香港特别行政区":  "香港",
	"澳门特别行政区":  "澳门",
}

// normalizeProvinceName shortens a Chinese province-level name, e.g.
// "广东省" -> "广东", "北京市" -> "北京", "新疆维吾尔自治区" -> "新疆".
// Known full names are mapped exactly. Otherwise each administrative suffix
// in the list is trimmed once, in list order. Surrounding whitespace is
// removed and "" stays "".
func normalizeProvinceName(province string) string {
	name := strings.TrimSpace(province)
	if name == "" {
		return ""
	}
	if short, ok := provinceFullNameAliases[name]; ok {
		return short
	}
	for _, suffix := range [...]string{"维吾尔自治区", "回族自治区", "壮族自治区", "自治区", "特别行政区", "省", "市"} {
		name = strings.TrimSuffix(name, suffix)
	}
	return name
}
// ruleLineSummary renders the non-default line constraints of rule, trimmed
// and joined with "/". Overseas rules list continent then country. Other rules
// list carrier, region, then province. A nil rule yields "".
func ruleLineSummary(rule *pb.HTTPDNSCustomRule) string {
	if rule == nil {
		return ""
	}
	var fields []string
	if strings.ToLower(strings.TrimSpace(rule.GetLineScope())) == "overseas" {
		fields = []string{rule.GetLineContinent(), rule.GetLineCountry()}
	} else {
		fields = []string{rule.GetLineCarrier(), rule.GetLineRegion(), rule.GetLineProvince()}
	}
	pieces := make([]string, 0, len(fields))
	for _, field := range fields {
		if !isDefaultLineValue(field) {
			pieces = append(pieces, strings.TrimSpace(field))
		}
	}
	return strings.Join(pieces, "/")
}
// ruleRegionSummary returns the first non-default location constraint of
// rule, trimmed, checking province, country, region and continent in that
// order. It returns "" for a nil rule or when every field is a default value.
func ruleRegionSummary(rule *pb.HTTPDNSCustomRule) string {
	if rule == nil {
		return ""
	}
	ordered := []string{
		rule.GetLineProvince(),
		rule.GetLineCountry(),
		rule.GetLineRegion(),
		rule.GetLineContinent(),
	}
	for _, value := range ordered {
		if !isDefaultLineValue(value) {
			return strings.TrimSpace(value)
		}
	}
	return ""
}
// enqueueAccessLog hands item to the access-log flusher without blocking.
// When the queue is full the entry is dropped and its request ID is logged,
// so a slow log writer never stalls request handling. nil items are ignored.
func (s *ResolveServer) enqueueAccessLog(item *pb.HTTPDNSAccessLog) {
	if item == nil {
		return
	}
	select {
	case s.logQueue <- item:
		return
	default:
	}
	log.Println("[HTTPDNS_NODE][resolve]access log queue is full, drop request:", item.GetRequestId())
}
// startAccessLogFlusher consumes s.logQueue and writes access logs to
// s.logWriter in batches. A batch is written when it reaches 128 entries and
// on every 1-second tick. When quitCh fires, the entries already queued at
// that moment are drained, the final batch is written, and the writer is
// closed (a close error is logged).
//
// Start runs this in its own goroutine.
func (s *ResolveServer) startAccessLogFlusher() {
	const maxBatch = 128
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()

	batch := make([]*pb.HTTPDNSAccessLog, 0, maxBatch)
	flush := func() {
		if len(batch) == 0 {
			return
		}
		// NOTE(review): the backing array is reused after this call; this
		// assumes WriteBatch does not retain the slice — confirm.
		s.logWriter.WriteBatch(batch)
		batch = batch[:0]
	}
	// add appends item and flushes as soon as a full batch is collected,
	// so batch never grows beyond maxBatch.
	add := func(item *pb.HTTPDNSAccessLog) {
		if item != nil {
			batch = append(batch, item)
		}
		if len(batch) >= maxBatch {
			flush()
		}
	}

	for {
		select {
		case item := <-s.logQueue:
			add(item)
		case <-s.ticker():
			flush()
		case <-s.quitCh:
			// select chooses randomly among ready cases, so entries can still be
			// queued when quit is observed. Drain what is there now; bounding the
			// loop by the current length keeps late producers from delaying exit.
		drain:
			for pending := len(s.logQueue); pending > 0; pending-- {
				select {
				case item := <-s.logQueue:
					add(item)
				default:
					break drain
				}
			}
			flush()
			if err := s.logWriter.Close(); err != nil {
				log.Println("[HTTPDNS_NODE][resolve]close access log writer failed:", err.Error())
			}
			return
		}
	}
}