Files
waf-platform/EdgeDNS/internal/nodes/server.go
2026-03-02 20:07:53 +08:00

942 lines
25 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeDNS/internal/agents"
"github.com/TeaOSLab/EdgeDNS/internal/models"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"github.com/TeaOSLab/EdgeDNS/internal/stats"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
"github.com/miekg/dns"
"io"
"math/rand"
"net"
"net/http"
"regexp"
"strings"
"time"
)
// sharedRecursionDNSClient is the package-wide DNS client used to forward
// queries to upstream resolvers when recursion is enabled.
var sharedRecursionDNSClient = &dns.Client{}
// httpContextKey is a private key type for values stored in an HTTP request
// context, so the key cannot collide with keys defined by other packages.
type httpContextKey struct {
key string
}
// HTTPConnContextKey is the context key under which the HTTPS server stores
// the underlying net.Conn of each connection.
var HTTPConnContextKey = &httpContextKey{key: "http-conn"}
// PingDomain is the question name that is answered directly with a fixed
// A record instead of going through domain lookup.
const PingDomain = "ping."
// Server is a single DNS listener. Depending on the configured protocol it is
// backed either by a dns.Server (rawServer) or by an HTTPS server (httpsServer).
type Server struct {
// config is the current listener configuration; it can be swapped by Reload.
config *ServerConfig
// rawServer is set for every protocol except HTTPS.
rawServer *dns.Server
// httpsServer is set only for the HTTPS (DoH) protocol.
httpsServer *http.Server
}
// NewServer creates a server for the given configuration and initializes its
// underlying listener. It returns the initialization error, if any.
func NewServer(config *ServerConfig) (*Server, error) {
	server := &Server{config: config}
	if err := server.init(); err != nil {
		return nil, err
	}
	return server, nil
}
// ListenAndServe starts serving on whichever underlying server was created by
// init and blocks until it stops. A regular shutdown of the HTTPS server
// (http.ErrServerClosed) is reported as success.
func (this *Server) ListenAndServe() error {
	switch {
	case this.rawServer != nil:
		return this.rawServer.ListenAndServe()
	case this.httpsServer != nil:
		listener, listenErr := net.Listen("tcp", this.httpsServer.Addr)
		if listenErr != nil {
			return listenErr
		}
		// Certificates are supplied by the TLSConfig callbacks, so no files are passed.
		serveErr := this.httpsServer.ServeTLS(listener, "", "")
		if serveErr == http.ErrServerClosed {
			return nil
		}
		return serveErr
	default:
		return errors.New("the server is not initialized")
	}
}
// Shutdown stops whichever underlying server is active. It returns an error
// when neither server has been initialized.
func (this *Server) Shutdown() error {
	switch {
	case this.rawServer != nil:
		return this.rawServer.Shutdown()
	case this.httpsServer != nil:
		return this.httpsServer.Shutdown(context.Background())
	default:
		return errors.New("the server is not initialized")
	}
}
// Reload replaces the configuration used by this server.
// Only the config pointer is swapped; the underlying listener is not rebuilt here.
// NOTE(review): the assignment is not synchronized with the handlers and TLS
// callbacks that read this.config — confirm that callers make this safe.
func (this *Server) Reload(config *ServerConfig) {
this.config = config
}
// init builds the underlying listener for the configured protocol.
//
// TCP, TLS and UDP are served by a dns.Server stored in rawServer; HTTPS (DoH)
// is served by an http.Server stored in httpsServer, in which case rawServer
// stays nil. TLS material is resolved lazily per handshake through the
// callbacks below, so a configuration swapped in by Reload is picked up.
//
// It returns an error when the server has no configuration.
func (this *Server) init() error {
	if this.config == nil {
		return errors.New("the server config should not be nil")
	}

	var addr = ""
	if len(this.config.Host) > 0 {
		addr += configutils.QuoteIP(this.config.Host)
	}
	addr += ":" + types.String(this.config.Port)

	var rawServer = &dns.Server{
		Addr:         addr,
		Handler:      dns.HandlerFunc(this.handleDNSMessage),
		ReadTimeout:  3 * time.Second,
		WriteTimeout: 3 * time.Second,
	}

	switch this.config.Protocol {
	case serverconfigs.ProtocolTCP:
		rawServer.Net = "tcp"
	case serverconfigs.ProtocolTLS:
		rawServer.Net = "tcp-tls"
		rawServer.TLSConfig = &tls.Config{
			GetConfigForClient: this.tlsConfigForClient,
			GetCertificate:     this.certificateForClient,
		}
	case serverconfigs.ProtocolHTTPS: // DoH
		rawServer = nil
		this.httpsServer = &http.Server{
			Addr:    addr,
			Handler: http.HandlerFunc(this.handleHTTP),
			TLSConfig: &tls.Config{
				GetConfigForClient: this.tlsConfigForClient,
				GetCertificate:     this.certificateForClient,
			},
			// Expose the raw connection to handlers so they can report client/server addresses.
			ConnContext: func(ctx context.Context, c net.Conn) context.Context {
				return context.WithValue(ctx, HTTPConnContextKey, c)
			},
			ReadTimeout:       5 * time.Second,
			ReadHeaderTimeout: 3 * time.Second,
			WriteTimeout:      3 * time.Second,
			IdleTimeout:       75 * time.Second,
			MaxHeaderBytes:    4096,
		}
	case serverconfigs.ProtocolUDP:
		rawServer.Net = "udp"
	}

	this.rawServer = rawServer
	return nil
}

