package nodes import ( "context" "crypto/hmac" "crypto/sha256" "crypto/tls" "encoding/hex" "encoding/json" "errors" "fmt" "log" "net" "net/http" "net/url" "sort" "strconv" "strings" "sync" "time" "github.com/TeaOSLab/EdgeCommon/pkg/iplibrary" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs" "github.com/TeaOSLab/EdgeHttpDNS/internal/accesslogs" "github.com/TeaOSLab/EdgeHttpDNS/internal/configs" "github.com/iwind/TeaGo/rands" "github.com/miekg/dns" ) // sharedRecursionDNSClient 共享的递归DNS客户端(对齐 EdgeDNS 实现) var sharedRecursionDNSClient = &dns.Client{ Timeout: 3 * time.Second, } const ( httpdnsCodeSuccess = "SUCCESS" httpdnsCodeAppInvalid = "APP_NOT_FOUND_OR_DISABLED" httpdnsCodeDomainNotBound = "DOMAIN_NOT_BOUND" httpdnsCodeSignInvalid = "SIGN_INVALID" httpdnsCodeNoRecords = "NO_RECORDS" httpdnsCodeInternalError = "RESOLVE_TIMEOUT_OR_INTERNAL" httpdnsCodeMethodNotAllow = "METHOD_NOT_ALLOWED" httpdnsCodeInvalidArgument = "INVALID_ARGUMENT" ) type resolveClientInfo struct { IP string `json:"ip"` Region string `json:"region"` Carrier string `json:"carrier"` Country string `json:"country"` } type resolveRecord struct { Type string `json:"type"` IP string `json:"ip"` Weight int32 `json:"weight,omitempty"` Line string `json:"line,omitempty"` Region string `json:"region,omitempty"` } type resolveData struct { Domain string `json:"domain"` QType string `json:"qtype"` TTL int32 `json:"ttl"` Records []*resolveRecord `json:"records"` Client *resolveClientInfo `json:"client"` Summary string `json:"summary"` } type resolveResponse struct { Code string `json:"code"` Message string `json:"message"` RequestID string `json:"requestId"` Data *resolveData `json:"data,omitempty"` } type clientRouteProfile struct { IP string Country string Province string Carrier string ProviderRaw string Region string Continent string RegionText string } type tlsListener struct { addr string // e.g. ":443" listener net.Listener server *http.Server } type ResolveServer struct { quitCh <-chan struct{} snapshotManager *SnapshotManager // Local config fallback fallbackAddr string certFile string keyFile string handler http.Handler // shared mux tlsConfig *tls.Config // shared TLS config (with GetCertificate) logWriter *accesslogs.HTTPDNSFileWriter logQueue chan *pb.HTTPDNSAccessLog // TLS certificate hot-reload certMu sync.RWMutex currentCert *tls.Certificate certSnapshotAt int64 // Listener hot-reload listenerMu sync.Mutex listeners map[string]*tlsListener // key: addr (e.g. ":443") } func NewResolveServer(quitCh <-chan struct{}, snapshotManager *SnapshotManager) *ResolveServer { fallbackAddr := ":443" certFile := "" keyFile := "" if apiConfig, err := configs.SharedAPIConfig(); err == nil && apiConfig != nil { if len(apiConfig.HTTPSListenAddr) > 0 { fallbackAddr = apiConfig.HTTPSListenAddr } certFile = apiConfig.HTTPSCert keyFile = apiConfig.HTTPSKey } logWriter := accesslogs.SharedHTTPDNSFileWriter() if apiConfig, err := configs.SharedAPIConfig(); err == nil && apiConfig != nil { if len(strings.TrimSpace(apiConfig.LogDir)) > 0 { logWriter.SetDir(strings.TrimSpace(apiConfig.LogDir)) } } if err := logWriter.EnsureInit(); err != nil { log.Println("[HTTPDNS_NODE][resolve]init access log file writer failed:", err.Error()) } instance := &ResolveServer{ quitCh: quitCh, snapshotManager: snapshotManager, fallbackAddr: fallbackAddr, certFile: certFile, keyFile: keyFile, logWriter: logWriter, logQueue: make(chan *pb.HTTPDNSAccessLog, 8192), listeners: make(map[string]*tlsListener), } mux := http.NewServeMux() mux.HandleFunc("/resolve", instance.handleResolve) mux.HandleFunc("/healthz", instance.handleHealth) instance.handler = mux instance.tlsConfig = &tls.Config{ MinVersion: tls.VersionTLS11, NextProtos: []string{"http/1.1"}, GetCertificate: instance.getCertificate, } return instance } func (s *ResolveServer) Start() { go s.startAccessLogFlusher() // 1. Load initial certificate from file (fallback) if len(s.certFile) > 0 && len(s.keyFile) > 0 { cert, err := tls.LoadX509KeyPair(s.certFile, s.keyFile) if err != nil { log.Println("[HTTPDNS_NODE][resolve]load cert file failed:", err.Error()) } else { s.currentCert = &cert log.Println("[HTTPDNS_NODE][resolve]loaded initial TLS cert from file") } } // 2. Try loading certificate from cluster snapshot (takes priority over file) if snapshot := s.snapshotManager.Current(); snapshot != nil { s.reloadCertFromSnapshot(snapshot) } if s.currentCert == nil { log.Println("[HTTPDNS_NODE][resolve]WARNING: no TLS certificate available, HTTPS will fail") } // 3. Parse initial listen addresses and start listeners if snapshot := s.snapshotManager.Current(); snapshot != nil { addrs := s.desiredAddrs(snapshot) s.syncListeners(addrs) } else { s.syncListeners([]string{s.fallbackAddr}) } // 4. Watch for changes (blocks until quit) s.watchLoop() } func (s *ResolveServer) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { s.certMu.RLock() cert := s.currentCert s.certMu.RUnlock() if cert != nil { return cert, nil } return nil, errors.New("no TLS certificate available") } type snapshotTLSConfig struct { Listen []*serverconfigs.NetworkAddressConfig `json:"listen"` SSLPolicy *sslconfigs.SSLPolicy `json:"sslPolicy"` } func (s *ResolveServer) parseTLSConfig(snapshot *LoadedSnapshot) *snapshotTLSConfig { if snapshot.ClusterID <= 0 { return nil } cluster := snapshot.Clusters[snapshot.ClusterID] if cluster == nil { return nil } raw := cluster.GetTlsPolicyJSON() if len(raw) == 0 { return nil } var cfg snapshotTLSConfig if err := json.Unmarshal(raw, &cfg); err != nil { log.Println("[HTTPDNS_NODE][resolve]parse tlsPolicyJSON failed:", err.Error()) return nil } return &cfg } func (s *ResolveServer) desiredAddrs(snapshot *LoadedSnapshot) []string { cfg := s.parseTLSConfig(snapshot) if cfg == nil || len(cfg.Listen) == 0 { return []string{s.fallbackAddr} } seen := make(map[string]struct{}) var addrs []string for _, listenCfg := range cfg.Listen { if listenCfg == nil { continue } if err := listenCfg.Init(); err != nil { log.Println("[HTTPDNS_NODE][resolve]init listen config failed:", err.Error()) continue } for _, addr := range listenCfg.Addresses() { if _, ok := seen[addr]; !ok { seen[addr] = struct{}{} addrs = append(addrs, addr) } } } if len(addrs) == 0 { return []string{s.fallbackAddr} } sort.Strings(addrs) return addrs } func (s *ResolveServer) reloadCertFromSnapshot(snapshot *LoadedSnapshot) { cfg := s.parseTLSConfig(snapshot) if cfg == nil || cfg.SSLPolicy == nil || len(cfg.SSLPolicy.Certs) == 0 { s.certMu.Lock() s.certSnapshotAt = snapshot.LoadedAt s.certMu.Unlock() return } if err := cfg.SSLPolicy.Init(context.Background()); err != nil { log.Println("[HTTPDNS_NODE][resolve]init SSLPolicy failed:", err.Error()) s.certMu.Lock() s.certSnapshotAt = snapshot.LoadedAt s.certMu.Unlock() return } cert := cfg.SSLPolicy.FirstCert() if cert == nil { s.certMu.Lock() s.certSnapshotAt = snapshot.LoadedAt s.certMu.Unlock() return } s.certMu.Lock() s.currentCert = cert s.certSnapshotAt = snapshot.LoadedAt s.certMu.Unlock() log.Println("[HTTPDNS_NODE][resolve]TLS certificate reloaded from snapshot") } func (s *ResolveServer) startListener(addr string) error { ln, err := net.Listen("tcp", addr) if err != nil { return fmt.Errorf("listen on %s: %w", addr, err) } tlsLn := tls.NewListener(ln, s.tlsConfig) srv := &http.Server{ Handler: s.handler, ReadTimeout: 5 * time.Second, ReadHeaderTimeout: 3 * time.Second, WriteTimeout: 5 * time.Second, IdleTimeout: 75 * time.Second, MaxHeaderBytes: 8 * 1024, TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){}, } s.listeners[addr] = &tlsListener{ addr: addr, listener: tlsLn, server: srv, } go func() { if err := srv.Serve(tlsLn); err != nil && !errors.Is(err, http.ErrServerClosed) { log.Println("[HTTPDNS_NODE][resolve]serve failed on", addr, ":", err.Error()) } }() log.Println("[HTTPDNS_NODE][resolve]listening HTTPS on", addr) return nil } func (s *ResolveServer) stopListener(addr string) { tl, ok := s.listeners[addr] if !ok { return } ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() _ = tl.server.Shutdown(ctx) delete(s.listeners, addr) log.Println("[HTTPDNS_NODE][resolve]stopped listener on", addr) } func (s *ResolveServer) syncListeners(desired []string) { s.listenerMu.Lock() defer s.listenerMu.Unlock() desiredSet := make(map[string]struct{}, len(desired)) for _, addr := range desired { desiredSet[addr] = struct{}{} } // Stop listeners that are no longer desired for addr := range s.listeners { if _, ok := desiredSet[addr]; !ok { s.stopListener(addr) } } // Start new listeners for _, addr := range desired { if _, ok := s.listeners[addr]; !ok { if err := s.startListener(addr); err != nil { log.Println("[HTTPDNS_NODE][resolve]start listener failed:", err.Error()) } } } } func (s *ResolveServer) watchLoop() { ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() for { select { case <-ticker.C: snapshot := s.snapshotManager.Current() if snapshot == nil { continue } s.certMu.RLock() lastAt := s.certSnapshotAt s.certMu.RUnlock() if snapshot.LoadedAt == lastAt { continue } // Snapshot changed — sync listeners and reload cert addrs := s.desiredAddrs(snapshot) s.syncListeners(addrs) s.reloadCertFromSnapshot(snapshot) case <-s.quitCh: s.shutdownAll() return } } } func (s *ResolveServer) shutdownAll() { s.listenerMu.Lock() defer s.listenerMu.Unlock() for addr := range s.listeners { s.stopListener(addr) } } func (s *ResolveServer) handleHealth(writer http.ResponseWriter, _ *http.Request) { writer.WriteHeader(http.StatusOK) _, _ = writer.Write([]byte("ok")) } func (s *ResolveServer) handleResolve(writer http.ResponseWriter, request *http.Request) { startAt := time.Now() requestID := "rid-" + rands.HexString(16) if request.Method != http.MethodGet { s.writeResolveJSON(writer, http.StatusMethodNotAllowed, &resolveResponse{ Code: httpdnsCodeMethodNotAllow, Message: "只允许使用 GET 方法", RequestID: requestID, }) return } query := request.URL.Query() appID := strings.TrimSpace(query.Get("appId")) domain := strings.TrimSuffix(strings.ToLower(strings.TrimSpace(query.Get("dn"))), ".") qtype := strings.ToUpper(strings.TrimSpace(query.Get("qtype"))) if len(qtype) == 0 { qtype = "A" } if qtype != "A" && qtype != "AAAA" { s.writeResolveJSON(writer, http.StatusBadRequest, &resolveResponse{ Code: httpdnsCodeInvalidArgument, Message: "qtype 参数仅支持 A 或 AAAA", RequestID: requestID, }) return } if len(appID) == 0 || len(domain) == 0 { s.writeResolveJSON(writer, http.StatusBadRequest, &resolveResponse{ Code: httpdnsCodeInvalidArgument, Message: "缺少必填参数: appId 和 dn", RequestID: requestID, }) return } snapshot := s.snapshotManager.Current() if snapshot == nil { s.writeResolveJSON(writer, http.StatusServiceUnavailable, &resolveResponse{ Code: httpdnsCodeInternalError, Message: "服务节点尚未准备就绪,请稍后再试", RequestID: requestID, }) return } loadedApp := snapshot.Apps[strings.ToLower(appID)] if loadedApp == nil || loadedApp.App == nil || !loadedApp.App.GetIsOn() { s.writeFailedResolve(writer, requestID, snapshot, nil, domain, qtype, httpdnsCodeAppInvalid, "找不到指定的应用,或该应用已下线", startAt, request, query) return } if snapshot.ClusterID > 0 { var appClusterIds []int64 if len(loadedApp.App.GetClusterIdsJSON()) > 0 { _ = json.Unmarshal(loadedApp.App.GetClusterIdsJSON(), &appClusterIds) } var clusterBound bool for _, cid := range appClusterIds { if cid == snapshot.ClusterID { clusterBound = true break } } if !clusterBound { s.writeFailedResolve(writer, requestID, snapshot, loadedApp.App, domain, qtype, httpdnsCodeAppInvalid, "当前应用未绑定到该解析集群", startAt, request, query) return } } loadedDomain := loadedApp.Domains[domain] if loadedDomain == nil || loadedDomain.Domain == nil || !loadedDomain.Domain.GetIsOn() { s.writeFailedResolve(writer, requestID, snapshot, loadedApp.App, domain, qtype, httpdnsCodeDomainNotBound, "应用尚未绑定该域名,或域名解析已暂停", startAt, request, query) return } if loadedApp.App.GetSignEnabled() { if !validateResolveSign(loadedApp.App.GetSignSecret(), loadedApp.App.GetAppId(), domain, qtype, query.Get("nonce"), query.Get("exp"), query.Get("sign")) { s.writeFailedResolve(writer, requestID, snapshot, loadedApp.App, domain, qtype, httpdnsCodeSignInvalid, "请求鉴权失败:签名无效或已过期", startAt, request, query) return } } clientIP := detectClientIP(request, query.Get("cip")) clientProfile := buildClientRouteProfile(clientIP) clusterTTL := pickDefaultTTL(snapshot, loadedApp.App) rule, records, ttl := pickRuleRecords(loadedDomain.Rules, qtype, clientProfile, clusterTTL) if len(records) == 0 { // Fallback:回源上游 DNS 查询真实记录 fallbackRecords, fallbackTTL, fallbackErr := fallbackResolve(domain, qtype) if fallbackErr != nil || len(fallbackRecords) == 0 { errMsg := "未找到解析记录" if fallbackErr != nil { errMsg = "未找到解析记录 (上游回源失败: " + fallbackErr.Error() + ")" } s.writeFailedResolve(writer, requestID, snapshot, loadedApp.App, domain, qtype, httpdnsCodeNoRecords, errMsg, startAt, request, query) return } records = fallbackRecords if fallbackTTL > 0 { ttl = fallbackTTL } else { ttl = clusterTTL } } if ttl <= 0 { ttl = clusterTTL } if ttl <= 0 { ttl = 30 } resultIPs := make([]string, 0, len(records)) for _, record := range records { resultIPs = append(resultIPs, record.IP) } summary := fmt.Sprintf("%s|%s(%s)|%s|%s %s -> %s|success|%dms", time.Now().Format("2006-01-02 15:04:05"), loadedApp.App.GetName(), loadedApp.App.GetAppId(), clientProfile.IP, qtype, domain, strings.Join(resultIPs, ", "), time.Since(startAt).Milliseconds(), ) if rule != nil && len(strings.TrimSpace(rule.GetRuleName())) > 0 { summary += "|rule:" + strings.TrimSpace(rule.GetRuleName()) } s.writeResolveJSON(writer, http.StatusOK, &resolveResponse{ Code: httpdnsCodeSuccess, Message: "ok", RequestID: requestID, Data: &resolveData{ Domain: domain, QType: qtype, TTL: ttl, Records: records, Client: &resolveClientInfo{ IP: clientProfile.IP, Region: clientProfile.RegionText, Carrier: clientProfile.Carrier, Country: clientProfile.Country, }, Summary: summary, }, }) s.enqueueAccessLog(&pb.HTTPDNSAccessLog{ RequestId: requestID, ClusterId: snapshot.ClusterID, NodeId: snapshot.NodeID, AppId: loadedApp.App.GetAppId(), AppName: loadedApp.App.GetName(), Domain: domain, Qtype: qtype, ClientIP: clientProfile.IP, ClientRegion: clientProfile.RegionText, Carrier: clientProfile.Carrier, SdkVersion: strings.TrimSpace(query.Get("sdk_version")), Os: strings.TrimSpace(query.Get("os")), ResultIPs: strings.Join(resultIPs, ","), Status: "success", ErrorCode: "none", CostMs: int32(time.Since(startAt).Milliseconds()), CreatedAt: time.Now().Unix(), Day: time.Now().Format("20060102"), Summary: summary, }) } func pickDefaultTTL(snapshot *LoadedSnapshot, app *pb.HTTPDNSApp) int32 { if snapshot == nil { return 30 } if snapshot.ClusterID > 0 { if cluster := snapshot.Clusters[snapshot.ClusterID]; cluster != nil && cluster.GetDefaultTTL() > 0 { return cluster.GetDefaultTTL() } } if app != nil { var appClusterIds []int64 if len(app.GetClusterIdsJSON()) > 0 { _ = json.Unmarshal(app.GetClusterIdsJSON(), &appClusterIds) } for _, cid := range appClusterIds { if cluster := snapshot.Clusters[cid]; cluster != nil && cluster.GetDefaultTTL() > 0 { return cluster.GetDefaultTTL() } } } return 30 } func (s *ResolveServer) writeFailedResolve( writer http.ResponseWriter, requestID string, snapshot *LoadedSnapshot, app *pb.HTTPDNSApp, domain string, qtype string, errorCode string, message string, startAt time.Time, request *http.Request, query url.Values, ) { clientIP := detectClientIP(request, query.Get("cip")) clientProfile := buildClientRouteProfile(clientIP) appID := "" appName := "" if app != nil { appID = app.GetAppId() appName = app.GetName() } summary := fmt.Sprintf("%s|%s(%s)|%s|%s %s -> [none]|failed(%s)|%dms", time.Now().Format("2006-01-02 15:04:05"), appName, appID, clientProfile.IP, qtype, domain, errorCode, time.Since(startAt).Milliseconds(), ) s.writeResolveJSON(writer, http.StatusOK, &resolveResponse{ Code: errorCode, Message: message, RequestID: requestID, }) clusterID := int64(0) nodeID := int64(0) if snapshot != nil { clusterID = snapshot.ClusterID nodeID = snapshot.NodeID } s.enqueueAccessLog(&pb.HTTPDNSAccessLog{ RequestId: requestID, ClusterId: clusterID, NodeId: nodeID, AppId: appID, AppName: appName, Domain: domain, Qtype: qtype, ClientIP: clientProfile.IP, ClientRegion: clientProfile.RegionText, Carrier: clientProfile.Carrier, SdkVersion: strings.TrimSpace(query.Get("sdk_version")), Os: strings.TrimSpace(query.Get("os")), ResultIPs: "", Status: "failed", ErrorCode: errorCode, CostMs: int32(time.Since(startAt).Milliseconds()), CreatedAt: time.Now().Unix(), Day: time.Now().Format("20060102"), Summary: summary, }) } func (s *ResolveServer) writeResolveJSON(writer http.ResponseWriter, status int, resp *resolveResponse) { writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.WriteHeader(status) data, err := json.Marshal(resp) if err != nil { _, _ = writer.Write([]byte(`{"code":"RESOLVE_TIMEOUT_OR_INTERNAL","message":"encode response failed"}`)) return } _, _ = writer.Write(data) } func detectClientIP(request *http.Request, cip string) string { if candidate := normalizeIPCandidate(cip); len(candidate) > 0 { return candidate } xff := strings.TrimSpace(request.Header.Get("X-Forwarded-For")) if len(xff) > 0 { for _, item := range strings.Split(xff, ",") { if candidate := normalizeIPCandidate(item); len(candidate) > 0 { return candidate } } } headerKeys := []string{"X-Real-IP", "X-Client-IP", "CF-Connecting-IP", "True-Client-IP"} for _, key := range headerKeys { if candidate := normalizeIPCandidate(request.Header.Get(key)); len(candidate) > 0 { return candidate } } return normalizeIPCandidate(request.RemoteAddr) } func validateResolveSign(signSecret string, appID string, domain string, qtype string, nonce string, exp string, sign string) bool { signSecret = strings.TrimSpace(signSecret) nonce = strings.TrimSpace(nonce) exp = strings.TrimSpace(exp) sign = strings.TrimSpace(sign) if len(signSecret) == 0 || len(nonce) == 0 || len(exp) == 0 || len(sign) == 0 { return false } expireAt, err := strconv.ParseInt(exp, 10, 64) if err != nil { return false } now := time.Now().Unix() if expireAt <= now-30 || expireAt > now+86400 { return false } raw := appID + "|" + strings.ToLower(domain) + "|" + strings.ToUpper(qtype) + "|" + exp + "|" + nonce mac := hmac.New(sha256.New, []byte(signSecret)) _, _ = mac.Write([]byte(raw)) expected := hex.EncodeToString(mac.Sum(nil)) return strings.EqualFold(expected, sign) } func buildClientRouteProfile(ip string) *clientRouteProfile { profile := &clientRouteProfile{ IP: normalizeIPCandidate(ip), } if net.ParseIP(profile.IP) == nil { return profile } result := iplibrary.LookupIP(profile.IP) if result == nil || !result.IsOk() { return profile } profile.Country = normalizeCountryName(strings.TrimSpace(result.CountryName())) profile.Province = normalizeProvinceName(strings.TrimSpace(result.ProvinceName())) profile.ProviderRaw = strings.TrimSpace(result.ProviderName()) profile.Carrier = normalizeCarrier(profile.ProviderRaw, profile.Country) if len(profile.Carrier) == 0 { if isMainlandChinaCountry(profile.Country) { profile.Carrier = "默认" } else { if len(profile.ProviderRaw) > 0 { profile.Carrier = profile.ProviderRaw } else { profile.Carrier = "默认" } } } profile.Region = normalizeChinaRegion(profile.Province) profile.Continent = normalizeContinent(profile.Country) profile.RegionText = strings.TrimSpace(result.RegionSummary()) if len(profile.RegionText) == 0 { pieces := make([]string, 0, 4) if len(profile.Country) > 0 { pieces = append(pieces, profile.Country) } if len(profile.Province) > 0 { pieces = append(pieces, profile.Province) } if len(profile.Region) > 0 { pieces = append(pieces, profile.Region) } if len(profile.Carrier) > 0 { pieces = append(pieces, profile.Carrier) } profile.RegionText = strings.Join(pieces, " ") } return profile } func normalizeCarrier(provider string, country string) string { value := strings.TrimSpace(provider) if len(value) == 0 { return "" } lower := strings.ToLower(value) switch { case strings.Contains(value, "电信"), strings.Contains(value, "天翼"), strings.Contains(lower, "telecom"), strings.Contains(lower, "chinanet"), strings.Contains(lower, "chinatelecom"), strings.Contains(lower, "ctnet"), strings.Contains(lower, "cn2"): return "电信" case strings.Contains(value, "联通"), strings.Contains(value, "网通"), strings.Contains(lower, "unicom"), strings.Contains(lower, "chinaunicom"), strings.Contains(lower, "cucc"), strings.Contains(lower, "china169"), strings.Contains(lower, "cnc"): return "联通" case strings.Contains(value, "移动"), strings.Contains(lower, "mobile"), strings.Contains(lower, "chinamobile"), strings.Contains(lower, "cmcc"), strings.Contains(lower, "cmnet"): return "移动" case strings.Contains(value, "教育"), strings.Contains(lower, "cernet"), strings.Contains(lower, "edu"), strings.Contains(lower, "education"): return "教育网" case strings.Contains(value, "鹏博士"), strings.Contains(lower, "drpeng"), strings.Contains(lower, "dr.peng"), strings.Contains(lower, "dr_peng"): return "鹏博士" case strings.Contains(value, "广电"), strings.Contains(lower, "broadcast"), strings.Contains(lower, "cable"), strings.Contains(lower, "radio"): return "广电" default: if isMainlandChinaCountry(country) { return "" } return value } } func normalizeChinaRegion(province string) string { switch normalizeProvinceName(province) { case "辽宁", "吉林", "黑龙江": return "东北" case "北京", "天津", "河北", "山西", "内蒙古": return "华北" case "上海", "江苏", "浙江", "安徽", "福建", "江西", "山东": return "华东" case "广东", "广西", "海南": return "华南" case "河南", "湖北", "湖南": return "华中" case "陕西", "甘肃", "青海", "宁夏", "新疆": return "西北" case "重庆", "四川", "贵州", "云南", "西藏": return "西南" default: return "" } } func normalizeContinent(country string) string { switch normalizeCountryName(country) { case "中国", "中国香港", "中国澳门", "中国台湾", "日本", "韩国", "新加坡", "印度", "泰国", "越南", "印度尼西亚", "马来西亚", "菲律宾", "柬埔寨", "缅甸", "老挝", "斯里兰卡", "孟加拉国", "巴基斯坦", "尼泊尔", "阿联酋", "沙特阿拉伯", "土耳其", "以色列", "伊朗", "伊拉克", "卡塔尔", "科威特", "蒙古": return "亚洲" case "美国", "加拿大", "墨西哥", "巴拿马", "哥斯达黎加", "古巴": return "北美洲" case "巴西", "阿根廷", "智利", "哥伦比亚", "秘鲁", "委内瑞拉", "厄瓜多尔": return "南美洲" case "德国", "英国", "法国", "荷兰", "西班牙", "意大利", "俄罗斯", "波兰", "瑞典", "瑞士", "挪威", "芬兰", "丹麦", "葡萄牙", "爱尔兰", "比利时", "奥地利", "乌克兰", "捷克", "罗马尼亚", "匈牙利", "希腊": return "欧洲" case "南非", "埃及", "尼日利亚", "肯尼亚", "摩洛哥", "阿尔及利亚", "坦桑尼亚", "埃塞俄比亚", "加纳", "突尼斯": return "非洲" case "澳大利亚", "新西兰": return "大洋洲" default: return "" } } func pickRuleRecords(rules []*pb.HTTPDNSCustomRule, qtype string, profile *clientRouteProfile, defaultTTL int32) (*pb.HTTPDNSCustomRule, []*resolveRecord, int32) { bestScore := -1 var bestRule *pb.HTTPDNSCustomRule var bestRecords []*resolveRecord bestTTL := defaultTTL for _, rule := range rules { if rule == nil || !rule.GetIsOn() { continue } score, ok := matchRuleLine(rule, profile) if !ok { continue } records := make([]*resolveRecord, 0) for _, item := range rule.GetRecords() { if item == nil { continue } if strings.ToUpper(strings.TrimSpace(item.GetRecordType())) != qtype { continue } value := strings.TrimSpace(item.GetRecordValue()) if len(value) == 0 { continue } records = append(records, &resolveRecord{ Type: qtype, IP: value, Weight: item.GetWeight(), Line: lookupIPLineLabel(value), Region: lookupIPRegionSummary(value), }) } if len(records) == 0 { continue } if score > bestScore { bestScore = score bestRule = rule bestRecords = records if rule.GetTtl() > 0 { bestTTL = rule.GetTtl() } else { bestTTL = defaultTTL } } } return bestRule, bestRecords, bestTTL } // fallbackResolve 当无自定义规则命中时,回源上游 DNS 查询真实记录(对齐 EdgeDNS 做法) func fallbackResolve(domain string, qtype string) ([]*resolveRecord, int32, error) { var dnsType uint16 switch qtype { case "A": dnsType = dns.TypeA case "AAAA": dnsType = dns.TypeAAAA default: return nil, 0, nil // 仅降级处理 A 和 AAAA } m := new(dns.Msg) m.SetQuestion(dns.Fqdn(domain), dnsType) m.RecursionDesired = true // 优先使用本机 /etc/resolv.conf 中的 DNS 服务器(对齐 EdgeDNS) var upstream = "223.5.5.5:53" resolveConfig, confErr := dns.ClientConfigFromFile("/etc/resolv.conf") if confErr == nil && len(resolveConfig.Servers) > 0 { port := resolveConfig.Port if len(port) == 0 { port = "53" } server := resolveConfig.Servers[rands.Int(0, len(resolveConfig.Servers)-1)] upstream = server + ":" + port } r, _, err := sharedRecursionDNSClient.Exchange(m, upstream) if err != nil { return nil, 0, err } if r.Rcode != dns.RcodeSuccess { return nil, 0, fmt.Errorf("upstream rcode: %d", r.Rcode) } var records []*resolveRecord var responseTTL int32 for _, ans := range r.Answer { switch t := ans.(type) { case *dns.A: if qtype == "A" { if t.Hdr.Ttl > 0 { ttl := int32(t.Hdr.Ttl) if responseTTL == 0 || ttl < responseTTL { responseTTL = ttl } } ip := t.A.String() records = append(records, &resolveRecord{ Type: "A", IP: ip, Line: lookupIPLineLabel(ip), Region: lookupIPRegionSummary(ip), }) } case *dns.AAAA: if qtype == "AAAA" { if t.Hdr.Ttl > 0 { ttl := int32(t.Hdr.Ttl) if responseTTL == 0 || ttl < responseTTL { responseTTL = ttl } } ip := t.AAAA.String() records = append(records, &resolveRecord{ Type: "AAAA", IP: ip, Line: lookupIPLineLabel(ip), Region: lookupIPRegionSummary(ip), }) } } } return records, responseTTL, nil } func lookupIPRegionSummary(ip string) string { address := strings.TrimSpace(ip) if len(address) == 0 || net.ParseIP(address) == nil { return "" } result := iplibrary.LookupIP(address) if result == nil || !result.IsOk() { return "" } return strings.TrimSpace(result.RegionSummary()) } func lookupIPLineLabel(ip string) string { address := strings.TrimSpace(ip) if len(address) == 0 || net.ParseIP(address) == nil { return "上游DNS" } result := iplibrary.LookupIP(address) if result == nil || !result.IsOk() { return "上游DNS" } provider := strings.TrimSpace(result.ProviderName()) country := normalizeCountryName(strings.TrimSpace(result.CountryName())) carrier := strings.TrimSpace(normalizeCarrier(provider, country)) if isMainlandChinaCountry(country) { if len(carrier) > 0 { return carrier } return "默认" } if len(carrier) > 0 { return carrier } if len(provider) > 0 { return provider } return "上游DNS" } func isMainlandChinaCountry(country string) bool { switch normalizeCountryName(country) { case "中国": return true } return false } func matchRuleLine(rule *pb.HTTPDNSCustomRule, profile *clientRouteProfile) (int, bool) { scope := strings.ToLower(strings.TrimSpace(rule.GetLineScope())) score := 0 if scope == "overseas" { // 境外规则只匹配非中国大陆来源 if isMainlandChinaCountry(profile.Country) { return 0, false } fieldScore, ok := matchRuleField(rule.GetLineContinent(), profile.Continent) if !ok { return 0, false } score += fieldScore fieldScore, ok = matchRuleField(rule.GetLineCountry(), profile.Country) if !ok { return 0, false } score += fieldScore return score, true } // 中国地区规则只匹配中国大陆来源 if !isMainlandChinaCountry(profile.Country) { return 0, false } fieldScore, ok := matchRuleField(rule.GetLineCountry(), profile.Country) if !ok { return 0, false } score += fieldScore fieldScore, ok = matchRuleField(rule.GetLineContinent(), profile.Continent) if !ok { return 0, false } score += fieldScore fieldScore, ok = matchRuleField(rule.GetLineCarrier(), profile.Carrier, profile.ProviderRaw) if !ok { return 0, false } score += fieldScore fieldScore, ok = matchRuleField(rule.GetLineRegion(), profile.Region) if !ok { return 0, false } score += fieldScore fieldScore, ok = matchRuleField(rule.GetLineProvince(), profile.Province) if !ok { return 0, false } score += fieldScore return score, true } func matchRuleField(ruleValue string, candidates ...string) (int, bool) { if isDefaultLineValue(ruleValue) { return 0, true } want := normalizeLineValue(ruleValue) if len(want) == 0 { return 0, true } for _, candidate := range candidates { got := normalizeLineValue(candidate) if len(got) == 0 { // 如果规则设了具体值(want),但客户端信息(got)为空,则不能匹配 // 否则 strings.Contains("xxx", "") 永远为 true continue } if want == got || strings.Contains(got, want) || strings.Contains(want, got) { return 1, true } } return 0, false } func isDefaultLineValue(value string) bool { switch normalizeLineValue(value) { case "", "default", "all", "*", "any", "默认", "全部", "不限": return true } return false } func normalizeLineValue(value string) string { v := strings.ToLower(strings.TrimSpace(value)) v = strings.ReplaceAll(v, " ", "") v = strings.ReplaceAll(v, "-", "") v = strings.ReplaceAll(v, "_", "") v = strings.ReplaceAll(v, "/", "") return v } func normalizeIPCandidate(ip string) string { value := strings.TrimSpace(ip) if len(value) == 0 { return "" } if host, _, err := net.SplitHostPort(value); err == nil { value = strings.TrimSpace(host) } if parsed := net.ParseIP(value); parsed != nil { return parsed.String() } return "" } func normalizeCountryName(country string) string { value := strings.TrimSpace(country) if len(value) == 0 { return "" } normalized := normalizeLineValue(value) switch normalized { case "中国香港", "香港", "hongkong": return "中国香港" case "中国澳门", "澳门", "macao", "macau": return "中国澳门" case "中国台湾", "台湾", "taiwan": return "中国台湾" } switch normalized { case "中国", "中国大陆", "中国内地", "中华人民共和国", "prc", "cn", "china", "mainlandchina", "peoplesrepublicofchina", "thepeoplesrepublicofchina": return "中国" case "美国", "usa", "unitedstates", "unitedstatesofamerica": return "美国" case "加拿大", "canada": return "加拿大" case "墨西哥", "mexico": return "墨西哥" case "日本", "japan": return "日本" case "韩国", "southkorea", "korea": return "韩国" case "新加坡", "singapore": return "新加坡" case "印度", "india": return "印度" case "泰国", "thailand": return "泰国" case "越南", "vietnam": return "越南" case "德国", "germany": return "德国" case "英国", "uk", "unitedkingdom", "greatbritain", "britain": return "英国" case "法国", "france": return "法国" case "荷兰", "netherlands": return "荷兰" case "西班牙", "spain": return "西班牙" case "意大利", "italy": return "意大利" case "俄罗斯", "russia": return "俄罗斯" case "巴西", "brazil": return "巴西" case "阿根廷", "argentina": return "阿根廷" case "智利", "chile": return "智利" case "哥伦比亚", "colombia": return "哥伦比亚" // --- 亚洲(新增)--- case "印度尼西亚", "indonesia": return "印度尼西亚" case "马来西亚", "malaysia": return "马来西亚" case "菲律宾", "philippines": return "菲律宾" case "柬埔寨", "cambodia": return "柬埔寨" case "缅甸", "myanmar", "burma": return "缅甸" case "老挝", "laos": return "老挝" case "斯里兰卡", "srilanka": return "斯里兰卡" case "孟加拉国", "孟加拉", "bangladesh": return "孟加拉国" case "巴基斯坦", "pakistan": return "巴基斯坦" case "尼泊尔", "nepal": return "尼泊尔" case "阿联酋", "阿拉伯联合酋长国", "uae", "unitedarabemirates": return "阿联酋" case "沙特阿拉伯", "沙特", "saudiarabia", "saudi": return "沙特阿拉伯" case "土耳其", "turkey", "türkiye", "turkiye": return "土耳其" case "以色列", "israel": return "以色列" case "伊朗", "iran": return "伊朗" case "伊拉克", "iraq": return "伊拉克" case "卡塔尔", "qatar": return "卡塔尔" case "科威特", "kuwait": return "科威特" case "蒙古", "mongolia": return "蒙古" // --- 欧洲(新增)--- case "波兰", "poland": return "波兰" case "瑞典", "sweden": return "瑞典" case "瑞士", "switzerland": return "瑞士" case "挪威", "norway": return "挪威" case "芬兰", "finland": return "芬兰" case "丹麦", "denmark": return "丹麦" case "葡萄牙", "portugal": return "葡萄牙" case "爱尔兰", "ireland": return "爱尔兰" case "比利时", "belgium": return "比利时" case "奥地利", "austria": return "奥地利" case "乌克兰", "ukraine": return "乌克兰" case "捷克", "czech", "czechrepublic", "czechia": return "捷克" case "罗马尼亚", "romania": return "罗马尼亚" case "匈牙利", "hungary": return "匈牙利" case "希腊", "greece": return "希腊" // --- 北美洲(新增)--- case "巴拿马", "panama": return "巴拿马" case "哥斯达黎加", "costarica": return "哥斯达黎加" case "古巴", "cuba": return "古巴" // --- 南美洲(新增)--- case "秘鲁", "peru": return "秘鲁" case "委内瑞拉", "venezuela": return "委内瑞拉" case "厄瓜多尔", "ecuador": return "厄瓜多尔" // --- 非洲 --- case "南非", "southafrica": return "南非" case "埃及", "egypt": return "埃及" case "尼日利亚", "nigeria": return "尼日利亚" case "肯尼亚", "kenya": return "肯尼亚" case "摩洛哥", "morocco": return "摩洛哥" case "阿尔及利亚", "algeria": return "阿尔及利亚" case "坦桑尼亚", "tanzania": return "坦桑尼亚" case "埃塞俄比亚", "ethiopia": return "埃塞俄比亚" case "加纳", "ghana": return "加纳" case "突尼斯", "tunisia": return "突尼斯" // --- 大洋洲 --- case "澳大利亚", "australia": return "澳大利亚" case "新西兰", "newzealand": return "新西兰" default: return value } } func normalizeProvinceName(province string) string { value := strings.TrimSpace(province) if len(value) == 0 { return "" } switch value { case "内蒙古自治区": return "内蒙古" case "广西壮族自治区": return "广西" case "宁夏回族自治区": return "宁夏" case "新疆维吾尔自治区": return "新疆" case "西藏自治区": return "西藏" case "香港特别行政区": return "香港" case "澳门特别行政区": return "澳门" } for _, suffix := range []string{"维吾尔自治区", "回族自治区", "壮族自治区", "自治区", "特别行政区", "省", "市"} { value = strings.TrimSuffix(value, suffix) } return value } func ruleLineSummary(rule *pb.HTTPDNSCustomRule) string { if rule == nil { return "" } scope := strings.ToLower(strings.TrimSpace(rule.GetLineScope())) if scope == "overseas" { pieces := make([]string, 0, 2) if !isDefaultLineValue(rule.GetLineContinent()) { pieces = append(pieces, strings.TrimSpace(rule.GetLineContinent())) } if !isDefaultLineValue(rule.GetLineCountry()) { pieces = append(pieces, strings.TrimSpace(rule.GetLineCountry())) } return strings.Join(pieces, "/") } pieces := make([]string, 0, 3) if !isDefaultLineValue(rule.GetLineCarrier()) { pieces = append(pieces, strings.TrimSpace(rule.GetLineCarrier())) } if !isDefaultLineValue(rule.GetLineRegion()) { pieces = append(pieces, strings.TrimSpace(rule.GetLineRegion())) } if !isDefaultLineValue(rule.GetLineProvince()) { pieces = append(pieces, strings.TrimSpace(rule.GetLineProvince())) } return strings.Join(pieces, "/") } func ruleRegionSummary(rule *pb.HTTPDNSCustomRule) string { if rule == nil { return "" } if !isDefaultLineValue(rule.GetLineProvince()) { return strings.TrimSpace(rule.GetLineProvince()) } if !isDefaultLineValue(rule.GetLineCountry()) { return strings.TrimSpace(rule.GetLineCountry()) } if !isDefaultLineValue(rule.GetLineRegion()) { return strings.TrimSpace(rule.GetLineRegion()) } if !isDefaultLineValue(rule.GetLineContinent()) { return strings.TrimSpace(rule.GetLineContinent()) } return "" } func (s *ResolveServer) enqueueAccessLog(item *pb.HTTPDNSAccessLog) { if item == nil { return } select { case s.logQueue <- item: default: log.Println("[HTTPDNS_NODE][resolve]access log queue is full, drop request:", item.GetRequestId()) } } func (s *ResolveServer) startAccessLogFlusher() { ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() batch := make([]*pb.HTTPDNSAccessLog, 0, 128) flush := func() { if len(batch) == 0 { return } s.logWriter.WriteBatch(batch) batch = batch[:0] } for { select { case item := <-s.logQueue: if item != nil { batch = append(batch, item) } if len(batch) > 4096 { log.Println("[HTTPDNS_NODE][resolve]access log flush backlog too large, trim:", len(batch)) batch = batch[len(batch)-2048:] } if len(batch) >= 128 { flush() } case <-ticker.C: flush() case <-s.quitCh: flush() _ = s.logWriter.Close() return } } }