v1.5.1 增强程序稳定性

This commit is contained in:
robin
2026-03-22 17:37:40 +08:00
parent afbaaa869c
commit 17e182b413
652 changed files with 22949 additions and 34397 deletions

View File

@@ -31,7 +31,11 @@ import (
"time"
)
var sharedRecursionDNSClient = &dns.Client{}
var sharedRecursionDNSClient = &dns.Client{
Timeout: 5 * time.Second,
ReadTimeout: 3 * time.Second,
WriteTimeout: 2 * time.Second,
}
type httpContextKey struct {
key string
@@ -171,9 +175,70 @@ func (this *Server) init() error {
return nil
}
// addECSOption 向 DNS 请求中添加 EDNS Client Subnet (ECS) 信息
// addECSOption 向 DNS 请求中设置 EDNS Client Subnet (ECS)。
// 如果请求已携带 ECS 则覆盖(避免双 ECS 导致上游 malformed request
func addECSOption(req *dns.Msg, clientIP string) {
if len(clientIP) == 0 {
return
}
ip := net.ParseIP(clientIP)
if ip == nil {
return
}
var ecs = &dns.EDNS0_SUBNET{
Code: dns.EDNS0SUBNET,
}
if ip.To4() != nil {
ecs.Family = 1 // IPv4
ecs.SourceNetmask = 24
ecs.Address = ip.To4()
} else {
ecs.Family = 2 // IPv6
ecs.SourceNetmask = 56
ecs.Address = ip
}
// 查找或创建 OPT 记录
var opt = req.IsEdns0()
if opt == nil {
req.SetEdns0(4096, false)
opt = req.IsEdns0()
}
if opt != nil {
// 删除已有的 ECS option避免出现双 EDNS0_SUBNET
var filtered []dns.EDNS0
for _, o := range opt.Option {
if o.Option() != dns.EDNS0SUBNET {
filtered = append(filtered, o)
}
}
opt.Option = append(filtered, ecs)
}
}
// stripECSFromExtra 从 Extra section 中移除 OPT 记录里的 EDNS0_SUBNET
// 防止服务端注入的 ECS 信息回传给下游客户端(隐私泄露风险)。
func stripECSFromExtra(extra []dns.RR) []dns.RR {
for _, rr := range extra {
if opt, ok := rr.(*dns.OPT); ok {
var filtered []dns.EDNS0
for _, o := range opt.Option {
if o.Option() != dns.EDNS0SUBNET {
filtered = append(filtered, o)
}
}
opt.Option = filtered
}
}
return extra
}
// 查询递归DNS
func (this *Server) lookupRecursionDNS(req *dns.Msg, resp *dns.Msg) error {
var config = sharedNodeConfig.RecursionConfig
func (this *Server) lookupRecursionDNS(req *dns.Msg, resp *dns.Msg, clientIP string) error {
var config = dnsNodeConfig().RecursionConfig
if config == nil {
return nil
}
@@ -182,6 +247,9 @@ func (this *Server) lookupRecursionDNS(req *dns.Msg, resp *dns.Msg) error {
}
// 是否允许
if len(req.Question) == 0 {
return nil
}
var domain = strings.TrimSuffix(req.Question[0].Name, ".")
if len(config.DenyDomains) > 0 && configutils.MatchDomains(config.DenyDomains, domain) {
return nil
@@ -190,6 +258,9 @@ func (this *Server) lookupRecursionDNS(req *dns.Msg, resp *dns.Msg) error {
return nil
}
// 携带客户端真实 IPECS向上游查询
addECSOption(req, clientIP)
if config.UseLocalHosts {
// TODO 需要缓存文件内容
resolveConfig, err := dns.ClientConfigFromFile("/etc/resolv.conf")
@@ -206,7 +277,12 @@ func (this *Server) lookupRecursionDNS(req *dns.Msg, resp *dns.Msg) error {
if err != nil {
return err
}
resp.Answer = r.Answer
if r != nil {
resp.Rcode = r.Rcode
resp.Answer = r.Answer
resp.Ns = r.Ns
resp.Extra = stripECSFromExtra(r.Extra)
}
} else if len(config.Hosts) > 0 {
var host = config.Hosts[rands.Int(0, len(config.Hosts)-1)]
if host.Port <= 0 {
@@ -216,7 +292,12 @@ func (this *Server) lookupRecursionDNS(req *dns.Msg, resp *dns.Msg) error {
if err != nil {
return err
}
resp.Answer = r.Answer
if r != nil {
resp.Rcode = r.Rcode
resp.Answer = r.Answer
resp.Ns = r.Ns
resp.Extra = stripECSFromExtra(r.Extra)
}
}
return nil
@@ -270,7 +351,8 @@ func (this *Server) parseAction(questionName string, remoteAddr *string) (string
// 记录日志
func (this *Server) addLog(networking string, question dns.Question, domainId int64, routeCode string, record *models.NSRecord, isRecursive bool, writer dns.ResponseWriter, remoteAddr string, err error) {
// 访问日志
var accessLogRef = sharedNodeConfig.AccessLogRef
var nodeConfig = dnsNodeConfig()
var accessLogRef = nodeConfig.AccessLogRef
if accessLogRef != nil && accessLogRef.IsOn {
if domainId == 0 && !accessLogRef.LogMissingDomains {
return
@@ -282,7 +364,7 @@ func (this *Server) addLog(networking string, question dns.Question, domainId in
var now = time.Now()
var pbAccessLog = &pb.NSAccessLog{
NsNodeId: sharedNodeConfig.Id,
NsNodeId: nodeConfig.Id,
RemoteAddr: remoteAddr,
NsDomainId: domainId,
QuestionName: question.Name,
@@ -428,8 +510,14 @@ func (this *Server) handleDNSMessage(writer dns.ResponseWriter, req *dns.Msg) {
domain, recordName = sharedDomainManager.SplitDomain(fullName)
if domain == nil {
// 检查递归DNS
if sharedNodeConfig.RecursionConfig != nil && sharedNodeConfig.RecursionConfig.IsOn {
err := this.lookupRecursionDNS(req, resp)
var recursionConfig = dnsNodeConfig().RecursionConfig
if recursionConfig != nil && recursionConfig.IsOn {
// 提取客户端 IP 用于 ECS
var clientIP = remoteAddr
if clientHost, _, splitErr := net.SplitHostPort(clientIP); splitErr == nil && len(clientHost) > 0 {
clientIP = clientHost
}
err := this.lookupRecursionDNS(req, resp, clientIP)
if err != nil {
this.addLog(networking, question, 0, "", nil, true, writer, remoteAddr, err)
} else {
@@ -459,7 +547,7 @@ func (this *Server) handleDNSMessage(writer dns.ResponseWriter, req *dns.Msg) {
// 是否为NS记录用于验证域名所有权
if question.Qtype == dns.TypeNS {
var hosts = sharedNodeConfig.Hosts
var hosts = dnsNodeConfig().Hosts
var l = len(hosts)
var record = &models.NSRecord{
Id: 0,
@@ -518,7 +606,7 @@ func (this *Server) handleDNSMessage(writer dns.ResponseWriter, req *dns.Msg) {
}
// 解析Agent
if sharedNodeConfig.DetectAgents {
if dnsNodeConfig().DetectAgents {
agents.SharedQueue.Push(clientIP)
}
@@ -569,7 +657,7 @@ func (this *Server) handleDNSMessage(writer dns.ResponseWriter, req *dns.Msg) {
}
// 对 NS.example.com NS|SOA 处理
if (question.Qtype == dns.TypeNS || (question.Qtype == dns.TypeSOA && len(records) == 0)) && lists.ContainsString(sharedNodeConfig.Hosts, fullName) {
if (question.Qtype == dns.TypeNS || (question.Qtype == dns.TypeSOA && len(records) == 0)) && lists.ContainsString(dnsNodeConfig().Hosts, fullName) {
var recordDNSType string
switch question.Qtype {
case dns.TypeNS:
@@ -663,7 +751,7 @@ func (this *Server) handleDNSMessage(writer dns.ResponseWriter, req *dns.Msg) {
}
case dnsconfigs.RecordTypeNS:
if record.Id == 0 {
var hosts = sharedNodeConfig.Hosts
var hosts = dnsNodeConfig().Hosts
var l = len(hosts)
if l > 0 {
// 随机
@@ -900,8 +988,9 @@ func (this *Server) handleHTTPJSONAPI(writer http.ResponseWriter, req *http.Requ
// 组合SOA回复信息
func (this *Server) composeSOAAnswer(question dns.Question, record *models.NSRecord, resp *dns.Msg) {
var config = sharedNodeConfig.SOA
var serial = sharedNodeConfig.SOASerial
var nodeCfg = dnsNodeConfig()
var config = nodeCfg.SOA
var serial = nodeCfg.SOASerial
if config == nil {
config = dnsconfigs.DefaultNSSOAConfig()
@@ -909,7 +998,7 @@ func (this *Server) composeSOAAnswer(question dns.Question, record *models.NSRec
var mName = config.MName
if len(mName) == 0 {
var hosts = sharedNodeConfig.Hosts
var hosts = nodeCfg.Hosts
var l = len(hosts)
if l > 0 {
var index = rands.Int(0, l-1)
@@ -919,7 +1008,7 @@ func (this *Server) composeSOAAnswer(question dns.Question, record *models.NSRec
var rName = config.RName
if len(rName) == 0 {
rName = sharedNodeConfig.Email
rName = nodeCfg.Email
}
rName = strings.ReplaceAll(rName, "@", ".")