1031 lines
27 KiB
Go
1031 lines
27 KiB
Go
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||
|
||
package nodes
|
||
|
||
import (
|
||
"context"
|
||
"crypto/tls"
|
||
"encoding/base64"
|
||
"encoding/json"
|
||
"errors"
|
||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||
"github.com/TeaOSLab/EdgeDNS/internal/agents"
|
||
"github.com/TeaOSLab/EdgeDNS/internal/models"
|
||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||
"github.com/TeaOSLab/EdgeDNS/internal/stats"
|
||
"github.com/iwind/TeaGo/lists"
|
||
"github.com/iwind/TeaGo/maps"
|
||
"github.com/iwind/TeaGo/rands"
|
||
"github.com/iwind/TeaGo/types"
|
||
"github.com/miekg/dns"
|
||
"io"
|
||
"math/rand"
|
||
"net"
|
||
"net/http"
|
||
"regexp"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
var sharedRecursionDNSClient = &dns.Client{
|
||
Timeout: 5 * time.Second,
|
||
ReadTimeout: 3 * time.Second,
|
||
WriteTimeout: 2 * time.Second,
|
||
}
|
||
|
||
type httpContextKey struct {
|
||
key string
|
||
}
|
||
|
||
var HTTPConnContextKey = &httpContextKey{key: "http-conn"}
|
||
|
||
const PingDomain = "ping."
|
||
|
||
// Server 服务
|
||
type Server struct {
|
||
config *ServerConfig
|
||
|
||
rawServer *dns.Server
|
||
httpsServer *http.Server
|
||
}
|
||
|
||
// NewServer 构造新服务
|
||
func NewServer(config *ServerConfig) (*Server, error) {
|
||
var server = &Server{
|
||
config: config,
|
||
}
|
||
err := server.init()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return server, nil
|
||
}
|
||
|
||
// ListenAndServe 监听
|
||
func (this *Server) ListenAndServe() error {
|
||
if this.rawServer != nil {
|
||
return this.rawServer.ListenAndServe()
|
||
}
|
||
if this.httpsServer != nil {
|
||
listener, err := net.Listen("tcp", this.httpsServer.Addr)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
err = this.httpsServer.ServeTLS(listener, "", "")
|
||
if err == http.ErrServerClosed {
|
||
err = nil
|
||
}
|
||
return err
|
||
}
|
||
|
||
return errors.New("the server is not initialized")
|
||
}
|
||
|
||
// Shutdown 关闭
|
||
func (this *Server) Shutdown() error {
|
||
if this.rawServer != nil {
|
||
return this.rawServer.Shutdown()
|
||
}
|
||
if this.httpsServer != nil {
|
||
return this.httpsServer.Shutdown(context.Background())
|
||
}
|
||
|
||
return errors.New("the server is not initialized")
|
||
}
|
||
|
||
// Reload replaces the server configuration with config.
//
// The listener built by init (address, protocol, timeouts) is not rebuilt; only
// values read from this.config at request or TLS-handshake time (such as the
// protocol name used for logging and the SSL policy) pick up the new config.
// NOTE(review): this.config is assigned without synchronization while request
// handlers may be reading it concurrently — confirm callers tolerate this or add locking.
func (this *Server) Reload(config *ServerConfig) {
	this.config = config
}
|
||
|
||
// 初始化
|
||
func (this *Server) init() error {
|
||
var rawServer = &dns.Server{
|
||
ReadTimeout: 3 * time.Second,
|
||
WriteTimeout: 3 * time.Second,
|
||
}
|
||
|
||
rawServer.Handler = dns.HandlerFunc(this.handleDNSMessage)
|
||
var addr = ""
|
||
if len(this.config.Host) > 0 {
|
||
addr += configutils.QuoteIP(this.config.Host)
|
||
}
|
||
addr += ":" + types.String(this.config.Port)
|
||
rawServer.Addr = addr
|
||
|
||
switch this.config.Protocol {
|
||
case serverconfigs.ProtocolTCP:
|
||
rawServer.Net = "tcp"
|
||
case serverconfigs.ProtocolTLS:
|
||
rawServer.Net = "tcp-tls"
|
||
rawServer.TLSConfig = &tls.Config{
|
||
Certificates: nil,
|
||
GetConfigForClient: func(clientInfo *tls.ClientHelloInfo) (config *tls.Config, e error) {
|
||
return this.config.SSLPolicy.TLSConfig(), nil
|
||
},
|
||
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
|
||
return this.config.SSLPolicy.FirstCert(), nil
|
||
},
|
||
}
|
||
case serverconfigs.ProtocolHTTPS: // DoH
|
||
rawServer = nil
|
||
this.httpsServer = &http.Server{
|
||
Addr: addr,
|
||
Handler: http.HandlerFunc(this.handleHTTP),
|
||
TLSConfig: &tls.Config{
|
||
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||
if this.config == nil {
|
||
return nil, errors.New("invalid 'ServerConfig.config'")
|
||
}
|
||
if this.config.SSLPolicy == nil {
|
||
return nil, errors.New("invalid 'ServerConfig.config.SSLPolicy'")
|
||
}
|
||
return this.config.SSLPolicy.TLSConfig(), nil
|
||
},
|
||
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
|
||
if this.config == nil {
|
||
return nil, errors.New("invalid 'ServerConfig.config'")
|
||
}
|
||
if this.config.SSLPolicy == nil {
|
||
return nil, errors.New("invalid 'ServerConfig.config.SSLPolicy'")
|
||
}
|
||
return this.config.SSLPolicy.FirstCert(), nil
|
||
},
|
||
},
|
||
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
|
||
return context.WithValue(ctx, HTTPConnContextKey, c)
|
||
},
|
||
ReadTimeout: 5 * time.Second,
|
||
ReadHeaderTimeout: 3 * time.Second,
|
||
WriteTimeout: 3 * time.Second,
|
||
IdleTimeout: 75 * time.Second,
|
||
MaxHeaderBytes: 4096,
|
||
}
|
||
case serverconfigs.ProtocolUDP:
|
||
rawServer.Net = "udp"
|
||
}
|
||
|
||
this.rawServer = rawServer
|
||
|
||
return nil
|
||
}
|
||
|
||
// addECSOption 向 DNS 请求中添加 EDNS Client Subnet (ECS) 信息
|
||
// addECSOption 向 DNS 请求中设置 EDNS Client Subnet (ECS)。
|
||
// 如果请求已携带 ECS 则覆盖(避免双 ECS 导致上游 malformed request)。
|
||
func addECSOption(req *dns.Msg, clientIP string) {
|
||
if len(clientIP) == 0 {
|
||
return
|
||
}
|
||
|
||
ip := net.ParseIP(clientIP)
|
||
if ip == nil {
|
||
return
|
||
}
|
||
|
||
var ecs = &dns.EDNS0_SUBNET{
|
||
Code: dns.EDNS0SUBNET,
|
||
}
|
||
if ip.To4() != nil {
|
||
ecs.Family = 1 // IPv4
|
||
ecs.SourceNetmask = 24
|
||
ecs.Address = ip.To4()
|
||
} else {
|
||
ecs.Family = 2 // IPv6
|
||
ecs.SourceNetmask = 56
|
||
ecs.Address = ip
|
||
}
|
||
|
||
// 查找或创建 OPT 记录
|
||
var opt = req.IsEdns0()
|
||
if opt == nil {
|
||
req.SetEdns0(4096, false)
|
||
opt = req.IsEdns0()
|
||
}
|
||
if opt != nil {
|
||
// 删除已有的 ECS option,避免出现双 EDNS0_SUBNET
|
||
var filtered []dns.EDNS0
|
||
for _, o := range opt.Option {
|
||
if o.Option() != dns.EDNS0SUBNET {
|
||
filtered = append(filtered, o)
|
||
}
|
||
}
|
||
opt.Option = append(filtered, ecs)
|
||
}
|
||
}
|
||
|
||
// stripECSFromExtra 从 Extra section 中移除 OPT 记录里的 EDNS0_SUBNET,
|
||
// 防止服务端注入的 ECS 信息回传给下游客户端(隐私泄露风险)。
|
||
func stripECSFromExtra(extra []dns.RR) []dns.RR {
|
||
for _, rr := range extra {
|
||
if opt, ok := rr.(*dns.OPT); ok {
|
||
var filtered []dns.EDNS0
|
||
for _, o := range opt.Option {
|
||
if o.Option() != dns.EDNS0SUBNET {
|
||
filtered = append(filtered, o)
|
||
}
|
||
}
|
||
opt.Option = filtered
|
||
}
|
||
}
|
||
return extra
|
||
}
|
||
|
||
// 查询递归DNS
|
||
func (this *Server) lookupRecursionDNS(req *dns.Msg, resp *dns.Msg, clientIP string) error {
|
||
var config = dnsNodeConfig().RecursionConfig
|
||
if config == nil {
|
||
return nil
|
||
}
|
||
if !config.IsOn {
|
||
return nil
|
||
}
|
||
|
||
// 是否允许
|
||
if len(req.Question) == 0 {
|
||
return nil
|
||
}
|
||
var domain = strings.TrimSuffix(req.Question[0].Name, ".")
|
||
if len(config.DenyDomains) > 0 && configutils.MatchDomains(config.DenyDomains, domain) {
|
||
return nil
|
||
}
|
||
if len(config.AllowDomains) > 0 && !configutils.MatchDomains(config.AllowDomains, domain) {
|
||
return nil
|
||
}
|
||
|
||
// 携带客户端真实 IP(ECS)向上游查询
|
||
addECSOption(req, clientIP)
|
||
|
||
if config.UseLocalHosts {
|
||
// TODO 需要缓存文件内容
|
||
resolveConfig, err := dns.ClientConfigFromFile("/etc/resolv.conf")
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if len(resolveConfig.Servers) == 0 {
|
||
return errors.New("no dns servers found in config file")
|
||
}
|
||
if len(resolveConfig.Port) == 0 {
|
||
resolveConfig.Port = "53"
|
||
}
|
||
r, _, err := sharedRecursionDNSClient.Exchange(req, configutils.QuoteIP(resolveConfig.Servers[rands.Int(0, len(resolveConfig.Servers)-1)])+":"+resolveConfig.Port)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if r != nil {
|
||
resp.Rcode = r.Rcode
|
||
resp.Answer = r.Answer
|
||
resp.Ns = r.Ns
|
||
resp.Extra = stripECSFromExtra(r.Extra)
|
||
}
|
||
} else if len(config.Hosts) > 0 {
|
||
var host = config.Hosts[rands.Int(0, len(config.Hosts)-1)]
|
||
if host.Port <= 0 {
|
||
host.Port = 53
|
||
}
|
||
r, _, err := sharedRecursionDNSClient.Exchange(req, configutils.QuoteIP(host.Host)+":"+types.String(host.Port))
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if r != nil {
|
||
resp.Rcode = r.Rcode
|
||
resp.Answer = r.Answer
|
||
resp.Ns = r.Ns
|
||
resp.Extra = stripECSFromExtra(r.Extra)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// 分析查询中的动作
|
||
func (this *Server) parseAction(questionName string, remoteAddr *string) (string, error) {
|
||
rpcClient, err := rpc.SharedRPC()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
// TODO 需要防止恶意攻击
|
||
var optionIndex = strings.Index(questionName, "-")
|
||
if optionIndex > 0 {
|
||
optionId := types.Int64(questionName[1:optionIndex])
|
||
optionResp, err := rpcClient.NSQuestionOptionRPC.FindNSQuestionOption(rpcClient.Context(), &pb.FindNSQuestionOptionRequest{NsQuestionOptionId: optionId})
|
||
if err != nil {
|
||
return "", errors.New("query question option failed: " + err.Error())
|
||
} else {
|
||
var option = optionResp.NsQuestionOption
|
||
if option != nil {
|
||
switch option.Name {
|
||
case "setRemoteAddr":
|
||
var m = maps.Map{}
|
||
err = json.Unmarshal(option.ValuesJSON, &m)
|
||
if err != nil {
|
||
return "", errors.New("decode question option failed: " + err.Error())
|
||
} else {
|
||
var ip = m.GetString("ip")
|
||
// 验证 IP 地址合法性,防止 IP 欺骗
|
||
parsedIP := net.ParseIP(ip)
|
||
if parsedIP == nil {
|
||
return "", errors.New("invalid IP address in setRemoteAddr: " + ip)
|
||
}
|
||
// 拒绝回环地址和未指定地址
|
||
if parsedIP.IsLoopback() || parsedIP.IsUnspecified() {
|
||
return "", errors.New("disallowed IP address in setRemoteAddr: " + ip)
|
||
}
|
||
*remoteAddr = ip
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
questionName = questionName[optionIndex+1:]
|
||
}
|
||
return questionName, nil
|
||
}
|
||
|
||
// 记录日志
|
||
func (this *Server) addLog(networking string, question dns.Question, domainId int64, routeCode string, record *models.NSRecord, isRecursive bool, writer dns.ResponseWriter, remoteAddr string, err error) {
|
||
// 访问日志
|
||
var nodeConfig = dnsNodeConfig()
|
||
var accessLogRef = nodeConfig.AccessLogRef
|
||
if accessLogRef != nil && accessLogRef.IsOn {
|
||
if domainId == 0 && !accessLogRef.LogMissingDomains {
|
||
return
|
||
}
|
||
|
||
if accessLogRef.MissingRecordsOnly && record != nil && len(record.Value) > 0 {
|
||
return
|
||
}
|
||
|
||
var now = time.Now()
|
||
var pbAccessLog = &pb.NSAccessLog{
|
||
NsNodeId: nodeConfig.Id,
|
||
RemoteAddr: remoteAddr,
|
||
NsDomainId: domainId,
|
||
QuestionName: question.Name,
|
||
QuestionType: dns.Type(question.Qtype).String(),
|
||
IsRecursive: isRecursive,
|
||
Networking: networking,
|
||
ServerAddr: writer.LocalAddr().String(),
|
||
Timestamp: now.Unix(),
|
||
TimeLocal: now.Format("2/Jan/2006:15:04:05 -0700"),
|
||
RequestId: "",
|
||
}
|
||
if record != nil {
|
||
pbAccessLog.NsRecordId = record.Id
|
||
if len(routeCode) > 0 {
|
||
pbAccessLog.NsRouteCodes = []string{routeCode}
|
||
}
|
||
pbAccessLog.RecordName = record.Name
|
||
pbAccessLog.RecordType = record.Type
|
||
pbAccessLog.RecordValue = record.Value
|
||
}
|
||
if err != nil {
|
||
pbAccessLog.Error = err.Error()
|
||
}
|
||
sharedNSAccessLogQueue.Push(pbAccessLog)
|
||
}
|
||
}
|
||
|
||
// 验证TSIG
|
||
func (this *Server) checkTSIG(msg *dns.Msg, domainId int64) error {
|
||
var tsig = msg.IsTsig()
|
||
if tsig == nil {
|
||
return errors.New("tsig: tsig required")
|
||
}
|
||
|
||
var keys = sharedKeyManager.FindKeysWithDomain(domainId)
|
||
if len(keys) == 0 {
|
||
return errors.New("tsig: no keys defined")
|
||
}
|
||
|
||
for _, key := range keys {
|
||
if key.Algo != tsig.Algorithm {
|
||
continue
|
||
}
|
||
|
||
// 需要重新Pack,每次Pack结果只能校验一次
|
||
msgData, err := msg.Pack()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if len(msgData) == 0 {
|
||
return nil
|
||
}
|
||
|
||
var base64Secret = key.Secret
|
||
if key.SecretType == dnsconfigs.NSKeySecretTypeClear {
|
||
base64Secret = base64.StdEncoding.EncodeToString([]byte(key.Secret))
|
||
}
|
||
err = dns.TsigVerify(msgData, base64Secret, "", false)
|
||
if err != nil {
|
||
continue
|
||
} else {
|
||
return nil
|
||
}
|
||
}
|
||
return dns.ErrSig
|
||
}
|
||
|
||
// 处理DNS请求
|
||
func (this *Server) handleDNSMessage(writer dns.ResponseWriter, req *dns.Msg) {
|
||
if len(req.Question) == 0 {
|
||
return
|
||
}
|
||
|
||
if sharedDomainManager == nil {
|
||
return
|
||
}
|
||
|
||
var networking = ""
|
||
if this.config != nil {
|
||
networking = this.config.Protocol.String()
|
||
}
|
||
|
||
var resultDomainId int64
|
||
var resultRecordIds [][2]int64 // [] { domainId, recordId}
|
||
|
||
var resp = new(dns.Msg)
|
||
resp.RecursionDesired = true
|
||
resp.RecursionAvailable = true
|
||
resp.SetReply(req)
|
||
|
||
resp.Answer = []dns.RR{}
|
||
|
||
var tsigIsChecked = false
|
||
var remoteAddr = writer.RemoteAddr().String()
|
||
|
||
for _, question := range req.Question {
|
||
if len(question.Name) == 0 {
|
||
continue
|
||
}
|
||
|
||
// PING
|
||
if question.Name == PingDomain {
|
||
resp.Answer = append(resp.Answer, &dns.A{
|
||
Hdr: dns.RR_Header{
|
||
Name: question.Name,
|
||
Rrtype: dns.TypeA,
|
||
Class: question.Qclass,
|
||
Ttl: 60,
|
||
},
|
||
A: net.ParseIP("127.0.0.1"),
|
||
})
|
||
resp.Rcode = dns.RcodeSuccess
|
||
err := writer.WriteMsg(resp)
|
||
if err != nil {
|
||
return
|
||
}
|
||
return
|
||
}
|
||
|
||
// 查询选项
|
||
if question.Name[0] == '$' {
|
||
_, port, _ := net.SplitHostPort(remoteAddr)
|
||
|
||
questionName, err := this.parseAction(question.Name, &remoteAddr)
|
||
if err != nil {
|
||
remotelogs.Error("SERVER", "invalid query option '"+question.Name+"'")
|
||
continue
|
||
}
|
||
question.Name = questionName
|
||
if len(port) > 0 {
|
||
if strings.Contains(remoteAddr, ":") { // IPv6
|
||
remoteAddr = "[" + remoteAddr + "]:" + port
|
||
} else {
|
||
remoteAddr += ":" + port
|
||
}
|
||
}
|
||
}
|
||
|
||
var fullName = strings.TrimSuffix(question.Name, ".")
|
||
var recordName string
|
||
var recordType = dns.Type(question.Qtype).String()
|
||
var domain *models.NSDomain
|
||
domain, recordName = sharedDomainManager.SplitDomain(fullName)
|
||
if domain == nil {
|
||
// 检查递归DNS
|
||
var recursionConfig = dnsNodeConfig().RecursionConfig
|
||
if recursionConfig != nil && recursionConfig.IsOn {
|
||
// 提取客户端 IP 用于 ECS
|
||
var clientIP = remoteAddr
|
||
if clientHost, _, splitErr := net.SplitHostPort(clientIP); splitErr == nil && len(clientHost) > 0 {
|
||
clientIP = clientHost
|
||
}
|
||
err := this.lookupRecursionDNS(req, resp, clientIP)
|
||
if err != nil {
|
||
this.addLog(networking, question, 0, "", nil, true, writer, remoteAddr, err)
|
||
} else {
|
||
var recordValue = ""
|
||
if len(resp.Answer) > 0 {
|
||
pieces := regexp.MustCompile(`\s+`).Split(resp.Answer[0].String(), 6)
|
||
if len(pieces) >= 5 {
|
||
recordValue = pieces[4]
|
||
}
|
||
}
|
||
this.addLog(networking, question, 0, "", &models.NSRecord{
|
||
Id: 0,
|
||
Name: recordName,
|
||
Type: recordType,
|
||
Value: recordValue,
|
||
Ttl: 0,
|
||
}, true, writer, remoteAddr, nil)
|
||
}
|
||
|
||
err = writer.WriteMsg(resp)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
// 是否为NS记录,用于验证域名所有权
|
||
if question.Qtype == dns.TypeNS {
|
||
var hosts = dnsNodeConfig().Hosts
|
||
var l = len(hosts)
|
||
var record = &models.NSRecord{
|
||
Id: 0,
|
||
Type: dnsconfigs.RecordTypeNS,
|
||
Ttl: 600, // TODO 可以设置
|
||
}
|
||
if l > 0 {
|
||
l = 1 // 目前只返回一个
|
||
|
||
// 随机
|
||
var indexes = []int{}
|
||
for i := 0; i < l; i++ {
|
||
indexes = append(indexes, i)
|
||
}
|
||
|
||
rand.Shuffle(l, func(i, j int) {
|
||
indexes[i], indexes[j] = indexes[j], indexes[i]
|
||
})
|
||
|
||
record.Value = hosts[0] + "."
|
||
for _, index := range indexes {
|
||
resp.Answer = append(resp.Answer, &dns.NS{
|
||
Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
|
||
Ns: hosts[index] + ".",
|
||
})
|
||
}
|
||
|
||
this.addLog(networking, question, 0, "", record, false, writer, remoteAddr, nil)
|
||
continue
|
||
}
|
||
|
||
this.addLog(networking, question, 0, "", nil, false, writer, remoteAddr, nil)
|
||
continue
|
||
}
|
||
|
||
this.addLog(networking, question, 0, "", nil, false, writer, remoteAddr, nil)
|
||
continue
|
||
}
|
||
|
||
// 检查TSIG
|
||
if domain.TSIG != nil && domain.TSIG.IsOn && !tsigIsChecked {
|
||
err := this.checkTSIG(req, domain.Id)
|
||
if err != nil {
|
||
this.addLog(networking, question, domain.Id, "", nil, false, writer, remoteAddr, err)
|
||
continue
|
||
}
|
||
tsigIsChecked = true
|
||
}
|
||
|
||
resultDomainId = domain.Id
|
||
|
||
var clientIP = remoteAddr
|
||
clientHost, _, err := net.SplitHostPort(clientIP)
|
||
if err == nil && len(clientHost) > 0 {
|
||
clientIP = clientHost
|
||
}
|
||
|
||
// 解析Agent
|
||
if dnsNodeConfig().DetectAgents {
|
||
agents.SharedQueue.Push(clientIP)
|
||
}
|
||
|
||
var routeCodes = sharedRouteManager.FindRouteCodes(clientIP, domain.UserId)
|
||
|
||
var records []*models.NSRecord
|
||
var matchedRouteCode string
|
||
if question.Qtype == dns.TypeSOA { // SOA
|
||
if len(recordName) == 0 { // 只有顶级域名才有SOA记录
|
||
records = []*models.NSRecord{
|
||
{
|
||
Id: 0,
|
||
Type: dnsconfigs.RecordTypeSOA,
|
||
Ttl: 600, // TODO 可以设置
|
||
},
|
||
}
|
||
}
|
||
} else if question.Qtype == dns.TypeNS { // NS
|
||
if len(recordName) == 0 { // 只有顶级域名才有NS记录
|
||
records = []*models.NSRecord{
|
||
{
|
||
Id: 0,
|
||
Type: dnsconfigs.RecordTypeNS,
|
||
Ttl: 600, // TODO 可以设置
|
||
},
|
||
}
|
||
}
|
||
} else if question.Qtype != dns.TypeCNAME {
|
||
// 是否有直接的设置
|
||
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, true)
|
||
|
||
// 检查CNAME
|
||
if len(records) == 0 {
|
||
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, dnsconfigs.RecordTypeCNAME, false)
|
||
if len(records) > 0 {
|
||
question.Qtype = dns.TypeCNAME
|
||
}
|
||
}
|
||
|
||
// 再次尝试查找默认设置
|
||
if len(records) == 0 {
|
||
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, false)
|
||
}
|
||
}
|
||
|
||
if len(records) == 0 {
|
||
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, false)
|
||
}
|
||
|
||
// 对 NS.example.com NS|SOA 处理
|
||
if (question.Qtype == dns.TypeNS || (question.Qtype == dns.TypeSOA && len(records) == 0)) && lists.ContainsString(dnsNodeConfig().Hosts, fullName) {
|
||
var recordDNSType string
|
||
switch question.Qtype {
|
||
case dns.TypeNS:
|
||
recordDNSType = dnsconfigs.RecordTypeNS
|
||
case dns.TypeSOA:
|
||
recordDNSType = dnsconfigs.RecordTypeSOA
|
||
}
|
||
this.composeSOAAnswer(question, &models.NSRecord{
|
||
Type: recordDNSType,
|
||
Ttl: 600,
|
||
}, resp)
|
||
}
|
||
|
||
if len(records) > 0 {
|
||
var firstRecord = records[0]
|
||
|
||
for _, record := range records {
|
||
resultRecordIds = append(resultRecordIds, [2]int64{record.DomainId, record.Id})
|
||
|
||
switch record.Type {
|
||
case dnsconfigs.RecordTypeA:
|
||
var answer = record.ToRRAnswer(question.Name, question.Qclass)
|
||
if answer != nil {
|
||
resp.Answer = append(resp.Answer, answer)
|
||
}
|
||
case dnsconfigs.RecordTypeCNAME:
|
||
var value = record.Value
|
||
if !strings.HasSuffix(value, ".") {
|
||
value += "."
|
||
}
|
||
var lastRecordValue = value
|
||
resp.Answer = append(resp.Answer, &dns.CNAME{
|
||
Hdr: record.ToRRHeader(question.Name, dns.TypeCNAME, question.Qclass),
|
||
Target: value,
|
||
})
|
||
|
||
// 继续查询CNAME
|
||
var allCNAMEValues = []string{lastRecordValue}
|
||
for {
|
||
// 限制最深32层
|
||
if len(allCNAMEValues) > 32 {
|
||
break
|
||
}
|
||
|
||
cnameDomain, cnameRecordName := sharedDomainManager.SplitDomain(lastRecordValue)
|
||
if cnameDomain == nil {
|
||
break
|
||
}
|
||
cnameRecords, _ := sharedRecordManager.FindRecords(cnameDomain.Id, sharedRouteManager.FindRouteCodes(clientIP, cnameDomain.UserId), cnameRecordName, dnsconfigs.RecordTypeCNAME, false)
|
||
if len(cnameRecords) == 0 {
|
||
break
|
||
}
|
||
var cnameRecord = cnameRecords[0]
|
||
if !lists.ContainsString(allCNAMEValues, cnameRecord.Value) {
|
||
resultRecordIds = append(resultRecordIds, [2]int64{cnameRecord.DomainId, cnameRecord.Id}) // 统计
|
||
var answer = cnameRecord.ToRRAnswer(lastRecordValue, question.Qclass)
|
||
if answer == nil {
|
||
break
|
||
}
|
||
resp.Answer = append(resp.Answer, answer)
|
||
lastRecordValue = cnameRecord.Value
|
||
allCNAMEValues = append(allCNAMEValues, lastRecordValue)
|
||
} else {
|
||
break
|
||
}
|
||
}
|
||
|
||
// 再次查询原始问题
|
||
if len(req.Question) > 0 {
|
||
var firstQuestion = req.Question[0]
|
||
if firstQuestion.Qtype != dns.TypeCNAME {
|
||
finalDomain, finalRecordName := sharedDomainManager.SplitDomain(lastRecordValue)
|
||
if finalDomain != nil {
|
||
var realRecords, _ = sharedRecordManager.FindRecords(finalDomain.Id, sharedRouteManager.FindRouteCodes(clientIP, finalDomain.UserId), finalRecordName, dns.Type(firstQuestion.Qtype).String(), false)
|
||
if len(realRecords) > 0 {
|
||
for _, realRecord := range realRecords {
|
||
resultRecordIds = append(resultRecordIds, [2]int64{realRecord.DomainId, realRecord.Id}) // 统计
|
||
var answer = realRecord.ToRRAnswer(lastRecordValue, question.Qclass)
|
||
if answer != nil {
|
||
resp.Answer = append(resp.Answer, answer)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
case dnsconfigs.RecordTypeAAAA:
|
||
var answer = record.ToRRAnswer(question.Name, question.Qclass)
|
||
if answer != nil {
|
||
resp.Answer = append(resp.Answer, answer)
|
||
}
|
||
case dnsconfigs.RecordTypeNS:
|
||
if record.Id == 0 {
|
||
var hosts = dnsNodeConfig().Hosts
|
||
var l = len(hosts)
|
||
if l > 0 {
|
||
// 随机
|
||
var indexes = []int{}
|
||
for i := 0; i < l; i++ {
|
||
indexes = append(indexes, i)
|
||
}
|
||
|
||
rand.Shuffle(l, func(i, j int) {
|
||
indexes[i], indexes[j] = indexes[j], indexes[i]
|
||
})
|
||
|
||
record.Value = hosts[0] + "."
|
||
for _, index := range indexes {
|
||
resp.Answer = append(resp.Answer, &dns.NS{
|
||
Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
|
||
Ns: hosts[index] + ".",
|
||
})
|
||
}
|
||
}
|
||
} else {
|
||
var value = record.Value
|
||
if !strings.HasSuffix(value, ".") {
|
||
value += "."
|
||
}
|
||
resp.Answer = append(resp.Answer, &dns.NS{
|
||
Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
|
||
Ns: value,
|
||
})
|
||
}
|
||
case dnsconfigs.RecordTypeMX:
|
||
var answer = record.ToRRAnswer(question.Name, question.Qclass)
|
||
if answer != nil {
|
||
resp.Answer = append(resp.Answer, answer)
|
||
}
|
||
case dnsconfigs.RecordTypeSRV:
|
||
var answer = record.ToRRAnswer(question.Name, question.Qclass)
|
||
if answer != nil {
|
||
resp.Answer = append(resp.Answer, answer)
|
||
}
|
||
case dnsconfigs.RecordTypeTXT:
|
||
var answer = record.ToRRAnswer(question.Name, question.Qclass)
|
||
if answer != nil {
|
||
resp.Answer = append(resp.Answer, answer)
|
||
}
|
||
case dnsconfigs.RecordTypeCAA:
|
||
var answer = record.ToRRAnswer(question.Name, question.Qclass)
|
||
if answer != nil {
|
||
resp.Answer = append(resp.Answer, answer)
|
||
}
|
||
case dnsconfigs.RecordTypeSOA:
|
||
this.composeSOAAnswer(question, record, resp)
|
||
}
|
||
}
|
||
|
||
// 访问日志
|
||
this.addLog(networking, question, resultDomainId, matchedRouteCode, firstRecord, false, writer, remoteAddr, nil)
|
||
} else {
|
||
this.addLog(networking, question, resultDomainId, "", nil, false, writer, remoteAddr, nil)
|
||
}
|
||
}
|
||
|
||
resp.Rcode = dns.RcodeSuccess
|
||
err := writer.WriteMsg(resp)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
// 统计
|
||
for _, resultRecordId := range resultRecordIds {
|
||
stats.SharedManager.Add(resultRecordId[0], resultRecordId[1], int64(resp.Len()))
|
||
}
|
||
}
|
||
|
||
// 处理HTTP请求
|
||
// 参考:https://datatracker.ietf.org/doc/html/rfc8484
|
||
// 参考:https://developers.google.com/speed/public-dns/docs/doh
|
||
func (this *Server) handleHTTP(writer http.ResponseWriter, req *http.Request) {
|
||
if req.URL.Path == "/dns-query" {
|
||
this.handleHTTPDNSMessage(writer, req)
|
||
return
|
||
}
|
||
|
||
if req.URL.Path == "/resolve" {
|
||
this.handleHTTPJSONAPI(writer, req)
|
||
return
|
||
}
|
||
|
||
writer.WriteHeader(http.StatusNotFound)
|
||
}
|
||
|
||
func (this *Server) handleHTTPDNSMessage(writer http.ResponseWriter, req *http.Request) {
|
||
const maxMessageSize = 512
|
||
|
||
writer.Header().Set("Accept", "application/dns-message")
|
||
|
||
if req.Method != http.MethodGet && req.Method != http.MethodPost {
|
||
writer.WriteHeader(http.StatusNotImplemented)
|
||
return
|
||
}
|
||
|
||
if req.ContentLength > maxMessageSize {
|
||
writer.WriteHeader(http.StatusRequestEntityTooLarge)
|
||
return
|
||
}
|
||
|
||
var messageData []byte
|
||
|
||
switch req.Method {
|
||
case http.MethodGet:
|
||
if len(req.URL.RawQuery) > maxMessageSize {
|
||
writer.WriteHeader(http.StatusRequestURITooLong)
|
||
return
|
||
}
|
||
var encodedMessage = req.URL.Query().Get("dns")
|
||
var err error
|
||
messageData, err = base64.StdEncoding.DecodeString(encodedMessage)
|
||
if err != nil {
|
||
writer.WriteHeader(http.StatusBadRequest)
|
||
return
|
||
}
|
||
case http.MethodPost:
|
||
var contentType = req.Header.Get("Content-Type")
|
||
if contentType != "application/dns-message" {
|
||
writer.WriteHeader(http.StatusUnsupportedMediaType)
|
||
return
|
||
}
|
||
data, err := io.ReadAll(io.LimitReader(req.Body, maxMessageSize))
|
||
if err != nil {
|
||
writer.WriteHeader(http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
messageData = data
|
||
}
|
||
|
||
if len(messageData) == 0 {
|
||
writer.WriteHeader(http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
var msg = &dns.Msg{}
|
||
err := msg.Unpack(messageData)
|
||
if err != nil {
|
||
writer.WriteHeader(http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
var connValue = req.Context().Value(HTTPConnContextKey)
|
||
if connValue == nil {
|
||
writer.WriteHeader(http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
conn, ok := connValue.(net.Conn)
|
||
if !ok {
|
||
writer.WriteHeader(http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
this.handleDNSMessage(NewHTTPWriter(writer, conn, "application/dns-message"), msg)
|
||
}
|
||
|
||
func (this *Server) handleHTTPJSONAPI(writer http.ResponseWriter, req *http.Request) {
|
||
if req.Method != http.MethodGet {
|
||
writer.WriteHeader(http.StatusNotImplemented)
|
||
return
|
||
}
|
||
|
||
var query = req.URL.Query()
|
||
var name = strings.TrimSpace(query.Get("name"))
|
||
var recordTypeString = strings.ToUpper(strings.TrimSpace(query.Get("type")))
|
||
|
||
if len(name) == 0 {
|
||
writer.WriteHeader(http.StatusBadRequest)
|
||
_, _ = writer.Write([]byte("invalid 'name' parameter"))
|
||
return
|
||
}
|
||
|
||
// add '.' to name
|
||
if !strings.HasSuffix(name, ".") {
|
||
name += "."
|
||
}
|
||
|
||
if len(recordTypeString) == 0 {
|
||
writer.WriteHeader(http.StatusBadRequest)
|
||
_, _ = writer.Write([]byte("invalid 'type' parameter"))
|
||
return
|
||
}
|
||
|
||
var recordType uint16
|
||
if regexp.MustCompile(`^\d{1,4}$`).MatchString(recordTypeString) {
|
||
recordType = types.Uint16(recordTypeString)
|
||
} else {
|
||
recordType = dns.StringToType[recordTypeString]
|
||
}
|
||
if recordType == 0 {
|
||
writer.WriteHeader(http.StatusBadRequest)
|
||
_, _ = writer.Write([]byte("invalid 'type' parameter"))
|
||
return
|
||
}
|
||
|
||
_, ok := dns.TypeToString[recordType]
|
||
if !ok {
|
||
writer.WriteHeader(http.StatusBadRequest)
|
||
_, _ = writer.Write([]byte("invalid 'type' parameter"))
|
||
return
|
||
}
|
||
|
||
var msg = &dns.Msg{}
|
||
msg.Question = []dns.Question{
|
||
{
|
||
Name: name,
|
||
Qtype: recordType,
|
||
Qclass: dns.ClassINET,
|
||
},
|
||
}
|
||
|
||
var connValue = req.Context().Value(HTTPConnContextKey)
|
||
if connValue == nil {
|
||
writer.WriteHeader(http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
// conn
|
||
conn, ok := connValue.(net.Conn)
|
||
if !ok {
|
||
writer.WriteHeader(http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
this.handleDNSMessage(NewHTTPWriter(writer, conn, "application/x-javascript"), msg)
|
||
}
|
||
|
||
// 组合SOA回复信息
|
||
func (this *Server) composeSOAAnswer(question dns.Question, record *models.NSRecord, resp *dns.Msg) {
|
||
var nodeCfg = dnsNodeConfig()
|
||
var config = nodeCfg.SOA
|
||
var serial = nodeCfg.SOASerial
|
||
|
||
if config == nil {
|
||
config = dnsconfigs.DefaultNSSOAConfig()
|
||
}
|
||
|
||
var mName = config.MName
|
||
if len(mName) == 0 {
|
||
var hosts = nodeCfg.Hosts
|
||
var l = len(hosts)
|
||
if l > 0 {
|
||
var index = rands.Int(0, l-1)
|
||
mName = hosts[index]
|
||
}
|
||
}
|
||
|
||
var rName = config.RName
|
||
if len(rName) == 0 {
|
||
rName = nodeCfg.Email
|
||
}
|
||
rName = strings.ReplaceAll(rName, "@", ".")
|
||
|
||
if len(mName) > 0 && len(rName) > 0 {
|
||
// 设置记录值
|
||
record.Value = mName + "."
|
||
|
||
resp.Answer = append(resp.Answer, &dns.SOA{
|
||
Hdr: record.ToRRHeader(question.Name, dns.TypeSOA, question.Qclass),
|
||
Ns: mName + ".",
|
||
Mbox: rName + ".",
|
||
Serial: serial,
|
||
Refresh: config.RefreshSeconds,
|
||
Retry: config.RetrySeconds,
|
||
Expire: config.ExpireSeconds,
|
||
Minttl: config.MinimumTTL,
|
||
})
|
||
}
|
||
}
|