Files
2026-03-22 17:37:40 +08:00

1031 lines
27 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeDNS/internal/agents"
"github.com/TeaOSLab/EdgeDNS/internal/models"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"github.com/TeaOSLab/EdgeDNS/internal/stats"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
"github.com/miekg/dns"
"io"
"math/rand"
"net"
"net/http"
"regexp"
"strings"
"time"
)
// sharedRecursionDNSClient is the single DNS client used to forward queries to
// recursion upstreams (see lookupRecursionDNS).
//
// NOTE(review): per the miekg/dns Client documentation a non-zero Timeout is a
// cumulative limit that overrides ReadTimeout/WriteTimeout — confirm whether the
// 3s/2s values below are actually meant to apply.
var sharedRecursionDNSClient = &dns.Client{
	WriteTimeout: 2 * time.Second,
	ReadTimeout:  3 * time.Second,
	Timeout:      5 * time.Second,
}
// httpContextKey is the unexported type of the context keys this package puts on
// HTTP requests; being unexported, it cannot collide with keys from other packages.
type httpContextKey struct {
	key string
}

// HTTPConnContextKey is the request-context key under which the DoH server stores
// the net.Conn of the connection a request arrived on.
var HTTPConnContextKey = &httpContextKey{"http-conn"}

// PingDomain is the question name that is answered directly with 127.0.0.1
// (presumably used as a liveness probe).
const PingDomain = "ping."
// Server is one DNS listener built from a ServerConfig.
// init sets rawServer for UDP/TCP/TLS (and any unrecognized protocol) and
// httpsServer for HTTPS (DNS-over-HTTPS); the other field stays nil.
type Server struct {
config *ServerConfig // current configuration; replaced by Reload
rawServer *dns.Server // miekg/dns server; nil for HTTPS
httpsServer *http.Server // DoH server; nil unless the protocol is HTTPS
}
// NewServer builds a Server for config and prepares its underlying listener
// (a miekg/dns server, or an HTTPS server for DoH) without starting it.
// It returns an error when initialization fails.
func NewServer(config *ServerConfig) (*Server, error) {
	server := &Server{config: config}
	if err := server.init(); err != nil {
		return nil, err
	}
	return server, nil
}
// ListenAndServe starts serving and blocks until the server stops.
//
// UDP/TCP/TLS listeners are run by miekg/dns. The DoH listener opens a TCP socket
// and serves TLS with empty certFile/keyFile, relying on the TLSConfig callbacks
// installed by init for certificates. A graceful Shutdown of the DoH server is
// reported as nil rather than http.ErrServerClosed.
func (this *Server) ListenAndServe() error {
	if this.rawServer != nil {
		return this.rawServer.ListenAndServe()
	}

	if this.httpsServer != nil {
		listener, err := net.Listen("tcp", this.httpsServer.Addr)
		if err != nil {
			return err
		}
		err = this.httpsServer.ServeTLS(listener, "", "")
		if errors.Is(err, http.ErrServerClosed) {
			return nil
		}
		return err
	}

	return errors.New("the server is not initialized")
}
// Shutdown stops whichever listener init created.
//
// The DoH server is shut down with context.Background(), i.e. without a deadline
// for in-flight requests. An error is returned if no listener exists.
func (this *Server) Shutdown() error {
	switch {
	case this.rawServer != nil:
		return this.rawServer.Shutdown()
	case this.httpsServer != nil:
		return this.httpsServer.Shutdown(context.Background())
	default:
		return errors.New("the server is not initialized")
	}
}