// tlsConfigForClient returns the TLS configuration of the current SSL policy.
// It fails the handshake with an error instead of panicking when the
// configuration or its SSL policy is missing.
func (this *Server) tlsConfigForClient(_ *tls.ClientHelloInfo) (*tls.Config, error) {
	// Read the pointer once so a concurrent Reload cannot change it between check and use.
	var config = this.config
	if config == nil {
		return nil, errors.New("invalid 'ServerConfig.config'")
	}
	if config.SSLPolicy == nil {
		return nil, errors.New("invalid 'ServerConfig.config.SSLPolicy'")
	}
	return config.SSLPolicy.TLSConfig(), nil
}

// certificateForClient returns the first certificate of the current SSL
// policy, with the same nil checks as tlsConfigForClient.
func (this *Server) certificateForClient(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
	var config = this.config
	if config == nil {
		return nil, errors.New("invalid 'ServerConfig.config'")
	}
	if config.SSLPolicy == nil {
		return nil, errors.New("invalid 'ServerConfig.config.SSLPolicy'")
	}
	return config.SSLPolicy.FirstCert(), nil
}
// lookupRecursionDNS forwards req to an upstream resolver and copies the
// upstream answer section into resp.
//
// Nothing is done (and nil is returned) when recursion is not configured or
// disabled, when the request has no question, when the queried domain is
// denied or not in the allow list, or when no upstream host is available.
// The upstream is either a random server from /etc/resolv.conf
// (UseLocalHosts) or a random entry of the configured hosts.
//
// It returns the error from reading the resolver file or from the exchange.
func (this *Server) lookupRecursionDNS(req *dns.Msg, resp *dns.Msg) error {
	var config = sharedNodeConfig.RecursionConfig
	if config == nil || !config.IsOn {
		return nil
	}
	if len(req.Question) == 0 {
		return nil
	}

	// Apply the deny list first, then the allow list.
	var domain = strings.TrimSuffix(req.Question[0].Name, ".")
	if len(config.DenyDomains) > 0 && configutils.MatchDomains(config.DenyDomains, domain) {
		return nil
	}
	if len(config.AllowDomains) > 0 && !configutils.MatchDomains(config.AllowDomains, domain) {
		return nil
	}

	var upstreamAddr string
	switch {
	case config.UseLocalHosts:
		// TODO cache the file content instead of reading it on every query
		resolveConfig, err := dns.ClientConfigFromFile("/etc/resolv.conf")
		if err != nil {
			return err
		}
		if len(resolveConfig.Servers) == 0 {
			return errors.New("no dns servers found in config file")
		}
		var port = resolveConfig.Port
		if len(port) == 0 {
			port = "53"
		}
		var server = resolveConfig.Servers[rands.Int(0, len(resolveConfig.Servers)-1)]
		upstreamAddr = configutils.QuoteIP(server) + ":" + port
	case len(config.Hosts) > 0:
		var host = config.Hosts[rands.Int(0, len(config.Hosts)-1)]
		// Default the port in a local copy: the host entry belongs to the shared
		// node configuration and must not be written from the query path.
		var port = host.Port
		if port <= 0 {
			port = 53
		}
		upstreamAddr = configutils.QuoteIP(host.Host) + ":" + types.String(port)
	default:
		return nil
	}

	r, _, err := sharedRecursionDNSClient.Exchange(req, upstreamAddr)
	if err != nil {
		return err
	}
	resp.Answer = r.Answer
	return nil
}
// parseAction interprets a question name carrying a query option.
//
// The text between the first character and the first "-" is taken as the
// option id, which is looked up over RPC. For a "setRemoteAddr" option the
// "ip" value is validated and written to *remoteAddr. The returned name is
// the part after the first "-"; when there is no "-" after the first
// character, the name is returned unchanged.
//
// An error is returned when the RPC client is unavailable, the lookup fails,
// the option values cannot be decoded, or the IP is invalid or disallowed.
func (this *Server) parseAction(questionName string, remoteAddr *string) (string, error) {
	rpcClient, err := rpc.SharedRPC()
	if err != nil {
		return "", err
	}

	// TODO protect against malicious queries
	var dashIndex = strings.Index(questionName, "-")
	if dashIndex <= 0 {
		return questionName, nil
	}

	var strippedName = questionName[dashIndex+1:]
	var optionId = types.Int64(questionName[1:dashIndex])
	optionResp, err := rpcClient.NSQuestionOptionRPC.FindNSQuestionOption(rpcClient.Context(), &pb.FindNSQuestionOptionRequest{NsQuestionOptionId: optionId})
	if err != nil {
		return "", errors.New("query question option failed: " + err.Error())
	}

	var option = optionResp.NsQuestionOption
	if option == nil || option.Name != "setRemoteAddr" {
		return strippedName, nil
	}

	var values = maps.Map{}
	if err = json.Unmarshal(option.ValuesJSON, &values); err != nil {
		return "", errors.New("decode question option failed: " + err.Error())
	}

	// Validate the address to prevent IP spoofing.
	var ip = values.GetString("ip")
	var parsedIP = net.ParseIP(ip)
	if parsedIP == nil {
		return "", errors.New("invalid IP address in setRemoteAddr: " + ip)
	}

	// Loopback and unspecified addresses are rejected.
	if parsedIP.IsLoopback() || parsedIP.IsUnspecified() {
		return "", errors.New("disallowed IP address in setRemoteAddr: " + ip)
	}

	*remoteAddr = ip
	return strippedName, nil
}
// addLog pushes an access log entry for one question onto the shared access
// log queue.
//
// Nothing is logged when access logging is off, when the question matched no
// domain and missing domains are not logged, or when only missing records are
// logged and the record has a value. record may be nil; err, when non-nil, is
// stored as the entry's error text.
func (this *Server) addLog(networking string, question dns.Question, domainId int64, routeCode string, record *models.NSRecord, isRecursive bool, writer dns.ResponseWriter, remoteAddr string, err error) {
	var logRef = sharedNodeConfig.AccessLogRef
	if logRef == nil || !logRef.IsOn {
		return
	}
	if domainId == 0 && !logRef.LogMissingDomains {
		return
	}
	if logRef.MissingRecordsOnly && record != nil && len(record.Value) > 0 {
		return
	}

	var now = time.Now()
	var entry = &pb.NSAccessLog{
		NsNodeId:     sharedNodeConfig.Id,
		RemoteAddr:   remoteAddr,
		NsDomainId:   domainId,
		QuestionName: question.Name,
		QuestionType: dns.Type(question.Qtype).String(),
		IsRecursive:  isRecursive,
		Networking:   networking,
		ServerAddr:   writer.LocalAddr().String(),
		Timestamp:    now.Unix(),
		TimeLocal:    now.Format("2/Jan/2006:15:04:05 -0700"),
		RequestId:    "",
	}

	if record != nil {
		entry.NsRecordId = record.Id
		if len(routeCode) > 0 {
			entry.NsRouteCodes = []string{routeCode}
		}
		entry.RecordName = record.Name
		entry.RecordType = record.Type
		entry.RecordValue = record.Value
	}

	if err != nil {
		entry.Error = err.Error()
	}

	sharedNSAccessLogQueue.Push(entry)
}
// checkTSIG verifies the TSIG signature of msg against the keys registered
// for the domain.
//
// It fails when the message carries no TSIG record or the domain has no keys.
// Keys whose algorithm differs from the one in the message are skipped; the
// first key that verifies makes the check succeed. dns.ErrSig is returned
// when no key verifies.
func (this *Server) checkTSIG(msg *dns.Msg, domainId int64) error {
	var tsig = msg.IsTsig()
	if tsig == nil {
		return errors.New("tsig: tsig required")
	}

	var domainKeys = sharedKeyManager.FindKeysWithDomain(domainId)
	if len(domainKeys) == 0 {
		return errors.New("tsig: no keys defined")
	}

	for _, domainKey := range domainKeys {
		if domainKey.Algo != tsig.Algorithm {
			continue
		}

		// The message is packed again for every key: a packed buffer can only be verified once.
		packed, packErr := msg.Pack()
		if packErr != nil {
			return packErr
		}
		if len(packed) == 0 {
			return nil
		}

		// TsigVerify expects a base64 secret, so clear-text secrets are encoded first.
		var secret = domainKey.Secret
		if domainKey.SecretType == dnsconfigs.NSKeySecretTypeClear {
			secret = base64.StdEncoding.EncodeToString([]byte(domainKey.Secret))
		}

		if dns.TsigVerify(packed, secret, "", false) == nil {
			return nil
		}
	}

	return dns.ErrSig
}
// dnsAnswerSpaceRegexp splits the textual form of a resource record into its
// whitespace-separated columns. It is compiled once instead of on every query.
var dnsAnswerSpaceRegexp = regexp.MustCompile(`\s+`)

