// Source file: waf-platform/EdgeAPI/internal/dnsclients/provider_azure_dns_plus.go
// Last modified: 2026-02-04 20:27:13 +08:00
// Size: 745 lines, 20 KiB, Go

// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package dnsclients
import (
"context"
"encoding/base64"
"errors"
"fmt"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns"
"github.com/TeaOSLab/EdgeAPI/internal/dnsclients/dnstypes"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"strings"
)
// AzureDNSProvider is the DNS provider backed by Microsoft Azure DNS zones.
type AzureDNSProvider struct {
	BaseProvider

	// ProviderId identifies this provider in the shared records cache;
	// the cache is only consulted and updated when it is positive.
	ProviderId int64

	// Azure clients and the target resource group, all populated by Auth.
	zonesClient       *armdns.ZonesClient
	recordSetsClient  *armdns.RecordSetsClient
	resourceGroupName string
}
// NewAzureDNSProvider returns an empty provider; Auth must be called on it
// before the Azure clients are available.
func NewAzureDNSProvider() *AzureDNSProvider {
	return new(AzureDNSProvider)
}
// Auth validates the credential parameters and builds the Azure DNS clients.
//
// The params 'subscriptionId', 'tenantId', 'clientId', 'clientSecret' and
// 'resourceGroupName' are all required; the first empty one (in that order)
// is reported as an error. Errors from creating the credential or the client
// factory are wrapped and returned.
func (this *AzureDNSProvider) Auth(params maps.Map) error {
	// the slice order decides which missing parameter is reported first
	for _, key := range []string{"subscriptionId", "tenantId", "clientId", "clientSecret", "resourceGroupName"} {
		if len(params.GetString(key)) == 0 {
			return errors.New("'" + key + "' required")
		}
	}

	// stored before the clients are created, so it is set even when the steps below fail
	this.resourceGroupName = params.GetString("resourceGroupName")

	credential, err := azidentity.NewClientSecretCredential(
		params.GetString("tenantId"),
		params.GetString("clientId"),
		params.GetString("clientSecret"),
		nil,
	)
	if err != nil {
		return fmt.Errorf("NewClientSecretCredential: %w", err)
	}

	factory, err := armdns.NewClientFactory(params.GetString("subscriptionId"), credential, nil)
	if err != nil {
		return fmt.Errorf("NewClientFactory: %w", err)
	}

	this.zonesClient = factory.NewZonesClient()
	this.recordSetsClient = factory.NewRecordSetsClient()
	return nil
}
// MaskParams replaces the 'clientSecret' parameter with its masked form, in place.
// A nil map is left untouched.
func (this *AzureDNSProvider) MaskParams(params maps.Map) {
	if params != nil {
		var secret = params.GetString("clientSecret")
		params["clientSecret"] = MaskString(secret)
	}
}
// GetDomains lists the names of the DNS zones in the configured resource group.
//
// The listing is scoped to this.resourceGroupName because every record
// operation of this provider addresses zones through that resource group;
// zones living in other resource groups of the subscription could be listed
// but not managed. Duplicate names are reported once.
func (this *AzureDNSProvider) GetDomains() (domains []string, err error) {
	var pager = this.zonesClient.NewListByResourceGroupPager(this.resourceGroupName, nil)
	for pager.More() {
		// a distinct name avoids shadowing the named result
		page, pageErr := pager.NextPage(context.Background())
		if pageErr != nil {
			return nil, pageErr
		}

		for _, zone := range page.Value {
			if zone == nil || zone.Name == nil {
				continue
			}
			if !lists.ContainsString(domains, *zone.Name) {
				domains = append(domains, *zone.Name)
			}
		}
	}

	return domains, nil
}
// GetRecords fetches all records of a domain (DNS zone).
//
// Record sets without a name or properties are skipped. When ProviderId is
// positive, the result is also written to the shared records cache.
func (this *AzureDNSProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) {
	var ctx = context.Background()
	var pager = this.recordSetsClient.NewListByDNSZonePager(this.resourceGroupName, domain, nil)

	for pager.More() {
		page, pageErr := pager.NextPage(ctx)
		if pageErr != nil {
			return nil, pageErr
		}

		for _, recordSet := range page.Value {
			if recordSet.Name != nil && recordSet.Properties != nil {
				records = append(records, this.recordSetToRecords(recordSet)...)
			}
		}
	}

	// write to cache
	if this.ProviderId > 0 {
		sharedDomainRecordsCache.WriteDomainRecords(this.ProviderId, domain, records)
	}

	return records, nil
}
// GetRoutes returns the routes (lines) available for a domain.
// Only the single default route is returned, whatever the domain.
func (this *AzureDNSProvider) GetRoutes(domain string) (routes []*dnstypes.Route, err error) {
	var defaultRoute = &dnstypes.Route{
		Name: "默认",
		Code: "default",
	}
	return []*dnstypes.Route{defaultRoute}, nil
}
// QueryRecord looks up a single record by name and type.
//
// When ProviderId is positive the shared cache is consulted first. Otherwise,
// or when the cache has no valid answer, the record set is fetched from Azure
// and its first record is returned. A missing record yields (nil, nil).
func (this *AzureDNSProvider) QueryRecord(domain string, name string, recordType dnstypes.RecordType) (*dnstypes.Record, error) {
	// read from cache; the lookup error is discarded and anything but a hit falls through to the API
	if this.ProviderId > 0 {
		cached, hit, _ := sharedDomainRecordsCache.QueryDomainRecord(this.ProviderId, domain, name, recordType)
		if hit {
			return cached, nil
		}
	}

	resp, err := this.recordSetsClient.Get(context.Background(), this.resourceGroupName, domain, name, armdns.RecordType(recordType), nil)
	switch {
	case err == nil:
		// continue below
	case this.isNotFound(err):
		return nil, nil
	default:
		return nil, err
	}

	var found = this.recordSetToRecords(&resp.RecordSet)
	if len(found) == 0 {
		return nil, nil
	}
	return found[0], nil
}
// QueryRecords looks up all records sharing a name and type.
//
// When ProviderId is positive the shared cache is consulted first. Otherwise,
// or when the cache has no valid answer, the record set is fetched from Azure
// and flattened. A missing record set yields (nil, nil).
func (this *AzureDNSProvider) QueryRecords(domain string, name string, recordType dnstypes.RecordType) ([]*dnstypes.Record, error) {
	// read from cache; the lookup error is discarded and anything but a hit falls through to the API
	if this.ProviderId > 0 {
		cached, hit, _ := sharedDomainRecordsCache.QueryDomainRecords(this.ProviderId, domain, name, recordType)
		if hit {
			return cached, nil
		}
	}

	resp, err := this.recordSetsClient.Get(context.Background(), this.resourceGroupName, domain, name, armdns.RecordType(recordType), nil)
	switch {
	case err == nil:
		return this.recordSetToRecords(&resp.RecordSet), nil
	case this.isNotFound(err):
		return nil, nil
	default:
		return nil, err
	}
}
// AddRecord adds a record value to the record set identified by the record's
// name and type, creating the record set when it does not exist yet.
//
// On success newRecord.Id is filled in. If the value is already present
// nothing is sent to Azure and only the id is set. A TTL <= 0 falls back to
// 600. When ProviderId is positive the new record is also added to the
// shared cache.
func (this *AzureDNSProvider) AddRecord(domain string, newRecord *dnstypes.Record) error {
	if newRecord == nil {
		return errors.New("invalid new record")
	}

	// make sure a CNAME value ends with a dot
	if newRecord.Type == dnstypes.RecordTypeCNAME && !strings.HasSuffix(newRecord.Value, ".") {
		newRecord.Value += "."
	}

	var ttlInt64 = int64(newRecord.TTL)
	if newRecord.TTL <= 0 {
		ttlInt64 = 600
	}

	var ctx = context.Background()
	var azureType = armdns.RecordType(newRecord.Type)

	// a "not found" answer leaves resp zero-valued, which is handled as "no existing record set"
	resp, err := this.recordSetsClient.Get(ctx, this.resourceGroupName, domain, newRecord.Name, azureType, nil)
	if err != nil && !this.isNotFound(err) {
		return err
	}

	var recordSet = armdns.RecordSet{
		Name: this.stringVal(newRecord.Name),
		Type: this.stringVal(newRecord.Type),
	}
	if existing := resp.RecordSet.Properties; existing != nil {
		// extend the existing record set, overriding its TTL
		existing.TTL = &ttlInt64
		recordSet.Properties = existing
		recordSet.ID = resp.RecordSet.ID
	} else {
		recordSet.Properties = &armdns.RecordSetProperties{TTL: &ttlInt64}
	}

	exists, err := this.recordSetAddRecordValue(&recordSet, newRecord.Value)
	if exists {
		newRecord.Id = this.ComposeRecordId(newRecord.Name, newRecord.Type, newRecord.Value)
		return err
	}
	if err != nil {
		return err
	}

	_, err = this.recordSetsClient.CreateOrUpdate(ctx, this.resourceGroupName, domain, newRecord.Name, azureType, recordSet, nil)
	if err != nil {
		return err
	}

	newRecord.Id = this.ComposeRecordId(newRecord.Name, newRecord.Type, newRecord.Value)

	// add to cache
	if this.ProviderId > 0 {
		sharedDomainRecordsCache.AddDomainRecord(this.ProviderId, domain, newRecord)
	}

	return nil
}
// UpdateRecord changes an existing record into newRecord.
//
// The old record's name, type and value are decoded from record.Id.
//   - Nothing changed (same name, type and value): returns nil without calling Azure.
//   - Name or type changed: the record lives in a different record set, so the
//     new record is added first and the old one is deleted afterwards. Adding
//     first means a failure never loses the old record; a conflict rejected by
//     Azure is returned as an error.
//   - Only the value changed: the value is replaced inside its record set and
//     the record set TTL is set to newRecord.TTL (600 when <= 0).
//
// If the record set or the old value no longer exists remotely, nil is
// returned and nothing is changed. On success newRecord.Id is refreshed and,
// when ProviderId is positive, the shared cache is updated.
func (this *AzureDNSProvider) UpdateRecord(domain string, record *dnstypes.Record, newRecord *dnstypes.Record) error {
	if record == nil {
		return errors.New("invalid record")
	}
	if newRecord == nil {
		return errors.New("invalid new record")
	}

	// make sure a CNAME value ends with a dot
	if newRecord.Type == dnstypes.RecordTypeCNAME && !strings.HasSuffix(newRecord.Value, ".") {
		newRecord.Value += "."
	}

	_, oldRecordName, oldRecordType, oldRecordValue, err := this.decodeRecordId(record.Id)
	if err != nil {
		return err
	}

	var sameRecordSet = oldRecordName == newRecord.Name && oldRecordType == newRecord.Type
	if sameRecordSet && oldRecordValue == newRecord.Value {
		return nil
	}

	// The old value can only be found in the record set of the OLD name/type.
	// Looking it up under the new name/type would never match, so a rename or
	// type change is carried out as add + delete; both calls keep the cache in sync.
	if !sameRecordSet {
		err = this.AddRecord(domain, newRecord)
		if err != nil {
			return err
		}
		return this.DeleteRecord(domain, record)
	}

	var ctx = context.Background()
	var azureType = armdns.RecordType(newRecord.Type)

	resp, err := this.recordSetsClient.Get(ctx, this.resourceGroupName, domain, newRecord.Name, azureType, nil)
	if err != nil {
		if this.isNotFound(err) {
			return nil
		}
		return err
	}

	var recordSet = resp.RecordSet
	if recordSet.Properties == nil {
		return nil
	}

	found, err := this.recordSetUpdateRecordValue(&recordSet, newRecord.Type, oldRecordValue, newRecord.Value)
	if err != nil {
		return err
	}
	if !found {
		return nil
	}

	var ttlInt64 = int64(newRecord.TTL)
	if ttlInt64 <= 0 {
		ttlInt64 = 600
	}
	recordSet.Properties.TTL = &ttlInt64

	_, err = this.recordSetsClient.Update(ctx, this.resourceGroupName, domain, newRecord.Name, azureType, recordSet, nil)
	if err != nil {
		return err
	}

	newRecord.Id = this.ComposeRecordId(newRecord.Name, newRecord.Type, newRecord.Value)

	// update cache
	if this.ProviderId > 0 {
		sharedDomainRecordsCache.UpdateDomainRecord(this.ProviderId, domain, newRecord)
	}

	return nil
}
// DeleteRecord removes a record value from its record set.
//
// The record's name, type and value are decoded from record.Id. The whole
// record set is deleted when the helper reports that nothing should remain;
// otherwise the record set is updated without the value. A record set or
// value that no longer exists is not an error. When ProviderId is positive,
// the record is also removed from the shared cache after a remote change.
func (this *AzureDNSProvider) DeleteRecord(domain string, record *dnstypes.Record) error {
	if record == nil {
		return errors.New("invalid record to delete")
	}

	_, recordName, recordType, recordValue, err := this.decodeRecordId(record.Id)
	if err != nil {
		return err
	}

	var ctx = context.Background()
	var azureType = armdns.RecordType(recordType)

	resp, err := this.recordSetsClient.Get(ctx, this.resourceGroupName, domain, recordName, azureType, nil)
	if err != nil {
		if this.isNotFound(err) {
			return nil
		}
		return err
	}

	var recordSet = resp.RecordSet
	if recordSet.Properties == nil {
		return nil
	}

	shouldUpdate, shouldDelete, err := this.recordSetDeleteRecordValue(&recordSet, recordType, recordValue)
	if err != nil {
		return err
	}

	// deleting takes precedence over updating
	switch {
	case shouldDelete:
		_, err = this.recordSetsClient.Delete(ctx, this.resourceGroupName, domain, recordName, azureType, nil)
	case shouldUpdate:
		_, err = this.recordSetsClient.Update(ctx, this.resourceGroupName, domain, recordName, azureType, recordSet, nil)
	default:
		// nothing matched: leave both Azure and the cache untouched
		return nil
	}
	if err != nil {
		return err
	}

	// delete from cache
	if this.ProviderId > 0 {
		sharedDomainRecordsCache.DeleteDomainRecord(this.ProviderId, domain, record.Id)
	}

	return nil
}
// DefaultRoute returns the code of the default route (line).
func (this *AzureDNSProvider) DefaultRoute() string {
	const defaultRouteCode = "default"
	return defaultRouteCode
}
// ComposeRecordId builds the record id used by this provider: the Base64
// (standard encoding) of "$<name>$<type>$<value>". The text before the first
// "$" is left empty; decodeRecordId reads it back as the record set id.
func (this *AzureDNSProvider) ComposeRecordId(recordName string, recordType string, recordValue string) string {
	var raw = make([]byte, 0, len(recordName)+len(recordType)+len(recordValue)+3)
	for _, part := range [...]string{recordName, recordType, recordValue} {
		raw = append(raw, '$')
		raw = append(raw, part...)
	}
	return base64.StdEncoding.EncodeToString(raw)
}
// decodeRecordId reverses ComposeRecordId: it Base64-decodes the id and
// splits it on "$" into at most four pieces (record set id, name, type,
// value), so a "$" inside the value is preserved.
//
// An error is returned when the id is not valid Base64, decodes to nothing,
// or does not contain all four pieces; the string results are then empty.
func (this *AzureDNSProvider) decodeRecordId(encodedRecordId string) (recordSetId string, recordName string, recordType string, recordValue string, err error) {
	raw, decodeErr := base64.StdEncoding.DecodeString(encodedRecordId)
	if decodeErr != nil {
		return "", "", "", "", decodeErr
	}

	// empty data leaves parts nil, which fails the length check below
	var parts []string
	if len(raw) > 0 {
		parts = strings.SplitN(string(raw), "$", 4)
	}
	if len(parts) != 4 {
		return "", "", "", "", errors.New("invalid record id")
	}

	return parts[0], parts[1], parts[2], parts[3], nil
}
// fixCNAME appends a trailing dot to the value of a CNAME record when it is
// missing. The record type is compared case-insensitively; values of other
// record types are returned unchanged.
func (this *AzureDNSProvider) fixCNAME(recordType string, recordValue string) string {
	var isCNAME = strings.ToUpper(recordType) == dnstypes.RecordTypeCNAME
	if !isCNAME || strings.HasSuffix(recordValue, ".") {
		return recordValue
	}
	return recordValue + "."
}
// isNotFound reports whether err wraps an Azure response error whose error
// code is "NotFound". A nil error is never "not found".
func (this *AzureDNSProvider) isNotFound(err error) bool {
	var respErr *azcore.ResponseError
	return err != nil && errors.As(err, &respErr) && respErr.ErrorCode == "NotFound"
}
// recordSetToRecords flattens an Azure record set into individual records,
// one per value.
//
// Handled types are A, AAAA, CNAME, TXT (one record per text value), NS and
// SOA; other record types are ignored. Every record gets the default route
// and the record set's TTL (0 when the TTL is missing).
//
// The SDK models all these fields as optional pointers, so nothing is
// dereferenced unchecked: a nil record set, or one without a name or
// properties, yields no records, and entries without a value are skipped.
func (this *AzureDNSProvider) recordSetToRecords(recordSet *armdns.RecordSet) (records []*dnstypes.Record) {
	if recordSet == nil || recordSet.Name == nil || recordSet.Properties == nil {
		return nil
	}

	var name = *recordSet.Name
	var properties = recordSet.Properties

	var ttl int32
	if properties.TTL != nil {
		ttl = types.Int32(*properties.TTL)
	}

	// appendRecord adds one record; idValue is the value encoded into the record id
	var appendRecord = func(recordType string, idValue string, value string) {
		records = append(records, &dnstypes.Record{
			Id:    this.ComposeRecordId(name, recordType, idValue),
			Name:  name,
			Type:  recordType,
			Value: value,
			Route: this.DefaultRoute(),
			TTL:   ttl,
		})
	}

	// nil-safe dereference helpers for the optional SOA fields
	var str = func(p *string) string {
		if p == nil {
			return ""
		}
		return *p
	}
	var num = func(p *int64) int64 {
		if p == nil {
			return 0
		}
		return *p
	}

	// A
	for _, record := range properties.ARecords {
		if record != nil && record.IPv4Address != nil {
			appendRecord("A", *record.IPv4Address, *record.IPv4Address)
		}
	}

	// AAAA
	for _, record := range properties.AaaaRecords {
		if record != nil && record.IPv6Address != nil {
			appendRecord("AAAA", *record.IPv6Address, *record.IPv6Address)
		}
	}

	// CNAME
	if record := properties.CnameRecord; record != nil && record.Cname != nil {
		appendRecord("CNAME", *record.Cname, *record.Cname)
	}

	// TXT
	for _, record := range properties.TxtRecords {
		if record == nil {
			continue
		}
		for _, value := range record.Value {
			if value != nil {
				appendRecord("TXT", *value, *value)
			}
		}
	}

	// NS
	for _, record := range properties.NsRecords {
		if record != nil && record.Nsdname != nil {
			appendRecord("NS", *record.Nsdname, *record.Nsdname)
		}
	}

	// SOA: the id is built from the host only, the value carries all fields
	if record := properties.SoaRecord; record != nil && record.Host != nil {
		var value = fmt.Sprintf("%s %s %d %d %d %d %d",
			*record.Host,
			str(record.Email),
			num(record.SerialNumber),
			num(record.RefreshTime),
			num(record.RetryTime),
			num(record.ExpireTime),
			num(record.MinimumTTL),
		)
		appendRecord("SOA", *record.Host, value)
	}

	// we don't support other record types yet
	return records
}
// recordSetAddRecordValue merges value into recordSet according to the record
// set's Type ("A", "AAAA", "CNAME" or "TXT").
//
// It returns exists=true, leaving the record set unchanged, when the value is
// already present. For A, AAAA and TXT the value is appended; for CNAME it
// replaces the current target. A record set without a type, or with a type
// other than the four above, results in an error.
func (this *AzureDNSProvider) recordSetAddRecordValue(recordSet *armdns.RecordSet, value string) (exists bool, err error) {
	if recordSet == nil || recordSet.Type == nil {
		return false, errors.New("recordSetAddRecordValue: record type required")
	}
	if recordSet.Properties == nil {
		recordSet.Properties = &armdns.RecordSetProperties{}
	}

	var recordType = *recordSet.Type
	var properties = recordSet.Properties

	// recordValueToProperties returns nil for record types it does not handle
	var newProperties = this.recordValueToProperties(recordType, value)
	if newProperties == nil {
		return false, errors.New("not supported record type '" + recordType + "'")
	}

	switch recordType {
	case "A":
		for _, record := range properties.ARecords {
			if record != nil && record.IPv4Address != nil && *record.IPv4Address == value {
				return true, nil
			}
		}
		properties.ARecords = append(properties.ARecords, newProperties.ARecords...)
	case "AAAA":
		for _, record := range properties.AaaaRecords {
			if record != nil && record.IPv6Address != nil && *record.IPv6Address == value {
				return true, nil
			}
		}
		properties.AaaaRecords = append(properties.AaaaRecords, newProperties.AaaaRecords...)
	case "CNAME":
		var record = properties.CnameRecord
		if record != nil && record.Cname != nil && *record.Cname == value {
			return true, nil
		}
		properties.CnameRecord = newProperties.CnameRecord
	case "TXT":
		for _, record := range properties.TxtRecords {
			if record == nil {
				continue
			}
			for _, txtValue := range record.Value {
				if txtValue != nil && *txtValue == value {
					return true, nil
				}
			}
		}
		properties.TxtRecords = append(properties.TxtRecords, newProperties.TxtRecords...)
	default:
		// only reachable if recordValueToProperties gains a type this switch does not handle
		return false, errors.New("not supported record type '" + recordType + "'")
	}

	return false, nil
}
// recordSetUpdateRecordValue replaces oldRecordValue with newRecordValue
// inside recordSet, for record types "A", "AAAA", "CNAME" and "TXT".
//
// foundRecord tells whether the old value was present; callers should only
// persist the record set when it is true. For CNAME the target is always
// replaced and foundRecord is always true. A nil record set or one without
// properties yields (false, nil); any other record type yields an error
// naming recordType.
func (this *AzureDNSProvider) recordSetUpdateRecordValue(recordSet *armdns.RecordSet, recordType string, oldRecordValue string, newRecordValue string) (foundRecord bool, err error) {
	if recordSet == nil || recordSet.Properties == nil {
		return false, nil
	}

	var properties = recordSet.Properties

	switch recordType {
	case "A":
		var newRecords = make([]*armdns.ARecord, 0, len(properties.ARecords)+1)
		for _, record := range properties.ARecords {
			if record == nil {
				continue
			}
			if record.IPv4Address != nil && *record.IPv4Address == oldRecordValue {
				foundRecord = true
				continue
			}
			newRecords = append(newRecords, record)
		}
		newRecords = append(newRecords, &armdns.ARecord{IPv4Address: this.stringVal(newRecordValue)})
		properties.ARecords = newRecords
	case "AAAA":
		var newRecords = make([]*armdns.AaaaRecord, 0, len(properties.AaaaRecords)+1)
		for _, record := range properties.AaaaRecords {
			if record == nil {
				continue
			}
			if record.IPv6Address != nil && *record.IPv6Address == oldRecordValue {
				foundRecord = true
				continue
			}
			newRecords = append(newRecords, record)
		}
		newRecords = append(newRecords, &armdns.AaaaRecord{IPv6Address: this.stringVal(newRecordValue)})
		properties.AaaaRecords = newRecords
	case "CNAME":
		properties.CnameRecord = &armdns.CnameRecord{Cname: this.stringVal(newRecordValue)}
		foundRecord = true
	case "TXT":
		var newRecords = []*armdns.TxtRecord{}
		var shouldAdd = true // stays true only when the old value was found nowhere
		for _, record := range properties.TxtRecords {
			if record == nil {
				continue
			}

			var keptValues = []*string{}
			var found = false
			for _, oldValue := range record.Value {
				if oldValue == nil {
					continue
				}
				if *oldValue == oldRecordValue {
					found = true
					shouldAdd = false
					foundRecord = true
					continue
				}
				keptValues = append(keptValues, oldValue)
			}

			if found {
				// the new value takes the place of the old one within the same TXT entry
				keptValues = append(keptValues, this.stringVal(newRecordValue))
				record.Value = keptValues
			}
			if len(keptValues) == 0 {
				continue
			}
			newRecords = append(newRecords, record)
		}
		if shouldAdd {
			newRecords = append(newRecords, &armdns.TxtRecord{Value: []*string{this.stringVal(newRecordValue)}})
		}
		properties.TxtRecords = newRecords
	default:
		// report the type that was switched on; recordSet.Type is not the
		// switch operand and may be nil
		return false, errors.New("not supported record type '" + recordType + "'")
	}

	return foundRecord, nil
}
// recordSetDeleteRecordValue removes recordValue from recordSet, for record
// types "A", "AAAA", "CNAME" and "TXT".
//
// Results:
//   - shouldUpdate: the value was found and removed from the record set.
//   - shouldDelete: no value remains, so the whole record set should be
//     deleted. This is always the case for CNAME and for a record set without
//     properties.
//
// Both are false when the value was not found. Any other record type yields
// an error naming recordType.
func (this *AzureDNSProvider) recordSetDeleteRecordValue(recordSet *armdns.RecordSet, recordType string, recordValue string) (shouldUpdate bool, shouldDelete bool, err error) {
	if recordSet == nil || recordSet.Properties == nil {
		return false, true, nil
	}

	var properties = recordSet.Properties

	switch recordType {
	case "A":
		var newRecords = []*armdns.ARecord{}
		for _, record := range properties.ARecords {
			if record == nil {
				continue
			}
			if record.IPv4Address != nil && *record.IPv4Address == recordValue {
				shouldUpdate = true
				continue
			}
			newRecords = append(newRecords, record)
		}
		properties.ARecords = newRecords
		shouldDelete = shouldUpdate && len(newRecords) == 0
	case "AAAA":
		var newRecords = []*armdns.AaaaRecord{}
		for _, record := range properties.AaaaRecords {
			if record == nil {
				continue
			}
			if record.IPv6Address != nil && *record.IPv6Address == recordValue {
				shouldUpdate = true
				continue
			}
			newRecords = append(newRecords, record)
		}
		properties.AaaaRecords = newRecords
		shouldDelete = shouldUpdate && len(newRecords) == 0
	case "CNAME":
		// the CNAME record set is removed as a whole, without comparing the value
		shouldDelete = true
	case "TXT":
		var newRecords = []*armdns.TxtRecord{}
		for _, record := range properties.TxtRecords {
			if record == nil {
				continue
			}

			var keptValues = []*string{}
			for _, oldValue := range record.Value {
				if oldValue == nil {
					continue
				}
				if *oldValue == recordValue {
					shouldUpdate = true
					continue
				}
				keptValues = append(keptValues, oldValue)
			}

			// entries left without any value are dropped
			if len(keptValues) == 0 {
				continue
			}
			record.Value = keptValues
			newRecords = append(newRecords, record)
		}
		properties.TxtRecords = newRecords
		shouldDelete = shouldUpdate && len(newRecords) == 0
	default:
		// report the type that was switched on; recordSet.Type is not the
		// switch operand and may be nil
		return false, false, errors.New("not supported record type '" + recordType + "'")
	}

	return shouldUpdate, shouldDelete, nil
}
// recordValueToProperties wraps a single record value into record set
// properties holding just that value, for record types "A", "AAAA", "CNAME"
// and "TXT". It returns nil for any other record type.
func (this *AzureDNSProvider) recordValueToProperties(recordType string, recordValue string) *armdns.RecordSetProperties {
	var value = this.stringVal(recordValue)

	switch recordType {
	case "A":
		return &armdns.RecordSetProperties{
			ARecords: []*armdns.ARecord{{IPv4Address: value}},
		}
	case "AAAA":
		return &armdns.RecordSetProperties{
			AaaaRecords: []*armdns.AaaaRecord{{IPv6Address: value}},
		}
	case "CNAME":
		return &armdns.RecordSetProperties{
			CnameRecord: &armdns.CnameRecord{Cname: value},
		}
	case "TXT":
		return &armdns.RecordSetProperties{
			TxtRecords: []*armdns.TxtRecord{{Value: []*string{value}}},
		}
	}

	// unsupported record type
	return nil
}
// stringVal returns a pointer to a copy of s, as needed for the pointer-typed
// string fields of the Azure SDK structs.
func (this *AzureDNSProvider) stringVal(s string) *string {
	var copied = s
	return &copied
}