// Reload swaps in a new configuration. The TLS callbacks and request handlers read
// this.config on every use, so the new value takes effect for subsequent
// handshakes and queries. No lock is taken here.
func (this *Server) Reload(config *ServerConfig) {
	this.config = config
}
// init builds the listener described by this.config.
//
// The listen address is "<host>:<port>", where the host is quoted by
// configutils.QuoteIP (for IPv6 literals) and omitted when empty. TCP, UDP and
// DNS-over-TLS are served by a miekg/dns server. HTTPS (DoH) is served by an
// http.Server that routes to handleHTTP. Certificates for both TLS variants are
// resolved per handshake from the current SSL policy. An error is returned when
// the server has no configuration.
func (this *Server) init() error {
	if this.config == nil {
		return errors.New("invalid 'ServerConfig.config'")
	}

	var addr = ""
	if len(this.config.Host) > 0 {
		addr += configutils.QuoteIP(this.config.Host)
	}
	addr += ":" + types.String(this.config.Port)

	// The TLS callbacks read this.config on every handshake, so a policy replaced by
	// Reload takes effect without restarting the listener. The config is read once
	// per call and nil-checked, so a missing policy fails the handshake instead of
	// panicking.
	var getConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
		var config = this.config
		if config == nil {
			return nil, errors.New("invalid 'ServerConfig.config'")
		}
		if config.SSLPolicy == nil {
			return nil, errors.New("invalid 'ServerConfig.config.SSLPolicy'")
		}
		return config.SSLPolicy.TLSConfig(), nil
	}
	var getCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
		var config = this.config
		if config == nil {
			return nil, errors.New("invalid 'ServerConfig.config'")
		}
		if config.SSLPolicy == nil {
			return nil, errors.New("invalid 'ServerConfig.config.SSLPolicy'")
		}
		return config.SSLPolicy.FirstCert(), nil
	}

	var rawServer = &dns.Server{
		Addr:         addr,
		ReadTimeout:  3 * time.Second,
		WriteTimeout: 3 * time.Second,
		Handler:      dns.HandlerFunc(this.handleDNSMessage),
	}

	switch this.config.Protocol {
	case serverconfigs.ProtocolTCP:
		rawServer.Net = "tcp"
	case serverconfigs.ProtocolTLS:
		rawServer.Net = "tcp-tls"
		rawServer.TLSConfig = &tls.Config{
			GetConfigForClient: getConfigForClient,
			GetCertificate:     getCertificate,
		}
	case serverconfigs.ProtocolHTTPS: // DoH
		rawServer = nil
		this.httpsServer = &http.Server{
			Addr:    addr,
			Handler: http.HandlerFunc(this.handleHTTP),
			TLSConfig: &tls.Config{
				GetConfigForClient: getConfigForClient,
				GetCertificate:     getCertificate,
			},
			// Store the raw connection in the request context; the DoH handlers fetch
			// it via HTTPConnContextKey and pass it to NewHTTPWriter.
			ConnContext: func(ctx context.Context, c net.Conn) context.Context {
				return context.WithValue(ctx, HTTPConnContextKey, c)
			},
			ReadTimeout:       5 * time.Second,
			ReadHeaderTimeout: 3 * time.Second,
			WriteTimeout:      3 * time.Second,
			IdleTimeout:       75 * time.Second,
			MaxHeaderBytes:    4096,
		}
	case serverconfigs.ProtocolUDP:
		rawServer.Net = "udp"
	}

	this.rawServer = rawServer
	return nil
}
// addECSOption sets an EDNS Client Subnet (ECS) option carrying clientIP on req.
//
// Any ECS option already present in the request is replaced, so the upstream
// never receives two of them, which it may reject as a malformed request. IPv4
// addresses are sent with a /24 source prefix and IPv6 addresses with /56. If
// req has no OPT record, one is added (UDP size 4096, DO bit unset). An empty or
// unparsable clientIP leaves req unchanged.
func addECSOption(req *dns.Msg, clientIP string) {
	if clientIP == "" {
		return
	}
	var parsed = net.ParseIP(clientIP)
	if parsed == nil {
		return
	}

	var subnet = &dns.EDNS0_SUBNET{Code: dns.EDNS0SUBNET}
	if v4 := parsed.To4(); v4 != nil {
		subnet.Family = 1 // IPv4
		subnet.SourceNetmask = 24
		subnet.Address = v4
	} else {
		subnet.Family = 2 // IPv6
		subnet.SourceNetmask = 56
		subnet.Address = parsed
	}

	// Find the OPT record, creating one if needed
	var opt = req.IsEdns0()
	if opt == nil {
		req.SetEdns0(4096, false)
		if opt = req.IsEdns0(); opt == nil {
			return
		}
	}

	// Keep every option except an existing ECS, then append ours
	var kept []dns.EDNS0
	for _, option := range opt.Option {
		if option.Option() == dns.EDNS0SUBNET {
			continue
		}
		kept = append(kept, option)
	}
	opt.Option = append(kept, subnet)
}
// stripECSFromExtra removes every EDNS0_SUBNET option from the OPT records in
// extra. This keeps the ECS data the server attached upstream from being echoed
// back to the downstream client, which would be a privacy leak. OPT records are
// modified in place, and the same slice is returned.
func stripECSFromExtra(extra []dns.RR) []dns.RR {
	for _, rr := range extra {
		opt, isOPT := rr.(*dns.OPT)
		if !isOPT {
			continue
		}
		var remaining []dns.EDNS0
		for _, option := range opt.Option {
			if option.Option() != dns.EDNS0SUBNET {
				remaining = append(remaining, option)
			}
		}
		opt.Option = remaining
	}
	return extra
}
// lookupRecursionDNS forwards req to a recursive upstream. It copies the upstream
// rcode and its Answer, Ns and Extra sections into resp, with ECS options stripped
// from Extra.
//
// It returns nil without touching resp in these cases:
//   - recursion is off;
//   - req has no question;
//   - the first question's domain is rejected by DenyDomains/AllowDomains;
//   - no upstream is configured.
//
// clientIP is attached to req as an ECS option, and req is modified in place.
// The upstream is a random nameserver from /etc/resolv.conf when UseLocalHosts
// is set; otherwise it is a random entry of the configured Hosts.
func (this *Server) lookupRecursionDNS(req *dns.Msg, resp *dns.Msg, clientIP string) error {
	var config = dnsNodeConfig().RecursionConfig
	if config == nil || !config.IsOn {
		return nil
	}

	// Check whether this domain may be resolved recursively
	if len(req.Question) == 0 {
		return nil
	}
	var domain = strings.TrimSuffix(req.Question[0].Name, ".")
	if len(config.DenyDomains) > 0 && configutils.MatchDomains(config.DenyDomains, domain) {
		return nil
	}
	if len(config.AllowDomains) > 0 && !configutils.MatchDomains(config.AllowDomains, domain) {
		return nil
	}

	// Send the real client IP (ECS) to the upstream
	addECSOption(req, clientIP)

	var upstreamAddr string
	if config.UseLocalHosts {
		// TODO cache the file content instead of parsing it on every query
		resolveConfig, err := dns.ClientConfigFromFile("/etc/resolv.conf")
		if err != nil {
			return err
		}
		if len(resolveConfig.Servers) == 0 {
			return errors.New("no dns servers found in config file")
		}
		var port = resolveConfig.Port
		if len(port) == 0 {
			port = "53"
		}
		upstreamAddr = configutils.QuoteIP(resolveConfig.Servers[rands.Int(0, len(resolveConfig.Servers)-1)]) + ":" + port
	} else if len(config.Hosts) > 0 {
		var host = config.Hosts[rands.Int(0, len(config.Hosts)-1)]
		// Apply the default port locally; writing it back into the shared config
		// entry would be an unsynchronized write from concurrent handlers.
		var port = host.Port
		if port <= 0 {
			port = 53
		}
		upstreamAddr = configutils.QuoteIP(host.Host) + ":" + types.String(port)
	} else {
		return nil
	}

	r, _, err := sharedRecursionDNSClient.Exchange(req, upstreamAddr)
	if err != nil {
		return err
	}
	if r != nil {
		resp.Rcode = r.Rcode
		resp.Answer = r.Answer
		resp.Ns = r.Ns
		resp.Extra = stripECSFromExtra(r.Extra)
	}
	return nil
}
// parseAction applies and strips a query option from questionName.
//
// A name of the form "$<optionId>-<rest>" refers to an NS question option. The
// option is fetched over RPC and "<rest>" is returned. For a "setRemoteAddr"
// option, *remoteAddr is set to the IP in the option's values, without a port;
// the IP must parse and must not be a loopback or unspecified address. Names
// with no '-' after the first character are returned unchanged without any RPC
// call.
//
// NOTE(review): assumes the caller has already checked that questionName starts
// with '$' (the first character is skipped when parsing the option id).
func (this *Server) parseAction(questionName string, remoteAddr *string) (string, error) {
	// TODO protect against malicious use: every option costs an RPC round trip
	var optionIndex = strings.Index(questionName, "-")
	if optionIndex <= 0 {
		return questionName, nil
	}

	// Only contact the API node when there is actually an option to look up
	rpcClient, err := rpc.SharedRPC()
	if err != nil {
		return "", err
	}

	var optionId = types.Int64(questionName[1:optionIndex])
	optionResp, err := rpcClient.NSQuestionOptionRPC.FindNSQuestionOption(rpcClient.Context(), &pb.FindNSQuestionOptionRequest{NsQuestionOptionId: optionId})
	if err != nil {
		return "", errors.New("query question option failed: " + err.Error())
	}

	var option = optionResp.NsQuestionOption
	if option != nil && option.Name == "setRemoteAddr" {
		var m = maps.Map{}
		err = json.Unmarshal(option.ValuesJSON, &m)
		if err != nil {
			return "", errors.New("decode question option failed: " + err.Error())
		}

		var ip = m.GetString("ip")
		// Validate the IP address to prevent spoofing with arbitrary values
		var parsedIP = net.ParseIP(ip)
		if parsedIP == nil {
			return "", errors.New("invalid IP address in setRemoteAddr: " + ip)
		}
		// Reject loopback and unspecified addresses
		if parsedIP.IsLoopback() || parsedIP.IsUnspecified() {
			return "", errors.New("disallowed IP address in setRemoteAddr: " + ip)
		}
		*remoteAddr = ip
	}

	return questionName[optionIndex+1:], nil
}
// addLog queues an access-log entry for one question when access logging is on.
//
// Nothing is logged in two cases:
//   - the question matched no domain (domainId == 0) and LogMissingDomains is off;
//   - MissingRecordsOnly is set and record has a non-empty Value.
//
// record and err are optional. routeCode is only recorded together with a record.
func (this *Server) addLog(networking string, question dns.Question, domainId int64, routeCode string, record *models.NSRecord, isRecursive bool, writer dns.ResponseWriter, remoteAddr string, err error) {
	var nodeConfig = dnsNodeConfig()
	var logRef = nodeConfig.AccessLogRef
	if logRef == nil || !logRef.IsOn {
		return
	}
	if domainId == 0 && !logRef.LogMissingDomains {
		return
	}
	var hasValue = record != nil && len(record.Value) > 0
	if logRef.MissingRecordsOnly && hasValue {
		return
	}

	var now = time.Now()
	var accessLog = &pb.NSAccessLog{
		NsNodeId:     nodeConfig.Id,
		RemoteAddr:   remoteAddr,
		NsDomainId:   domainId,
		QuestionName: question.Name,
		QuestionType: dns.Type(question.Qtype).String(),
		IsRecursive:  isRecursive,
		Networking:   networking,
		ServerAddr:   writer.LocalAddr().String(),
		Timestamp:    now.Unix(),
		TimeLocal:    now.Format("2/Jan/2006:15:04:05 -0700"),
		RequestId:    "",
	}

	if record != nil {
		accessLog.NsRecordId = record.Id
		if routeCode != "" {
			accessLog.NsRouteCodes = []string{routeCode}
		}
		accessLog.RecordName = record.Name
		accessLog.RecordType = record.Type
		accessLog.RecordValue = record.Value
	}
	if err != nil {
		accessLog.Error = err.Error()
	}

	sharedNSAccessLogQueue.Push(accessLog)
}
// checkTSIG verifies the TSIG record of msg against the keys of domainId.
//
// It returns:
//   - nil as soon as a key whose algorithm matches verifies the message;
//   - an error when msg has no TSIG record or the domain has no keys;
//   - dns.ErrSig when no key verifies the message.
//
// Keys stored as clear text are base64-encoded first, because dns.TsigVerify
// expects a base64 secret.
func (this *Server) checkTSIG(msg *dns.Msg, domainId int64) error {
	var tsig = msg.IsTsig()
	if tsig == nil {
		return errors.New("tsig: tsig required")
	}

	var keys = sharedKeyManager.FindKeysWithDomain(domainId)
	if len(keys) == 0 {
		return errors.New("tsig: no keys defined")
	}

	for _, key := range keys {
		// Algorithm names are domain names, which compare case-insensitively
		if !strings.EqualFold(key.Algo, tsig.Algorithm) {
			continue
		}

		// Pack again for every attempt: a packed buffer can only be verified once
		// (verification may modify the buffer it is given)
		msgData, err := msg.Pack()
		if err != nil {
			return err
		}
		if len(msgData) == 0 {
			// An empty message must never count as verified
			return dns.ErrSig
		}

		var base64Secret = key.Secret
		if key.SecretType == dnsconfigs.NSKeySecretTypeClear {
			base64Secret = base64.StdEncoding.EncodeToString([]byte(key.Secret))
		}
		if dns.TsigVerify(msgData, base64Secret, "", false) == nil {
			return nil
		}
	}
	return dns.ErrSig
}
// handleDNSMessage answers one DNS request.
//
// Each question in req is processed in order:
//   - PingDomain is answered immediately with an A record of 127.0.0.1.
//   - A name starting with '$' carries a query option (see parseAction). The option
//     prefix is removed, and for "setRemoteAddr" the client address used for
//     routing and logging is replaced, keeping the original port.
//   - A name outside every managed domain is forwarded upstream when recursion is
//     enabled. The upstream result (SERVFAIL on upstream error) is written at once.
//     Otherwise an NS question is answered with the node's first host (used to
//     verify domain ownership), and any other question is only logged.
//   - A name inside a managed domain is checked against TSIG when the domain
//     enables it, then answered from the record manager, following CNAME chains.
//
// When the loop ends, the response is written with RcodeSuccess. The response
// size is then added to the statistics of every record that contributed.
func (this *Server) handleDNSMessage(writer dns.ResponseWriter, req *dns.Msg) {
	if len(req.Question) == 0 {
		return
	}
	if sharedDomainManager == nil {
		return
	}

	var networking = ""
	if this.config != nil {
		networking = this.config.Protocol.String()
	}

	var resultDomainId int64
	var resultRecordIds [][2]int64 // {domainId, recordId} pairs, used for statistics

	var resp = new(dns.Msg)
	resp.RecursionDesired = true
	resp.RecursionAvailable = true
	resp.SetReply(req)
	resp.Answer = []dns.RR{}

	var tsigIsChecked = false
	var remoteAddr = writer.RemoteAddr().String()

	for _, question := range req.Question {
		if len(question.Name) == 0 {
			continue
		}

		// PING
		if question.Name == PingDomain {
			resp.Answer = append(resp.Answer, &dns.A{
				Hdr: dns.RR_Header{
					Name:   question.Name,
					Rrtype: dns.TypeA,
					Class:  question.Qclass,
					Ttl:    60,
				},
				A: net.ParseIP("127.0.0.1"),
			})
			resp.Rcode = dns.RcodeSuccess
			_ = writer.WriteMsg(resp) // nothing more can be done if the write fails
			return
		}

		// Query option, e.g. "$<optionId>-www.example.com."
		if question.Name[0] == '$' {
			var overrideAddr = ""
			questionName, err := this.parseAction(question.Name, &overrideAddr)
			if err != nil {
				remotelogs.Error("SERVER", "invalid query option '"+question.Name+"'")
				continue
			}
			question.Name = questionName

			// Only rewrite remoteAddr when the option supplied a new IP. Appending the
			// port unconditionally would turn "1.2.3.4:53" into "[1.2.3.4:53]:53".
			if len(overrideAddr) > 0 {
				_, port, _ := net.SplitHostPort(remoteAddr)
				if len(port) > 0 {
					remoteAddr = net.JoinHostPort(overrideAddr, port)
				} else {
					remoteAddr = overrideAddr
				}
			}
		}

		var fullName = strings.TrimSuffix(question.Name, ".")
		var recordName string
		var recordType = dns.Type(question.Qtype).String()
		var originalQtype = question.Qtype // question.Qtype is switched to CNAME below when a CNAME answers the name
		var domain *models.NSDomain
		domain, recordName = sharedDomainManager.SplitDomain(fullName)
		if domain == nil {
			// Not a managed domain: try recursion
			var recursionConfig = dnsNodeConfig().RecursionConfig
			if recursionConfig != nil && recursionConfig.IsOn {
				// Client IP without port, used for ECS
				var clientIP = remoteAddr
				if clientHost, _, splitErr := net.SplitHostPort(clientIP); splitErr == nil && len(clientHost) > 0 {
					clientIP = clientHost
				}
				err := this.lookupRecursionDNS(req, resp, clientIP)
				if err != nil {
					// Report the upstream failure instead of an empty NOERROR answer,
					// which resolvers would treat (and cache) as "no data".
					resp.Rcode = dns.RcodeServerFailure
					this.addLog(networking, question, 0, "", nil, true, writer, remoteAddr, err)
				} else {
					var recordValue = ""
					if len(resp.Answer) > 0 {
						// RR.String() is "name ttl class type rdata...": log the first rdata field
						var fields = strings.Fields(resp.Answer[0].String())
						if len(fields) >= 5 {
							recordValue = fields[4]
						}
					}
					this.addLog(networking, question, 0, "", &models.NSRecord{
						Id:    0,
						Name:  recordName,
						Type:  recordType,
						Value: recordValue,
						Ttl:   0,
					}, true, writer, remoteAddr, nil)
				}
				_ = writer.WriteMsg(resp)
				return
			}

			// NS query for an unmanaged name: answer with this node's host, which is
			// used to verify domain ownership
			if question.Qtype == dns.TypeNS {
				var hosts = dnsNodeConfig().Hosts
				if len(hosts) > 0 {
					// Only one host (the first) is returned for now
					var record = &models.NSRecord{
						Id:    0,
						Type:  dnsconfigs.RecordTypeNS,
						Ttl:   600, // TODO make configurable
						Value: hosts[0] + ".",
					}
					resp.Answer = append(resp.Answer, &dns.NS{
						Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
						Ns:  hosts[0] + ".",
					})
					this.addLog(networking, question, 0, "", record, false, writer, remoteAddr, nil)
					continue
				}
			}
			this.addLog(networking, question, 0, "", nil, false, writer, remoteAddr, nil)
			continue
		}

		// TSIG is verified once per request
		if domain.TSIG != nil && domain.TSIG.IsOn && !tsigIsChecked {
			err := this.checkTSIG(req, domain.Id)
			if err != nil {
				this.addLog(networking, question, domain.Id, "", nil, false, writer, remoteAddr, err)
				continue
			}
			tsigIsChecked = true
		}

		resultDomainId = domain.Id

		var clientIP = remoteAddr
		if clientHost, _, err := net.SplitHostPort(clientIP); err == nil && len(clientHost) > 0 {
			clientIP = clientHost
		}

		// Agent detection
		if dnsNodeConfig().DetectAgents {
			agents.SharedQueue.Push(clientIP)
		}

		var routeCodes = sharedRouteManager.FindRouteCodes(clientIP, domain.UserId)
		var records []*models.NSRecord
		var matchedRouteCode string
		switch question.Qtype {
		case dns.TypeSOA:
			if len(recordName) == 0 { // only the apex has an SOA record
				records = []*models.NSRecord{
					{
						Id:   0,
						Type: dnsconfigs.RecordTypeSOA,
						Ttl:  600, // TODO make configurable
					},
				}
			}
		case dns.TypeNS:
			if len(recordName) == 0 { // only the apex has NS records
				records = []*models.NSRecord{
					{
						Id:   0,
						Type: dnsconfigs.RecordTypeNS,
						Ttl:  600, // TODO make configurable
					},
				}
			}
		case dns.TypeCNAME:
			// looked up by the fallback below
		default:
			// Lookup with the last argument true first (presumably route-specific
			// settings; confirm against sharedRecordManager.FindRecords)
			records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, true)

			// Then a CNAME for the same name
			if len(records) == 0 {
				records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, dnsconfigs.RecordTypeCNAME, false)
				if len(records) > 0 {
					question.Qtype = dns.TypeCNAME
				}
			}
		}

		// Fall back to the default settings for the requested type
		if len(records) == 0 {
			records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, false)
		}

		// NS / SOA questions for the node's own host names (e.g. NS.example.com)
		if (question.Qtype == dns.TypeNS || (question.Qtype == dns.TypeSOA && len(records) == 0)) && lists.ContainsString(dnsNodeConfig().Hosts, fullName) {
			var recordDNSType string
			switch question.Qtype {
			case dns.TypeNS:
				recordDNSType = dnsconfigs.RecordTypeNS
			case dns.TypeSOA:
				recordDNSType = dnsconfigs.RecordTypeSOA
			}
			this.composeSOAAnswer(question, &models.NSRecord{
				Type: recordDNSType,
				Ttl:  600,
			}, resp)
		}

		if len(records) == 0 {
			this.addLog(networking, question, resultDomainId, "", nil, false, writer, remoteAddr, nil)
			continue
		}

		var firstRecord = records[0]
		for _, record := range records {
			resultRecordIds = append(resultRecordIds, [2]int64{record.DomainId, record.Id})
			switch record.Type {
			case dnsconfigs.RecordTypeA, dnsconfigs.RecordTypeAAAA, dnsconfigs.RecordTypeMX,
				dnsconfigs.RecordTypeSRV, dnsconfigs.RecordTypeTXT, dnsconfigs.RecordTypeCAA:
				if answer := record.ToRRAnswer(question.Name, question.Qclass); answer != nil {
					resp.Answer = append(resp.Answer, answer)
				}
			case dnsconfigs.RecordTypeCNAME:
				var value = record.Value
				if !strings.HasSuffix(value, ".") {
					value += "."
				}
				var lastRecordValue = value
				resp.Answer = append(resp.Answer, &dns.CNAME{
					Hdr:    record.ToRRHeader(question.Name, dns.TypeCNAME, question.Qclass),
					Target: value,
				})

				// Follow the CNAME chain through managed domains: at most 32 levels,
				// and stop when a target repeats
				var allCNAMEValues = []string{lastRecordValue}
				for len(allCNAMEValues) <= 32 {
					cnameDomain, cnameRecordName := sharedDomainManager.SplitDomain(lastRecordValue)
					if cnameDomain == nil {
						break
					}
					cnameRecords, _ := sharedRecordManager.FindRecords(cnameDomain.Id, sharedRouteManager.FindRouteCodes(clientIP, cnameDomain.UserId), cnameRecordName, dnsconfigs.RecordTypeCNAME, false)
					if len(cnameRecords) == 0 {
						break
					}
					var cnameRecord = cnameRecords[0]
					if lists.ContainsString(allCNAMEValues, cnameRecord.Value) {
						break
					}
					resultRecordIds = append(resultRecordIds, [2]int64{cnameRecord.DomainId, cnameRecord.Id}) // statistics
					var answer = cnameRecord.ToRRAnswer(lastRecordValue, question.Qclass)
					if answer == nil {
						break
					}
					resp.Answer = append(resp.Answer, answer)
					lastRecordValue = cnameRecord.Value
					allCNAMEValues = append(allCNAMEValues, lastRecordValue)
				}

				// Resolve this question's original type at the end of the chain
				if originalQtype != dns.TypeCNAME {
					finalDomain, finalRecordName := sharedDomainManager.SplitDomain(lastRecordValue)
					if finalDomain != nil {
						var realRecords, _ = sharedRecordManager.FindRecords(finalDomain.Id, sharedRouteManager.FindRouteCodes(clientIP, finalDomain.UserId), finalRecordName, recordType, false)
						for _, realRecord := range realRecords {
							resultRecordIds = append(resultRecordIds, [2]int64{realRecord.DomainId, realRecord.Id}) // statistics
							if answer := realRecord.ToRRAnswer(lastRecordValue, question.Qclass); answer != nil {
								resp.Answer = append(resp.Answer, answer)
							}
						}
					}
				}
			case dnsconfigs.RecordTypeNS:
				if record.Id == 0 {
					// Synthesized apex NS record: answer with all node hosts in random order
					var hosts = dnsNodeConfig().Hosts
					if len(hosts) > 0 {
						record.Value = hosts[0] + "."
						for _, index := range rand.Perm(len(hosts)) {
							resp.Answer = append(resp.Answer, &dns.NS{
								Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
								Ns:  hosts[index] + ".",
							})
						}
					}
				} else {
					var value = record.Value
					if !strings.HasSuffix(value, ".") {
						value += "."
					}
					resp.Answer = append(resp.Answer, &dns.NS{
						Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
						Ns:  value,
					})
				}
			case dnsconfigs.RecordTypeSOA:
				this.composeSOAAnswer(question, record, resp)
			}
		}

		// Access log
		this.addLog(networking, question, resultDomainId, matchedRouteCode, firstRecord, false, writer, remoteAddr, nil)
	}

	resp.Rcode = dns.RcodeSuccess
	if err := writer.WriteMsg(resp); err != nil {
		return
	}

	// Statistics
	if len(resultRecordIds) > 0 {
		var respLen = int64(resp.Len())
		for _, resultRecordId := range resultRecordIds {
			stats.SharedManager.Add(resultRecordId[0], resultRecordId[1], respLen)
		}
	}
}
// handleHTTP routes DNS-over-HTTPS requests:
//   - "/dns-query": RFC 8484 wire-format messages
//     (https://datatracker.ietf.org/doc/html/rfc8484);
//   - "/resolve": a JSON API modeled on
//     https://developers.google.com/speed/public-dns/docs/doh;
//   - any other path: 404 Not Found.
func (this *Server) handleHTTP(writer http.ResponseWriter, req *http.Request) {
	switch req.URL.Path {
	case "/dns-query":
		this.handleHTTPDNSMessage(writer, req)
	case "/resolve":
		this.handleHTTPJSONAPI(writer, req)
	default:
		writer.WriteHeader(http.StatusNotFound)
	}
}
// handleHTTPDNSMessage serves RFC 8484 wire-format DoH requests.
//
// GET carries the message in the "dns" query parameter. RFC 8484 specifies this
// as base64url without padding; standard base64 is still accepted. POST carries
// it in the body with Content-Type "application/dns-message". Other methods get
// 501. Oversized, undecodable or unparsable messages get 4xx. The parsed message
// is answered by handleDNSMessage through an HTTP-backed writer wrapping the
// request's connection.
func (this *Server) handleHTTPDNSMessage(writer http.ResponseWriter, req *http.Request) {
	const maxMessageSize = 512

	writer.Header().Set("Accept", "application/dns-message")

	if req.Method != http.MethodGet && req.Method != http.MethodPost {
		writer.WriteHeader(http.StatusNotImplemented)
		return
	}

	if req.ContentLength > maxMessageSize {
		writer.WriteHeader(http.StatusRequestEntityTooLarge)
		return
	}

	var messageData []byte
	switch req.Method {
	case http.MethodGet:
		if len(req.URL.RawQuery) > maxMessageSize {
			writer.WriteHeader(http.StatusRequestURITooLong)
			return
		}
		var encodedMessage = req.URL.Query().Get("dns")
		// RFC 8484 §4.1/§6: base64url without padding (tolerate stray padding).
		data, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(encodedMessage, "="))
		if err != nil {
			// Fall back to standard base64 for clients that send it
			data, err = base64.StdEncoding.DecodeString(encodedMessage)
		}
		if err != nil {
			writer.WriteHeader(http.StatusBadRequest)
			return
		}
		messageData = data
	case http.MethodPost:
		var contentType = req.Header.Get("Content-Type")
		if contentType != "application/dns-message" {
			writer.WriteHeader(http.StatusUnsupportedMediaType)
			return
		}
		data, err := io.ReadAll(io.LimitReader(req.Body, maxMessageSize))
		if err != nil {
			writer.WriteHeader(http.StatusBadRequest)
			return
		}
		messageData = data
	}

	if len(messageData) == 0 {
		writer.WriteHeader(http.StatusBadRequest)
		return
	}

	var msg = &dns.Msg{}
	if err := msg.Unpack(messageData); err != nil {
		writer.WriteHeader(http.StatusBadRequest)
		return
	}

	// The connection is stored in the context by the server's ConnContext hook
	conn, ok := req.Context().Value(HTTPConnContextKey).(net.Conn)
	if !ok {
		writer.WriteHeader(http.StatusInternalServerError)
		return
	}

	this.handleDNSMessage(NewHTTPWriter(writer, conn, "application/dns-message"), msg)
}
func (this *Server) handleHTTPJSONAPI(writer http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodGet {
writer.WriteHeader(http.StatusNotImplemented)
return
}
var query = req.URL.Query()
var name = strings.TrimSpace(query.Get("name"))
var recordTypeString = strings.ToUpper(strings.TrimSpace(query.Get("type")))
if len(name) == 0 {
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid 'name' parameter"))
return
}
// add '.' to name
if !strings.HasSuffix(name, ".") {
name += "."
}
if len(recordTypeString) == 0 {
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid 'type' parameter"))
return
}
var recordType uint16
if regexp.MustCompile(`^\d{1,4}$`).MatchString(recordTypeString) {
recordType = types.Uint16(recordTypeString)
} else {
recordType = dns.StringToType[recordTypeString]
}
if recordType == 0 {
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid 'type' parameter"))
return
}
_, ok := dns.TypeToString[recordType]
if !ok {
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid 'type' parameter"))
return
}
var msg = &dns.Msg{}
msg.Question = []dns.Question{
{
Name: name,
Qtype: recordType,
Qclass: dns.ClassINET,
},
}
var connValue = req.Context().Value(HTTPConnContextKey)
if connValue == nil {
writer.WriteHeader(http.StatusInternalServerError)
return
}
// conn
conn, ok := connValue.(net.Conn)
if !ok {
writer.WriteHeader(http.StatusInternalServerError)
return
}
this.handleDNSMessage(NewHTTPWriter(writer, conn, "application/x-javascript"), msg)
}
// composeSOAAnswer appends an SOA record for question to resp.
//
// Settings come from the node's SOA config, or dnsconfigs.DefaultNSSOAConfig()
// when none is set.
//   - Without a configured MName, a random node host is used.
//   - Without a configured RName, the node email is used.
//   - Any '@' in RName is replaced by '.'.
//   - If either name is still empty, nothing is appended.
//
// Names are made fully qualified with dns.Fqdn, so values configured with a
// trailing dot no longer produce an invalid "name..". record supplies the RR
// header, and its Value is set to the primary name server.
func (this *Server) composeSOAAnswer(question dns.Question, record *models.NSRecord, resp *dns.Msg) {
	var nodeCfg = dnsNodeConfig()
	var config = nodeCfg.SOA
	var serial = nodeCfg.SOASerial
	if config == nil {
		config = dnsconfigs.DefaultNSSOAConfig()
	}

	var mName = config.MName
	if len(mName) == 0 {
		var hosts = nodeCfg.Hosts
		if len(hosts) > 0 {
			mName = hosts[rands.Int(0, len(hosts)-1)]
		}
	}

	var rName = config.RName
	if len(rName) == 0 {
		rName = nodeCfg.Email
	}
	rName = strings.ReplaceAll(rName, "@", ".")

	if len(mName) == 0 || len(rName) == 0 {
		return
	}

	// dns.Fqdn only appends the root dot when it is missing
	mName = dns.Fqdn(mName)
	record.Value = mName
	resp.Answer = append(resp.Answer, &dns.SOA{
		Hdr:     record.ToRRHeader(question.Name, dns.TypeSOA, question.Qclass),
		Ns:      mName,
		Mbox:    dns.Fqdn(rName),
		Serial:  serial,
		Refresh: config.RefreshSeconds,
		Retry:   config.RetrySeconds,
		Expire:  config.ExpireSeconds,
		Minttl:  config.MinimumTTL,
	})
}