// handleDNSMessage answers one DNS request.
//
// For every question it:
//   - answers PingDomain directly with 127.0.0.1 and returns;
//   - strips and applies a leading "$" query option (see parseAction);
//   - when the name belongs to no managed domain, forwards the request to the
//     recursion upstream if enabled (and returns), or answers NS questions
//     with one of the node's hosts;
//   - otherwise checks TSIG when the domain requires it, finds the records
//     for the client's routes, follows CNAME chains inside managed domains
//     (at most 32 levels) and appends the answers to the response.
//
// Every question is written to the access log through addLog. After the
// response has been written, the response size is added to the statistics of
// every record that contributed to it.
func (this *Server) handleDNSMessage(writer dns.ResponseWriter, req *dns.Msg) {
	if len(req.Question) == 0 {
		return
	}

	if sharedDomainManager == nil {
		return
	}

	var networking = ""
	if this.config != nil {
		networking = this.config.Protocol.String()
	}

	var resultDomainId int64
	var resultRecordIds [][2]int64 // [] { domainId, recordId }

	var resp = new(dns.Msg)
	resp.RecursionDesired = true
	resp.RecursionAvailable = true
	resp.SetReply(req)
	resp.Answer = []dns.RR{}

	var tsigIsChecked = false
	var remoteAddr = writer.RemoteAddr().String()

	for _, question := range req.Question {
		if len(question.Name) == 0 {
			continue
		}

		// PING
		if question.Name == PingDomain {
			resp.Answer = append(resp.Answer, &dns.A{
				Hdr: dns.RR_Header{
					Name:   question.Name,
					Rrtype: dns.TypeA,
					Class:  question.Qclass,
					Ttl:    60,
				},
				A: net.ParseIP("127.0.0.1"),
			})
			resp.Rcode = dns.RcodeSuccess
			// Nothing more can be done if the write fails.
			_ = writer.WriteMsg(resp)
			return
		}

		// Query options
		if question.Name[0] == '$' {
			var originalAddr = remoteAddr
			_, port, _ := net.SplitHostPort(remoteAddr)
			questionName, err := this.parseAction(question.Name, &remoteAddr)
			if err != nil {
				remotelogs.Error("SERVER", "invalid query option '"+question.Name+"'")
				continue
			}
			question.Name = questionName

			// parseAction replaces remoteAddr with a bare IP only when a
			// setRemoteAddr option was applied. Re-attach the original port in
			// that case only; otherwise remoteAddr still is "host:port" and
			// appending the port again would corrupt it.
			if remoteAddr != originalAddr && len(port) > 0 {
				remoteAddr = net.JoinHostPort(remoteAddr, port)
			}
		}

		var fullName = strings.TrimSuffix(question.Name, ".")
		var recordName string
		var recordType = dns.Type(question.Qtype).String()

		var domain *models.NSDomain
		domain, recordName = sharedDomainManager.SplitDomain(fullName)

		if domain == nil {
			// Try the recursion upstream
			if sharedNodeConfig.RecursionConfig != nil && sharedNodeConfig.RecursionConfig.IsOn {
				err := this.lookupRecursionDNS(req, resp)
				if err != nil {
					this.addLog(networking, question, 0, "", nil, true, writer, remoteAddr, err)
				} else {
					// The fifth column of the first answer's text form is logged as the record value.
					var recordValue = ""
					if len(resp.Answer) > 0 {
						pieces := dnsAnswerSpaceRegexp.Split(resp.Answer[0].String(), 6)
						if len(pieces) >= 5 {
							recordValue = pieces[4]
						}
					}
					this.addLog(networking, question, 0, "", &models.NSRecord{
						Id:    0,
						Name:  recordName,
						Type:  recordType,
						Value: recordValue,
						Ttl:   0,
					}, true, writer, remoteAddr, nil)
				}
				// Nothing more can be done if the write fails.
				_ = writer.WriteMsg(resp)
				return
			}

			// NS questions are answered so that domain ownership can be verified
			if question.Qtype == dns.TypeNS {
				var hosts = sharedNodeConfig.Hosts
				var l = len(hosts)
				var record = &models.NSRecord{
					Id:   0,
					Type: dnsconfigs.RecordTypeNS,
					Ttl:  600, // TODO make configurable
				}
				if l > 0 {
					l = 1 // only one host is returned for now

					// random order
					var indexes = []int{}
					for i := 0; i < l; i++ {
						indexes = append(indexes, i)
					}
					rand.Shuffle(l, func(i, j int) {
						indexes[i], indexes[j] = indexes[j], indexes[i]
					})

					record.Value = hosts[0] + "."
					for _, index := range indexes {
						resp.Answer = append(resp.Answer, &dns.NS{
							Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
							Ns:  hosts[index] + ".",
						})
					}
					this.addLog(networking, question, 0, "", record, false, writer, remoteAddr, nil)
					continue
				}

				this.addLog(networking, question, 0, "", nil, false, writer, remoteAddr, nil)
				continue
			}

			this.addLog(networking, question, 0, "", nil, false, writer, remoteAddr, nil)
			continue
		}

		// Check TSIG (once per request)
		if domain.TSIG != nil && domain.TSIG.IsOn && !tsigIsChecked {
			err := this.checkTSIG(req, domain.Id)
			if err != nil {
				this.addLog(networking, question, domain.Id, "", nil, false, writer, remoteAddr, err)
				continue
			}
			tsigIsChecked = true
		}

		resultDomainId = domain.Id

		var clientIP = remoteAddr
		clientHost, _, err := net.SplitHostPort(clientIP)
		if err == nil && len(clientHost) > 0 {
			clientIP = clientHost
		}

		// Agent detection
		if sharedNodeConfig.DetectAgents {
			agents.SharedQueue.Push(clientIP)
		}

		var routeCodes = sharedRouteManager.FindRouteCodes(clientIP, domain.UserId)

		var records []*models.NSRecord
		var matchedRouteCode string

		if question.Qtype == dns.TypeSOA { // SOA
			if len(recordName) == 0 { // only the zone apex has an SOA record
				records = []*models.NSRecord{
					{
						Id:   0,
						Type: dnsconfigs.RecordTypeSOA,
						Ttl:  600, // TODO make configurable
					},
				}
			}
		} else if question.Qtype == dns.TypeNS { // NS
			if len(recordName) == 0 { // only the zone apex has built-in NS records
				records = []*models.NSRecord{
					{
						Id:   0,
						Type: dnsconfigs.RecordTypeNS,
						Ttl:  600, // TODO make configurable
					},
				}
			}
		} else if question.Qtype != dns.TypeCNAME {
			// Records set directly for the requested type
			records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, true)

			// Fall back to a CNAME
			if len(records) == 0 {
				records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, dnsconfigs.RecordTypeCNAME, false)
				if len(records) > 0 {
					question.Qtype = dns.TypeCNAME
				}
			}

			// Try the default settings for the requested type
			if len(records) == 0 {
				records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, false)
			}
		}

		if len(records) == 0 {
			records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, false)
		}

		// NS|SOA questions for one of the node's own host names (NS.example.com)
		if (question.Qtype == dns.TypeNS || (question.Qtype == dns.TypeSOA && len(records) == 0)) && lists.ContainsString(sharedNodeConfig.Hosts, fullName) {
			var recordDNSType string
			switch question.Qtype {
			case dns.TypeNS:
				recordDNSType = dnsconfigs.RecordTypeNS
			case dns.TypeSOA:
				recordDNSType = dnsconfigs.RecordTypeSOA
			}
			this.composeSOAAnswer(question, &models.NSRecord{
				Type: recordDNSType,
				Ttl:  600,
			}, resp)
		}

		if len(records) > 0 {
			var firstRecord = records[0]

			for _, record := range records {
				resultRecordIds = append(resultRecordIds, [2]int64{record.DomainId, record.Id})

				switch record.Type {
				case dnsconfigs.RecordTypeA:
					var answer = record.ToRRAnswer(question.Name, question.Qclass)
					if answer != nil {
						resp.Answer = append(resp.Answer, answer)
					}
				case dnsconfigs.RecordTypeCNAME:
					var value = record.Value
					if !strings.HasSuffix(value, ".") {
						value += "."
					}
					var lastRecordValue = value
					resp.Answer = append(resp.Answer, &dns.CNAME{
						Hdr:    record.ToRRHeader(question.Name, dns.TypeCNAME, question.Qclass),
						Target: value,
					})

					// Follow the CNAME chain inside managed domains
					var allCNAMEValues = []string{lastRecordValue}
					for {
						// at most 32 levels deep
						if len(allCNAMEValues) > 32 {
							break
						}

						cnameDomain, cnameRecordName := sharedDomainManager.SplitDomain(lastRecordValue)
						if cnameDomain == nil {
							break
						}

						cnameRecords, _ := sharedRecordManager.FindRecords(cnameDomain.Id, sharedRouteManager.FindRouteCodes(clientIP, cnameDomain.UserId), cnameRecordName, dnsconfigs.RecordTypeCNAME, false)
						if len(cnameRecords) == 0 {
							break
						}

						var cnameRecord = cnameRecords[0]
						// A value seen before means the chain loops; stop there.
						if !lists.ContainsString(allCNAMEValues, cnameRecord.Value) {
							resultRecordIds = append(resultRecordIds, [2]int64{cnameRecord.DomainId, cnameRecord.Id}) // statistics
							var answer = cnameRecord.ToRRAnswer(lastRecordValue, question.Qclass)
							if answer == nil {
								break
							}
							resp.Answer = append(resp.Answer, answer)
							lastRecordValue = cnameRecord.Value
							allCNAMEValues = append(allCNAMEValues, lastRecordValue)
						} else {
							break
						}
					}

					// Resolve the originally requested type at the end of the chain
					if len(req.Question) > 0 {
						var firstQuestion = req.Question[0]
						if firstQuestion.Qtype != dns.TypeCNAME {
							finalDomain, finalRecordName := sharedDomainManager.SplitDomain(lastRecordValue)
							if finalDomain != nil {
								var realRecords, _ = sharedRecordManager.FindRecords(finalDomain.Id, sharedRouteManager.FindRouteCodes(clientIP, finalDomain.UserId), finalRecordName, dns.Type(firstQuestion.Qtype).String(), false)
								if len(realRecords) > 0 {
									for _, realRecord := range realRecords {
										resultRecordIds = append(resultRecordIds, [2]int64{realRecord.DomainId, realRecord.Id}) // statistics
										var answer = realRecord.ToRRAnswer(lastRecordValue, question.Qclass)
										if answer != nil {
											resp.Answer = append(resp.Answer, answer)
										}
									}
								}
							}
						}
					}
				case dnsconfigs.RecordTypeAAAA:
					var answer = record.ToRRAnswer(question.Name, question.Qclass)
					if answer != nil {
						resp.Answer = append(resp.Answer, answer)
					}
				case dnsconfigs.RecordTypeNS:
					if record.Id == 0 {
						// Built-in NS record: answer with all node hosts in random order
						var hosts = sharedNodeConfig.Hosts
						var l = len(hosts)
						if l > 0 {
							var indexes = []int{}
							for i := 0; i < l; i++ {
								indexes = append(indexes, i)
							}
							rand.Shuffle(l, func(i, j int) {
								indexes[i], indexes[j] = indexes[j], indexes[i]
							})

							record.Value = hosts[0] + "."
							for _, index := range indexes {
								resp.Answer = append(resp.Answer, &dns.NS{
									Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
									Ns:  hosts[index] + ".",
								})
							}
						}
					} else {
						var value = record.Value
						if !strings.HasSuffix(value, ".") {
							value += "."
						}
						resp.Answer = append(resp.Answer, &dns.NS{
							Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
							Ns:  value,
						})
					}
				case dnsconfigs.RecordTypeMX:
					var answer = record.ToRRAnswer(question.Name, question.Qclass)
					if answer != nil {
						resp.Answer = append(resp.Answer, answer)
					}
				case dnsconfigs.RecordTypeSRV:
					var answer = record.ToRRAnswer(question.Name, question.Qclass)
					if answer != nil {
						resp.Answer = append(resp.Answer, answer)
					}
				case dnsconfigs.RecordTypeTXT:
					var answer = record.ToRRAnswer(question.Name, question.Qclass)
					if answer != nil {
						resp.Answer = append(resp.Answer, answer)
					}
				case dnsconfigs.RecordTypeCAA:
					var answer = record.ToRRAnswer(question.Name, question.Qclass)
					if answer != nil {
						resp.Answer = append(resp.Answer, answer)
					}
				case dnsconfigs.RecordTypeSOA:
					this.composeSOAAnswer(question, record, resp)
				}
			}

			// Access log
			this.addLog(networking, question, resultDomainId, matchedRouteCode, firstRecord, false, writer, remoteAddr, nil)
		} else {
			this.addLog(networking, question, resultDomainId, "", nil, false, writer, remoteAddr, nil)
		}
	}

	resp.Rcode = dns.RcodeSuccess
	err := writer.WriteMsg(resp)
	if err != nil {
		return
	}

	// Statistics: the response size is the same for every record, so compute it once.
	if len(resultRecordIds) > 0 {
		var respSize = int64(resp.Len())
		for _, resultRecordId := range resultRecordIds {
			stats.SharedManager.Add(resultRecordId[0], resultRecordId[1], respSize)
		}
	}
}
// handleHTTP routes DNS-over-HTTPS requests by path: "/dns-query" carries
// binary DNS messages and "/resolve" is the JSON API. Any other path gets a
// 404 response.
// See https://datatracker.ietf.org/doc/html/rfc8484
// See https://developers.google.com/speed/public-dns/docs/doh
func (this *Server) handleHTTP(writer http.ResponseWriter, req *http.Request) {
	switch req.URL.Path {
	case "/dns-query":
		this.handleHTTPDNSMessage(writer, req)
	case "/resolve":
		this.handleHTTPJSONAPI(writer, req)
	default:
		writer.WriteHeader(http.StatusNotFound)
	}
}
// handleHTTPDNSMessage serves the "/dns-query" endpoint.
//
// The DNS message is taken from the "dns" query parameter of a GET request
// or from the body of a POST request with content type
// "application/dns-message". The unpacked message is then answered by
// handleDNSMessage through an HTTP response writer.
//
// Status codes: 501 for other methods, 413/414 for oversized input, 415 for
// a wrong POST content type, 400 for an empty or malformed message and 500
// when the underlying connection is not available in the request context.
func (this *Server) handleHTTPDNSMessage(writer http.ResponseWriter, req *http.Request) {
	const maxMessageSize = 512

	writer.Header().Set("Accept", "application/dns-message")

	if req.Method != http.MethodGet && req.Method != http.MethodPost {
		writer.WriteHeader(http.StatusNotImplemented)
		return
	}

	if req.ContentLength > maxMessageSize {
		writer.WriteHeader(http.StatusRequestEntityTooLarge)
		return
	}

	var messageData []byte

	switch req.Method {
	case http.MethodGet:
		if len(req.URL.RawQuery) > maxMessageSize {
			writer.WriteHeader(http.StatusRequestURITooLong)
			return
		}

		var encodedMessage = req.URL.Query().Get("dns")

		// RFC 8484 encodes the message as base64url without padding. Decode
		// that form first (tolerating padding) and fall back to standard
		// base64 so clients that relied on the previous behavior keep working.
		var err error
		messageData, err = base64.RawURLEncoding.DecodeString(strings.TrimRight(encodedMessage, "="))
		if err != nil {
			messageData, err = base64.StdEncoding.DecodeString(encodedMessage)
		}
		if err != nil {
			writer.WriteHeader(http.StatusBadRequest)
			return
		}
	case http.MethodPost:
		var contentType = req.Header.Get("Content-Type")
		if contentType != "application/dns-message" {
			writer.WriteHeader(http.StatusUnsupportedMediaType)
			return
		}

		// The reader is capped because ContentLength may be unknown.
		data, err := io.ReadAll(io.LimitReader(req.Body, maxMessageSize))
		if err != nil {
			writer.WriteHeader(http.StatusBadRequest)
			return
		}
		messageData = data
	}

	if len(messageData) == 0 {
		writer.WriteHeader(http.StatusBadRequest)
		return
	}

	var msg = &dns.Msg{}
	if err := msg.Unpack(messageData); err != nil {
		writer.WriteHeader(http.StatusBadRequest)
		return
	}

	// The connection is stored in the request context by the server's ConnContext hook.
	conn, ok := req.Context().Value(HTTPConnContextKey).(net.Conn)
	if !ok {
		writer.WriteHeader(http.StatusInternalServerError)
		return
	}

	this.handleDNSMessage(NewHTTPWriter(writer, conn, "application/dns-message"), msg)
}
// jsonAPINumericTypeRegexp matches a record type given as a number of up to
// four digits. It is compiled once instead of on every request.
var jsonAPINumericTypeRegexp = regexp.MustCompile(`^\d{1,4}$`)

