// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn . //go:build plus package dnsclients import ( "context" "encoding/base64" "errors" "fmt" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns" "github.com/TeaOSLab/EdgeAPI/internal/dnsclients/dnstypes" "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" "strings" ) // AzureDNSProvider 微软Azure DNS zone type AzureDNSProvider struct { BaseProvider ProviderId int64 zonesClient *armdns.ZonesClient recordSetsClient *armdns.RecordSetsClient resourceGroupName string } func NewAzureDNSProvider() *AzureDNSProvider { return &AzureDNSProvider{} } // Auth 认证 func (this *AzureDNSProvider) Auth(params maps.Map) error { var subscriptionId = params.GetString("subscriptionId") var tenantId = params.GetString("tenantId") var clientId = params.GetString("clientId") var clientSecret = params.GetString("clientSecret") var resourceGroupName = params.GetString("resourceGroupName") if len(subscriptionId) == 0 { return errors.New("'subscriptionId' required") } if len(tenantId) == 0 { return errors.New("'tenantId' required") } if len(clientId) == 0 { return errors.New("'clientId' required") } if len(clientSecret) == 0 { return errors.New("'clientSecret' required") } if len(resourceGroupName) == 0 { return errors.New("'resourceGroupName' required") } this.resourceGroupName = resourceGroupName cred, err := azidentity.NewClientSecretCredential(tenantId, clientId, clientSecret, nil) if err != nil { return fmt.Errorf("NewClientSecretCredential: %w", err) } clientFactory, err := armdns.NewClientFactory(subscriptionId, cred, nil) if err != nil { return fmt.Errorf("NewClientFactory: %w", err) } this.zonesClient = clientFactory.NewZonesClient() this.recordSetsClient = clientFactory.NewRecordSetsClient() return nil } // MaskParams 对参数进行掩码 func (this *AzureDNSProvider) MaskParams(params maps.Map) { if params == nil { return } params["clientSecret"] = MaskString(params.GetString("clientSecret")) } // GetDomains 获取所有域名列表 func (this *AzureDNSProvider) GetDomains() (domains []string, err error) { var pager = this.zonesClient.NewListPager(nil) for pager.More() { page, err := pager.NextPage(context.Background()) if err != nil { return nil, err } for _, zone := range page.Value { if zone.Name != nil && !lists.ContainsString(domains, *zone.Name) { domains = append(domains, *zone.Name) } } } return } // GetRecords 获取域名列表 func (this *AzureDNSProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) { var pager = this.recordSetsClient.NewListByDNSZonePager(this.resourceGroupName, domain, nil) for pager.More() { page, err := pager.NextPage(context.Background()) if err != nil { return nil, err } for _, recordSet := range page.Value { if recordSet.Name == nil || recordSet.Properties == nil { continue } records = append(records, this.recordSetToRecords(recordSet)...) } } // 写入缓存 if this.ProviderId > 0 { sharedDomainRecordsCache.WriteDomainRecords(this.ProviderId, domain, records) } return } // GetRoutes 读取线路数据 func (this *AzureDNSProvider) GetRoutes(domain string) (routes []*dnstypes.Route, err error) { routes = []*dnstypes.Route{ { Name: "默认", Code: "default", }, } return } // QueryRecord 查询单个记录 func (this *AzureDNSProvider) QueryRecord(domain string, name string, recordType dnstypes.RecordType) (*dnstypes.Record, error) { // 从缓存中读取 if this.ProviderId > 0 { record, hasRecords, _ := sharedDomainRecordsCache.QueryDomainRecord(this.ProviderId, domain, name, recordType) if hasRecords { // 有效的搜索 return record, nil } } resp, err := this.recordSetsClient.Get(context.Background(), this.resourceGroupName, domain, name, armdns.RecordType(recordType), nil) if err != nil { if this.isNotFound(err) { return nil, nil } return nil, err } var records = this.recordSetToRecords(&resp.RecordSet) if len(records) > 0 { return records[0], nil } return nil, nil } // QueryRecords 查询多个记录 func (this *AzureDNSProvider) QueryRecords(domain string, name string, recordType dnstypes.RecordType) ([]*dnstypes.Record, error) { // 从缓存中读取 if this.ProviderId > 0 { records, hasRecords, _ := sharedDomainRecordsCache.QueryDomainRecords(this.ProviderId, domain, name, recordType) if hasRecords { // 有效的搜索 return records, nil } } resp, err := this.recordSetsClient.Get(context.Background(), this.resourceGroupName, domain, name, armdns.RecordType(recordType), nil) if err != nil { if this.isNotFound(err) { return nil, nil } return nil, err } return this.recordSetToRecords(&resp.RecordSet), nil } // AddRecord 设置记录 func (this *AzureDNSProvider) AddRecord(domain string, newRecord *dnstypes.Record) error { if newRecord == nil { return errors.New("invalid new record") } // 在CHANGE记录后面加入点 if newRecord.Type == dnstypes.RecordTypeCNAME && !strings.HasSuffix(newRecord.Value, ".") { newRecord.Value += "." } var ttl = newRecord.TTL if ttl <= 0 { ttl = 600 } var recordSet = armdns.RecordSet{ Etag: nil, Properties: nil, ID: nil, Name: this.stringVal(newRecord.Name), Type: this.stringVal(newRecord.Type), } resp, err := this.recordSetsClient.Get(context.Background(), this.resourceGroupName, domain, newRecord.Name, armdns.RecordType(newRecord.Type), nil) if err != nil { if this.isNotFound(err) { err = nil } else { return err } } var ttlInt64 = int64(ttl) if resp.RecordSet.Properties != nil { recordSet.Properties = resp.RecordSet.Properties recordSet.Properties.TTL = &ttlInt64 recordSet.ID = resp.RecordSet.ID } else { recordSet.Properties = &armdns.RecordSetProperties{ TTL: &ttlInt64, } } exists, err := this.recordSetAddRecordValue(&recordSet, newRecord.Value) if exists || err != nil { if exists { newRecord.Id = this.ComposeRecordId(newRecord.Name, newRecord.Type, newRecord.Value) } return err } _, err = this.recordSetsClient.CreateOrUpdate(context.Background(), this.resourceGroupName, domain, newRecord.Name, armdns.RecordType(newRecord.Type), recordSet, nil) if err != nil { return err } newRecord.Id = this.ComposeRecordId(newRecord.Name, newRecord.Type, newRecord.Value) // 加入缓存 if this.ProviderId > 0 { sharedDomainRecordsCache.AddDomainRecord(this.ProviderId, domain, newRecord) } return nil } // UpdateRecord 修改记录 func (this *AzureDNSProvider) UpdateRecord(domain string, record *dnstypes.Record, newRecord *dnstypes.Record) error { if record == nil { return errors.New("invalid record") } if newRecord == nil { return errors.New("invalid new record") } // 在CHANGE记录后面加入点 if newRecord.Type == dnstypes.RecordTypeCNAME && !strings.HasSuffix(newRecord.Value, ".") { newRecord.Value += "." } var newRoute = newRecord.Route if len(newRoute) == 0 { newRoute = this.DefaultRoute() } _, oldRecordName, oldRecordType, oldRecordValue, err := this.decodeRecordId(record.Id) if err != nil { return err } if oldRecordName == newRecord.Name && oldRecordType == newRecord.Type && oldRecordValue == newRecord.Value { return nil } resp, err := this.recordSetsClient.Get(context.Background(), this.resourceGroupName, domain, newRecord.Name, armdns.RecordType(newRecord.Type), nil) if err != nil { if this.isNotFound(err) { return nil } else { return err } } var recordSet = resp.RecordSet if recordSet.Properties == nil { return nil } found, err := this.recordSetUpdateRecordValue(&recordSet, newRecord.Type, oldRecordValue, newRecord.Value) if err != nil { return err } if !found { return nil } var ttlInt64 = int64(newRecord.TTL) if ttlInt64 <= 0 { ttlInt64 = 600 } recordSet.Properties.TTL = &ttlInt64 _, err = this.recordSetsClient.Update(context.Background(), this.resourceGroupName, domain, newRecord.Name, armdns.RecordType(newRecord.Type), recordSet, nil) if err != nil { return err } newRecord.Id = this.ComposeRecordId(newRecord.Name, newRecord.Type, newRecord.Value) // 修改缓存 if this.ProviderId > 0 { sharedDomainRecordsCache.UpdateDomainRecord(this.ProviderId, domain, newRecord) } return nil } // DeleteRecord 删除记录 func (this *AzureDNSProvider) DeleteRecord(domain string, record *dnstypes.Record) error { if record == nil { return errors.New("invalid record to delete") } _, recordName, recordType, recordValue, err := this.decodeRecordId(record.Id) if err != nil { return err } resp, err := this.recordSetsClient.Get(context.Background(), this.resourceGroupName, domain, recordName, armdns.RecordType(recordType), nil) if err != nil { if this.isNotFound(err) { return nil } else { return err } } var recordSet = resp.RecordSet if recordSet.Properties == nil { return nil } shouldUpdate, shouldDelete, err := this.recordSetDeleteRecordValue(&recordSet, recordType, recordValue) if err != nil { return err } if shouldDelete { _, err = this.recordSetsClient.Delete(context.Background(), this.resourceGroupName, domain, recordName, armdns.RecordType(recordType), nil) if err != nil { return err } } else if shouldUpdate { _, err = this.recordSetsClient.Update(context.Background(), this.resourceGroupName, domain, recordName, armdns.RecordType(recordType), recordSet, nil) if err != nil { return err } } else { return nil } // 删除缓存 if this.ProviderId > 0 { sharedDomainRecordsCache.DeleteDomainRecord(this.ProviderId, domain, record.Id) } return nil } // DefaultRoute 默认线路 func (this *AzureDNSProvider) DefaultRoute() string { return "default" } func (this *AzureDNSProvider) ComposeRecordId(recordName string, recordType string, recordValue string) string { return base64.StdEncoding.EncodeToString([]byte("$" + recordName + "$" + recordType + "$" + recordValue)) } func (this *AzureDNSProvider) decodeRecordId(encodedRecordId string) (recordSetId string, recordName string, recordType string, recordValue string, err error) { data, err := base64.StdEncoding.DecodeString(encodedRecordId) if err != nil { return "", "", "", "", err } if len(data) == 0 { err = errors.New("invalid record id") return } var pieces = strings.SplitN(string(data), "$", 4) if len(pieces) != 4 { err = errors.New("invalid record id") return } recordSetId = pieces[0] recordName = pieces[1] recordType = pieces[2] recordValue = pieces[3] return } func (this *AzureDNSProvider) fixCNAME(recordType string, recordValue string) string { // 修正Record if strings.ToUpper(recordType) == dnstypes.RecordTypeCNAME && !strings.HasSuffix(recordValue, ".") { recordValue += "." } return recordValue } func (this *AzureDNSProvider) isNotFound(err error) bool { if err == nil { return false } var respErr *azcore.ResponseError if errors.As(err, &respErr) && respErr.ErrorCode == "NotFound" { return true } return false } func (this *AzureDNSProvider) recordSetToRecords(recordSet *armdns.RecordSet) (records []*dnstypes.Record) { if recordSet == nil || recordSet.Properties == nil { return } // A for _, record := range recordSet.Properties.ARecords { var recordType = "A" records = append(records, &dnstypes.Record{ Id: this.ComposeRecordId(*recordSet.Name, recordType, *record.IPv4Address), Name: *recordSet.Name, Type: recordType, Value: *record.IPv4Address, Route: this.DefaultRoute(), TTL: types.Int32(*recordSet.Properties.TTL), }) } // AAAA for _, record := range recordSet.Properties.AaaaRecords { var recordType = "AAAA" records = append(records, &dnstypes.Record{ Id: this.ComposeRecordId(*recordSet.Name, recordType, *record.IPv6Address), Name: *recordSet.Name, Type: recordType, Value: *record.IPv6Address, Route: this.DefaultRoute(), TTL: types.Int32(*recordSet.Properties.TTL), }) } // CNAME if recordSet.Properties.CnameRecord != nil { var recordType = "CNAME" var record = recordSet.Properties.CnameRecord records = append(records, &dnstypes.Record{ Id: this.ComposeRecordId(*recordSet.Name, recordType, *record.Cname), Name: *recordSet.Name, Type: recordType, Value: *record.Cname, Route: this.DefaultRoute(), TTL: types.Int32(*recordSet.Properties.TTL), }) } // TXT if recordSet.Properties.TxtRecords != nil { var recordType = "TXT" for _, record := range recordSet.Properties.TxtRecords { if len(record.Value) == 0 { continue } for _, value := range record.Value { records = append(records, &dnstypes.Record{ Id: this.ComposeRecordId(*recordSet.Name, recordType, *value), Name: *recordSet.Name, Type: recordType, Value: *value, Route: this.DefaultRoute(), TTL: types.Int32(*recordSet.Properties.TTL), }) } } } // NS for _, record := range recordSet.Properties.NsRecords { var recordType = "NS" records = append(records, &dnstypes.Record{ Id: this.ComposeRecordId(*recordSet.Name, recordType, *record.Nsdname), Name: *recordSet.Name, Type: recordType, Value: *record.Nsdname, Route: this.DefaultRoute(), TTL: types.Int32(*recordSet.Properties.TTL), }) } // SOA if recordSet.Properties.SoaRecord != nil { var recordType = "SOA" var record = recordSet.Properties.SoaRecord records = append(records, &dnstypes.Record{ Id: this.ComposeRecordId(*recordSet.Name, recordType, *record.Host), Name: *recordSet.Name, Type: recordType, Value: fmt.Sprintf("%s %s %d %d %d %d %d", *record.Host, *record.Email, *record.SerialNumber, *record.RefreshTime, *record.RetryTime, *record.ExpireTime, *record.MinimumTTL), Route: this.DefaultRoute(), TTL: types.Int32(*recordSet.Properties.TTL), }) } // we don't support other record type yet return } func (this *AzureDNSProvider) recordSetAddRecordValue(recordSet *armdns.RecordSet, value string) (exists bool, err error) { if recordSet.Properties == nil { recordSet.Properties = &armdns.RecordSetProperties{} } var newProperties = this.recordValueToProperties(*recordSet.Type, value) if newProperties == nil { return false, errors.New("recordSetMergeRecordValue: invalid properties") } switch *recordSet.Type { case "A": for _, record := range recordSet.Properties.ARecords { if *record.IPv4Address == value { return true, nil } } recordSet.Properties.ARecords = append(recordSet.Properties.ARecords, newProperties.ARecords...) case "AAAA": for _, record := range recordSet.Properties.AaaaRecords { if *record.IPv6Address == value { return true, nil } } recordSet.Properties.AaaaRecords = append(recordSet.Properties.AaaaRecords, newProperties.AaaaRecords...) case "CNAME": var record = recordSet.Properties.CnameRecord if record != nil && record.Cname != nil { if *record.Cname == value { return true, nil } } recordSet.Properties.CnameRecord = newProperties.CnameRecord case "TXT": for _, record := range recordSet.Properties.TxtRecords { for _, txtValue := range record.Value { if *txtValue == value { return true, nil } } } recordSet.Properties.TxtRecords = append(recordSet.Properties.TxtRecords, newProperties.TxtRecords...) default: // ignore return false, errors.New("not supported record type '" + (*recordSet.Type) + "'") } return false, nil } func (this *AzureDNSProvider) recordSetUpdateRecordValue(recordSet *armdns.RecordSet, recordType string, oldRecordValue string, newRecordValue string) (foundRecord bool, err error) { if recordSet.Properties == nil { return false, nil } switch recordType { case "A": var newRecords = []*armdns.ARecord{} for _, record := range recordSet.Properties.ARecords { if *record.IPv4Address == oldRecordValue { foundRecord = true continue } newRecords = append(newRecords, record) } newRecords = append(newRecords, &armdns.ARecord{IPv4Address: this.stringVal(newRecordValue)}) recordSet.Properties.ARecords = newRecords case "AAAA": var newRecords = []*armdns.AaaaRecord{} for _, record := range recordSet.Properties.AaaaRecords { if *record.IPv6Address == oldRecordValue { foundRecord = true continue } newRecords = append(newRecords, record) } newRecords = append(newRecords, &armdns.AaaaRecord{IPv6Address: this.stringVal(newRecordValue)}) recordSet.Properties.AaaaRecords = newRecords case "CNAME": recordSet.Properties.CnameRecord = &armdns.CnameRecord{Cname: this.stringVal(newRecordValue)} foundRecord = true case "TXT": var newRecords = []*armdns.TxtRecord{} var shouldAdd = true for _, record := range recordSet.Properties.TxtRecords { var oldValues = []*string{} var found = false for _, oldValue := range record.Value { if *oldValue == oldRecordValue { found = true shouldAdd = false foundRecord = true continue } oldValues = append(oldValues, oldValue) } if found { oldValues = append(oldValues, this.stringVal(newRecordValue)) record.Value = oldValues // overwrite } if len(oldValues) == 0 { continue } newRecords = append(newRecords, record) } if shouldAdd { newRecords = append(newRecords, &armdns.TxtRecord{Value: []*string{this.stringVal(newRecordValue)}}) } recordSet.Properties.TxtRecords = newRecords default: // ignore return false, errors.New("not supported record type '" + (*recordSet.Type) + "'") } return } func (this *AzureDNSProvider) recordSetDeleteRecordValue(recordSet *armdns.RecordSet, recordType string, recordValue string) (shouldUpdate bool, shouldDelete bool, err error) { if recordSet.Properties == nil { shouldDelete = true return } switch recordType { case "A": var newRecords = []*armdns.ARecord{} for _, record := range recordSet.Properties.ARecords { if *record.IPv4Address == recordValue { shouldUpdate = true continue } newRecords = append(newRecords, record) } recordSet.Properties.ARecords = newRecords if !shouldUpdate { return } if len(newRecords) == 0 { shouldDelete = true } case "AAAA": var newRecords = []*armdns.AaaaRecord{} for _, record := range recordSet.Properties.AaaaRecords { if *record.IPv6Address == recordValue { shouldUpdate = true continue } newRecords = append(newRecords, record) } if !shouldUpdate { return } if len(newRecords) == 0 { shouldDelete = true } recordSet.Properties.AaaaRecords = newRecords case "CNAME": shouldDelete = true case "TXT": var newRecords = []*armdns.TxtRecord{} for _, record := range recordSet.Properties.TxtRecords { var oldValues = []*string{} for _, oldValue := range record.Value { if *oldValue == recordValue { shouldUpdate = true continue } oldValues = append(oldValues, oldValue) } if len(oldValues) == 0 { continue } record.Value = oldValues newRecords = append(newRecords, record) } recordSet.Properties.TxtRecords = newRecords if !shouldUpdate { return } if len(newRecords) == 0 { shouldDelete = true } default: // ignore return false, false, errors.New("not supported record type '" + (*recordSet.Type) + "'") } return } func (this *AzureDNSProvider) recordValueToProperties(recordType string, recordValue string) *armdns.RecordSetProperties { var properties = &armdns.RecordSetProperties{} switch recordType { case "A": properties.ARecords = []*armdns.ARecord{ { IPv4Address: this.stringVal(recordValue), }, } case "AAAA": properties.AaaaRecords = []*armdns.AaaaRecord{ { IPv6Address: this.stringVal(recordValue), }, } case "CNAME": properties.CnameRecord = &armdns.CnameRecord{Cname: this.stringVal(recordValue)} case "TXT": properties.TxtRecords = []*armdns.TxtRecord{ { Value: []*string{this.stringVal(recordValue)}, }, } default: // ignore return nil } return properties } func (this *AzureDNSProvider) stringVal(s string) *string { return &s }