换成单集群模式

This commit is contained in:
robin
2026-03-02 20:07:53 +08:00
parent 5d0b7c7e91
commit 2a76d1773d
432 changed files with 5681 additions and 5095 deletions

View File

@@ -13,14 +13,18 @@ import (
"net"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/TeaOSLab/EdgeCommon/pkg/iplibrary"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
"github.com/TeaOSLab/EdgeHttpDNS/internal/accesslogs"
"github.com/TeaOSLab/EdgeHttpDNS/internal/configs"
"github.com/TeaOSLab/EdgeHttpDNS/internal/rpc"
"github.com/iwind/TeaGo/rands"
"github.com/miekg/dns"
)
@@ -83,61 +87,79 @@ type clientRouteProfile struct {
RegionText string
}
type ResolveServer struct {
quitCh <-chan struct{}
// tlsListener groups everything that belongs to one listening address: the
// address string used as the registry key, the TLS-wrapped listener bound to
// it, and the http.Server that serves connections accepted on that listener.
type tlsListener struct {
addr string // e.g. ":443"
listener net.Listener // TLS listener produced by tls.NewListener in startListener
server *http.Server // dedicated server instance; shut down in stopListener
}
// ResolveServer serves the HTTPS resolve API (the /resolve and /healthz
// routes registered in NewResolveServer). It keeps a TLS certificate that can
// be swapped at runtime and a registry of per-address listeners that is
// reconciled against the current snapshot.
type ResolveServer struct {
quitCh <-chan struct{} // receiving from it triggers shutdown (see watchLoop)
snapshotManager *SnapshotManager // source of the current snapshot
listenAddr string
certFile string
keyFile string
server *http.Server
// Local config fallback
fallbackAddr string // used when the snapshot yields no listen addresses
certFile string
keyFile string
logQueue chan *pb.HTTPDNSAccessLog
handler http.Handler // shared mux
tlsConfig *tls.Config // shared TLS config (with GetCertificate)
logWriter *accesslogs.HTTPDNSFileWriter
logQueue chan *pb.HTTPDNSAccessLog
// TLS certificate hot-reload
certMu sync.RWMutex // guards currentCert and certSnapshotAt
currentCert *tls.Certificate // certificate returned by getCertificate
certSnapshotAt int64 // LoadedAt of the snapshot last handled by reloadCertFromSnapshot
// Listener hot-reload
listenerMu sync.Mutex // held by syncListeners and shutdownAll while touching listeners
listeners map[string]*tlsListener // key: addr (e.g. ":443")
}
func NewResolveServer(quitCh <-chan struct{}, snapshotManager *SnapshotManager) *ResolveServer {
listenAddr := ":443"
fallbackAddr := ":443"
certFile := ""
keyFile := ""
if apiConfig, err := configs.SharedAPIConfig(); err == nil && apiConfig != nil {
if len(apiConfig.HTTPSListenAddr) > 0 {
listenAddr = apiConfig.HTTPSListenAddr
fallbackAddr = apiConfig.HTTPSListenAddr
}
certFile = apiConfig.HTTPSCert
keyFile = apiConfig.HTTPSKey
}
logWriter := accesslogs.SharedHTTPDNSFileWriter()
if apiConfig, err := configs.SharedAPIConfig(); err == nil && apiConfig != nil {
if len(strings.TrimSpace(apiConfig.LogDir)) > 0 {
logWriter.SetDir(strings.TrimSpace(apiConfig.LogDir))
}
}
if err := logWriter.EnsureInit(); err != nil {
log.Println("[HTTPDNS_NODE][resolve]init access log file writer failed:", err.Error())
}
instance := &ResolveServer{
quitCh: quitCh,
snapshotManager: snapshotManager,
listenAddr: listenAddr,
fallbackAddr: fallbackAddr,
certFile: certFile,
keyFile: keyFile,
logWriter: logWriter,
logQueue: make(chan *pb.HTTPDNSAccessLog, 8192),
listeners: make(map[string]*tlsListener),
}
mux := http.NewServeMux()
mux.HandleFunc("/resolve", instance.handleResolve)
mux.HandleFunc("/healthz", instance.handleHealth)
instance.handler = mux
instance.server = &http.Server{
Addr: instance.listenAddr,
Handler: mux,
ReadTimeout: 5 * time.Second,
ReadHeaderTimeout: 3 * time.Second,
WriteTimeout: 5 * time.Second,
IdleTimeout: 75 * time.Second,
MaxHeaderBytes: 8 * 1024,
TLSConfig: &tls.Config{
MinVersion: tls.VersionTLS11,
// /resolve is a small JSON API; pin to HTTP/1.1 to avoid ALPN/h2 handshake variance
// across some clients and middleboxes.
NextProtos: []string{"http/1.1"},
},
// Disable automatic HTTP/2 upgrade on TLS listeners. This keeps handshake behavior
// deterministic for SDK resolve calls.
TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){},
instance.tlsConfig = &tls.Config{
MinVersion: tls.VersionTLS11,
NextProtos: []string{"http/1.1"},
GetCertificate: instance.getCertificate,
}
return instance
@@ -145,19 +167,240 @@ func NewResolveServer(quitCh <-chan struct{}, snapshotManager *SnapshotManager)
// Start brings the resolve server up and then blocks until shutdown.
//
// It launches the background goroutines, loads an initial TLS certificate
// (from certFile/keyFile first, then from the current snapshot, which replaces
// the file certificate when it carries one), starts listeners on the addresses
// the snapshot asks for (or on fallbackAddr when there is no snapshot), and
// finally enters watchLoop.
func (s *ResolveServer) Start() {
go s.startAccessLogFlusher()
go s.waitForShutdown()
log.Println("[HTTPDNS_NODE][resolve]listening HTTPS on", s.listenAddr)
if err := s.server.ListenAndServeTLS(s.certFile, s.keyFile); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Println("[HTTPDNS_NODE][resolve]listen failed:", err.Error())
// 1. Load initial certificate from file (fallback)
if len(s.certFile) > 0 && len(s.keyFile) > 0 {
cert, err := tls.LoadX509KeyPair(s.certFile, s.keyFile)
if err != nil {
// A bad file is not fatal: the snapshot may still provide a certificate.
log.Println("[HTTPDNS_NODE][resolve]load cert file failed:", err.Error())
} else {
// NOTE(review): currentCert is written without certMu here; presumably
// safe because no listener has been started yet — confirm.
s.currentCert = &cert
log.Println("[HTTPDNS_NODE][resolve]loaded initial TLS cert from file")
}
}
// 2. Try loading certificate from cluster snapshot (takes priority over file)
if snapshot := s.snapshotManager.Current(); snapshot != nil {
s.reloadCertFromSnapshot(snapshot)
}
// Listeners are still started without a certificate; getCertificate will
// reject handshakes until one is installed.
if s.currentCert == nil {
log.Println("[HTTPDNS_NODE][resolve]WARNING: no TLS certificate available, HTTPS will fail")
}
// 3. Parse initial listen addresses and start listeners
if snapshot := s.snapshotManager.Current(); snapshot != nil {
addrs := s.desiredAddrs(snapshot)
s.syncListeners(addrs)
} else {
s.syncListeners([]string{s.fallbackAddr})
}
// 4. Watch for changes (blocks until quit)
s.watchLoop()
}
// getCertificate is installed as tls.Config.GetCertificate. It hands out the
// certificate currently stored on the server, ignoring the client hello, and
// fails the handshake with an error when no certificate has been loaded.
func (s *ResolveServer) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
	s.certMu.RLock()
	defer s.certMu.RUnlock()

	if s.currentCert == nil {
		return nil, errors.New("no TLS certificate available")
	}
	return s.currentCert, nil
}
// snapshotTLSConfig is the decoded form of a cluster's TLS policy JSON
// (see parseTLSConfig): the addresses to listen on and the SSL policy that
// supplies the serving certificate.
type snapshotTLSConfig struct {
Listen []*serverconfigs.NetworkAddressConfig `json:"listen"` // listen address configs; expanded by desiredAddrs
SSLPolicy *sslconfigs.SSLPolicy `json:"sslPolicy"` // certificate source; used by reloadCertFromSnapshot
}
// parseTLSConfig decodes the TLS policy JSON of the snapshot's own cluster.
//
// It returns nil when the snapshot has no positive ClusterID, when that
// cluster is missing from snapshot.Clusters, when the cluster carries no TLS
// policy JSON, or when the JSON cannot be decoded (the decode error is logged).
func (s *ResolveServer) parseTLSConfig(snapshot *LoadedSnapshot) *snapshotTLSConfig {
	if snapshot.ClusterID <= 0 {
		return nil
	}

	ownCluster := snapshot.Clusters[snapshot.ClusterID]
	if ownCluster == nil {
		return nil
	}

	policyJSON := ownCluster.GetTlsPolicyJSON()
	if len(policyJSON) == 0 {
		return nil
	}

	parsed := new(snapshotTLSConfig)
	if decodeErr := json.Unmarshal(policyJSON, parsed); decodeErr != nil {
		log.Println("[HTTPDNS_NODE][resolve]parse tlsPolicyJSON failed:", decodeErr.Error())
		return nil
	}
	return parsed
}
// desiredAddrs computes the set of addresses the server should be listening
// on for the given snapshot.
//
// Every listen entry of the snapshot's TLS config is initialised and expanded
// into its addresses; entries that are nil or fail to initialise are skipped
// (the failure is logged). The result is de-duplicated and sorted. Whenever
// the snapshot provides no usable address, the single fallbackAddr is returned.
func (s *ResolveServer) desiredAddrs(snapshot *LoadedSnapshot) []string {
	fallback := []string{s.fallbackAddr}

	tlsCfg := s.parseTLSConfig(snapshot)
	if tlsCfg == nil || len(tlsCfg.Listen) == 0 {
		return fallback
	}

	known := make(map[string]struct{})
	var result []string
	for _, entry := range tlsCfg.Listen {
		if entry == nil {
			continue
		}
		if initErr := entry.Init(); initErr != nil {
			log.Println("[HTTPDNS_NODE][resolve]init listen config failed:", initErr.Error())
			continue
		}
		for _, address := range entry.Addresses() {
			if _, dup := known[address]; dup {
				continue
			}
			known[address] = struct{}{}
			result = append(result, address)
		}
	}

	if len(result) == 0 {
		return fallback
	}
	sort.Strings(result)
	return result
}
// reloadCertFromSnapshot installs the first certificate of the snapshot's SSL
// policy as the serving certificate.
//
// The snapshot's LoadedAt is always recorded in certSnapshotAt, even when no
// certificate could be obtained (no TLS config, no certs, policy init failure,
// or no first cert); in those cases the previously installed certificate is
// left untouched.
func (s *ResolveServer) reloadCertFromSnapshot(snapshot *LoadedSnapshot) {
	// record marks the snapshot as processed and, when replacement is non-nil,
	// swaps it in as the current certificate under the same lock.
	record := func(replacement *tls.Certificate) {
		s.certMu.Lock()
		defer s.certMu.Unlock()
		if replacement != nil {
			s.currentCert = replacement
		}
		s.certSnapshotAt = snapshot.LoadedAt
	}

	tlsCfg := s.parseTLSConfig(snapshot)
	if tlsCfg == nil || tlsCfg.SSLPolicy == nil || len(tlsCfg.SSLPolicy.Certs) == 0 {
		record(nil)
		return
	}

	if initErr := tlsCfg.SSLPolicy.Init(context.Background()); initErr != nil {
		log.Println("[HTTPDNS_NODE][resolve]init SSLPolicy failed:", initErr.Error())
		record(nil)
		return
	}

	fresh := tlsCfg.SSLPolicy.FirstCert()
	record(fresh)
	if fresh != nil {
		log.Println("[HTTPDNS_NODE][resolve]TLS certificate reloaded from snapshot")
	}
}
// startListener binds addr over TCP, wraps the socket with the shared TLS
// config, registers the resulting listener/server pair in s.listeners and
// serves the shared handler on it in a new goroutine.
//
// It returns an error only when the TCP bind fails; serve errors that occur
// later are logged from the goroutine. The method writes to s.listeners
// without locking; syncListeners calls it with listenerMu held.
func (s *ResolveServer) startListener(addr string) error {
	rawLn, listenErr := net.Listen("tcp", addr)
	if listenErr != nil {
		return fmt.Errorf("listen on %s: %w", addr, listenErr)
	}

	entry := &tlsListener{
		addr:     addr,
		listener: tls.NewListener(rawLn, s.tlsConfig),
		server: &http.Server{
			Handler:           s.handler,
			ReadTimeout:       5 * time.Second,
			ReadHeaderTimeout: 3 * time.Second,
			WriteTimeout:      5 * time.Second,
			IdleTimeout:       75 * time.Second,
			MaxHeaderBytes:    8 * 1024,
			// An empty, non-nil map keeps net/http from enabling HTTP/2 on this server.
			TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){},
		},
	}
	s.listeners[addr] = entry

	go func() {
		serveErr := entry.server.Serve(entry.listener)
		if serveErr == nil || errors.Is(serveErr, http.ErrServerClosed) {
			return
		}
		log.Println("[HTTPDNS_NODE][resolve]serve failed on", addr, ":", serveErr.Error())
	}()

	log.Println("[HTTPDNS_NODE][resolve]listening HTTPS on", addr)
	return nil
}
// stopListener shuts down the server registered for addr and removes it from
// s.listeners. Unknown addresses are ignored.
//
// A graceful Shutdown is attempted for up to 3 seconds. If it does not finish
// (for example because the deadline expires while connections are still
// open), the failure is logged and the server is force-closed so that no
// connection outlives a listener that is reported as stopped.
//
// The method mutates s.listeners without locking; syncListeners and
// shutdownAll call it with listenerMu held.
func (s *ResolveServer) stopListener(addr string) {
	tl, ok := s.listeners[addr]
	if !ok {
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	if err := tl.server.Shutdown(ctx); err != nil {
		log.Println("[HTTPDNS_NODE][resolve]graceful shutdown failed on", addr, ":", err.Error())
		// Best effort: the listener is going away regardless, so a Close
		// error leaves nothing further to do.
		_ = tl.server.Close()
	}

	delete(s.listeners, addr)
	log.Println("[HTTPDNS_NODE][resolve]stopped listener on", addr)
}
// syncListeners reconciles the running listeners with the desired address
// list while holding listenerMu: listeners whose address is no longer wanted
// are stopped first, then a listener is started for every wanted address that
// is not running yet. A failure to start one address is logged and does not
// prevent the remaining addresses from being started.
func (s *ResolveServer) syncListeners(desired []string) {
	s.listenerMu.Lock()
	defer s.listenerMu.Unlock()

	wanted := make(map[string]struct{}, len(desired))
	for _, a := range desired {
		wanted[a] = struct{}{}
	}

	// Retire listeners that dropped out of the desired set.
	for active := range s.listeners {
		if _, keep := wanted[active]; keep {
			continue
		}
		s.stopListener(active)
	}

	// Bring up whatever is missing.
	for _, a := range desired {
		if _, running := s.listeners[a]; running {
			continue
		}
		startErr := s.startListener(a)
		if startErr != nil {
			log.Println("[HTTPDNS_NODE][resolve]start listener failed:", startErr.Error())
		}
	}
}
func (s *ResolveServer) waitForShutdown() {
<-s.quitCh
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_ = s.server.Shutdown(ctx)
// watchLoop polls the snapshot manager every 5 seconds and blocks until
// quitCh fires.
//
// When the current snapshot's LoadedAt differs from the value recorded by the
// last certificate reload, the listeners are re-synced to the snapshot's
// addresses and the certificate is reloaded. On quit, all listeners are shut
// down before the loop returns.
func (s *ResolveServer) watchLoop() {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-s.quitCh:
			s.shutdownAll()
			return

		case <-ticker.C:
			current := s.snapshotManager.Current()
			if current == nil {
				continue
			}

			s.certMu.RLock()
			seenAt := s.certSnapshotAt
			s.certMu.RUnlock()

			// Only act when the snapshot changed since the last reload.
			if current.LoadedAt != seenAt {
				s.syncListeners(s.desiredAddrs(current))
				s.reloadCertFromSnapshot(current)
			}
		}
	}
}
// shutdownAll stops every registered listener while holding listenerMu.
func (s *ResolveServer) shutdownAll() {
	s.listenerMu.Lock()
	defer s.listenerMu.Unlock()

	// Take the addresses up front; stopListener removes entries from the map.
	pending := make([]string, 0, len(s.listeners))
	for addr := range s.listeners {
		pending = append(pending, addr)
	}
	for _, addr := range pending {
		s.stopListener(addr)
	}
}
func (s *ResolveServer) handleHealth(writer http.ResponseWriter, _ *http.Request) {
@@ -219,11 +462,22 @@ func (s *ResolveServer) handleResolve(writer http.ResponseWriter, request *http.
return
}
if snapshot.ClusterID > 0 &&
loadedApp.App.GetPrimaryClusterId() != snapshot.ClusterID &&
loadedApp.App.GetBackupClusterId() != snapshot.ClusterID {
s.writeFailedResolve(writer, requestID, snapshot, loadedApp.App, domain, qtype, httpdnsCodeAppInvalid, "当前应用未绑定到该解析集群", startAt, request, query)
return
if snapshot.ClusterID > 0 {
var appClusterIds []int64
if len(loadedApp.App.GetClusterIdsJSON()) > 0 {
_ = json.Unmarshal(loadedApp.App.GetClusterIdsJSON(), &appClusterIds)
}
var clusterBound bool
for _, cid := range appClusterIds {
if cid == snapshot.ClusterID {
clusterBound = true
break
}
}
if !clusterBound {
s.writeFailedResolve(writer, requestID, snapshot, loadedApp.App, domain, qtype, httpdnsCodeAppInvalid, "当前应用未绑定到该解析集群", startAt, request, query)
return
}
}
loadedDomain := loadedApp.Domains[domain]
@@ -341,11 +595,14 @@ func pickDefaultTTL(snapshot *LoadedSnapshot, app *pb.HTTPDNSApp) int32 {
}
}
if app != nil {
if cluster := snapshot.Clusters[app.GetPrimaryClusterId()]; cluster != nil && cluster.GetDefaultTTL() > 0 {
return cluster.GetDefaultTTL()
var appClusterIds []int64
if len(app.GetClusterIdsJSON()) > 0 {
_ = json.Unmarshal(app.GetClusterIdsJSON(), &appClusterIds)
}
if cluster := snapshot.Clusters[app.GetBackupClusterId()]; cluster != nil && cluster.GetDefaultTTL() > 0 {
return cluster.GetDefaultTTL()
for _, cid := range appClusterIds {
if cluster := snapshot.Clusters[cid]; cluster != nil && cluster.GetDefaultTTL() > 0 {
return cluster.GetDefaultTTL()
}
}
}
return 30
@@ -591,15 +848,19 @@ func normalizeChinaRegion(province string) string {
func normalizeContinent(country string) string {
switch normalizeCountryName(country) {
case "中国", "中国香港", "中国澳门", "中国台湾", "日本", "韩国", "新加坡", "印度", "泰国", "越南":
case "中国", "中国香港", "中国澳门", "中国台湾", "日本", "韩国", "新加坡", "印度", "泰国", "越南",
"印度尼西亚", "马来西亚", "菲律宾", "柬埔寨", "缅甸", "老挝", "斯里兰卡", "孟加拉国", "巴基斯坦", "尼泊尔",
"阿联酋", "沙特阿拉伯", "土耳其", "以色列", "伊朗", "伊拉克", "卡塔尔", "科威特", "蒙古":
return "亚洲"
case "美国", "加拿大", "墨西哥":
case "美国", "加拿大", "墨西哥", "巴拿马", "哥斯达黎加", "古巴":
return "北美洲"
case "巴西", "阿根廷", "智利", "哥伦比亚":
case "巴西", "阿根廷", "智利", "哥伦比亚", "秘鲁", "委内瑞拉", "厄瓜多尔":
return "南美洲"
case "德国", "英国", "法国", "荷兰", "西班牙", "意大利", "俄罗斯":
case "德国", "英国", "法国", "荷兰", "西班牙", "意大利", "俄罗斯",
"波兰", "瑞典", "瑞士", "挪威", "芬兰", "丹麦", "葡萄牙", "爱尔兰", "比利时", "奥地利",
"乌克兰", "捷克", "罗马尼亚", "匈牙利", "希腊":
return "欧洲"
case "南非", "埃及", "尼日利亚", "肯尼亚", "摩洛哥":
case "南非", "埃及", "尼日利亚", "肯尼亚", "摩洛哥", "阿尔及利亚", "坦桑尼亚", "埃塞俄比亚", "加纳", "突尼斯":
return "非洲"
case "澳大利亚", "新西兰":
return "大洋洲"
@@ -639,8 +900,8 @@ func pickRuleRecords(rules []*pb.HTTPDNSCustomRule, qtype string, profile *clien
Type: qtype,
IP: value,
Weight: item.GetWeight(),
Line: ruleLineSummary(rule),
Region: ruleRegionSummary(rule),
Line: lookupIPLineLabel(value),
Region: lookupIPRegionSummary(value),
})
}
if len(records) == 0 {
@@ -969,6 +1230,91 @@ func normalizeCountryName(country string) string {
return "智利"
case "哥伦比亚", "colombia":
return "哥伦比亚"
// --- 亚洲(新增)---
case "印度尼西亚", "indonesia":
return "印度尼西亚"
case "马来西亚", "malaysia":
return "马来西亚"
case "菲律宾", "philippines":
return "菲律宾"
case "柬埔寨", "cambodia":
return "柬埔寨"
case "缅甸", "myanmar", "burma":
return "缅甸"
case "老挝", "laos":
return "老挝"
case "斯里兰卡", "srilanka":
return "斯里兰卡"
case "孟加拉国", "孟加拉", "bangladesh":
return "孟加拉国"
case "巴基斯坦", "pakistan":
return "巴基斯坦"
case "尼泊尔", "nepal":
return "尼泊尔"
case "阿联酋", "阿拉伯联合酋长国", "uae", "unitedarabemirates":
return "阿联酋"
case "沙特阿拉伯", "沙特", "saudiarabia", "saudi":
return "沙特阿拉伯"
case "土耳其", "turkey", "türkiye", "turkiye":
return "土耳其"
case "以色列", "israel":
return "以色列"
case "伊朗", "iran":
return "伊朗"
case "伊拉克", "iraq":
return "伊拉克"
case "卡塔尔", "qatar":
return "卡塔尔"
case "科威特", "kuwait":
return "科威特"
case "蒙古", "mongolia":
return "蒙古"
// --- 欧洲(新增)---
case "波兰", "poland":
return "波兰"
case "瑞典", "sweden":
return "瑞典"
case "瑞士", "switzerland":
return "瑞士"
case "挪威", "norway":
return "挪威"
case "芬兰", "finland":
return "芬兰"
case "丹麦", "denmark":
return "丹麦"
case "葡萄牙", "portugal":
return "葡萄牙"
case "爱尔兰", "ireland":
return "爱尔兰"
case "比利时", "belgium":
return "比利时"
case "奥地利", "austria":
return "奥地利"
case "乌克兰", "ukraine":
return "乌克兰"
case "捷克", "czech", "czechrepublic", "czechia":
return "捷克"
case "罗马尼亚", "romania":
return "罗马尼亚"
case "匈牙利", "hungary":
return "匈牙利"
case "希腊", "greece":
return "希腊"
// --- 北美洲(新增)---
case "巴拿马", "panama":
return "巴拿马"
case "哥斯达黎加", "costarica":
return "哥斯达黎加"
case "古巴", "cuba":
return "古巴"
// --- 南美洲(新增)---
case "秘鲁", "peru":
return "秘鲁"
case "委内瑞拉", "venezuela":
return "委内瑞拉"
case "厄瓜多尔", "ecuador":
return "厄瓜多尔"
// --- 非洲 ---
case "南非", "southafrica":
return "南非"
case "埃及", "egypt":
@@ -979,6 +1325,17 @@ func normalizeCountryName(country string) string {
return "肯尼亚"
case "摩洛哥", "morocco":
return "摩洛哥"
case "阿尔及利亚", "algeria":
return "阿尔及利亚"
case "坦桑尼亚", "tanzania":
return "坦桑尼亚"
case "埃塞俄比亚", "ethiopia":
return "埃塞俄比亚"
case "加纳", "ghana":
return "加纳"
case "突尼斯", "tunisia":
return "突尼斯"
// --- 大洋洲 ---
case "澳大利亚", "australia":
return "澳大利亚"
case "新西兰", "newzealand":
@@ -1087,18 +1444,7 @@ func (s *ResolveServer) startAccessLogFlusher() {
if len(batch) == 0 {
return
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
log.Println("[HTTPDNS_NODE][resolve]access-log rpc unavailable:", err.Error())
return
}
_, err = rpcClient.HTTPDNSAccessLogRPC.CreateHTTPDNSAccessLogs(rpcClient.Context(), &pb.CreateHTTPDNSAccessLogsRequest{
Logs: batch,
})
if err != nil {
log.Println("[HTTPDNS_NODE][resolve]flush access logs failed:", err.Error())
return
}
s.logWriter.WriteBatch(batch)
batch = batch[:0]
}
@@ -1119,6 +1465,7 @@ func (s *ResolveServer) startAccessLogFlusher() {
flush()
case <-s.quitCh:
flush()
_ = s.logWriter.Close()
return
}
}