Initial commit (code only without large binaries)
This commit is contained in:
932
EdgeDNS/internal/nodes/server.go
Normal file
932
EdgeDNS/internal/nodes/server.go
Normal file
@@ -0,0 +1,932 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/agents"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/models"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeDNS/internal/stats"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/miekg/dns"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var sharedRecursionDNSClient = &dns.Client{}
|
||||
|
||||
type httpContextKey struct {
|
||||
key string
|
||||
}
|
||||
|
||||
var HTTPConnContextKey = &httpContextKey{key: "http-conn"}
|
||||
|
||||
const PingDomain = "ping."
|
||||
|
||||
// Server 服务
|
||||
type Server struct {
|
||||
config *ServerConfig
|
||||
|
||||
rawServer *dns.Server
|
||||
httpsServer *http.Server
|
||||
}
|
||||
|
||||
// NewServer 构造新服务
|
||||
func NewServer(config *ServerConfig) (*Server, error) {
|
||||
var server = &Server{
|
||||
config: config,
|
||||
}
|
||||
err := server.init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// ListenAndServe 监听
|
||||
func (this *Server) ListenAndServe() error {
|
||||
if this.rawServer != nil {
|
||||
return this.rawServer.ListenAndServe()
|
||||
}
|
||||
if this.httpsServer != nil {
|
||||
listener, err := net.Listen("tcp", this.httpsServer.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = this.httpsServer.ServeTLS(listener, "", "")
|
||||
if err == http.ErrServerClosed {
|
||||
err = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return errors.New("the server is not initialized")
|
||||
}
|
||||
|
||||
// Shutdown 关闭
|
||||
func (this *Server) Shutdown() error {
|
||||
if this.rawServer != nil {
|
||||
return this.rawServer.Shutdown()
|
||||
}
|
||||
if this.httpsServer != nil {
|
||||
return this.httpsServer.Shutdown(context.Background())
|
||||
}
|
||||
|
||||
return errors.New("the server is not initialized")
|
||||
}
|
||||
|
||||
// Reload 重载配置
|
||||
func (this *Server) Reload(config *ServerConfig) {
|
||||
this.config = config
|
||||
}
|
||||
|
||||
// 初始化
|
||||
func (this *Server) init() error {
|
||||
var rawServer = &dns.Server{
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
rawServer.Handler = dns.HandlerFunc(this.handleDNSMessage)
|
||||
var addr = ""
|
||||
if len(this.config.Host) > 0 {
|
||||
addr += configutils.QuoteIP(this.config.Host)
|
||||
}
|
||||
addr += ":" + types.String(this.config.Port)
|
||||
rawServer.Addr = addr
|
||||
|
||||
switch this.config.Protocol {
|
||||
case serverconfigs.ProtocolTCP:
|
||||
rawServer.Net = "tcp"
|
||||
case serverconfigs.ProtocolTLS:
|
||||
rawServer.Net = "tcp-tls"
|
||||
rawServer.TLSConfig = &tls.Config{
|
||||
Certificates: nil,
|
||||
GetConfigForClient: func(clientInfo *tls.ClientHelloInfo) (config *tls.Config, e error) {
|
||||
return this.config.SSLPolicy.TLSConfig(), nil
|
||||
},
|
||||
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
|
||||
return this.config.SSLPolicy.FirstCert(), nil
|
||||
},
|
||||
}
|
||||
case serverconfigs.ProtocolHTTPS: // DoH
|
||||
rawServer = nil
|
||||
this.httpsServer = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: http.HandlerFunc(this.handleHTTP),
|
||||
TLSConfig: &tls.Config{
|
||||
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
if this.config == nil {
|
||||
return nil, errors.New("invalid 'ServerConfig.config'")
|
||||
}
|
||||
if this.config.SSLPolicy == nil {
|
||||
return nil, errors.New("invalid 'ServerConfig.config.SSLPolicy'")
|
||||
}
|
||||
return this.config.SSLPolicy.TLSConfig(), nil
|
||||
},
|
||||
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
|
||||
if this.config == nil {
|
||||
return nil, errors.New("invalid 'ServerConfig.config'")
|
||||
}
|
||||
if this.config.SSLPolicy == nil {
|
||||
return nil, errors.New("invalid 'ServerConfig.config.SSLPolicy'")
|
||||
}
|
||||
return this.config.SSLPolicy.FirstCert(), nil
|
||||
},
|
||||
},
|
||||
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
|
||||
return context.WithValue(ctx, HTTPConnContextKey, c)
|
||||
},
|
||||
ReadTimeout: 5 * time.Second,
|
||||
ReadHeaderTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
IdleTimeout: 75 * time.Second,
|
||||
MaxHeaderBytes: 4096,
|
||||
}
|
||||
case serverconfigs.ProtocolUDP:
|
||||
rawServer.Net = "udp"
|
||||
}
|
||||
|
||||
this.rawServer = rawServer
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 查询递归DNS
|
||||
func (this *Server) lookupRecursionDNS(req *dns.Msg, resp *dns.Msg) error {
|
||||
var config = sharedNodeConfig.RecursionConfig
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
if !config.IsOn {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 是否允许
|
||||
var domain = strings.TrimSuffix(req.Question[0].Name, ".")
|
||||
if len(config.DenyDomains) > 0 && configutils.MatchDomains(config.DenyDomains, domain) {
|
||||
return nil
|
||||
}
|
||||
if len(config.AllowDomains) > 0 && !configutils.MatchDomains(config.AllowDomains, domain) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if config.UseLocalHosts {
|
||||
// TODO 需要缓存文件内容
|
||||
resolveConfig, err := dns.ClientConfigFromFile("/etc/resolv.conf")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resolveConfig.Servers) == 0 {
|
||||
return errors.New("no dns servers found in config file")
|
||||
}
|
||||
if len(resolveConfig.Port) == 0 {
|
||||
resolveConfig.Port = "53"
|
||||
}
|
||||
r, _, err := sharedRecursionDNSClient.Exchange(req, configutils.QuoteIP(resolveConfig.Servers[rands.Int(0, len(resolveConfig.Servers)-1)])+":"+resolveConfig.Port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Answer = r.Answer
|
||||
} else if len(config.Hosts) > 0 {
|
||||
var host = config.Hosts[rands.Int(0, len(config.Hosts)-1)]
|
||||
if host.Port <= 0 {
|
||||
host.Port = 53
|
||||
}
|
||||
r, _, err := sharedRecursionDNSClient.Exchange(req, configutils.QuoteIP(host.Host)+":"+types.String(host.Port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Answer = r.Answer
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 分析查询中的动作
|
||||
func (this *Server) parseAction(questionName string, remoteAddr *string) (string, error) {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// TODO 需要防止恶意攻击
|
||||
var optionIndex = strings.Index(questionName, "-")
|
||||
if optionIndex > 0 {
|
||||
optionId := types.Int64(questionName[1:optionIndex])
|
||||
optionResp, err := rpcClient.NSQuestionOptionRPC.FindNSQuestionOption(rpcClient.Context(), &pb.FindNSQuestionOptionRequest{NsQuestionOptionId: optionId})
|
||||
if err != nil {
|
||||
return "", errors.New("query question option failed: " + err.Error())
|
||||
} else {
|
||||
var option = optionResp.NsQuestionOption
|
||||
if option != nil {
|
||||
switch option.Name {
|
||||
case "setRemoteAddr":
|
||||
var m = maps.Map{}
|
||||
err = json.Unmarshal(option.ValuesJSON, &m)
|
||||
if err != nil {
|
||||
return "", errors.New("decode question option failed: " + err.Error())
|
||||
} else {
|
||||
var ip = m.GetString("ip")
|
||||
*remoteAddr = ip
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
questionName = questionName[optionIndex+1:]
|
||||
}
|
||||
return questionName, nil
|
||||
}
|
||||
|
||||
// 记录日志
|
||||
func (this *Server) addLog(networking string, question dns.Question, domainId int64, routeCode string, record *models.NSRecord, isRecursive bool, writer dns.ResponseWriter, remoteAddr string, err error) {
|
||||
// 访问日志
|
||||
var accessLogRef = sharedNodeConfig.AccessLogRef
|
||||
if accessLogRef != nil && accessLogRef.IsOn {
|
||||
if domainId == 0 && !accessLogRef.LogMissingDomains {
|
||||
return
|
||||
}
|
||||
|
||||
if accessLogRef.MissingRecordsOnly && record != nil && len(record.Value) > 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var now = time.Now()
|
||||
var pbAccessLog = &pb.NSAccessLog{
|
||||
NsNodeId: sharedNodeConfig.Id,
|
||||
RemoteAddr: remoteAddr,
|
||||
NsDomainId: domainId,
|
||||
QuestionName: question.Name,
|
||||
QuestionType: dns.Type(question.Qtype).String(),
|
||||
IsRecursive: isRecursive,
|
||||
Networking: networking,
|
||||
ServerAddr: writer.LocalAddr().String(),
|
||||
Timestamp: now.Unix(),
|
||||
TimeLocal: now.Format("2/Jan/2006:15:04:05 -0700"),
|
||||
RequestId: "",
|
||||
}
|
||||
if record != nil {
|
||||
pbAccessLog.NsRecordId = record.Id
|
||||
if len(routeCode) > 0 {
|
||||
pbAccessLog.NsRouteCodes = []string{routeCode}
|
||||
}
|
||||
pbAccessLog.RecordName = record.Name
|
||||
pbAccessLog.RecordType = record.Type
|
||||
pbAccessLog.RecordValue = record.Value
|
||||
}
|
||||
if err != nil {
|
||||
pbAccessLog.Error = err.Error()
|
||||
}
|
||||
sharedNSAccessLogQueue.Push(pbAccessLog)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证TSIG
|
||||
func (this *Server) checkTSIG(msg *dns.Msg, domainId int64) error {
|
||||
var tsig = msg.IsTsig()
|
||||
if tsig == nil {
|
||||
return errors.New("tsig: tsig required")
|
||||
}
|
||||
|
||||
var keys = sharedKeyManager.FindKeysWithDomain(domainId)
|
||||
if len(keys) == 0 {
|
||||
return errors.New("tsig: no keys defined")
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
if key.Algo != tsig.Algorithm {
|
||||
continue
|
||||
}
|
||||
|
||||
// 需要重新Pack,每次Pack结果只能校验一次
|
||||
msgData, err := msg.Pack()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(msgData) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var base64Secret = key.Secret
|
||||
if key.SecretType == dnsconfigs.NSKeySecretTypeClear {
|
||||
base64Secret = base64.StdEncoding.EncodeToString([]byte(key.Secret))
|
||||
}
|
||||
err = dns.TsigVerify(msgData, base64Secret, "", false)
|
||||
if err != nil {
|
||||
continue
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return dns.ErrSig
|
||||
}
|
||||
|
||||
// 处理DNS请求
|
||||
func (this *Server) handleDNSMessage(writer dns.ResponseWriter, req *dns.Msg) {
|
||||
if len(req.Question) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if sharedDomainManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var networking = ""
|
||||
if this.config != nil {
|
||||
networking = this.config.Protocol.String()
|
||||
}
|
||||
|
||||
var resultDomainId int64
|
||||
var resultRecordIds [][2]int64 // [] { domainId, recordId}
|
||||
|
||||
var resp = new(dns.Msg)
|
||||
resp.RecursionDesired = true
|
||||
resp.RecursionAvailable = true
|
||||
resp.SetReply(req)
|
||||
|
||||
resp.Answer = []dns.RR{}
|
||||
|
||||
var tsigIsChecked = false
|
||||
var remoteAddr = writer.RemoteAddr().String()
|
||||
|
||||
for _, question := range req.Question {
|
||||
if len(question.Name) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// PING
|
||||
if question.Name == PingDomain {
|
||||
resp.Answer = append(resp.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: question.Qclass,
|
||||
Ttl: 60,
|
||||
},
|
||||
A: net.ParseIP("127.0.0.1"),
|
||||
})
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
err := writer.WriteMsg(resp)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 查询选项
|
||||
if question.Name[0] == '$' {
|
||||
_, port, _ := net.SplitHostPort(remoteAddr)
|
||||
|
||||
questionName, err := this.parseAction(question.Name, &remoteAddr)
|
||||
if err != nil {
|
||||
remotelogs.Error("SERVER", "invalid query option '"+question.Name+"'")
|
||||
continue
|
||||
}
|
||||
question.Name = questionName
|
||||
if len(port) > 0 {
|
||||
if strings.Contains(remoteAddr, ":") { // IPv6
|
||||
remoteAddr = "[" + remoteAddr + "]:" + port
|
||||
} else {
|
||||
remoteAddr += ":" + port
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var fullName = strings.TrimSuffix(question.Name, ".")
|
||||
var recordName string
|
||||
var recordType = dns.Type(question.Qtype).String()
|
||||
var domain *models.NSDomain
|
||||
domain, recordName = sharedDomainManager.SplitDomain(fullName)
|
||||
if domain == nil {
|
||||
// 检查递归DNS
|
||||
if sharedNodeConfig.RecursionConfig != nil && sharedNodeConfig.RecursionConfig.IsOn {
|
||||
err := this.lookupRecursionDNS(req, resp)
|
||||
if err != nil {
|
||||
this.addLog(networking, question, 0, "", nil, true, writer, remoteAddr, err)
|
||||
} else {
|
||||
var recordValue = ""
|
||||
if len(resp.Answer) > 0 {
|
||||
pieces := regexp.MustCompile(`\s+`).Split(resp.Answer[0].String(), 6)
|
||||
if len(pieces) >= 5 {
|
||||
recordValue = pieces[4]
|
||||
}
|
||||
}
|
||||
this.addLog(networking, question, 0, "", &models.NSRecord{
|
||||
Id: 0,
|
||||
Name: recordName,
|
||||
Type: recordType,
|
||||
Value: recordValue,
|
||||
Ttl: 0,
|
||||
}, true, writer, remoteAddr, nil)
|
||||
}
|
||||
|
||||
err = writer.WriteMsg(resp)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 是否为NS记录,用于验证域名所有权
|
||||
if question.Qtype == dns.TypeNS {
|
||||
var hosts = sharedNodeConfig.Hosts
|
||||
var l = len(hosts)
|
||||
var record = &models.NSRecord{
|
||||
Id: 0,
|
||||
Type: dnsconfigs.RecordTypeNS,
|
||||
Ttl: 600, // TODO 可以设置
|
||||
}
|
||||
if l > 0 {
|
||||
l = 1 // 目前只返回一个
|
||||
|
||||
// 随机
|
||||
var indexes = []int{}
|
||||
for i := 0; i < l; i++ {
|
||||
indexes = append(indexes, i)
|
||||
}
|
||||
|
||||
rand.Shuffle(l, func(i, j int) {
|
||||
indexes[i], indexes[j] = indexes[j], indexes[i]
|
||||
})
|
||||
|
||||
record.Value = hosts[0] + "."
|
||||
for _, index := range indexes {
|
||||
resp.Answer = append(resp.Answer, &dns.NS{
|
||||
Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
|
||||
Ns: hosts[index] + ".",
|
||||
})
|
||||
}
|
||||
|
||||
this.addLog(networking, question, 0, "", record, false, writer, remoteAddr, nil)
|
||||
continue
|
||||
}
|
||||
|
||||
this.addLog(networking, question, 0, "", nil, false, writer, remoteAddr, nil)
|
||||
continue
|
||||
}
|
||||
|
||||
this.addLog(networking, question, 0, "", nil, false, writer, remoteAddr, nil)
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查TSIG
|
||||
if domain.TSIG != nil && domain.TSIG.IsOn && !tsigIsChecked {
|
||||
err := this.checkTSIG(req, domain.Id)
|
||||
if err != nil {
|
||||
this.addLog(networking, question, domain.Id, "", nil, false, writer, remoteAddr, err)
|
||||
continue
|
||||
}
|
||||
tsigIsChecked = true
|
||||
}
|
||||
|
||||
resultDomainId = domain.Id
|
||||
|
||||
var clientIP = remoteAddr
|
||||
clientHost, _, err := net.SplitHostPort(clientIP)
|
||||
if err == nil && len(clientHost) > 0 {
|
||||
clientIP = clientHost
|
||||
}
|
||||
|
||||
// 解析Agent
|
||||
if sharedNodeConfig.DetectAgents {
|
||||
agents.SharedQueue.Push(clientIP)
|
||||
}
|
||||
|
||||
var routeCodes = sharedRouteManager.FindRouteCodes(clientIP, domain.UserId)
|
||||
|
||||
var records []*models.NSRecord
|
||||
var matchedRouteCode string
|
||||
if question.Qtype == dns.TypeSOA { // SOA
|
||||
if len(recordName) == 0 { // 只有顶级域名才有SOA记录
|
||||
records = []*models.NSRecord{
|
||||
{
|
||||
Id: 0,
|
||||
Type: dnsconfigs.RecordTypeSOA,
|
||||
Ttl: 600, // TODO 可以设置
|
||||
},
|
||||
}
|
||||
}
|
||||
} else if question.Qtype == dns.TypeNS { // NS
|
||||
if len(recordName) == 0 { // 只有顶级域名才有NS记录
|
||||
records = []*models.NSRecord{
|
||||
{
|
||||
Id: 0,
|
||||
Type: dnsconfigs.RecordTypeNS,
|
||||
Ttl: 600, // TODO 可以设置
|
||||
},
|
||||
}
|
||||
}
|
||||
} else if question.Qtype != dns.TypeCNAME {
|
||||
// 是否有直接的设置
|
||||
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, true)
|
||||
|
||||
// 检查CNAME
|
||||
if len(records) == 0 {
|
||||
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, dnsconfigs.RecordTypeCNAME, false)
|
||||
if len(records) > 0 {
|
||||
question.Qtype = dns.TypeCNAME
|
||||
}
|
||||
}
|
||||
|
||||
// 再次尝试查找默认设置
|
||||
if len(records) == 0 {
|
||||
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, false)
|
||||
}
|
||||
}
|
||||
|
||||
if len(records) == 0 {
|
||||
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, false)
|
||||
}
|
||||
|
||||
// 对 NS.example.com NS|SOA 处理
|
||||
if (question.Qtype == dns.TypeNS || (question.Qtype == dns.TypeSOA && len(records) == 0)) && lists.ContainsString(sharedNodeConfig.Hosts, fullName) {
|
||||
var recordDNSType string
|
||||
switch question.Qtype {
|
||||
case dns.TypeNS:
|
||||
recordDNSType = dnsconfigs.RecordTypeNS
|
||||
case dns.TypeSOA:
|
||||
recordDNSType = dnsconfigs.RecordTypeSOA
|
||||
}
|
||||
this.composeSOAAnswer(question, &models.NSRecord{
|
||||
Type: recordDNSType,
|
||||
Ttl: 600,
|
||||
}, resp)
|
||||
}
|
||||
|
||||
if len(records) > 0 {
|
||||
var firstRecord = records[0]
|
||||
|
||||
for _, record := range records {
|
||||
resultRecordIds = append(resultRecordIds, [2]int64{record.DomainId, record.Id})
|
||||
|
||||
switch record.Type {
|
||||
case dnsconfigs.RecordTypeA:
|
||||
var answer = record.ToRRAnswer(question.Name, question.Qclass)
|
||||
if answer != nil {
|
||||
resp.Answer = append(resp.Answer, answer)
|
||||
}
|
||||
case dnsconfigs.RecordTypeCNAME:
|
||||
var value = record.Value
|
||||
if !strings.HasSuffix(value, ".") {
|
||||
value += "."
|
||||
}
|
||||
var lastRecordValue = value
|
||||
resp.Answer = append(resp.Answer, &dns.CNAME{
|
||||
Hdr: record.ToRRHeader(question.Name, dns.TypeCNAME, question.Qclass),
|
||||
Target: value,
|
||||
})
|
||||
|
||||
// 继续查询CNAME
|
||||
var allCNAMEValues = []string{lastRecordValue}
|
||||
for {
|
||||
// 限制最深32层
|
||||
if len(allCNAMEValues) > 32 {
|
||||
break
|
||||
}
|
||||
|
||||
cnameDomain, cnameRecordName := sharedDomainManager.SplitDomain(lastRecordValue)
|
||||
if cnameDomain == nil {
|
||||
break
|
||||
}
|
||||
cnameRecords, _ := sharedRecordManager.FindRecords(cnameDomain.Id, sharedRouteManager.FindRouteCodes(clientIP, cnameDomain.UserId), cnameRecordName, dnsconfigs.RecordTypeCNAME, false)
|
||||
if len(cnameRecords) == 0 {
|
||||
break
|
||||
}
|
||||
var cnameRecord = cnameRecords[0]
|
||||
if !lists.ContainsString(allCNAMEValues, cnameRecord.Value) {
|
||||
resultRecordIds = append(resultRecordIds, [2]int64{cnameRecord.DomainId, cnameRecord.Id}) // 统计
|
||||
var answer = cnameRecord.ToRRAnswer(lastRecordValue, question.Qclass)
|
||||
if answer == nil {
|
||||
break
|
||||
}
|
||||
resp.Answer = append(resp.Answer, answer)
|
||||
lastRecordValue = cnameRecord.Value
|
||||
allCNAMEValues = append(allCNAMEValues, lastRecordValue)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 再次查询原始问题
|
||||
if len(req.Question) > 0 {
|
||||
var firstQuestion = req.Question[0]
|
||||
if firstQuestion.Qtype != dns.TypeCNAME {
|
||||
finalDomain, finalRecordName := sharedDomainManager.SplitDomain(lastRecordValue)
|
||||
if finalDomain != nil {
|
||||
var realRecords, _ = sharedRecordManager.FindRecords(finalDomain.Id, sharedRouteManager.FindRouteCodes(clientIP, finalDomain.UserId), finalRecordName, dns.Type(firstQuestion.Qtype).String(), false)
|
||||
if len(realRecords) > 0 {
|
||||
for _, realRecord := range realRecords {
|
||||
resultRecordIds = append(resultRecordIds, [2]int64{realRecord.DomainId, realRecord.Id}) // 统计
|
||||
var answer = realRecord.ToRRAnswer(lastRecordValue, question.Qclass)
|
||||
if answer != nil {
|
||||
resp.Answer = append(resp.Answer, answer)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case dnsconfigs.RecordTypeAAAA:
|
||||
var answer = record.ToRRAnswer(question.Name, question.Qclass)
|
||||
if answer != nil {
|
||||
resp.Answer = append(resp.Answer, answer)
|
||||
}
|
||||
case dnsconfigs.RecordTypeNS:
|
||||
if record.Id == 0 {
|
||||
var hosts = sharedNodeConfig.Hosts
|
||||
var l = len(hosts)
|
||||
if l > 0 {
|
||||
// 随机
|
||||
var indexes = []int{}
|
||||
for i := 0; i < l; i++ {
|
||||
indexes = append(indexes, i)
|
||||
}
|
||||
|
||||
rand.Shuffle(l, func(i, j int) {
|
||||
indexes[i], indexes[j] = indexes[j], indexes[i]
|
||||
})
|
||||
|
||||
record.Value = hosts[0] + "."
|
||||
for _, index := range indexes {
|
||||
resp.Answer = append(resp.Answer, &dns.NS{
|
||||
Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
|
||||
Ns: hosts[index] + ".",
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var value = record.Value
|
||||
if !strings.HasSuffix(value, ".") {
|
||||
value += "."
|
||||
}
|
||||
resp.Answer = append(resp.Answer, &dns.NS{
|
||||
Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
|
||||
Ns: value,
|
||||
})
|
||||
}
|
||||
case dnsconfigs.RecordTypeMX:
|
||||
var answer = record.ToRRAnswer(question.Name, question.Qclass)
|
||||
if answer != nil {
|
||||
resp.Answer = append(resp.Answer, answer)
|
||||
}
|
||||
case dnsconfigs.RecordTypeSRV:
|
||||
var answer = record.ToRRAnswer(question.Name, question.Qclass)
|
||||
if answer != nil {
|
||||
resp.Answer = append(resp.Answer, answer)
|
||||
}
|
||||
case dnsconfigs.RecordTypeTXT:
|
||||
var answer = record.ToRRAnswer(question.Name, question.Qclass)
|
||||
if answer != nil {
|
||||
resp.Answer = append(resp.Answer, answer)
|
||||
}
|
||||
case dnsconfigs.RecordTypeCAA:
|
||||
var answer = record.ToRRAnswer(question.Name, question.Qclass)
|
||||
if answer != nil {
|
||||
resp.Answer = append(resp.Answer, answer)
|
||||
}
|
||||
case dnsconfigs.RecordTypeSOA:
|
||||
this.composeSOAAnswer(question, record, resp)
|
||||
}
|
||||
}
|
||||
|
||||
// 访问日志
|
||||
this.addLog(networking, question, resultDomainId, matchedRouteCode, firstRecord, false, writer, remoteAddr, nil)
|
||||
} else {
|
||||
this.addLog(networking, question, resultDomainId, "", nil, false, writer, remoteAddr, nil)
|
||||
}
|
||||
}
|
||||
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
err := writer.WriteMsg(resp)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 统计
|
||||
for _, resultRecordId := range resultRecordIds {
|
||||
stats.SharedManager.Add(resultRecordId[0], resultRecordId[1], int64(resp.Len()))
|
||||
}
|
||||
}
|
||||
|
||||
// 处理HTTP请求
|
||||
// 参考:https://datatracker.ietf.org/doc/html/rfc8484
|
||||
// 参考:https://developers.google.com/speed/public-dns/docs/doh
|
||||
func (this *Server) handleHTTP(writer http.ResponseWriter, req *http.Request) {
|
||||
if req.URL.Path == "/dns-query" {
|
||||
this.handleHTTPDNSMessage(writer, req)
|
||||
return
|
||||
}
|
||||
|
||||
if req.URL.Path == "/resolve" {
|
||||
this.handleHTTPJSONAPI(writer, req)
|
||||
return
|
||||
}
|
||||
|
||||
writer.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
|
||||
func (this *Server) handleHTTPDNSMessage(writer http.ResponseWriter, req *http.Request) {
|
||||
const maxMessageSize = 512
|
||||
|
||||
writer.Header().Set("Accept", "application/dns-message")
|
||||
|
||||
if req.Method != http.MethodGet && req.Method != http.MethodPost {
|
||||
writer.WriteHeader(http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
if req.ContentLength > maxMessageSize {
|
||||
writer.WriteHeader(http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
|
||||
var messageData []byte
|
||||
|
||||
switch req.Method {
|
||||
case http.MethodGet:
|
||||
if len(req.URL.RawQuery) > maxMessageSize {
|
||||
writer.WriteHeader(http.StatusRequestURITooLong)
|
||||
return
|
||||
}
|
||||
var encodedMessage = req.URL.Query().Get("dns")
|
||||
var err error
|
||||
messageData, err = base64.StdEncoding.DecodeString(encodedMessage)
|
||||
if err != nil {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
case http.MethodPost:
|
||||
var contentType = req.Header.Get("Content-Type")
|
||||
if contentType != "application/dns-message" {
|
||||
writer.WriteHeader(http.StatusUnsupportedMediaType)
|
||||
return
|
||||
}
|
||||
data, err := io.ReadAll(io.LimitReader(req.Body, maxMessageSize))
|
||||
if err != nil {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
messageData = data
|
||||
}
|
||||
|
||||
if len(messageData) == 0 {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var msg = &dns.Msg{}
|
||||
err := msg.Unpack(messageData)
|
||||
if err != nil {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var connValue = req.Context().Value(HTTPConnContextKey)
|
||||
if connValue == nil {
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
conn, ok := connValue.(net.Conn)
|
||||
if !ok {
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
this.handleDNSMessage(NewHTTPWriter(writer, conn, "application/dns-message"), msg)
|
||||
}
|
||||
|
||||
func (this *Server) handleHTTPJSONAPI(writer http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodGet {
|
||||
writer.WriteHeader(http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
var query = req.URL.Query()
|
||||
var name = strings.TrimSpace(query.Get("name"))
|
||||
var recordTypeString = strings.ToUpper(strings.TrimSpace(query.Get("type")))
|
||||
|
||||
if len(name) == 0 {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = writer.Write([]byte("invalid 'name' parameter"))
|
||||
return
|
||||
}
|
||||
|
||||
// add '.' to name
|
||||
if !strings.HasSuffix(name, ".") {
|
||||
name += "."
|
||||
}
|
||||
|
||||
if len(recordTypeString) == 0 {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = writer.Write([]byte("invalid 'type' parameter"))
|
||||
return
|
||||
}
|
||||
|
||||
var recordType uint16
|
||||
if regexp.MustCompile(`^\d{1,4}$`).MatchString(recordTypeString) {
|
||||
recordType = types.Uint16(recordTypeString)
|
||||
} else {
|
||||
recordType = dns.StringToType[recordTypeString]
|
||||
}
|
||||
if recordType == 0 {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = writer.Write([]byte("invalid 'type' parameter"))
|
||||
return
|
||||
}
|
||||
|
||||
_, ok := dns.TypeToString[recordType]
|
||||
if !ok {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = writer.Write([]byte("invalid 'type' parameter"))
|
||||
return
|
||||
}
|
||||
|
||||
var msg = &dns.Msg{}
|
||||
msg.Question = []dns.Question{
|
||||
{
|
||||
Name: name,
|
||||
Qtype: recordType,
|
||||
Qclass: dns.ClassINET,
|
||||
},
|
||||
}
|
||||
|
||||
var connValue = req.Context().Value(HTTPConnContextKey)
|
||||
if connValue == nil {
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// conn
|
||||
conn, ok := connValue.(net.Conn)
|
||||
if !ok {
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
this.handleDNSMessage(NewHTTPWriter(writer, conn, "application/x-javascript"), msg)
|
||||
}
|
||||
|
||||
// 组合SOA回复信息
|
||||
func (this *Server) composeSOAAnswer(question dns.Question, record *models.NSRecord, resp *dns.Msg) {
|
||||
var config = sharedNodeConfig.SOA
|
||||
var serial = sharedNodeConfig.SOASerial
|
||||
|
||||
if config == nil {
|
||||
config = dnsconfigs.DefaultNSSOAConfig()
|
||||
}
|
||||
|
||||
var mName = config.MName
|
||||
if len(mName) == 0 {
|
||||
var hosts = sharedNodeConfig.Hosts
|
||||
var l = len(hosts)
|
||||
if l > 0 {
|
||||
var index = rands.Int(0, l-1)
|
||||
mName = hosts[index]
|
||||
}
|
||||
}
|
||||
|
||||
var rName = config.RName
|
||||
if len(rName) == 0 {
|
||||
rName = sharedNodeConfig.Email
|
||||
}
|
||||
rName = strings.ReplaceAll(rName, "@", ".")
|
||||
|
||||
if len(mName) > 0 && len(rName) > 0 {
|
||||
// 设置记录值
|
||||
record.Value = mName + "."
|
||||
|
||||
resp.Answer = append(resp.Answer, &dns.SOA{
|
||||
Hdr: record.ToRRHeader(question.Name, dns.TypeSOA, question.Qclass),
|
||||
Ns: mName + ".",
|
||||
Mbox: rName + ".",
|
||||
Serial: serial,
|
||||
Refresh: config.RefreshSeconds,
|
||||
Retry: config.RetrySeconds,
|
||||
Expire: config.ExpireSeconds,
|
||||
Minttl: config.MinimumTTL,
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user