Initial commit (code only without large binaries)

This commit is contained in:
robin
2026-02-15 18:58:44 +08:00
commit 35df75498f
9442 changed files with 1495866 additions and 0 deletions

View File

@@ -0,0 +1,932 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
import (
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeDNS/internal/agents"
"github.com/TeaOSLab/EdgeDNS/internal/models"
"github.com/TeaOSLab/EdgeDNS/internal/remotelogs"
"github.com/TeaOSLab/EdgeDNS/internal/rpc"
"github.com/TeaOSLab/EdgeDNS/internal/stats"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
"github.com/miekg/dns"
"io"
"math/rand"
"net"
"net/http"
"regexp"
"strings"
"time"
)
var sharedRecursionDNSClient = &dns.Client{}
type httpContextKey struct {
key string
}
var HTTPConnContextKey = &httpContextKey{key: "http-conn"}
const PingDomain = "ping."
// Server 服务
type Server struct {
config *ServerConfig
rawServer *dns.Server
httpsServer *http.Server
}
// NewServer 构造新服务
func NewServer(config *ServerConfig) (*Server, error) {
var server = &Server{
config: config,
}
err := server.init()
if err != nil {
return nil, err
}
return server, nil
}
// ListenAndServe 监听
func (this *Server) ListenAndServe() error {
if this.rawServer != nil {
return this.rawServer.ListenAndServe()
}
if this.httpsServer != nil {
listener, err := net.Listen("tcp", this.httpsServer.Addr)
if err != nil {
return err
}
err = this.httpsServer.ServeTLS(listener, "", "")
if err == http.ErrServerClosed {
err = nil
}
return err
}
return errors.New("the server is not initialized")
}
// Shutdown 关闭
func (this *Server) Shutdown() error {
if this.rawServer != nil {
return this.rawServer.Shutdown()
}
if this.httpsServer != nil {
return this.httpsServer.Shutdown(context.Background())
}
return errors.New("the server is not initialized")
}
// Reload 重载配置
func (this *Server) Reload(config *ServerConfig) {
this.config = config
}
// 初始化
func (this *Server) init() error {
var rawServer = &dns.Server{
ReadTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
}
rawServer.Handler = dns.HandlerFunc(this.handleDNSMessage)
var addr = ""
if len(this.config.Host) > 0 {
addr += configutils.QuoteIP(this.config.Host)
}
addr += ":" + types.String(this.config.Port)
rawServer.Addr = addr
switch this.config.Protocol {
case serverconfigs.ProtocolTCP:
rawServer.Net = "tcp"
case serverconfigs.ProtocolTLS:
rawServer.Net = "tcp-tls"
rawServer.TLSConfig = &tls.Config{
Certificates: nil,
GetConfigForClient: func(clientInfo *tls.ClientHelloInfo) (config *tls.Config, e error) {
return this.config.SSLPolicy.TLSConfig(), nil
},
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
return this.config.SSLPolicy.FirstCert(), nil
},
}
case serverconfigs.ProtocolHTTPS: // DoH
rawServer = nil
this.httpsServer = &http.Server{
Addr: addr,
Handler: http.HandlerFunc(this.handleHTTP),
TLSConfig: &tls.Config{
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
if this.config == nil {
return nil, errors.New("invalid 'ServerConfig.config'")
}
if this.config.SSLPolicy == nil {
return nil, errors.New("invalid 'ServerConfig.config.SSLPolicy'")
}
return this.config.SSLPolicy.TLSConfig(), nil
},
GetCertificate: func(clientInfo *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) {
if this.config == nil {
return nil, errors.New("invalid 'ServerConfig.config'")
}
if this.config.SSLPolicy == nil {
return nil, errors.New("invalid 'ServerConfig.config.SSLPolicy'")
}
return this.config.SSLPolicy.FirstCert(), nil
},
},
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
return context.WithValue(ctx, HTTPConnContextKey, c)
},
ReadTimeout: 5 * time.Second,
ReadHeaderTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
IdleTimeout: 75 * time.Second,
MaxHeaderBytes: 4096,
}
case serverconfigs.ProtocolUDP:
rawServer.Net = "udp"
}
this.rawServer = rawServer
return nil
}
// 查询递归DNS
func (this *Server) lookupRecursionDNS(req *dns.Msg, resp *dns.Msg) error {
var config = sharedNodeConfig.RecursionConfig
if config == nil {
return nil
}
if !config.IsOn {
return nil
}
// 是否允许
var domain = strings.TrimSuffix(req.Question[0].Name, ".")
if len(config.DenyDomains) > 0 && configutils.MatchDomains(config.DenyDomains, domain) {
return nil
}
if len(config.AllowDomains) > 0 && !configutils.MatchDomains(config.AllowDomains, domain) {
return nil
}
if config.UseLocalHosts {
// TODO 需要缓存文件内容
resolveConfig, err := dns.ClientConfigFromFile("/etc/resolv.conf")
if err != nil {
return err
}
if len(resolveConfig.Servers) == 0 {
return errors.New("no dns servers found in config file")
}
if len(resolveConfig.Port) == 0 {
resolveConfig.Port = "53"
}
r, _, err := sharedRecursionDNSClient.Exchange(req, configutils.QuoteIP(resolveConfig.Servers[rands.Int(0, len(resolveConfig.Servers)-1)])+":"+resolveConfig.Port)
if err != nil {
return err
}
resp.Answer = r.Answer
} else if len(config.Hosts) > 0 {
var host = config.Hosts[rands.Int(0, len(config.Hosts)-1)]
if host.Port <= 0 {
host.Port = 53
}
r, _, err := sharedRecursionDNSClient.Exchange(req, configutils.QuoteIP(host.Host)+":"+types.String(host.Port))
if err != nil {
return err
}
resp.Answer = r.Answer
}
return nil
}
// 分析查询中的动作
func (this *Server) parseAction(questionName string, remoteAddr *string) (string, error) {
rpcClient, err := rpc.SharedRPC()
if err != nil {
return "", err
}
// TODO 需要防止恶意攻击
var optionIndex = strings.Index(questionName, "-")
if optionIndex > 0 {
optionId := types.Int64(questionName[1:optionIndex])
optionResp, err := rpcClient.NSQuestionOptionRPC.FindNSQuestionOption(rpcClient.Context(), &pb.FindNSQuestionOptionRequest{NsQuestionOptionId: optionId})
if err != nil {
return "", errors.New("query question option failed: " + err.Error())
} else {
var option = optionResp.NsQuestionOption
if option != nil {
switch option.Name {
case "setRemoteAddr":
var m = maps.Map{}
err = json.Unmarshal(option.ValuesJSON, &m)
if err != nil {
return "", errors.New("decode question option failed: " + err.Error())
} else {
var ip = m.GetString("ip")
*remoteAddr = ip
}
}
}
}
questionName = questionName[optionIndex+1:]
}
return questionName, nil
}
// 记录日志
func (this *Server) addLog(networking string, question dns.Question, domainId int64, routeCode string, record *models.NSRecord, isRecursive bool, writer dns.ResponseWriter, remoteAddr string, err error) {
// 访问日志
var accessLogRef = sharedNodeConfig.AccessLogRef
if accessLogRef != nil && accessLogRef.IsOn {
if domainId == 0 && !accessLogRef.LogMissingDomains {
return
}
if accessLogRef.MissingRecordsOnly && record != nil && len(record.Value) > 0 {
return
}
var now = time.Now()
var pbAccessLog = &pb.NSAccessLog{
NsNodeId: sharedNodeConfig.Id,
RemoteAddr: remoteAddr,
NsDomainId: domainId,
QuestionName: question.Name,
QuestionType: dns.Type(question.Qtype).String(),
IsRecursive: isRecursive,
Networking: networking,
ServerAddr: writer.LocalAddr().String(),
Timestamp: now.Unix(),
TimeLocal: now.Format("2/Jan/2006:15:04:05 -0700"),
RequestId: "",
}
if record != nil {
pbAccessLog.NsRecordId = record.Id
if len(routeCode) > 0 {
pbAccessLog.NsRouteCodes = []string{routeCode}
}
pbAccessLog.RecordName = record.Name
pbAccessLog.RecordType = record.Type
pbAccessLog.RecordValue = record.Value
}
if err != nil {
pbAccessLog.Error = err.Error()
}
sharedNSAccessLogQueue.Push(pbAccessLog)
}
}
// 验证TSIG
func (this *Server) checkTSIG(msg *dns.Msg, domainId int64) error {
var tsig = msg.IsTsig()
if tsig == nil {
return errors.New("tsig: tsig required")
}
var keys = sharedKeyManager.FindKeysWithDomain(domainId)
if len(keys) == 0 {
return errors.New("tsig: no keys defined")
}
for _, key := range keys {
if key.Algo != tsig.Algorithm {
continue
}
// 需要重新Pack每次Pack结果只能校验一次
msgData, err := msg.Pack()
if err != nil {
return err
}
if len(msgData) == 0 {
return nil
}
var base64Secret = key.Secret
if key.SecretType == dnsconfigs.NSKeySecretTypeClear {
base64Secret = base64.StdEncoding.EncodeToString([]byte(key.Secret))
}
err = dns.TsigVerify(msgData, base64Secret, "", false)
if err != nil {
continue
} else {
return nil
}
}
return dns.ErrSig
}
// 处理DNS请求
func (this *Server) handleDNSMessage(writer dns.ResponseWriter, req *dns.Msg) {
if len(req.Question) == 0 {
return
}
if sharedDomainManager == nil {
return
}
var networking = ""
if this.config != nil {
networking = this.config.Protocol.String()
}
var resultDomainId int64
var resultRecordIds [][2]int64 // [] { domainId, recordId}
var resp = new(dns.Msg)
resp.RecursionDesired = true
resp.RecursionAvailable = true
resp.SetReply(req)
resp.Answer = []dns.RR{}
var tsigIsChecked = false
var remoteAddr = writer.RemoteAddr().String()
for _, question := range req.Question {
if len(question.Name) == 0 {
continue
}
// PING
if question.Name == PingDomain {
resp.Answer = append(resp.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: question.Name,
Rrtype: dns.TypeA,
Class: question.Qclass,
Ttl: 60,
},
A: net.ParseIP("127.0.0.1"),
})
resp.Rcode = dns.RcodeSuccess
err := writer.WriteMsg(resp)
if err != nil {
return
}
return
}
// 查询选项
if question.Name[0] == '$' {
_, port, _ := net.SplitHostPort(remoteAddr)
questionName, err := this.parseAction(question.Name, &remoteAddr)
if err != nil {
remotelogs.Error("SERVER", "invalid query option '"+question.Name+"'")
continue
}
question.Name = questionName
if len(port) > 0 {
if strings.Contains(remoteAddr, ":") { // IPv6
remoteAddr = "[" + remoteAddr + "]:" + port
} else {
remoteAddr += ":" + port
}
}
}
var fullName = strings.TrimSuffix(question.Name, ".")
var recordName string
var recordType = dns.Type(question.Qtype).String()
var domain *models.NSDomain
domain, recordName = sharedDomainManager.SplitDomain(fullName)
if domain == nil {
// 检查递归DNS
if sharedNodeConfig.RecursionConfig != nil && sharedNodeConfig.RecursionConfig.IsOn {
err := this.lookupRecursionDNS(req, resp)
if err != nil {
this.addLog(networking, question, 0, "", nil, true, writer, remoteAddr, err)
} else {
var recordValue = ""
if len(resp.Answer) > 0 {
pieces := regexp.MustCompile(`\s+`).Split(resp.Answer[0].String(), 6)
if len(pieces) >= 5 {
recordValue = pieces[4]
}
}
this.addLog(networking, question, 0, "", &models.NSRecord{
Id: 0,
Name: recordName,
Type: recordType,
Value: recordValue,
Ttl: 0,
}, true, writer, remoteAddr, nil)
}
err = writer.WriteMsg(resp)
if err != nil {
return
}
return
}
// 是否为NS记录用于验证域名所有权
if question.Qtype == dns.TypeNS {
var hosts = sharedNodeConfig.Hosts
var l = len(hosts)
var record = &models.NSRecord{
Id: 0,
Type: dnsconfigs.RecordTypeNS,
Ttl: 600, // TODO 可以设置
}
if l > 0 {
l = 1 // 目前只返回一个
// 随机
var indexes = []int{}
for i := 0; i < l; i++ {
indexes = append(indexes, i)
}
rand.Shuffle(l, func(i, j int) {
indexes[i], indexes[j] = indexes[j], indexes[i]
})
record.Value = hosts[0] + "."
for _, index := range indexes {
resp.Answer = append(resp.Answer, &dns.NS{
Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
Ns: hosts[index] + ".",
})
}
this.addLog(networking, question, 0, "", record, false, writer, remoteAddr, nil)
continue
}
this.addLog(networking, question, 0, "", nil, false, writer, remoteAddr, nil)
continue
}
this.addLog(networking, question, 0, "", nil, false, writer, remoteAddr, nil)
continue
}
// 检查TSIG
if domain.TSIG != nil && domain.TSIG.IsOn && !tsigIsChecked {
err := this.checkTSIG(req, domain.Id)
if err != nil {
this.addLog(networking, question, domain.Id, "", nil, false, writer, remoteAddr, err)
continue
}
tsigIsChecked = true
}
resultDomainId = domain.Id
var clientIP = remoteAddr
clientHost, _, err := net.SplitHostPort(clientIP)
if err == nil && len(clientHost) > 0 {
clientIP = clientHost
}
// 解析Agent
if sharedNodeConfig.DetectAgents {
agents.SharedQueue.Push(clientIP)
}
var routeCodes = sharedRouteManager.FindRouteCodes(clientIP, domain.UserId)
var records []*models.NSRecord
var matchedRouteCode string
if question.Qtype == dns.TypeSOA { // SOA
if len(recordName) == 0 { // 只有顶级域名才有SOA记录
records = []*models.NSRecord{
{
Id: 0,
Type: dnsconfigs.RecordTypeSOA,
Ttl: 600, // TODO 可以设置
},
}
}
} else if question.Qtype == dns.TypeNS { // NS
if len(recordName) == 0 { // 只有顶级域名才有NS记录
records = []*models.NSRecord{
{
Id: 0,
Type: dnsconfigs.RecordTypeNS,
Ttl: 600, // TODO 可以设置
},
}
}
} else if question.Qtype != dns.TypeCNAME {
// 是否有直接的设置
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, true)
// 检查CNAME
if len(records) == 0 {
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, dnsconfigs.RecordTypeCNAME, false)
if len(records) > 0 {
question.Qtype = dns.TypeCNAME
}
}
// 再次尝试查找默认设置
if len(records) == 0 {
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, false)
}
}
if len(records) == 0 {
records, matchedRouteCode = sharedRecordManager.FindRecords(domain.Id, routeCodes, recordName, recordType, false)
}
// 对 NS.example.com NS|SOA 处理
if (question.Qtype == dns.TypeNS || (question.Qtype == dns.TypeSOA && len(records) == 0)) && lists.ContainsString(sharedNodeConfig.Hosts, fullName) {
var recordDNSType string
switch question.Qtype {
case dns.TypeNS:
recordDNSType = dnsconfigs.RecordTypeNS
case dns.TypeSOA:
recordDNSType = dnsconfigs.RecordTypeSOA
}
this.composeSOAAnswer(question, &models.NSRecord{
Type: recordDNSType,
Ttl: 600,
}, resp)
}
if len(records) > 0 {
var firstRecord = records[0]
for _, record := range records {
resultRecordIds = append(resultRecordIds, [2]int64{record.DomainId, record.Id})
switch record.Type {
case dnsconfigs.RecordTypeA:
var answer = record.ToRRAnswer(question.Name, question.Qclass)
if answer != nil {
resp.Answer = append(resp.Answer, answer)
}
case dnsconfigs.RecordTypeCNAME:
var value = record.Value
if !strings.HasSuffix(value, ".") {
value += "."
}
var lastRecordValue = value
resp.Answer = append(resp.Answer, &dns.CNAME{
Hdr: record.ToRRHeader(question.Name, dns.TypeCNAME, question.Qclass),
Target: value,
})
// 继续查询CNAME
var allCNAMEValues = []string{lastRecordValue}
for {
// 限制最深32层
if len(allCNAMEValues) > 32 {
break
}
cnameDomain, cnameRecordName := sharedDomainManager.SplitDomain(lastRecordValue)
if cnameDomain == nil {
break
}
cnameRecords, _ := sharedRecordManager.FindRecords(cnameDomain.Id, sharedRouteManager.FindRouteCodes(clientIP, cnameDomain.UserId), cnameRecordName, dnsconfigs.RecordTypeCNAME, false)
if len(cnameRecords) == 0 {
break
}
var cnameRecord = cnameRecords[0]
if !lists.ContainsString(allCNAMEValues, cnameRecord.Value) {
resultRecordIds = append(resultRecordIds, [2]int64{cnameRecord.DomainId, cnameRecord.Id}) // 统计
var answer = cnameRecord.ToRRAnswer(lastRecordValue, question.Qclass)
if answer == nil {
break
}
resp.Answer = append(resp.Answer, answer)
lastRecordValue = cnameRecord.Value
allCNAMEValues = append(allCNAMEValues, lastRecordValue)
} else {
break
}
}
// 再次查询原始问题
if len(req.Question) > 0 {
var firstQuestion = req.Question[0]
if firstQuestion.Qtype != dns.TypeCNAME {
finalDomain, finalRecordName := sharedDomainManager.SplitDomain(lastRecordValue)
if finalDomain != nil {
var realRecords, _ = sharedRecordManager.FindRecords(finalDomain.Id, sharedRouteManager.FindRouteCodes(clientIP, finalDomain.UserId), finalRecordName, dns.Type(firstQuestion.Qtype).String(), false)
if len(realRecords) > 0 {
for _, realRecord := range realRecords {
resultRecordIds = append(resultRecordIds, [2]int64{realRecord.DomainId, realRecord.Id}) // 统计
var answer = realRecord.ToRRAnswer(lastRecordValue, question.Qclass)
if answer != nil {
resp.Answer = append(resp.Answer, answer)
}
}
}
}
}
}
case dnsconfigs.RecordTypeAAAA:
var answer = record.ToRRAnswer(question.Name, question.Qclass)
if answer != nil {
resp.Answer = append(resp.Answer, answer)
}
case dnsconfigs.RecordTypeNS:
if record.Id == 0 {
var hosts = sharedNodeConfig.Hosts
var l = len(hosts)
if l > 0 {
// 随机
var indexes = []int{}
for i := 0; i < l; i++ {
indexes = append(indexes, i)
}
rand.Shuffle(l, func(i, j int) {
indexes[i], indexes[j] = indexes[j], indexes[i]
})
record.Value = hosts[0] + "."
for _, index := range indexes {
resp.Answer = append(resp.Answer, &dns.NS{
Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
Ns: hosts[index] + ".",
})
}
}
} else {
var value = record.Value
if !strings.HasSuffix(value, ".") {
value += "."
}
resp.Answer = append(resp.Answer, &dns.NS{
Hdr: record.ToRRHeader(question.Name, dns.TypeNS, question.Qclass),
Ns: value,
})
}
case dnsconfigs.RecordTypeMX:
var answer = record.ToRRAnswer(question.Name, question.Qclass)
if answer != nil {
resp.Answer = append(resp.Answer, answer)
}
case dnsconfigs.RecordTypeSRV:
var answer = record.ToRRAnswer(question.Name, question.Qclass)
if answer != nil {
resp.Answer = append(resp.Answer, answer)
}
case dnsconfigs.RecordTypeTXT:
var answer = record.ToRRAnswer(question.Name, question.Qclass)
if answer != nil {
resp.Answer = append(resp.Answer, answer)
}
case dnsconfigs.RecordTypeCAA:
var answer = record.ToRRAnswer(question.Name, question.Qclass)
if answer != nil {
resp.Answer = append(resp.Answer, answer)
}
case dnsconfigs.RecordTypeSOA:
this.composeSOAAnswer(question, record, resp)
}
}
// 访问日志
this.addLog(networking, question, resultDomainId, matchedRouteCode, firstRecord, false, writer, remoteAddr, nil)
} else {
this.addLog(networking, question, resultDomainId, "", nil, false, writer, remoteAddr, nil)
}
}
resp.Rcode = dns.RcodeSuccess
err := writer.WriteMsg(resp)
if err != nil {
return
}
// 统计
for _, resultRecordId := range resultRecordIds {
stats.SharedManager.Add(resultRecordId[0], resultRecordId[1], int64(resp.Len()))
}
}
// 处理HTTP请求
// 参考https://datatracker.ietf.org/doc/html/rfc8484
// 参考https://developers.google.com/speed/public-dns/docs/doh
func (this *Server) handleHTTP(writer http.ResponseWriter, req *http.Request) {
if req.URL.Path == "/dns-query" {
this.handleHTTPDNSMessage(writer, req)
return
}
if req.URL.Path == "/resolve" {
this.handleHTTPJSONAPI(writer, req)
return
}
writer.WriteHeader(http.StatusNotFound)
}
func (this *Server) handleHTTPDNSMessage(writer http.ResponseWriter, req *http.Request) {
const maxMessageSize = 512
writer.Header().Set("Accept", "application/dns-message")
if req.Method != http.MethodGet && req.Method != http.MethodPost {
writer.WriteHeader(http.StatusNotImplemented)
return
}
if req.ContentLength > maxMessageSize {
writer.WriteHeader(http.StatusRequestEntityTooLarge)
return
}
var messageData []byte
switch req.Method {
case http.MethodGet:
if len(req.URL.RawQuery) > maxMessageSize {
writer.WriteHeader(http.StatusRequestURITooLong)
return
}
var encodedMessage = req.URL.Query().Get("dns")
var err error
messageData, err = base64.StdEncoding.DecodeString(encodedMessage)
if err != nil {
writer.WriteHeader(http.StatusBadRequest)
return
}
case http.MethodPost:
var contentType = req.Header.Get("Content-Type")
if contentType != "application/dns-message" {
writer.WriteHeader(http.StatusUnsupportedMediaType)
return
}
data, err := io.ReadAll(io.LimitReader(req.Body, maxMessageSize))
if err != nil {
writer.WriteHeader(http.StatusBadRequest)
return
}
messageData = data
}
if len(messageData) == 0 {
writer.WriteHeader(http.StatusBadRequest)
return
}
var msg = &dns.Msg{}
err := msg.Unpack(messageData)
if err != nil {
writer.WriteHeader(http.StatusBadRequest)
return
}
var connValue = req.Context().Value(HTTPConnContextKey)
if connValue == nil {
writer.WriteHeader(http.StatusInternalServerError)
return
}
conn, ok := connValue.(net.Conn)
if !ok {
writer.WriteHeader(http.StatusInternalServerError)
return
}
this.handleDNSMessage(NewHTTPWriter(writer, conn, "application/dns-message"), msg)
}
func (this *Server) handleHTTPJSONAPI(writer http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodGet {
writer.WriteHeader(http.StatusNotImplemented)
return
}
var query = req.URL.Query()
var name = strings.TrimSpace(query.Get("name"))
var recordTypeString = strings.ToUpper(strings.TrimSpace(query.Get("type")))
if len(name) == 0 {
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid 'name' parameter"))
return
}
// add '.' to name
if !strings.HasSuffix(name, ".") {
name += "."
}
if len(recordTypeString) == 0 {
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid 'type' parameter"))
return
}
var recordType uint16
if regexp.MustCompile(`^\d{1,4}$`).MatchString(recordTypeString) {
recordType = types.Uint16(recordTypeString)
} else {
recordType = dns.StringToType[recordTypeString]
}
if recordType == 0 {
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid 'type' parameter"))
return
}
_, ok := dns.TypeToString[recordType]
if !ok {
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid 'type' parameter"))
return
}
var msg = &dns.Msg{}
msg.Question = []dns.Question{
{
Name: name,
Qtype: recordType,
Qclass: dns.ClassINET,
},
}
var connValue = req.Context().Value(HTTPConnContextKey)
if connValue == nil {
writer.WriteHeader(http.StatusInternalServerError)
return
}
// conn
conn, ok := connValue.(net.Conn)
if !ok {
writer.WriteHeader(http.StatusInternalServerError)
return
}
this.handleDNSMessage(NewHTTPWriter(writer, conn, "application/x-javascript"), msg)
}
// 组合SOA回复信息
func (this *Server) composeSOAAnswer(question dns.Question, record *models.NSRecord, resp *dns.Msg) {
var config = sharedNodeConfig.SOA
var serial = sharedNodeConfig.SOASerial
if config == nil {
config = dnsconfigs.DefaultNSSOAConfig()
}
var mName = config.MName
if len(mName) == 0 {
var hosts = sharedNodeConfig.Hosts
var l = len(hosts)
if l > 0 {
var index = rands.Int(0, l-1)
mName = hosts[index]
}
}
var rName = config.RName
if len(rName) == 0 {
rName = sharedNodeConfig.Email
}
rName = strings.ReplaceAll(rName, "@", ".")
if len(mName) > 0 && len(rName) > 0 {
// 设置记录值
record.Value = mName + "."
resp.Answer = append(resp.Answer, &dns.SOA{
Hdr: record.ToRRHeader(question.Name, dns.TypeSOA, question.Qclass),
Ns: mName + ".",
Mbox: rName + ".",
Serial: serial,
Refresh: config.RefreshSeconds,
Retry: config.RetrySeconds,
Expire: config.ExpireSeconds,
Minttl: config.MinimumTTL,
})
}
}