1120 lines
30 KiB
Go
1120 lines
30 KiB
Go
package nodes
|
||
|
||
import (
|
||
"context"
|
||
"crypto/hmac"
|
||
"crypto/sha256"
|
||
"crypto/tls"
|
||
"encoding/hex"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"log"
|
||
"net"
|
||
"net/http"
|
||
"net/url"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
|
||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||
"github.com/TeaOSLab/EdgeHttpDNS/internal/configs"
|
||
"github.com/TeaOSLab/EdgeHttpDNS/internal/rpc"
|
||
"github.com/iwind/TeaGo/rands"
|
||
"github.com/miekg/dns"
|
||
)
|
||
|
||
// sharedRecursionDNSClient 共享的递归DNS客户端(对齐 EdgeDNS 实现)
|
||
var sharedRecursionDNSClient = &dns.Client{
|
||
Timeout: 3 * time.Second,
|
||
}
|
||
|
||
// Result codes placed in the "code" field of /resolve responses. Failure codes
// are also stored as the ErrorCode of the corresponding access-log entries.
const (
	// Resolution succeeded and data carries the records.
	httpdnsCodeSuccess = "SUCCESS"
	// Unknown or disabled app, or app not bound to this node's cluster.
	httpdnsCodeAppInvalid = "APP_NOT_FOUND_OR_DISABLED"
	// Domain not bound to the app, or its resolution is switched off.
	httpdnsCodeDomainNotBound = "DOMAIN_NOT_BOUND"
	// Request signature missing, invalid or expired.
	httpdnsCodeSignInvalid = "SIGN_INVALID"
	// Neither custom rules nor the upstream DNS produced records.
	httpdnsCodeNoRecords = "NO_RECORDS"
	// Node not ready (no snapshot yet) or another internal failure.
	httpdnsCodeInternalError = "RESOLVE_TIMEOUT_OR_INTERNAL"
	// Non-GET request.
	httpdnsCodeMethodNotAllow = "METHOD_NOT_ALLOWED"
	// Missing or malformed query parameters.
	httpdnsCodeInvalidArgument = "INVALID_ARGUMENT"
)
|
||
|
||
// JSON payload types of the /resolve endpoint. Field order and json tags
// define the wire format returned to SDKs and must stay as they are.
type (
	// resolveClientInfo describes the requesting client as seen by the node.
	resolveClientInfo struct {
		IP      string `json:"ip"`
		Region  string `json:"region"`
		Carrier string `json:"carrier"`
		Country string `json:"country"`
	}

	// resolveRecord is one answer record; the optional fields are omitted
	// from the JSON when empty/zero.
	resolveRecord struct {
		Type   string `json:"type"`
		IP     string `json:"ip"`
		Weight int32  `json:"weight,omitempty"`
		Line   string `json:"line,omitempty"`
		Region string `json:"region,omitempty"`
	}

	// resolveData is the payload of a successful resolution.
	resolveData struct {
		Domain  string             `json:"domain"`
		QType   string             `json:"qtype"`
		TTL     int32              `json:"ttl"`
		Records []*resolveRecord   `json:"records"`
		Client  *resolveClientInfo `json:"client"`
		Summary string             `json:"summary"`
	}

	// resolveResponse is the top-level envelope; Data is omitted on failure.
	resolveResponse struct {
		Code      string       `json:"code"`
		Message   string       `json:"message"`
		RequestID string       `json:"requestId"`
		Data      *resolveData `json:"data,omitempty"`
	}
)
|
||
|
||
// clientRouteProfile is the normalized view of a client used for custom-rule
// matching and for logging; it is filled in by buildClientRouteProfile.
type clientRouteProfile struct {
	IP          string // normalized client IP ("" when none could be parsed)
	Country     string // canonical country name from the IP library
	Province    string // province name with administrative suffixes stripped
	Carrier     string // normalized carrier label, or a fallback label
	ProviderRaw string // provider name exactly as reported by the IP library (trimmed)
	Region      string // mainland-China macro region derived from Province ("" if unknown)
	Continent   string // continent derived from Country ("" if unknown)
	RegionText  string // human-readable location summary
}
|
||
|
||
type ResolveServer struct {
|
||
quitCh <-chan struct{}
|
||
|
||
snapshotManager *SnapshotManager
|
||
|
||
listenAddr string
|
||
certFile string
|
||
keyFile string
|
||
server *http.Server
|
||
|
||
logQueue chan *pb.HTTPDNSAccessLog
|
||
}
|
||
|
||
func NewResolveServer(quitCh <-chan struct{}, snapshotManager *SnapshotManager) *ResolveServer {
|
||
listenAddr := ":443"
|
||
certFile := ""
|
||
keyFile := ""
|
||
if apiConfig, err := configs.SharedAPIConfig(); err == nil && apiConfig != nil {
|
||
if len(apiConfig.HTTPSListenAddr) > 0 {
|
||
listenAddr = apiConfig.HTTPSListenAddr
|
||
}
|
||
certFile = apiConfig.HTTPSCert
|
||
keyFile = apiConfig.HTTPSKey
|
||
}
|
||
|
||
instance := &ResolveServer{
|
||
quitCh: quitCh,
|
||
snapshotManager: snapshotManager,
|
||
listenAddr: listenAddr,
|
||
certFile: certFile,
|
||
keyFile: keyFile,
|
||
logQueue: make(chan *pb.HTTPDNSAccessLog, 8192),
|
||
}
|
||
|
||
mux := http.NewServeMux()
|
||
mux.HandleFunc("/resolve", instance.handleResolve)
|
||
mux.HandleFunc("/healthz", instance.handleHealth)
|
||
|
||
instance.server = &http.Server{
|
||
Addr: instance.listenAddr,
|
||
Handler: mux,
|
||
ReadTimeout: 5 * time.Second,
|
||
ReadHeaderTimeout: 3 * time.Second,
|
||
WriteTimeout: 5 * time.Second,
|
||
IdleTimeout: 75 * time.Second,
|
||
MaxHeaderBytes: 8 * 1024,
|
||
TLSConfig: &tls.Config{
|
||
MinVersion: tls.VersionTLS12,
|
||
},
|
||
}
|
||
|
||
return instance
|
||
}
|
||
|
||
func (s *ResolveServer) Start() {
|
||
go s.startAccessLogFlusher()
|
||
go s.waitForShutdown()
|
||
|
||
log.Println("[HTTPDNS_NODE][resolve]listening HTTPS on", s.listenAddr)
|
||
if err := s.server.ListenAndServeTLS(s.certFile, s.keyFile); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||
log.Println("[HTTPDNS_NODE][resolve]listen failed:", err.Error())
|
||
}
|
||
}
|
||
|
||
func (s *ResolveServer) waitForShutdown() {
|
||
<-s.quitCh
|
||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||
defer cancel()
|
||
_ = s.server.Shutdown(ctx)
|
||
}
|
||
|
||
func (s *ResolveServer) handleHealth(writer http.ResponseWriter, _ *http.Request) {
|
||
writer.WriteHeader(http.StatusOK)
|
||
_, _ = writer.Write([]byte("ok"))
|
||
}
|
||
|
||
func (s *ResolveServer) handleResolve(writer http.ResponseWriter, request *http.Request) {
|
||
startAt := time.Now()
|
||
requestID := "rid-" + rands.HexString(16)
|
||
|
||
if request.Method != http.MethodGet {
|
||
s.writeResolveJSON(writer, http.StatusMethodNotAllowed, &resolveResponse{
|
||
Code: httpdnsCodeMethodNotAllow,
|
||
Message: "只允许使用 GET 方法",
|
||
RequestID: requestID,
|
||
})
|
||
return
|
||
}
|
||
|
||
query := request.URL.Query()
|
||
appID := strings.TrimSpace(query.Get("appId"))
|
||
domain := strings.TrimSuffix(strings.ToLower(strings.TrimSpace(query.Get("dn"))), ".")
|
||
qtype := strings.ToUpper(strings.TrimSpace(query.Get("qtype")))
|
||
if len(qtype) == 0 {
|
||
qtype = "A"
|
||
}
|
||
if qtype != "A" && qtype != "AAAA" {
|
||
s.writeResolveJSON(writer, http.StatusBadRequest, &resolveResponse{
|
||
Code: httpdnsCodeInvalidArgument,
|
||
Message: "qtype 参数仅支持 A 或 AAAA",
|
||
RequestID: requestID,
|
||
})
|
||
return
|
||
}
|
||
|
||
if len(appID) == 0 || len(domain) == 0 {
|
||
s.writeResolveJSON(writer, http.StatusBadRequest, &resolveResponse{
|
||
Code: httpdnsCodeInvalidArgument,
|
||
Message: "缺少必填参数: appId 和 dn",
|
||
RequestID: requestID,
|
||
})
|
||
return
|
||
}
|
||
|
||
snapshot := s.snapshotManager.Current()
|
||
if snapshot == nil {
|
||
s.writeResolveJSON(writer, http.StatusServiceUnavailable, &resolveResponse{
|
||
Code: httpdnsCodeInternalError,
|
||
Message: "服务节点尚未准备就绪,请稍后再试",
|
||
RequestID: requestID,
|
||
})
|
||
return
|
||
}
|
||
|
||
loadedApp := snapshot.Apps[strings.ToLower(appID)]
|
||
if loadedApp == nil || loadedApp.App == nil || !loadedApp.App.GetIsOn() {
|
||
s.writeFailedResolve(writer, requestID, snapshot, nil, domain, qtype, httpdnsCodeAppInvalid, "找不到指定的应用,或该应用已下线", startAt, request, query)
|
||
return
|
||
}
|
||
|
||
if snapshot.ClusterID > 0 &&
|
||
loadedApp.App.GetPrimaryClusterId() != snapshot.ClusterID &&
|
||
loadedApp.App.GetBackupClusterId() != snapshot.ClusterID {
|
||
s.writeFailedResolve(writer, requestID, snapshot, loadedApp.App, domain, qtype, httpdnsCodeAppInvalid, "当前应用未绑定到该解析集群", startAt, request, query)
|
||
return
|
||
}
|
||
|
||
loadedDomain := loadedApp.Domains[domain]
|
||
if loadedDomain == nil || loadedDomain.Domain == nil || !loadedDomain.Domain.GetIsOn() {
|
||
s.writeFailedResolve(writer, requestID, snapshot, loadedApp.App, domain, qtype, httpdnsCodeDomainNotBound, "应用尚未绑定该域名,或域名解析已暂停", startAt, request, query)
|
||
return
|
||
}
|
||
|
||
if loadedApp.App.GetSignEnabled() {
|
||
if !validateResolveSign(loadedApp.App.GetSignSecret(), loadedApp.App.GetAppId(), domain, qtype, query.Get("nonce"), query.Get("exp"), query.Get("sign")) {
|
||
s.writeFailedResolve(writer, requestID, snapshot, loadedApp.App, domain, qtype, httpdnsCodeSignInvalid, "请求鉴权失败:签名无效或已过期", startAt, request, query)
|
||
return
|
||
}
|
||
}
|
||
|
||
clientIP := detectClientIP(request, query.Get("cip"))
|
||
clientProfile := buildClientRouteProfile(clientIP)
|
||
|
||
clusterTTL := pickDefaultTTL(snapshot, loadedApp.App)
|
||
rule, records, ttl := pickRuleRecords(loadedDomain.Rules, qtype, clientProfile, clusterTTL)
|
||
if len(records) == 0 {
|
||
// Fallback:回源上游 DNS 查询真实记录
|
||
fallbackRecords, fallbackTTL, fallbackErr := fallbackResolve(domain, qtype)
|
||
if fallbackErr != nil || len(fallbackRecords) == 0 {
|
||
errMsg := "未找到解析记录"
|
||
if fallbackErr != nil {
|
||
errMsg = "未找到解析记录 (上游回源失败: " + fallbackErr.Error() + ")"
|
||
}
|
||
s.writeFailedResolve(writer, requestID, snapshot, loadedApp.App, domain, qtype, httpdnsCodeNoRecords, errMsg, startAt, request, query)
|
||
return
|
||
}
|
||
records = fallbackRecords
|
||
if fallbackTTL > 0 {
|
||
ttl = fallbackTTL
|
||
} else {
|
||
ttl = clusterTTL
|
||
}
|
||
}
|
||
if ttl <= 0 {
|
||
ttl = clusterTTL
|
||
}
|
||
if ttl <= 0 {
|
||
ttl = 30
|
||
}
|
||
|
||
resultIPs := make([]string, 0, len(records))
|
||
for _, record := range records {
|
||
resultIPs = append(resultIPs, record.IP)
|
||
}
|
||
|
||
summary := fmt.Sprintf("%s|%s(%s)|%s|%s %s -> %s|success|%dms",
|
||
time.Now().Format("2006-01-02 15:04:05"),
|
||
loadedApp.App.GetName(),
|
||
loadedApp.App.GetAppId(),
|
||
clientProfile.IP,
|
||
qtype,
|
||
domain,
|
||
strings.Join(resultIPs, ", "),
|
||
time.Since(startAt).Milliseconds(),
|
||
)
|
||
if rule != nil && len(strings.TrimSpace(rule.GetRuleName())) > 0 {
|
||
summary += "|rule:" + strings.TrimSpace(rule.GetRuleName())
|
||
}
|
||
|
||
s.writeResolveJSON(writer, http.StatusOK, &resolveResponse{
|
||
Code: httpdnsCodeSuccess,
|
||
Message: "ok",
|
||
RequestID: requestID,
|
||
Data: &resolveData{
|
||
Domain: domain,
|
||
QType: qtype,
|
||
TTL: ttl,
|
||
Records: records,
|
||
Client: &resolveClientInfo{
|
||
IP: clientProfile.IP,
|
||
Region: clientProfile.RegionText,
|
||
Carrier: clientProfile.Carrier,
|
||
Country: clientProfile.Country,
|
||
},
|
||
Summary: summary,
|
||
},
|
||
})
|
||
|
||
s.enqueueAccessLog(&pb.HTTPDNSAccessLog{
|
||
RequestId: requestID,
|
||
ClusterId: snapshot.ClusterID,
|
||
NodeId: snapshot.NodeID,
|
||
AppId: loadedApp.App.GetAppId(),
|
||
AppName: loadedApp.App.GetName(),
|
||
Domain: domain,
|
||
Qtype: qtype,
|
||
ClientIP: clientProfile.IP,
|
||
ClientRegion: clientProfile.RegionText,
|
||
Carrier: clientProfile.Carrier,
|
||
SdkVersion: strings.TrimSpace(query.Get("sdk_version")),
|
||
Os: strings.TrimSpace(query.Get("os")),
|
||
ResultIPs: strings.Join(resultIPs, ","),
|
||
Status: "success",
|
||
ErrorCode: "none",
|
||
CostMs: int32(time.Since(startAt).Milliseconds()),
|
||
CreatedAt: time.Now().Unix(),
|
||
Day: time.Now().Format("20060102"),
|
||
Summary: summary,
|
||
})
|
||
}
|
||
|
||
func pickDefaultTTL(snapshot *LoadedSnapshot, app *pb.HTTPDNSApp) int32 {
|
||
if snapshot == nil {
|
||
return 30
|
||
}
|
||
|
||
if snapshot.ClusterID > 0 {
|
||
if cluster := snapshot.Clusters[snapshot.ClusterID]; cluster != nil && cluster.GetDefaultTTL() > 0 {
|
||
return cluster.GetDefaultTTL()
|
||
}
|
||
}
|
||
if app != nil {
|
||
if cluster := snapshot.Clusters[app.GetPrimaryClusterId()]; cluster != nil && cluster.GetDefaultTTL() > 0 {
|
||
return cluster.GetDefaultTTL()
|
||
}
|
||
if cluster := snapshot.Clusters[app.GetBackupClusterId()]; cluster != nil && cluster.GetDefaultTTL() > 0 {
|
||
return cluster.GetDefaultTTL()
|
||
}
|
||
}
|
||
return 30
|
||
}
|
||
|
||
func (s *ResolveServer) writeFailedResolve(
|
||
writer http.ResponseWriter,
|
||
requestID string,
|
||
snapshot *LoadedSnapshot,
|
||
app *pb.HTTPDNSApp,
|
||
domain string,
|
||
qtype string,
|
||
errorCode string,
|
||
message string,
|
||
startAt time.Time,
|
||
request *http.Request,
|
||
query url.Values,
|
||
) {
|
||
clientIP := detectClientIP(request, query.Get("cip"))
|
||
clientProfile := buildClientRouteProfile(clientIP)
|
||
|
||
appID := ""
|
||
appName := ""
|
||
if app != nil {
|
||
appID = app.GetAppId()
|
||
appName = app.GetName()
|
||
}
|
||
|
||
summary := fmt.Sprintf("%s|%s(%s)|%s|%s %s -> [none]|failed(%s)|%dms",
|
||
time.Now().Format("2006-01-02 15:04:05"),
|
||
appName,
|
||
appID,
|
||
clientProfile.IP,
|
||
qtype,
|
||
domain,
|
||
errorCode,
|
||
time.Since(startAt).Milliseconds(),
|
||
)
|
||
|
||
s.writeResolveJSON(writer, http.StatusOK, &resolveResponse{
|
||
Code: errorCode,
|
||
Message: message,
|
||
RequestID: requestID,
|
||
})
|
||
|
||
clusterID := int64(0)
|
||
nodeID := int64(0)
|
||
if snapshot != nil {
|
||
clusterID = snapshot.ClusterID
|
||
nodeID = snapshot.NodeID
|
||
}
|
||
|
||
s.enqueueAccessLog(&pb.HTTPDNSAccessLog{
|
||
RequestId: requestID,
|
||
ClusterId: clusterID,
|
||
NodeId: nodeID,
|
||
AppId: appID,
|
||
AppName: appName,
|
||
Domain: domain,
|
||
Qtype: qtype,
|
||
ClientIP: clientProfile.IP,
|
||
ClientRegion: clientProfile.RegionText,
|
||
Carrier: clientProfile.Carrier,
|
||
SdkVersion: strings.TrimSpace(query.Get("sdk_version")),
|
||
Os: strings.TrimSpace(query.Get("os")),
|
||
ResultIPs: "",
|
||
Status: "failed",
|
||
ErrorCode: errorCode,
|
||
CostMs: int32(time.Since(startAt).Milliseconds()),
|
||
CreatedAt: time.Now().Unix(),
|
||
Day: time.Now().Format("20060102"),
|
||
Summary: summary,
|
||
})
|
||
}
|
||
|
||
func (s *ResolveServer) writeResolveJSON(writer http.ResponseWriter, status int, resp *resolveResponse) {
|
||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||
writer.WriteHeader(status)
|
||
data, err := json.Marshal(resp)
|
||
if err != nil {
|
||
_, _ = writer.Write([]byte(`{"code":"RESOLVE_TIMEOUT_OR_INTERNAL","message":"encode response failed"}`))
|
||
return
|
||
}
|
||
_, _ = writer.Write(data)
|
||
}
|
||
|
||
func detectClientIP(request *http.Request, cip string) string {
|
||
if candidate := normalizeIPCandidate(cip); len(candidate) > 0 {
|
||
return candidate
|
||
}
|
||
|
||
xff := strings.TrimSpace(request.Header.Get("X-Forwarded-For"))
|
||
if len(xff) > 0 {
|
||
for _, item := range strings.Split(xff, ",") {
|
||
if candidate := normalizeIPCandidate(item); len(candidate) > 0 {
|
||
return candidate
|
||
}
|
||
}
|
||
}
|
||
|
||
headerKeys := []string{"X-Real-IP", "X-Client-IP", "CF-Connecting-IP", "True-Client-IP"}
|
||
for _, key := range headerKeys {
|
||
if candidate := normalizeIPCandidate(request.Header.Get(key)); len(candidate) > 0 {
|
||
return candidate
|
||
}
|
||
}
|
||
|
||
return normalizeIPCandidate(request.RemoteAddr)
|
||
}
|
||
|
||
// validateResolveSign verifies an HMAC-SHA256 request signature.
//
// The signed string is appID|lower(domain)|upper(qtype)|exp|nonce, keyed with
// signSecret. sign must be its hex encoding; letter case is ignored. The
// secret, nonce, exp and sign are trimmed first, and all four must be
// non-empty. exp is a Unix timestamp in seconds. It is rejected when it
// expired 30 seconds ago or earlier, or lies more than 86400 seconds (one day)
// in the future.
//
// The signature comparison is constant-time (hmac.Equal), so response timing
// does not reveal how much of a forged signature was correct.
//
// NOTE(review): nonce is only bound into the signature, never tracked, so a
// captured signed request can be replayed until exp; confirm this is acceptable.
func validateResolveSign(signSecret string, appID string, domain string, qtype string, nonce string, exp string, sign string) bool {
	signSecret = strings.TrimSpace(signSecret)
	nonce = strings.TrimSpace(nonce)
	exp = strings.TrimSpace(exp)
	sign = strings.TrimSpace(sign)
	if len(signSecret) == 0 || len(nonce) == 0 || len(exp) == 0 || len(sign) == 0 {
		return false
	}

	expireAt, err := strconv.ParseInt(exp, 10, 64)
	if err != nil {
		return false
	}
	// Tolerate up to 30s past expiry (clock skew); refuse far-future expiries.
	now := time.Now().Unix()
	if expireAt <= now-30 || expireAt > now+86400 {
		return false
	}

	raw := appID + "|" + strings.ToLower(domain) + "|" + strings.ToUpper(qtype) + "|" + exp + "|" + nonce
	mac := hmac.New(sha256.New, []byte(signSecret))
	_, _ = mac.Write([]byte(raw)) // hash.Hash.Write never returns an error
	expected := hex.EncodeToString(mac.Sum(nil))

	// expected is lower-case hex. Lower-casing the provided signature keeps the
	// previous case-insensitive acceptance. hmac.Equal replaces
	// strings.EqualFold, which returns at the first mismatching byte and so
	// leaks timing.
	return hmac.Equal([]byte(expected), []byte(strings.ToLower(sign)))
}
|
||
|
||
func buildClientRouteProfile(ip string) *clientRouteProfile {
|
||
profile := &clientRouteProfile{
|
||
IP: normalizeIPCandidate(ip),
|
||
}
|
||
if net.ParseIP(profile.IP) == nil {
|
||
return profile
|
||
}
|
||
|
||
result := iplibrary.LookupIP(profile.IP)
|
||
if result == nil || !result.IsOk() {
|
||
return profile
|
||
}
|
||
|
||
profile.Country = normalizeCountryName(strings.TrimSpace(result.CountryName()))
|
||
profile.Province = normalizeProvinceName(strings.TrimSpace(result.ProvinceName()))
|
||
profile.ProviderRaw = strings.TrimSpace(result.ProviderName())
|
||
profile.Carrier = normalizeCarrier(profile.ProviderRaw, profile.Country)
|
||
if len(profile.Carrier) == 0 {
|
||
if isMainlandChinaCountry(profile.Country) {
|
||
profile.Carrier = "默认"
|
||
} else {
|
||
if len(profile.ProviderRaw) > 0 {
|
||
profile.Carrier = profile.ProviderRaw
|
||
} else {
|
||
profile.Carrier = "默认"
|
||
}
|
||
}
|
||
}
|
||
profile.Region = normalizeChinaRegion(profile.Province)
|
||
profile.Continent = normalizeContinent(profile.Country)
|
||
profile.RegionText = strings.TrimSpace(result.RegionSummary())
|
||
if len(profile.RegionText) == 0 {
|
||
pieces := make([]string, 0, 4)
|
||
if len(profile.Country) > 0 {
|
||
pieces = append(pieces, profile.Country)
|
||
}
|
||
if len(profile.Province) > 0 {
|
||
pieces = append(pieces, profile.Province)
|
||
}
|
||
if len(profile.Region) > 0 {
|
||
pieces = append(pieces, profile.Region)
|
||
}
|
||
if len(profile.Carrier) > 0 {
|
||
pieces = append(pieces, profile.Carrier)
|
||
}
|
||
profile.RegionText = strings.Join(pieces, " ")
|
||
}
|
||
|
||
return profile
|
||
}
|
||
|
||
func normalizeCarrier(provider string, country string) string {
|
||
value := strings.TrimSpace(provider)
|
||
if len(value) == 0 {
|
||
return ""
|
||
}
|
||
lower := strings.ToLower(value)
|
||
switch {
|
||
case strings.Contains(value, "电信"), strings.Contains(value, "天翼"),
|
||
strings.Contains(lower, "telecom"), strings.Contains(lower, "chinanet"),
|
||
strings.Contains(lower, "chinatelecom"), strings.Contains(lower, "ctnet"), strings.Contains(lower, "cn2"):
|
||
return "电信"
|
||
case strings.Contains(value, "联通"), strings.Contains(value, "网通"),
|
||
strings.Contains(lower, "unicom"), strings.Contains(lower, "chinaunicom"),
|
||
strings.Contains(lower, "cucc"), strings.Contains(lower, "china169"), strings.Contains(lower, "cnc"):
|
||
return "联通"
|
||
case strings.Contains(value, "移动"),
|
||
strings.Contains(lower, "mobile"), strings.Contains(lower, "chinamobile"),
|
||
strings.Contains(lower, "cmcc"), strings.Contains(lower, "cmnet"):
|
||
return "移动"
|
||
case strings.Contains(value, "教育"),
|
||
strings.Contains(lower, "cernet"), strings.Contains(lower, "edu"), strings.Contains(lower, "education"):
|
||
return "教育网"
|
||
case strings.Contains(value, "鹏博士"),
|
||
strings.Contains(lower, "drpeng"), strings.Contains(lower, "dr.peng"), strings.Contains(lower, "dr_peng"):
|
||
return "鹏博士"
|
||
case strings.Contains(value, "广电"),
|
||
strings.Contains(lower, "broadcast"), strings.Contains(lower, "cable"), strings.Contains(lower, "radio"):
|
||
return "广电"
|
||
default:
|
||
if isMainlandChinaCountry(country) {
|
||
return ""
|
||
}
|
||
return value
|
||
}
|
||
}
|
||
|
||
func normalizeChinaRegion(province string) string {
|
||
switch normalizeProvinceName(province) {
|
||
case "辽宁", "吉林", "黑龙江":
|
||
return "东北"
|
||
case "北京", "天津", "河北", "山西", "内蒙古":
|
||
return "华北"
|
||
case "上海", "江苏", "浙江", "安徽", "福建", "江西", "山东":
|
||
return "华东"
|
||
case "广东", "广西", "海南":
|
||
return "华南"
|
||
case "河南", "湖北", "湖南":
|
||
return "华中"
|
||
case "陕西", "甘肃", "青海", "宁夏", "新疆":
|
||
return "西北"
|
||
case "重庆", "四川", "贵州", "云南", "西藏":
|
||
return "西南"
|
||
default:
|
||
return ""
|
||
}
|
||
}
|
||
|
||
func normalizeContinent(country string) string {
|
||
switch normalizeCountryName(country) {
|
||
case "中国", "中国香港", "中国澳门", "中国台湾", "日本", "韩国", "新加坡", "印度", "泰国", "越南":
|
||
return "亚洲"
|
||
case "美国", "加拿大", "墨西哥":
|
||
return "北美洲"
|
||
case "巴西", "阿根廷", "智利", "哥伦比亚":
|
||
return "南美洲"
|
||
case "德国", "英国", "法国", "荷兰", "西班牙", "意大利", "俄罗斯":
|
||
return "欧洲"
|
||
case "南非", "埃及", "尼日利亚", "肯尼亚", "摩洛哥":
|
||
return "非洲"
|
||
case "澳大利亚", "新西兰":
|
||
return "大洋洲"
|
||
default:
|
||
return ""
|
||
}
|
||
}
|
||
func pickRuleRecords(rules []*pb.HTTPDNSCustomRule, qtype string, profile *clientRouteProfile, defaultTTL int32) (*pb.HTTPDNSCustomRule, []*resolveRecord, int32) {
|
||
bestScore := -1
|
||
var bestRule *pb.HTTPDNSCustomRule
|
||
var bestRecords []*resolveRecord
|
||
bestTTL := defaultTTL
|
||
|
||
for _, rule := range rules {
|
||
if rule == nil || !rule.GetIsOn() {
|
||
continue
|
||
}
|
||
|
||
score, ok := matchRuleLine(rule, profile)
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
records := make([]*resolveRecord, 0)
|
||
for _, item := range rule.GetRecords() {
|
||
if item == nil {
|
||
continue
|
||
}
|
||
if strings.ToUpper(strings.TrimSpace(item.GetRecordType())) != qtype {
|
||
continue
|
||
}
|
||
value := strings.TrimSpace(item.GetRecordValue())
|
||
if len(value) == 0 {
|
||
continue
|
||
}
|
||
records = append(records, &resolveRecord{
|
||
Type: qtype,
|
||
IP: value,
|
||
Weight: item.GetWeight(),
|
||
Line: ruleLineSummary(rule),
|
||
Region: ruleRegionSummary(rule),
|
||
})
|
||
}
|
||
if len(records) == 0 {
|
||
continue
|
||
}
|
||
|
||
if score > bestScore {
|
||
bestScore = score
|
||
bestRule = rule
|
||
bestRecords = records
|
||
if rule.GetTtl() > 0 {
|
||
bestTTL = rule.GetTtl()
|
||
} else {
|
||
bestTTL = defaultTTL
|
||
}
|
||
}
|
||
}
|
||
|
||
return bestRule, bestRecords, bestTTL
|
||
}
|
||
|
||
// fallbackResolve 当无自定义规则命中时,回源上游 DNS 查询真实记录(对齐 EdgeDNS 做法)
|
||
func fallbackResolve(domain string, qtype string) ([]*resolveRecord, int32, error) {
|
||
var dnsType uint16
|
||
switch qtype {
|
||
case "A":
|
||
dnsType = dns.TypeA
|
||
case "AAAA":
|
||
dnsType = dns.TypeAAAA
|
||
default:
|
||
return nil, 0, nil // 仅降级处理 A 和 AAAA
|
||
}
|
||
|
||
m := new(dns.Msg)
|
||
m.SetQuestion(dns.Fqdn(domain), dnsType)
|
||
m.RecursionDesired = true
|
||
|
||
// 优先使用本机 /etc/resolv.conf 中的 DNS 服务器(对齐 EdgeDNS)
|
||
var upstream = "223.5.5.5:53"
|
||
resolveConfig, confErr := dns.ClientConfigFromFile("/etc/resolv.conf")
|
||
if confErr == nil && len(resolveConfig.Servers) > 0 {
|
||
port := resolveConfig.Port
|
||
if len(port) == 0 {
|
||
port = "53"
|
||
}
|
||
server := resolveConfig.Servers[rands.Int(0, len(resolveConfig.Servers)-1)]
|
||
upstream = server + ":" + port
|
||
}
|
||
|
||
r, _, err := sharedRecursionDNSClient.Exchange(m, upstream)
|
||
if err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
if r.Rcode != dns.RcodeSuccess {
|
||
return nil, 0, fmt.Errorf("upstream rcode: %d", r.Rcode)
|
||
}
|
||
|
||
var records []*resolveRecord
|
||
var responseTTL int32
|
||
for _, ans := range r.Answer {
|
||
switch t := ans.(type) {
|
||
case *dns.A:
|
||
if qtype == "A" {
|
||
if t.Hdr.Ttl > 0 {
|
||
ttl := int32(t.Hdr.Ttl)
|
||
if responseTTL == 0 || ttl < responseTTL {
|
||
responseTTL = ttl
|
||
}
|
||
}
|
||
ip := t.A.String()
|
||
records = append(records, &resolveRecord{
|
||
Type: "A",
|
||
IP: ip,
|
||
Line: lookupIPLineLabel(ip),
|
||
Region: lookupIPRegionSummary(ip),
|
||
})
|
||
}
|
||
case *dns.AAAA:
|
||
if qtype == "AAAA" {
|
||
if t.Hdr.Ttl > 0 {
|
||
ttl := int32(t.Hdr.Ttl)
|
||
if responseTTL == 0 || ttl < responseTTL {
|
||
responseTTL = ttl
|
||
}
|
||
}
|
||
ip := t.AAAA.String()
|
||
records = append(records, &resolveRecord{
|
||
Type: "AAAA",
|
||
IP: ip,
|
||
Line: lookupIPLineLabel(ip),
|
||
Region: lookupIPRegionSummary(ip),
|
||
})
|
||
}
|
||
}
|
||
}
|
||
|
||
return records, responseTTL, nil
|
||
}
|
||
|
||
func lookupIPRegionSummary(ip string) string {
|
||
address := strings.TrimSpace(ip)
|
||
if len(address) == 0 || net.ParseIP(address) == nil {
|
||
return ""
|
||
}
|
||
|
||
result := iplibrary.LookupIP(address)
|
||
if result == nil || !result.IsOk() {
|
||
return ""
|
||
}
|
||
|
||
return strings.TrimSpace(result.RegionSummary())
|
||
}
|
||
|
||
func lookupIPLineLabel(ip string) string {
|
||
address := strings.TrimSpace(ip)
|
||
if len(address) == 0 || net.ParseIP(address) == nil {
|
||
return "上游DNS"
|
||
}
|
||
|
||
result := iplibrary.LookupIP(address)
|
||
if result == nil || !result.IsOk() {
|
||
return "上游DNS"
|
||
}
|
||
|
||
provider := strings.TrimSpace(result.ProviderName())
|
||
country := normalizeCountryName(strings.TrimSpace(result.CountryName()))
|
||
carrier := strings.TrimSpace(normalizeCarrier(provider, country))
|
||
if isMainlandChinaCountry(country) {
|
||
if len(carrier) > 0 {
|
||
return carrier
|
||
}
|
||
return "默认"
|
||
}
|
||
|
||
if len(carrier) > 0 {
|
||
return carrier
|
||
}
|
||
if len(provider) > 0 {
|
||
return provider
|
||
}
|
||
return "上游DNS"
|
||
}
|
||
|
||
func isMainlandChinaCountry(country string) bool {
|
||
switch normalizeCountryName(country) {
|
||
case "中国":
|
||
return true
|
||
}
|
||
return false
|
||
}
|
||
|
||
func matchRuleLine(rule *pb.HTTPDNSCustomRule, profile *clientRouteProfile) (int, bool) {
|
||
scope := strings.ToLower(strings.TrimSpace(rule.GetLineScope()))
|
||
score := 0
|
||
|
||
if scope == "overseas" {
|
||
// 境外规则只匹配非中国大陆来源
|
||
if isMainlandChinaCountry(profile.Country) {
|
||
return 0, false
|
||
}
|
||
|
||
fieldScore, ok := matchRuleField(rule.GetLineContinent(), profile.Continent)
|
||
if !ok {
|
||
return 0, false
|
||
}
|
||
score += fieldScore
|
||
|
||
fieldScore, ok = matchRuleField(rule.GetLineCountry(), profile.Country)
|
||
if !ok {
|
||
return 0, false
|
||
}
|
||
score += fieldScore
|
||
return score, true
|
||
}
|
||
|
||
// 中国地区规则只匹配中国大陆来源
|
||
if !isMainlandChinaCountry(profile.Country) {
|
||
return 0, false
|
||
}
|
||
|
||
fieldScore, ok := matchRuleField(rule.GetLineCountry(), profile.Country)
|
||
if !ok {
|
||
return 0, false
|
||
}
|
||
score += fieldScore
|
||
|
||
fieldScore, ok = matchRuleField(rule.GetLineContinent(), profile.Continent)
|
||
if !ok {
|
||
return 0, false
|
||
}
|
||
score += fieldScore
|
||
|
||
fieldScore, ok = matchRuleField(rule.GetLineCarrier(), profile.Carrier, profile.ProviderRaw)
|
||
if !ok {
|
||
return 0, false
|
||
}
|
||
score += fieldScore
|
||
|
||
fieldScore, ok = matchRuleField(rule.GetLineRegion(), profile.Region)
|
||
if !ok {
|
||
return 0, false
|
||
}
|
||
score += fieldScore
|
||
|
||
fieldScore, ok = matchRuleField(rule.GetLineProvince(), profile.Province)
|
||
if !ok {
|
||
return 0, false
|
||
}
|
||
score += fieldScore
|
||
|
||
return score, true
|
||
}
|
||
|
||
func matchRuleField(ruleValue string, candidates ...string) (int, bool) {
|
||
if isDefaultLineValue(ruleValue) {
|
||
return 0, true
|
||
}
|
||
|
||
want := normalizeLineValue(ruleValue)
|
||
if len(want) == 0 {
|
||
return 0, true
|
||
}
|
||
|
||
for _, candidate := range candidates {
|
||
got := normalizeLineValue(candidate)
|
||
if len(got) == 0 {
|
||
// 如果规则设了具体值(want),但客户端信息(got)为空,则不能匹配
|
||
// 否则 strings.Contains("xxx", "") 永远为 true
|
||
continue
|
||
}
|
||
if want == got || strings.Contains(got, want) || strings.Contains(want, got) {
|
||
return 1, true
|
||
}
|
||
}
|
||
return 0, false
|
||
}
|
||
|
||
func isDefaultLineValue(value string) bool {
|
||
switch normalizeLineValue(value) {
|
||
case "", "default", "all", "*", "any", "默认", "全部", "不限":
|
||
return true
|
||
}
|
||
return false
|
||
}
|
||
|
||
// normalizeLineValue canonicalizes a line value for comparison. It trims
// surrounding whitespace, lower-cases the value, and removes every space,
// '-', '_' and '/'.
func normalizeLineValue(value string) string {
	lowered := strings.ToLower(strings.TrimSpace(value))
	return strings.Map(func(r rune) rune {
		switch r {
		case ' ', '-', '_', '/':
			return -1 // drop separator
		}
		return r
	}, lowered)
}
|
||
|
||
// normalizeIPCandidate extracts a canonical IP string from raw input. The input
// may carry surrounding whitespace and an optional ":port" suffix, with IPv6
// in "[addr]:port" form. IPv4-mapped IPv6 addresses come back in dotted IPv4
// form. It returns "" when no valid IP can be parsed.
func normalizeIPCandidate(ip string) string {
	host := strings.TrimSpace(ip)
	if h, _, splitErr := net.SplitHostPort(host); splitErr == nil {
		host = strings.TrimSpace(h)
	}
	if addr := net.ParseIP(host); addr != nil {
		return addr.String()
	}
	return ""
}
|
||
|
||
func normalizeCountryName(country string) string {
|
||
value := strings.TrimSpace(country)
|
||
if len(value) == 0 {
|
||
return ""
|
||
}
|
||
|
||
normalized := normalizeLineValue(value)
|
||
switch normalized {
|
||
case "中国香港", "香港", "hongkong":
|
||
return "中国香港"
|
||
case "中国澳门", "澳门", "macao", "macau":
|
||
return "中国澳门"
|
||
case "中国台湾", "台湾", "taiwan":
|
||
return "中国台湾"
|
||
}
|
||
|
||
switch normalized {
|
||
case "中国", "中国大陆", "中国内地", "中华人民共和国", "prc", "cn", "china", "mainlandchina", "peoplesrepublicofchina", "thepeoplesrepublicofchina":
|
||
return "中国"
|
||
case "美国", "usa", "unitedstates", "unitedstatesofamerica":
|
||
return "美国"
|
||
case "加拿大", "canada":
|
||
return "加拿大"
|
||
case "墨西哥", "mexico":
|
||
return "墨西哥"
|
||
case "日本", "japan":
|
||
return "日本"
|
||
case "韩国", "southkorea", "korea":
|
||
return "韩国"
|
||
case "新加坡", "singapore":
|
||
return "新加坡"
|
||
case "印度", "india":
|
||
return "印度"
|
||
case "泰国", "thailand":
|
||
return "泰国"
|
||
case "越南", "vietnam":
|
||
return "越南"
|
||
case "德国", "germany":
|
||
return "德国"
|
||
case "英国", "uk", "unitedkingdom", "greatbritain", "britain":
|
||
return "英国"
|
||
case "法国", "france":
|
||
return "法国"
|
||
case "荷兰", "netherlands":
|
||
return "荷兰"
|
||
case "西班牙", "spain":
|
||
return "西班牙"
|
||
case "意大利", "italy":
|
||
return "意大利"
|
||
case "俄罗斯", "russia":
|
||
return "俄罗斯"
|
||
case "巴西", "brazil":
|
||
return "巴西"
|
||
case "阿根廷", "argentina":
|
||
return "阿根廷"
|
||
case "智利", "chile":
|
||
return "智利"
|
||
case "哥伦比亚", "colombia":
|
||
return "哥伦比亚"
|
||
case "南非", "southafrica":
|
||
return "南非"
|
||
case "埃及", "egypt":
|
||
return "埃及"
|
||
case "尼日利亚", "nigeria":
|
||
return "尼日利亚"
|
||
case "肯尼亚", "kenya":
|
||
return "肯尼亚"
|
||
case "摩洛哥", "morocco":
|
||
return "摩洛哥"
|
||
case "澳大利亚", "australia":
|
||
return "澳大利亚"
|
||
case "新西兰", "newzealand":
|
||
return "新西兰"
|
||
default:
|
||
return value
|
||
}
|
||
}
|
||
|
||
// normalizeProvinceName reduces a province name to its short form. It trims
// whitespace, maps the listed full autonomous-region / SAR names explicitly,
// and otherwise strips administrative suffixes. Each suffix is tried once, in
// the listed order.
func normalizeProvinceName(province string) string {
	name := strings.TrimSpace(province)
	switch name {
	case "":
		return ""
	case "内蒙古自治区":
		return "内蒙古"
	case "广西壮族自治区":
		return "广西"
	case "宁夏回族自治区":
		return "宁夏"
	case "新疆维吾尔自治区":
		return "新疆"
	case "西藏自治区":
		return "西藏"
	case "香港特别行政区":
		return "香港"
	case "澳门特别行政区":
		return "澳门"
	}

	suffixes := [...]string{"维吾尔自治区", "回族自治区", "壮族自治区", "自治区", "特别行政区", "省", "市"}
	for i := 0; i < len(suffixes); i++ {
		name = strings.TrimSuffix(name, suffixes[i])
	}
	return name
}
|
||
|
||
func ruleLineSummary(rule *pb.HTTPDNSCustomRule) string {
|
||
if rule == nil {
|
||
return ""
|
||
}
|
||
|
||
scope := strings.ToLower(strings.TrimSpace(rule.GetLineScope()))
|
||
if scope == "overseas" {
|
||
pieces := make([]string, 0, 2)
|
||
if !isDefaultLineValue(rule.GetLineContinent()) {
|
||
pieces = append(pieces, strings.TrimSpace(rule.GetLineContinent()))
|
||
}
|
||
if !isDefaultLineValue(rule.GetLineCountry()) {
|
||
pieces = append(pieces, strings.TrimSpace(rule.GetLineCountry()))
|
||
}
|
||
return strings.Join(pieces, "/")
|
||
}
|
||
|
||
pieces := make([]string, 0, 3)
|
||
if !isDefaultLineValue(rule.GetLineCarrier()) {
|
||
pieces = append(pieces, strings.TrimSpace(rule.GetLineCarrier()))
|
||
}
|
||
if !isDefaultLineValue(rule.GetLineRegion()) {
|
||
pieces = append(pieces, strings.TrimSpace(rule.GetLineRegion()))
|
||
}
|
||
if !isDefaultLineValue(rule.GetLineProvince()) {
|
||
pieces = append(pieces, strings.TrimSpace(rule.GetLineProvince()))
|
||
}
|
||
return strings.Join(pieces, "/")
|
||
}
|
||
|
||
func ruleRegionSummary(rule *pb.HTTPDNSCustomRule) string {
|
||
if rule == nil {
|
||
return ""
|
||
}
|
||
if !isDefaultLineValue(rule.GetLineProvince()) {
|
||
return strings.TrimSpace(rule.GetLineProvince())
|
||
}
|
||
if !isDefaultLineValue(rule.GetLineCountry()) {
|
||
return strings.TrimSpace(rule.GetLineCountry())
|
||
}
|
||
if !isDefaultLineValue(rule.GetLineRegion()) {
|
||
return strings.TrimSpace(rule.GetLineRegion())
|
||
}
|
||
if !isDefaultLineValue(rule.GetLineContinent()) {
|
||
return strings.TrimSpace(rule.GetLineContinent())
|
||
}
|
||
return ""
|
||
}
|
||
|
||
func (s *ResolveServer) enqueueAccessLog(item *pb.HTTPDNSAccessLog) {
|
||
if item == nil {
|
||
return
|
||
}
|
||
select {
|
||
case s.logQueue <- item:
|
||
default:
|
||
log.Println("[HTTPDNS_NODE][resolve]access log queue is full, drop request:", item.GetRequestId())
|
||
}
|
||
}
|
||
|
||
func (s *ResolveServer) startAccessLogFlusher() {
|
||
ticker := time.NewTicker(1 * time.Second)
|
||
defer ticker.Stop()
|
||
|
||
batch := make([]*pb.HTTPDNSAccessLog, 0, 128)
|
||
flush := func() {
|
||
if len(batch) == 0 {
|
||
return
|
||
}
|
||
rpcClient, err := rpc.SharedRPC()
|
||
if err != nil {
|
||
log.Println("[HTTPDNS_NODE][resolve]access-log rpc unavailable:", err.Error())
|
||
return
|
||
}
|
||
_, err = rpcClient.HTTPDNSAccessLogRPC.CreateHTTPDNSAccessLogs(rpcClient.Context(), &pb.CreateHTTPDNSAccessLogsRequest{
|
||
Logs: batch,
|
||
})
|
||
if err != nil {
|
||
log.Println("[HTTPDNS_NODE][resolve]flush access logs failed:", err.Error())
|
||
return
|
||
}
|
||
batch = batch[:0]
|
||
}
|
||
|
||
for {
|
||
select {
|
||
case item := <-s.logQueue:
|
||
if item != nil {
|
||
batch = append(batch, item)
|
||
}
|
||
if len(batch) > 4096 {
|
||
log.Println("[HTTPDNS_NODE][resolve]access log flush backlog too large, trim:", len(batch))
|
||
batch = batch[len(batch)-2048:]
|
||
}
|
||
if len(batch) >= 128 {
|
||
flush()
|
||
}
|
||
case <-ticker.C:
|
||
flush()
|
||
case <-s.quitCh:
|
||
flush()
|
||
return
|
||
}
|
||
}
|
||
}
|