1.4.5.2
This commit is contained in:
744
EdgeAPI/internal/dnsclients/provider_azure_dns_plus.go
Normal file
744
EdgeAPI/internal/dnsclients/provider_azure_dns_plus.go
Normal file
@@ -0,0 +1,744 @@
|
||||
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
//go:build plus
|
||||
|
||||
package dnsclients
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/dnsclients/dnstypes"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AzureDNSProvider 微软Azure DNS zone
|
||||
type AzureDNSProvider struct {
|
||||
BaseProvider
|
||||
|
||||
ProviderId int64
|
||||
|
||||
zonesClient *armdns.ZonesClient
|
||||
recordSetsClient *armdns.RecordSetsClient
|
||||
|
||||
resourceGroupName string
|
||||
}
|
||||
|
||||
func NewAzureDNSProvider() *AzureDNSProvider {
|
||||
return &AzureDNSProvider{}
|
||||
}
|
||||
|
||||
// Auth 认证
|
||||
func (this *AzureDNSProvider) Auth(params maps.Map) error {
|
||||
var subscriptionId = params.GetString("subscriptionId")
|
||||
var tenantId = params.GetString("tenantId")
|
||||
var clientId = params.GetString("clientId")
|
||||
var clientSecret = params.GetString("clientSecret")
|
||||
var resourceGroupName = params.GetString("resourceGroupName")
|
||||
|
||||
if len(subscriptionId) == 0 {
|
||||
return errors.New("'subscriptionId' required")
|
||||
}
|
||||
if len(tenantId) == 0 {
|
||||
return errors.New("'tenantId' required")
|
||||
}
|
||||
if len(clientId) == 0 {
|
||||
return errors.New("'clientId' required")
|
||||
}
|
||||
if len(clientSecret) == 0 {
|
||||
return errors.New("'clientSecret' required")
|
||||
}
|
||||
if len(resourceGroupName) == 0 {
|
||||
return errors.New("'resourceGroupName' required")
|
||||
}
|
||||
this.resourceGroupName = resourceGroupName
|
||||
|
||||
cred, err := azidentity.NewClientSecretCredential(tenantId, clientId, clientSecret, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("NewClientSecretCredential: %w", err)
|
||||
}
|
||||
|
||||
clientFactory, err := armdns.NewClientFactory(subscriptionId, cred, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("NewClientFactory: %w", err)
|
||||
}
|
||||
this.zonesClient = clientFactory.NewZonesClient()
|
||||
this.recordSetsClient = clientFactory.NewRecordSetsClient()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaskParams 对参数进行掩码
|
||||
func (this *AzureDNSProvider) MaskParams(params maps.Map) {
|
||||
if params == nil {
|
||||
return
|
||||
}
|
||||
params["clientSecret"] = MaskString(params.GetString("clientSecret"))
|
||||
}
|
||||
|
||||
// GetDomains 获取所有域名列表
|
||||
func (this *AzureDNSProvider) GetDomains() (domains []string, err error) {
|
||||
var pager = this.zonesClient.NewListPager(nil)
|
||||
for pager.More() {
|
||||
page, err := pager.NextPage(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, zone := range page.Value {
|
||||
if zone.Name != nil && !lists.ContainsString(domains, *zone.Name) {
|
||||
domains = append(domains, *zone.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// GetRecords 获取域名列表
|
||||
func (this *AzureDNSProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) {
|
||||
var pager = this.recordSetsClient.NewListByDNSZonePager(this.resourceGroupName, domain, nil)
|
||||
for pager.More() {
|
||||
page, err := pager.NextPage(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, recordSet := range page.Value {
|
||||
if recordSet.Name == nil || recordSet.Properties == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
records = append(records, this.recordSetToRecords(recordSet)...)
|
||||
}
|
||||
}
|
||||
|
||||
// 写入缓存
|
||||
if this.ProviderId > 0 {
|
||||
sharedDomainRecordsCache.WriteDomainRecords(this.ProviderId, domain, records)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// GetRoutes 读取线路数据
|
||||
func (this *AzureDNSProvider) GetRoutes(domain string) (routes []*dnstypes.Route, err error) {
|
||||
routes = []*dnstypes.Route{
|
||||
{
|
||||
Name: "默认",
|
||||
Code: "default",
|
||||
},
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// QueryRecord 查询单个记录
|
||||
func (this *AzureDNSProvider) QueryRecord(domain string, name string, recordType dnstypes.RecordType) (*dnstypes.Record, error) {
|
||||
// 从缓存中读取
|
||||
if this.ProviderId > 0 {
|
||||
record, hasRecords, _ := sharedDomainRecordsCache.QueryDomainRecord(this.ProviderId, domain, name, recordType)
|
||||
if hasRecords { // 有效的搜索
|
||||
return record, nil
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := this.recordSetsClient.Get(context.Background(), this.resourceGroupName, domain, name, armdns.RecordType(recordType), nil)
|
||||
if err != nil {
|
||||
if this.isNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var records = this.recordSetToRecords(&resp.RecordSet)
|
||||
if len(records) > 0 {
|
||||
return records[0], nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// QueryRecords 查询多个记录
|
||||
func (this *AzureDNSProvider) QueryRecords(domain string, name string, recordType dnstypes.RecordType) ([]*dnstypes.Record, error) {
|
||||
// 从缓存中读取
|
||||
if this.ProviderId > 0 {
|
||||
records, hasRecords, _ := sharedDomainRecordsCache.QueryDomainRecords(this.ProviderId, domain, name, recordType)
|
||||
if hasRecords { // 有效的搜索
|
||||
return records, nil
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := this.recordSetsClient.Get(context.Background(), this.resourceGroupName, domain, name, armdns.RecordType(recordType), nil)
|
||||
if err != nil {
|
||||
if this.isNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return this.recordSetToRecords(&resp.RecordSet), nil
|
||||
}
|
||||
|
||||
// AddRecord 设置记录
|
||||
func (this *AzureDNSProvider) AddRecord(domain string, newRecord *dnstypes.Record) error {
|
||||
if newRecord == nil {
|
||||
return errors.New("invalid new record")
|
||||
}
|
||||
|
||||
// 在CHANGE记录后面加入点
|
||||
if newRecord.Type == dnstypes.RecordTypeCNAME && !strings.HasSuffix(newRecord.Value, ".") {
|
||||
newRecord.Value += "."
|
||||
}
|
||||
|
||||
var ttl = newRecord.TTL
|
||||
if ttl <= 0 {
|
||||
ttl = 600
|
||||
}
|
||||
|
||||
var recordSet = armdns.RecordSet{
|
||||
Etag: nil,
|
||||
Properties: nil,
|
||||
ID: nil,
|
||||
Name: this.stringVal(newRecord.Name),
|
||||
Type: this.stringVal(newRecord.Type),
|
||||
}
|
||||
|
||||
resp, err := this.recordSetsClient.Get(context.Background(), this.resourceGroupName, domain, newRecord.Name, armdns.RecordType(newRecord.Type), nil)
|
||||
if err != nil {
|
||||
if this.isNotFound(err) {
|
||||
err = nil
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var ttlInt64 = int64(ttl)
|
||||
if resp.RecordSet.Properties != nil {
|
||||
recordSet.Properties = resp.RecordSet.Properties
|
||||
recordSet.Properties.TTL = &ttlInt64
|
||||
recordSet.ID = resp.RecordSet.ID
|
||||
} else {
|
||||
recordSet.Properties = &armdns.RecordSetProperties{
|
||||
TTL: &ttlInt64,
|
||||
}
|
||||
}
|
||||
exists, err := this.recordSetAddRecordValue(&recordSet, newRecord.Value)
|
||||
if exists || err != nil {
|
||||
if exists {
|
||||
newRecord.Id = this.ComposeRecordId(newRecord.Name, newRecord.Type, newRecord.Value)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = this.recordSetsClient.CreateOrUpdate(context.Background(), this.resourceGroupName, domain, newRecord.Name, armdns.RecordType(newRecord.Type), recordSet, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newRecord.Id = this.ComposeRecordId(newRecord.Name, newRecord.Type, newRecord.Value)
|
||||
|
||||
// 加入缓存
|
||||
if this.ProviderId > 0 {
|
||||
sharedDomainRecordsCache.AddDomainRecord(this.ProviderId, domain, newRecord)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateRecord 修改记录
|
||||
func (this *AzureDNSProvider) UpdateRecord(domain string, record *dnstypes.Record, newRecord *dnstypes.Record) error {
|
||||
if record == nil {
|
||||
return errors.New("invalid record")
|
||||
}
|
||||
if newRecord == nil {
|
||||
return errors.New("invalid new record")
|
||||
}
|
||||
|
||||
// 在CHANGE记录后面加入点
|
||||
if newRecord.Type == dnstypes.RecordTypeCNAME && !strings.HasSuffix(newRecord.Value, ".") {
|
||||
newRecord.Value += "."
|
||||
}
|
||||
|
||||
var newRoute = newRecord.Route
|
||||
if len(newRoute) == 0 {
|
||||
newRoute = this.DefaultRoute()
|
||||
}
|
||||
|
||||
_, oldRecordName, oldRecordType, oldRecordValue, err := this.decodeRecordId(record.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if oldRecordName == newRecord.Name && oldRecordType == newRecord.Type && oldRecordValue == newRecord.Value {
|
||||
return nil
|
||||
}
|
||||
|
||||
resp, err := this.recordSetsClient.Get(context.Background(), this.resourceGroupName, domain, newRecord.Name, armdns.RecordType(newRecord.Type), nil)
|
||||
if err != nil {
|
||||
if this.isNotFound(err) {
|
||||
return nil
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var recordSet = resp.RecordSet
|
||||
if recordSet.Properties == nil {
|
||||
return nil
|
||||
}
|
||||
found, err := this.recordSetUpdateRecordValue(&recordSet, newRecord.Type, oldRecordValue, newRecord.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !found {
|
||||
return nil
|
||||
}
|
||||
|
||||
var ttlInt64 = int64(newRecord.TTL)
|
||||
if ttlInt64 <= 0 {
|
||||
ttlInt64 = 600
|
||||
}
|
||||
recordSet.Properties.TTL = &ttlInt64
|
||||
|
||||
_, err = this.recordSetsClient.Update(context.Background(), this.resourceGroupName, domain, newRecord.Name, armdns.RecordType(newRecord.Type), recordSet, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newRecord.Id = this.ComposeRecordId(newRecord.Name, newRecord.Type, newRecord.Value)
|
||||
|
||||
// 修改缓存
|
||||
if this.ProviderId > 0 {
|
||||
sharedDomainRecordsCache.UpdateDomainRecord(this.ProviderId, domain, newRecord)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRecord 删除记录
|
||||
func (this *AzureDNSProvider) DeleteRecord(domain string, record *dnstypes.Record) error {
|
||||
if record == nil {
|
||||
return errors.New("invalid record to delete")
|
||||
}
|
||||
|
||||
_, recordName, recordType, recordValue, err := this.decodeRecordId(record.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := this.recordSetsClient.Get(context.Background(), this.resourceGroupName, domain, recordName, armdns.RecordType(recordType), nil)
|
||||
if err != nil {
|
||||
if this.isNotFound(err) {
|
||||
return nil
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var recordSet = resp.RecordSet
|
||||
if recordSet.Properties == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
shouldUpdate, shouldDelete, err := this.recordSetDeleteRecordValue(&recordSet, recordType, recordValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if shouldDelete {
|
||||
_, err = this.recordSetsClient.Delete(context.Background(), this.resourceGroupName, domain, recordName, armdns.RecordType(recordType), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if shouldUpdate {
|
||||
_, err = this.recordSetsClient.Update(context.Background(), this.resourceGroupName, domain, recordName, armdns.RecordType(recordType), recordSet, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 删除缓存
|
||||
if this.ProviderId > 0 {
|
||||
sharedDomainRecordsCache.DeleteDomainRecord(this.ProviderId, domain, record.Id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DefaultRoute 默认线路
|
||||
func (this *AzureDNSProvider) DefaultRoute() string {
|
||||
return "default"
|
||||
}
|
||||
|
||||
func (this *AzureDNSProvider) ComposeRecordId(recordName string, recordType string, recordValue string) string {
|
||||
return base64.StdEncoding.EncodeToString([]byte("$" + recordName + "$" + recordType + "$" + recordValue))
|
||||
}
|
||||
|
||||
func (this *AzureDNSProvider) decodeRecordId(encodedRecordId string) (recordSetId string, recordName string, recordType string, recordValue string, err error) {
|
||||
data, err := base64.StdEncoding.DecodeString(encodedRecordId)
|
||||
if err != nil {
|
||||
return "", "", "", "", err
|
||||
}
|
||||
if len(data) == 0 {
|
||||
err = errors.New("invalid record id")
|
||||
return
|
||||
}
|
||||
var pieces = strings.SplitN(string(data), "$", 4)
|
||||
if len(pieces) != 4 {
|
||||
err = errors.New("invalid record id")
|
||||
return
|
||||
}
|
||||
recordSetId = pieces[0]
|
||||
recordName = pieces[1]
|
||||
recordType = pieces[2]
|
||||
recordValue = pieces[3]
|
||||
return
|
||||
}
|
||||
|
||||
func (this *AzureDNSProvider) fixCNAME(recordType string, recordValue string) string {
|
||||
// 修正Record
|
||||
if strings.ToUpper(recordType) == dnstypes.RecordTypeCNAME && !strings.HasSuffix(recordValue, ".") {
|
||||
recordValue += "."
|
||||
}
|
||||
return recordValue
|
||||
}
|
||||
|
||||
func (this *AzureDNSProvider) isNotFound(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var respErr *azcore.ResponseError
|
||||
if errors.As(err, &respErr) && respErr.ErrorCode == "NotFound" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *AzureDNSProvider) recordSetToRecords(recordSet *armdns.RecordSet) (records []*dnstypes.Record) {
|
||||
if recordSet == nil || recordSet.Properties == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// A
|
||||
for _, record := range recordSet.Properties.ARecords {
|
||||
var recordType = "A"
|
||||
records = append(records, &dnstypes.Record{
|
||||
Id: this.ComposeRecordId(*recordSet.Name, recordType, *record.IPv4Address),
|
||||
Name: *recordSet.Name,
|
||||
Type: recordType,
|
||||
Value: *record.IPv4Address,
|
||||
Route: this.DefaultRoute(),
|
||||
TTL: types.Int32(*recordSet.Properties.TTL),
|
||||
})
|
||||
}
|
||||
|
||||
// AAAA
|
||||
for _, record := range recordSet.Properties.AaaaRecords {
|
||||
var recordType = "AAAA"
|
||||
records = append(records, &dnstypes.Record{
|
||||
Id: this.ComposeRecordId(*recordSet.Name, recordType, *record.IPv6Address),
|
||||
Name: *recordSet.Name,
|
||||
Type: recordType,
|
||||
Value: *record.IPv6Address,
|
||||
Route: this.DefaultRoute(),
|
||||
TTL: types.Int32(*recordSet.Properties.TTL),
|
||||
})
|
||||
}
|
||||
|
||||
// CNAME
|
||||
if recordSet.Properties.CnameRecord != nil {
|
||||
var recordType = "CNAME"
|
||||
var record = recordSet.Properties.CnameRecord
|
||||
records = append(records, &dnstypes.Record{
|
||||
Id: this.ComposeRecordId(*recordSet.Name, recordType, *record.Cname),
|
||||
Name: *recordSet.Name,
|
||||
Type: recordType,
|
||||
Value: *record.Cname,
|
||||
Route: this.DefaultRoute(),
|
||||
TTL: types.Int32(*recordSet.Properties.TTL),
|
||||
})
|
||||
}
|
||||
|
||||
// TXT
|
||||
if recordSet.Properties.TxtRecords != nil {
|
||||
var recordType = "TXT"
|
||||
for _, record := range recordSet.Properties.TxtRecords {
|
||||
if len(record.Value) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, value := range record.Value {
|
||||
records = append(records, &dnstypes.Record{
|
||||
Id: this.ComposeRecordId(*recordSet.Name, recordType, *value),
|
||||
Name: *recordSet.Name,
|
||||
Type: recordType,
|
||||
Value: *value,
|
||||
Route: this.DefaultRoute(),
|
||||
TTL: types.Int32(*recordSet.Properties.TTL),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NS
|
||||
for _, record := range recordSet.Properties.NsRecords {
|
||||
var recordType = "NS"
|
||||
records = append(records, &dnstypes.Record{
|
||||
Id: this.ComposeRecordId(*recordSet.Name, recordType, *record.Nsdname),
|
||||
Name: *recordSet.Name,
|
||||
Type: recordType,
|
||||
Value: *record.Nsdname,
|
||||
Route: this.DefaultRoute(),
|
||||
TTL: types.Int32(*recordSet.Properties.TTL),
|
||||
})
|
||||
}
|
||||
|
||||
// SOA
|
||||
if recordSet.Properties.SoaRecord != nil {
|
||||
var recordType = "SOA"
|
||||
var record = recordSet.Properties.SoaRecord
|
||||
records = append(records, &dnstypes.Record{
|
||||
Id: this.ComposeRecordId(*recordSet.Name, recordType, *record.Host),
|
||||
Name: *recordSet.Name,
|
||||
Type: recordType,
|
||||
Value: fmt.Sprintf("%s %s %d %d %d %d %d", *record.Host, *record.Email, *record.SerialNumber, *record.RefreshTime, *record.RetryTime, *record.ExpireTime, *record.MinimumTTL),
|
||||
Route: this.DefaultRoute(),
|
||||
TTL: types.Int32(*recordSet.Properties.TTL),
|
||||
})
|
||||
}
|
||||
|
||||
// we don't support other record type yet
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (this *AzureDNSProvider) recordSetAddRecordValue(recordSet *armdns.RecordSet, value string) (exists bool, err error) {
|
||||
if recordSet.Properties == nil {
|
||||
recordSet.Properties = &armdns.RecordSetProperties{}
|
||||
}
|
||||
|
||||
var newProperties = this.recordValueToProperties(*recordSet.Type, value)
|
||||
if newProperties == nil {
|
||||
return false, errors.New("recordSetMergeRecordValue: invalid properties")
|
||||
}
|
||||
|
||||
switch *recordSet.Type {
|
||||
case "A":
|
||||
for _, record := range recordSet.Properties.ARecords {
|
||||
if *record.IPv4Address == value {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
recordSet.Properties.ARecords = append(recordSet.Properties.ARecords, newProperties.ARecords...)
|
||||
case "AAAA":
|
||||
for _, record := range recordSet.Properties.AaaaRecords {
|
||||
if *record.IPv6Address == value {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
recordSet.Properties.AaaaRecords = append(recordSet.Properties.AaaaRecords, newProperties.AaaaRecords...)
|
||||
case "CNAME":
|
||||
var record = recordSet.Properties.CnameRecord
|
||||
if record != nil && record.Cname != nil {
|
||||
if *record.Cname == value {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
recordSet.Properties.CnameRecord = newProperties.CnameRecord
|
||||
case "TXT":
|
||||
for _, record := range recordSet.Properties.TxtRecords {
|
||||
for _, txtValue := range record.Value {
|
||||
if *txtValue == value {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
recordSet.Properties.TxtRecords = append(recordSet.Properties.TxtRecords, newProperties.TxtRecords...)
|
||||
default:
|
||||
// ignore
|
||||
return false, errors.New("not supported record type '" + (*recordSet.Type) + "'")
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (this *AzureDNSProvider) recordSetUpdateRecordValue(recordSet *armdns.RecordSet, recordType string, oldRecordValue string, newRecordValue string) (foundRecord bool, err error) {
|
||||
if recordSet.Properties == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
switch recordType {
|
||||
case "A":
|
||||
var newRecords = []*armdns.ARecord{}
|
||||
for _, record := range recordSet.Properties.ARecords {
|
||||
if *record.IPv4Address == oldRecordValue {
|
||||
foundRecord = true
|
||||
continue
|
||||
}
|
||||
newRecords = append(newRecords, record)
|
||||
}
|
||||
newRecords = append(newRecords, &armdns.ARecord{IPv4Address: this.stringVal(newRecordValue)})
|
||||
recordSet.Properties.ARecords = newRecords
|
||||
case "AAAA":
|
||||
var newRecords = []*armdns.AaaaRecord{}
|
||||
for _, record := range recordSet.Properties.AaaaRecords {
|
||||
if *record.IPv6Address == oldRecordValue {
|
||||
foundRecord = true
|
||||
continue
|
||||
}
|
||||
newRecords = append(newRecords, record)
|
||||
}
|
||||
newRecords = append(newRecords, &armdns.AaaaRecord{IPv6Address: this.stringVal(newRecordValue)})
|
||||
recordSet.Properties.AaaaRecords = newRecords
|
||||
case "CNAME":
|
||||
recordSet.Properties.CnameRecord = &armdns.CnameRecord{Cname: this.stringVal(newRecordValue)}
|
||||
foundRecord = true
|
||||
case "TXT":
|
||||
var newRecords = []*armdns.TxtRecord{}
|
||||
var shouldAdd = true
|
||||
for _, record := range recordSet.Properties.TxtRecords {
|
||||
var oldValues = []*string{}
|
||||
var found = false
|
||||
for _, oldValue := range record.Value {
|
||||
if *oldValue == oldRecordValue {
|
||||
found = true
|
||||
shouldAdd = false
|
||||
foundRecord = true
|
||||
continue
|
||||
}
|
||||
oldValues = append(oldValues, oldValue)
|
||||
}
|
||||
if found {
|
||||
oldValues = append(oldValues, this.stringVal(newRecordValue))
|
||||
record.Value = oldValues // overwrite
|
||||
}
|
||||
if len(oldValues) == 0 {
|
||||
continue
|
||||
}
|
||||
newRecords = append(newRecords, record)
|
||||
}
|
||||
if shouldAdd {
|
||||
newRecords = append(newRecords, &armdns.TxtRecord{Value: []*string{this.stringVal(newRecordValue)}})
|
||||
}
|
||||
recordSet.Properties.TxtRecords = newRecords
|
||||
default:
|
||||
// ignore
|
||||
return false, errors.New("not supported record type '" + (*recordSet.Type) + "'")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (this *AzureDNSProvider) recordSetDeleteRecordValue(recordSet *armdns.RecordSet, recordType string, recordValue string) (shouldUpdate bool, shouldDelete bool, err error) {
|
||||
if recordSet.Properties == nil {
|
||||
shouldDelete = true
|
||||
return
|
||||
}
|
||||
|
||||
switch recordType {
|
||||
case "A":
|
||||
var newRecords = []*armdns.ARecord{}
|
||||
for _, record := range recordSet.Properties.ARecords {
|
||||
if *record.IPv4Address == recordValue {
|
||||
shouldUpdate = true
|
||||
continue
|
||||
}
|
||||
newRecords = append(newRecords, record)
|
||||
}
|
||||
recordSet.Properties.ARecords = newRecords
|
||||
if !shouldUpdate {
|
||||
return
|
||||
}
|
||||
if len(newRecords) == 0 {
|
||||
shouldDelete = true
|
||||
}
|
||||
case "AAAA":
|
||||
var newRecords = []*armdns.AaaaRecord{}
|
||||
for _, record := range recordSet.Properties.AaaaRecords {
|
||||
if *record.IPv6Address == recordValue {
|
||||
shouldUpdate = true
|
||||
continue
|
||||
}
|
||||
newRecords = append(newRecords, record)
|
||||
}
|
||||
if !shouldUpdate {
|
||||
return
|
||||
}
|
||||
if len(newRecords) == 0 {
|
||||
shouldDelete = true
|
||||
}
|
||||
recordSet.Properties.AaaaRecords = newRecords
|
||||
case "CNAME":
|
||||
shouldDelete = true
|
||||
case "TXT":
|
||||
var newRecords = []*armdns.TxtRecord{}
|
||||
for _, record := range recordSet.Properties.TxtRecords {
|
||||
var oldValues = []*string{}
|
||||
for _, oldValue := range record.Value {
|
||||
if *oldValue == recordValue {
|
||||
shouldUpdate = true
|
||||
continue
|
||||
}
|
||||
oldValues = append(oldValues, oldValue)
|
||||
}
|
||||
if len(oldValues) == 0 {
|
||||
continue
|
||||
}
|
||||
record.Value = oldValues
|
||||
newRecords = append(newRecords, record)
|
||||
}
|
||||
recordSet.Properties.TxtRecords = newRecords
|
||||
if !shouldUpdate {
|
||||
return
|
||||
}
|
||||
if len(newRecords) == 0 {
|
||||
shouldDelete = true
|
||||
}
|
||||
default:
|
||||
// ignore
|
||||
return false, false, errors.New("not supported record type '" + (*recordSet.Type) + "'")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (this *AzureDNSProvider) recordValueToProperties(recordType string, recordValue string) *armdns.RecordSetProperties {
|
||||
var properties = &armdns.RecordSetProperties{}
|
||||
|
||||
switch recordType {
|
||||
case "A":
|
||||
properties.ARecords = []*armdns.ARecord{
|
||||
{
|
||||
IPv4Address: this.stringVal(recordValue),
|
||||
},
|
||||
}
|
||||
case "AAAA":
|
||||
properties.AaaaRecords = []*armdns.AaaaRecord{
|
||||
{
|
||||
IPv6Address: this.stringVal(recordValue),
|
||||
},
|
||||
}
|
||||
case "CNAME":
|
||||
properties.CnameRecord = &armdns.CnameRecord{Cname: this.stringVal(recordValue)}
|
||||
case "TXT":
|
||||
properties.TxtRecords = []*armdns.TxtRecord{
|
||||
{
|
||||
Value: []*string{this.stringVal(recordValue)},
|
||||
},
|
||||
}
|
||||
default:
|
||||
// ignore
|
||||
return nil
|
||||
}
|
||||
|
||||
return properties
|
||||
}
|
||||
|
||||
func (this *AzureDNSProvider) stringVal(s string) *string {
|
||||
return &s
|
||||
}
|
||||
Reference in New Issue
Block a user