// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. //go:build plus package test import ( "bytes" "crypto/tls" "github.com/TeaOSLab/EdgeAdmin/internal/web/actions/actionutils" "github.com/TeaOSLab/EdgeAdmin/internal/web/actions/default/dns/domains/domainutils" "github.com/TeaOSLab/EdgeCommon/pkg/configutils" "github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/iwind/TeaGo/actions" "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" "github.com/miekg/dns" "io" "net" "net/http" "regexp" "strings" ) var dohClient = &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, }, }, } type IndexAction struct { actionutils.ParentAction } func (this *IndexAction) Init() { this.Nav("", "", "") } func (this *IndexAction) RunGet(params struct{}) { // 集群列表 clustersResp, err := this.RPC().NSClusterRPC().FindAllNSClusters(this.AdminContext(), &pb.FindAllNSClustersRequest{}) if err != nil { this.ErrorPage(err) return } var clusterMaps = []maps.Map{} for _, cluster := range clustersResp.NsClusters { if !cluster.IsOn { continue } countNodesResp, err := this.RPC().NSNodeRPC().CountAllNSNodesMatch(this.AdminContext(), &pb.CountAllNSNodesMatchRequest{ NsClusterId: cluster.Id, InstallState: 0, ActiveState: 0, Keyword: "", }) if err != nil { this.ErrorPage(err) return } var countNodes = countNodesResp.Count if countNodes <= 0 { continue } clusterMaps = append(clusterMaps, maps.Map{ "id": cluster.Id, "name": cluster.Name, "countNodes": countNodes, }) } this.Data["clusters"] = clusterMaps // 记录类型 this.Data["recordTypes"] = dnsconfigs.FindAllRecordTypeDefinitions() this.Show() } func (this *IndexAction) RunPost(params struct { NodeId int64 Domain string Type string Ip string ClientIP string Port string Must *actions.Must }) { nodeResp, err := this.RPC().NSNodeRPC().FindNSNode(this.AdminContext(), &pb.FindNSNodeRequest{NsNodeId: params.NodeId}) if err != nil { this.ErrorPage(err) return } var node = nodeResp.NsNode if node == nil || node.NsCluster == nil { this.Fail("找不到要测试的节点") } var isOk = false var errMsg string var isNetError = false var result string defer func() { this.Data["isOk"] = isOk this.Data["err"] = errMsg this.Data["isNetErr"] = isNetError this.Data["result"] = result this.Success() }() if len(params.Domain) == 0 { errMsg = "请输入要解析的域名" return } if !domainutils.ValidateDomainFormat(params.Domain) { errMsg = "域名格式错误" return } recordType, ok := dns.StringToType[params.Type] if !ok { errMsg = "不支持此记录类型" return } if len(params.ClientIP) > 0 && net.ParseIP(params.ClientIP) == nil { errMsg = "客户端IP格式不正确" return } // 域名状态 var domainPieces = strings.Split(params.Domain, ".") var clusterId = node.NsCluster.Id var pbDomain *pb.NSDomain for index := range domainPieces { var rootDomainName = strings.Join(domainPieces[len(domainPieces)-1-index:], ".") domainResp, err := this.RPC().NSDomainRPC().FindVerifiedNSDomainOnCluster(this.AdminContext(), &pb.FindVerifiedNSDomainOnClusterRequest{ NsClusterId: clusterId, Name: rootDomainName, }) if err != nil { this.ErrorPage(err) return } if domainResp.NsDomain != nil { pbDomain = domainResp.NsDomain break } } if pbDomain == nil { this.Data["domain"] = nil return } this.Data["domain"] = maps.Map{ "id": pbDomain.Id, "name": pbDomain.Name, } var optionId int64 if len(params.ClientIP) > 0 { optionResp, err := this.RPC().NSQuestionOptionRPC().CreateNSQuestionOption(this.AdminContext(), &pb.CreateNSQuestionOptionRequest{ Name: "setRemoteAddr", ValuesJSON: maps.Map{"ip": params.ClientIP}.AsJSON(), }) if err != nil { this.ErrorPage(err) return } optionId = optionResp.NsQuestionOptionId defer func() { _, err = this.RPC().NSQuestionOptionRPC().DeleteNSQuestionOption(this.AdminContext(), &pb.DeleteNSQuestionOptionRequest{NsQuestionOptionId: optionId}) if err != nil { this.ErrorPage(err) } }() } var c = new(dns.Client) var m = new(dns.Msg) // 协议 var portString = "" var isDoH = false if len(params.Port) > 0 { var pieces = strings.Split(params.Port, "/") if len(pieces) == 2 { portString = pieces[0] var protocol = pieces[1] switch protocol { case "udp", "tcp": c.Net = protocol case "tls": c.Net = "tcp-tls" c.TLSConfig = &tls.Config{ InsecureSkipVerify: true, } case "doh": isDoH = true } } } if len(portString) == 0 { // 默认 c.Net = "udp" portString = "53" } var domain = params.Domain + "." if optionId > 0 { domain = "$" + types.String(optionId) + "-" + domain } m.SetQuestion(domain, recordType) // 处理DoH if isDoH { msgData, err := m.Pack() if err != nil { this.ErrorPage(err) return } req, err := http.NewRequest(http.MethodPost, "https://"+configutils.QuoteIP(params.Ip)+":"+portString+"/dns-query", bytes.NewReader(msgData)) if err != nil { this.ErrorPage(err) return } req.Header.Set("Content-Length", types.String(len(msgData))) req.Header.Set("Content-Type", "application/dns-message") req.Header.Set("Accept", "application/dns-message") resp, err := dohClient.Do(req) if err != nil { errMsg = "解析过程中出错:" + err.Error() // 是否为网络错误 if regexp.MustCompile(`timeout|connect`).MatchString(err.Error()) { isNetError = true } return } defer func() { _ = resp.Body.Close() }() replyData, err := io.ReadAll(resp.Body) if err != nil { errMsg = "读取响应时失败:" + err.Error() return } var replyMsg = &dns.Msg{} err = replyMsg.Unpack(replyData) if err != nil { errMsg = "解析响应时失败:" + err.Error() return } result = replyMsg.String() result = regexp.MustCompile(`\$\d+-`).ReplaceAllString(result, "") isOk = true return } // 处理一般的DNS查询 replyMsg, _, err := c.Exchange(m, configutils.QuoteIP(params.Ip)+":"+portString) if err != nil { errMsg = "解析过程中出错:" + err.Error() // 是否为网络错误 if regexp.MustCompile(`timeout|connect`).MatchString(err.Error()) { isNetError = true } return } result = replyMsg.String() result = regexp.MustCompile(`\$\d+-`).ReplaceAllString(result, "") isOk = true }