// handleHTTPJSONAPI serves the "/resolve" endpoint.
//
// It accepts GET requests with a "name" and a "type" query parameter, where
// the type is either a record type name (case-insensitive) or its numeric
// code. A single-question DNS message is built from them and answered by
// handleDNSMessage through an HTTP response writer.
//
// Status codes: 501 for other methods, 400 for a missing name or an unknown
// type and 500 when the underlying connection is not available in the
// request context.
func (this *Server) handleHTTPJSONAPI(writer http.ResponseWriter, req *http.Request) {
	if req.Method != http.MethodGet {
		writer.WriteHeader(http.StatusNotImplemented)
		return
	}

	var query = req.URL.Query()
	var name = strings.TrimSpace(query.Get("name"))
	var recordTypeString = strings.ToUpper(strings.TrimSpace(query.Get("type")))

	if len(name) == 0 {
		writer.WriteHeader(http.StatusBadRequest)
		_, _ = writer.Write([]byte("invalid 'name' parameter"))
		return
	}

	// Question names are fully qualified.
	if !strings.HasSuffix(name, ".") {
		name += "."
	}

	if len(recordTypeString) == 0 {
		writer.WriteHeader(http.StatusBadRequest)
		_, _ = writer.Write([]byte("invalid 'type' parameter"))
		return
	}

	var recordType uint16
	if jsonAPINumericTypeRegexp.MatchString(recordTypeString) {
		recordType = types.Uint16(recordTypeString)
	} else {
		recordType = dns.StringToType[recordTypeString]
	}
	if recordType == 0 {
		writer.WriteHeader(http.StatusBadRequest)
		_, _ = writer.Write([]byte("invalid 'type' parameter"))
		return
	}

	// Numeric codes must map to a known record type as well.
	if _, isKnownType := dns.TypeToString[recordType]; !isKnownType {
		writer.WriteHeader(http.StatusBadRequest)
		_, _ = writer.Write([]byte("invalid 'type' parameter"))
		return
	}

	var msg = &dns.Msg{}
	msg.Question = []dns.Question{
		{
			Name:   name,
			Qtype:  recordType,
			Qclass: dns.ClassINET,
		},
	}

	// The connection is stored in the request context by the server's ConnContext hook.
	conn, ok := req.Context().Value(HTTPConnContextKey).(net.Conn)
	if !ok {
		writer.WriteHeader(http.StatusInternalServerError)
		return
	}

	this.handleDNSMessage(NewHTTPWriter(writer, conn, "application/x-javascript"), msg)
}
// composeSOAAnswer appends an SOA answer for question to resp and stores the
// primary name server (with a trailing dot) in record.Value.
//
// The values come from the node's SOA settings, or from the default SOA
// settings when none are configured. A missing MName falls back to a random
// node host and a missing RName to the node's email address; "@" in the
// mailbox is replaced by ".". Nothing is appended when either name is empty.
func (this *Server) composeSOAAnswer(question dns.Question, record *models.NSRecord, resp *dns.Msg) {
	var config = sharedNodeConfig.SOA
	var serial = sharedNodeConfig.SOASerial

	if config == nil {
		config = dnsconfigs.DefaultNSSOAConfig()
	}

	// Names are normalized to have no trailing dot, so that appending the root
	// dot below never produces "name.." for values already written as FQDNs.
	var mName = strings.TrimSuffix(config.MName, ".")
	if len(mName) == 0 {
		var hosts = sharedNodeConfig.Hosts
		var l = len(hosts)
		if l > 0 {
			var index = rands.Int(0, l-1)
			mName = strings.TrimSuffix(hosts[index], ".")
		}
	}

	var rName = config.RName
	if len(rName) == 0 {
		rName = sharedNodeConfig.Email
	}
	rName = strings.TrimSuffix(strings.ReplaceAll(rName, "@", "."), ".")

	if len(mName) == 0 || len(rName) == 0 {
		return
	}

	// Record value, as used by the access log
	record.Value = mName + "."

	resp.Answer = append(resp.Answer, &dns.SOA{
		Hdr:     record.ToRRHeader(question.Name, dns.TypeSOA, question.Qclass),
		Ns:      mName + ".",
		Mbox:    rName + ".",
		Serial:  serial,
		Refresh: config.RefreshSeconds,
		Retry:   config.RetrySeconds,
		Expire:  config.ExpireSeconds,
		Minttl:  config.MinimumTTL,
	})
}