// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn . //go:build plus package dnsclients import ( "encoding/base64" "errors" "fmt" "github.com/TeaOSLab/EdgeAPI/internal/dnsclients/dnstypes" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/route53" "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" stringutil "github.com/iwind/TeaGo/utils/string" "strings" ) // AmazonRoute53Provider Amazon Route 53 // 其中我们总是判断 recordSet.SetIdentifier != nil 是为了只关注我们自己创建的记录集 type AmazonRoute53Provider struct { BaseProvider ProviderId int64 client *route53.Route53 } func NewAmazonRoute53Provider() *AmazonRoute53Provider { return &AmazonRoute53Provider{} } // Auth 认证 func (this *AmazonRoute53Provider) Auth(params maps.Map) error { var accessKeyId = params.GetString("accessKeyId") var accessKeySecret = params.GetString("accessKeySecret") var region = params.GetString("region") if len(accessKeyId) == 0 { return errors.New("'accessKeyId' required") } if len(accessKeySecret) == 0 { return errors.New("'accessKeySecret' required") } var regionPtr *string if len(region) > 0 { regionPtr = aws.String(region) } sess, err := session.NewSession(&aws.Config{ Credentials: credentials.NewCredentials(NewAmazonCredentialProvider(accessKeyId, accessKeySecret)), Region: regionPtr, }) if err != nil { return err } this.client = route53.New(sess) return nil } // MaskParams 对参数进行掩码 func (this *AmazonRoute53Provider) MaskParams(params maps.Map) { if params == nil { return } params["accessKeySecret"] = MaskString(params.GetString("accessKeySecret")) } // GetDomains 获取所有域名列表 func (this *AmazonRoute53Provider) GetDomains() (domains []string, err error) { var nextMarker *string for { var input = &route53.ListHostedZonesInput{ Marker: nextMarker, } output, err := this.client.ListHostedZones(input) if err != nil { return nil, err } for _, zone := range output.HostedZones { domains = append(domains, strings.TrimSuffix(*zone.Name, ".")) } if *output.IsTruncated { nextMarker = output.NextMarker } else { break } } return } // GetRecords 获取域名列表 func (this *AmazonRoute53Provider) GetRecords(domain string) (records []*dnstypes.Record, err error) { zoneId, err := this.getZoneId(domain) if err != nil || zoneId == nil { return nil, err } var nextRecordIdentifier *string var nextRecordName *string var nextRecordType *string for { var input = &route53.ListResourceRecordSetsInput{ HostedZoneId: zoneId, StartRecordIdentifier: nextRecordIdentifier, StartRecordName: nextRecordName, StartRecordType: nextRecordType, } output, err := this.client.ListResourceRecordSets(input) if err != nil { return nil, err } for _, recordSet := range output.ResourceRecordSets { // 检查返回值是否正常 if recordSet.SetIdentifier == nil || recordSet.Name == nil || recordSet.Type == nil || recordSet.TTL == nil { continue } for _, rawRecord := range recordSet.ResourceRecords { if rawRecord.Value == nil { continue } var recordName = strings.TrimSuffix(strings.TrimSuffix(*recordSet.Name, domain+"."), ".") records = append(records, &dnstypes.Record{ Id: this.composeRecordId(*recordSet.SetIdentifier, recordName, *recordSet.Type, *rawRecord.Value), Name: recordName, Type: *recordSet.Type, Value: *rawRecord.Value, Route: this.composeGeoLocationCode(recordSet.GeoLocation), TTL: types.Int32(*recordSet.TTL), }) } } if *output.IsTruncated { nextRecordIdentifier = output.NextRecordIdentifier nextRecordName = output.NextRecordName nextRecordType = output.NextRecordType } else { break } } // 写入缓存 if this.ProviderId > 0 { sharedDomainRecordsCache.WriteDomainRecords(this.ProviderId, domain, records) } return } // GetRoutes 读取线路数据 func (this *AmazonRoute53Provider) GetRoutes(domain string) (routes []*dnstypes.Route, err error) { var nextContinentCode *string var nextCountryCode *string var nextSubdivisionCode *string for { var input = &route53.ListGeoLocationsInput{ MaxItems: nil, StartContinentCode: nextContinentCode, StartCountryCode: nextCountryCode, StartSubdivisionCode: nextSubdivisionCode, } output, err := this.client.ListGeoLocations(input) if err != nil { return nil, err } for _, location := range output.GeoLocationDetailsList { locationName, locationCode := this.composeGeoLocationDetail(location) routes = append(routes, &dnstypes.Route{ Name: locationName, Code: locationCode, }) } if *output.IsTruncated { nextContinentCode = output.NextContinentCode nextCountryCode = output.NextCountryCode nextSubdivisionCode = output.NextSubdivisionCode } else { break } } return } // QueryRecord 查询单个记录 func (this *AmazonRoute53Provider) QueryRecord(domain string, name string, recordType dnstypes.RecordType) (*dnstypes.Record, error) { // 从缓存中读取 if this.ProviderId > 0 { record, hasRecords, _ := sharedDomainRecordsCache.QueryDomainRecord(this.ProviderId, domain, name, recordType) if hasRecords { // 有效的搜索 return record, nil } } zoneId, err := this.getZoneId(domain) if err != nil || zoneId == nil { return nil, err } var fullname = name + "." + domain + "." var input = &route53.ListResourceRecordSetsInput{ HostedZoneId: zoneId, StartRecordName: aws.String(fullname), StartRecordType: aws.String(recordType), } output, err := this.client.ListResourceRecordSets(input) if err != nil { return nil, err } for _, recordSet := range output.ResourceRecordSets { if recordSet.SetIdentifier == nil { continue } if *recordSet.Name == fullname && *recordSet.Type == recordType { for _, rawRecord := range recordSet.ResourceRecords { return &dnstypes.Record{ Id: this.composeRecordId(*recordSet.SetIdentifier, name, *recordSet.Type, *rawRecord.Value), Name: name, Type: *recordSet.Type, Value: *rawRecord.Value, Route: this.composeGeoLocationCode(recordSet.GeoLocation), TTL: types.Int32(*recordSet.TTL), }, nil } } } return nil, nil } // QueryRecords 查询多个记录 func (this *AmazonRoute53Provider) QueryRecords(domain string, name string, recordType dnstypes.RecordType) ([]*dnstypes.Record, error) { // 从缓存中读取 if this.ProviderId > 0 { records, hasRecords, _ := sharedDomainRecordsCache.QueryDomainRecords(this.ProviderId, domain, name, recordType) if hasRecords { // 有效的搜索 return records, nil } } zoneId, err := this.getZoneId(domain) if err != nil || zoneId == nil { return nil, err } var fullname = name + "." + domain + "." var input = &route53.ListResourceRecordSetsInput{ HostedZoneId: zoneId, StartRecordName: aws.String(fullname), StartRecordType: aws.String(recordType), } output, err := this.client.ListResourceRecordSets(input) if err != nil { return nil, err } var result []*dnstypes.Record for _, recordSet := range output.ResourceRecordSets { if recordSet.SetIdentifier == nil { continue } if *recordSet.Name == fullname && *recordSet.Type == recordType { for _, rawRecord := range recordSet.ResourceRecords { result = append(result, &dnstypes.Record{ Id: this.composeRecordId(*recordSet.SetIdentifier, name, *recordSet.Type, *rawRecord.Value), Name: name, Type: *recordSet.Type, Value: *rawRecord.Value, Route: this.composeGeoLocationCode(recordSet.GeoLocation), TTL: types.Int32(*recordSet.TTL), }) } } } return result, nil } // AddRecord 设置记录 func (this *AmazonRoute53Provider) AddRecord(domain string, newRecord *dnstypes.Record) error { if newRecord == nil { return errors.New("invalid new record") } // 在CHANGE记录后面加入点 if newRecord.Type == dnstypes.RecordTypeCNAME && !strings.HasSuffix(newRecord.Value, ".") { newRecord.Value += "." } zoneId, err := this.getZoneId(domain) if err != nil || zoneId == nil { return err } var recordSetId = newRecord.Name + stringutil.Md5(newRecord.Name+"@"+newRecord.Type+"@"+newRecord.Route) // 检查是否已存在 var recordValues []*route53.ResourceRecord var existRecordValue = false { var input = &route53.ListResourceRecordSetsInput{ HostedZoneId: zoneId, StartRecordName: aws.String(newRecord.Name + "." + domain + "."), StartRecordType: aws.String(newRecord.Type), } output, err := this.client.ListResourceRecordSets(input) if err != nil { return fmt.Errorf("ListResourceRecordSets failed: %w", err) } if output.ResourceRecordSets != nil { for _, recordSet := range output.ResourceRecordSets { if recordSet.SetIdentifier == nil { continue } if *recordSet.SetIdentifier == recordSetId { for _, rawRecord := range recordSet.ResourceRecords { recordValues = append(recordValues, &route53.ResourceRecord{ Value: aws.String(*rawRecord.Value), }) if *rawRecord.Value == newRecord.Value { existRecordValue = true } } break } } } } if existRecordValue { return nil } recordValues = append(recordValues, &route53.ResourceRecord{ Value: aws.String(newRecord.Value), }) var geoLocation *route53.GeoLocation var recordRoute = newRecord.Route if len(recordRoute) == 0 { recordRoute = this.DefaultRoute() } if recordRoute != this.DefaultRoute() { for _, piece := range strings.Split(recordRoute, "@") { if strings.Contains(piece, ":") { var pieces2 = strings.SplitN(piece, ":", 2) if len(pieces2) == 2 && pieces2[1] != "*" { if geoLocation == nil { geoLocation = &route53.GeoLocation{} } switch pieces2[0] { case "CONTINENT": geoLocation.ContinentCode = aws.String(pieces2[1]) case "COUNTRY": geoLocation.CountryCode = aws.String(pieces2[1]) case "SUBDIVISION": geoLocation.SubdivisionCode = aws.String(pieces2[1]) } } } } } else { geoLocation = &route53.GeoLocation{ CountryCode: aws.String(recordRoute), } } var ttl = newRecord.TTL if ttl <= 0 { ttl = 600 } var input = &route53.ChangeResourceRecordSetsInput{ ChangeBatch: &route53.ChangeBatch{ Changes: []*route53.Change{ { Action: aws.String("UPSERT"), ResourceRecordSet: &route53.ResourceRecordSet{ AliasTarget: nil, Failover: nil, GeoLocation: geoLocation, HealthCheckId: nil, MultiValueAnswer: nil, Name: aws.String(newRecord.Name + "." + domain + "."), Region: nil, ResourceRecords: recordValues, SetIdentifier: aws.String(recordSetId), TTL: aws.Int64(int64(ttl)), TrafficPolicyInstanceId: nil, Type: aws.String(newRecord.Type), Weight: nil, }, }, }, Comment: nil, }, HostedZoneId: zoneId, } _, err = this.client.ChangeResourceRecordSets(input) if err != nil { return err } newRecord.Id = this.composeRecordId(recordSetId, newRecord.Name, newRecord.Type, newRecord.Value) // 加入缓存 if this.ProviderId > 0 { sharedDomainRecordsCache.AddDomainRecord(this.ProviderId, domain, newRecord) } return nil } // UpdateRecord 修改记录 func (this *AmazonRoute53Provider) UpdateRecord(domain string, record *dnstypes.Record, newRecord *dnstypes.Record) error { if record == nil { return errors.New("invalid record") } if newRecord == nil { return errors.New("invalid new record") } // 在CHANGE记录后面加入点 if newRecord.Type == dnstypes.RecordTypeCNAME && !strings.HasSuffix(newRecord.Value, ".") { newRecord.Value += "." } var newRoute = newRecord.Route if len(newRoute) == 0 { newRoute = this.DefaultRoute() } recordSetId, recordName, recordType, _, err := this.decodeRecordId(record.Id) if err != nil { return err } zoneId, err := this.getZoneId(domain) if err != nil || zoneId == nil { return err } // 检查是否已存在 var recordValues []*route53.ResourceRecord var shouldChanged = false { var input = &route53.ListResourceRecordSetsInput{ HostedZoneId: zoneId, StartRecordName: aws.String(newRecord.Name + "." + domain + "."), StartRecordType: aws.String(newRecord.Type), } output, err := this.client.ListResourceRecordSets(input) if err != nil { return fmt.Errorf("ListResourceRecordSets failed: %w", err) } if output.ResourceRecordSets != nil { for _, recordSet := range output.ResourceRecordSets { if recordSet.SetIdentifier == nil { continue } if *recordSet.SetIdentifier == recordSetId { shouldChanged = true for _, rawRecord := range recordSet.ResourceRecords { // skip old if record.Id == this.composeRecordId(recordSetId, recordName, recordType, *rawRecord.Value) || *rawRecord.Value == newRecord.Value { continue } recordValues = append(recordValues, &route53.ResourceRecord{ Value: aws.String(*rawRecord.Value), }) } break } } } } if !shouldChanged { return nil } recordValues = append(recordValues, &route53.ResourceRecord{ Value: aws.String(newRecord.Value), }) var geoLocation *route53.GeoLocation var recordRoute = newRecord.Route if len(recordRoute) == 0 { recordRoute = this.DefaultRoute() } if recordRoute != this.DefaultRoute() { for _, piece := range strings.Split(recordRoute, "@") { if strings.Contains(piece, ":") { var pieces2 = strings.SplitN(piece, ":", 2) if len(pieces2) == 2 && pieces2[1] != "*" { if geoLocation == nil { geoLocation = &route53.GeoLocation{} } switch pieces2[0] { case "CONTINENT": geoLocation.ContinentCode = aws.String(pieces2[1]) case "COUNTRY": geoLocation.CountryCode = aws.String(pieces2[1]) case "SUBDIVISION": geoLocation.SubdivisionCode = aws.String(pieces2[1]) } } } } } else { geoLocation = &route53.GeoLocation{ CountryCode: aws.String(recordRoute), } } var ttl = newRecord.TTL if ttl <= 0 { ttl = 600 } var input = &route53.ChangeResourceRecordSetsInput{ ChangeBatch: &route53.ChangeBatch{ Changes: []*route53.Change{ { Action: aws.String("UPSERT"), ResourceRecordSet: &route53.ResourceRecordSet{ AliasTarget: nil, Failover: nil, GeoLocation: geoLocation, HealthCheckId: nil, MultiValueAnswer: nil, Name: aws.String(newRecord.Name + "." + domain + "."), Region: nil, ResourceRecords: recordValues, SetIdentifier: aws.String(recordSetId), TTL: aws.Int64(int64(ttl)), TrafficPolicyInstanceId: nil, Type: aws.String(newRecord.Type), Weight: nil, }, }, }, Comment: nil, }, HostedZoneId: zoneId, } _, err = this.client.ChangeResourceRecordSets(input) if err != nil { return err } newRecord.Id = this.composeRecordId(recordSetId, newRecord.Name, newRecord.Type, newRecord.Value) // 修改缓存 if this.ProviderId > 0 { sharedDomainRecordsCache.UpdateDomainRecord(this.ProviderId, domain, newRecord) } return nil } // DeleteRecord 删除记录 func (this *AmazonRoute53Provider) DeleteRecord(domain string, record *dnstypes.Record) error { if record == nil { return errors.New("invalid record to delete") } recordSetId, recordName, recordType, recordValue, err := this.decodeRecordId(record.Id) if err != nil { return fmt.Errorf("decode record id failed: %w", err) } zoneId, err := this.getZoneId(domain) if err != nil || zoneId == nil { return err } var newRecordValues []*route53.ResourceRecord var foundRecordSet *route53.ResourceRecordSet { var input = &route53.ListResourceRecordSetsInput{ HostedZoneId: zoneId, StartRecordIdentifier: aws.String(recordSetId), StartRecordName: aws.String(recordName + "." + domain + "."), StartRecordType: aws.String(recordType), } output, err := this.client.ListResourceRecordSets(input) if err != nil { return fmt.Errorf("ListResourceRecordSets failed: %w", err) } if output.ResourceRecordSets != nil { for _, recordSet := range output.ResourceRecordSets { if recordSet.SetIdentifier == nil { continue } if *recordSet.SetIdentifier == recordSetId { foundRecordSet = recordSet var foundRecord = false for _, rawRecord := range recordSet.ResourceRecords { if *rawRecord.Value == recordValue { foundRecord = true } else { newRecordValues = append(newRecordValues, &route53.ResourceRecord{ Value: aws.String(*rawRecord.Value), }) } } if !foundRecord { return nil } break } } } } if foundRecordSet == nil { return nil } var action = "UPSERT" if len(newRecordValues) == 0 { action = "DELETE" } else { foundRecordSet.ResourceRecords = newRecordValues } var input = &route53.ChangeResourceRecordSetsInput{ ChangeBatch: &route53.ChangeBatch{ Changes: []*route53.Change{ { Action: aws.String(action), ResourceRecordSet: foundRecordSet, }, }, Comment: nil, }, HostedZoneId: zoneId, } _, err = this.client.ChangeResourceRecordSets(input) if err != nil { return err } // 删除缓存 if this.ProviderId > 0 { sharedDomainRecordsCache.DeleteDomainRecord(this.ProviderId, domain, record.Id) } return nil } // DefaultRoute 默认线路 func (this *AmazonRoute53Provider) DefaultRoute() string { return "*" } func (this *AmazonRoute53Provider) getZoneId(domain string) (*string, error) { var input = &route53.ListHostedZonesByNameInput{ DNSName: aws.String(domain + "."), HostedZoneId: nil, MaxItems: nil, } output, err := this.client.ListHostedZonesByName(input) if err != nil { return nil, err } for _, zone := range output.HostedZones { if strings.TrimSuffix(*zone.Name, ".") == domain { return zone.Id, nil } } return nil, nil } func (this *AmazonRoute53Provider) composeGeoLocationCode(location *route53.GeoLocation) (locationCode string) { if location == nil { return } var codes []string if location.ContinentCode != nil && len(*location.ContinentCode) > 0 { codes = append(codes, "CONTINENT:"+(*location.ContinentCode)) } if location.CountryCode != nil && len(*location.CountryCode) > 0 { if *location.CountryCode == "*" { codes = append(codes, "*") } else { codes = append(codes, "COUNTRY:"+(*location.CountryCode)) } } if location.SubdivisionCode != nil && len(*location.SubdivisionCode) > 0 { codes = append(codes, "SUBDIVISION:"+(*location.SubdivisionCode)) } return strings.Join(codes, "@") } func (this *AmazonRoute53Provider) composeGeoLocationDetail(location *route53.GeoLocationDetails) (locationName string, locationCode string) { if location == nil { return } var names []string var codes []string if location.ContinentName != nil && len(*location.ContinentName) > 0 { names = append(names, "CONTINENT:"+(*location.ContinentName)) } if location.ContinentCode != nil && len(*location.ContinentCode) > 0 { codes = append(codes, "CONTINENT:"+(*location.ContinentCode)) } if location.CountryName != nil && len(*location.CountryName) > 0 { if *location.CountryName == "Default" { names = append(names, "Default") } else { names = append(names, "COUNTRY/REGION:"+(*location.CountryName)) } } if location.CountryCode != nil && len(*location.CountryCode) > 0 { if *location.CountryCode == "*" { codes = append(codes, "*") } else { codes = append(codes, "COUNTRY:"+(*location.CountryCode)) } } if location.SubdivisionName != nil && len(*location.SubdivisionName) > 0 { names = append(names, "SUBDIVISION:"+(*location.SubdivisionName)) } if location.SubdivisionCode != nil && len(*location.SubdivisionCode) > 0 { codes = append(codes, "SUBDIVISION:"+(*location.SubdivisionCode)) } return strings.Join(names, " "), strings.Join(codes, "@") } func (this *AmazonRoute53Provider) composeRecordId(recordSetId string, recordName string, recordType string, recordValue string) string { return base64.StdEncoding.EncodeToString([]byte(recordSetId + "$" + recordName + "$" + recordType + "$" + recordValue)) } func (this *AmazonRoute53Provider) decodeRecordId(encodedRecordId string) (recordSetId string, recordName string, recordType string, recordValue string, err error) { data, err := base64.StdEncoding.DecodeString(encodedRecordId) if err != nil { return "", "", "", "", err } if len(data) == 0 { err = errors.New("invalid record id") return } var pieces = strings.SplitN(string(data), "$", 4) if len(pieces) != 4 { err = errors.New("invalid record id") return } recordSetId = pieces[0] recordName = pieces[1] recordType = pieces[2] recordValue = pieces[3] return } func (this *AmazonRoute53Provider) fixCNAME(recordType string, recordValue string) string { // 修正Record if strings.ToUpper(recordType) == dnstypes.RecordTypeCNAME && !strings.HasSuffix(recordValue, ".") { recordValue += "." } return recordValue } // AmazonCredentialProvider Amazon认证服务 type AmazonCredentialProvider struct { accessKeyId string secretAccessKey string } func NewAmazonCredentialProvider(accessKeyId string, secretAccessKey string) *AmazonCredentialProvider { return &AmazonCredentialProvider{ accessKeyId: accessKeyId, secretAccessKey: secretAccessKey, } } func (this *AmazonCredentialProvider) Retrieve() (credentials.Value, error) { return credentials.Value{ AccessKeyID: this.accessKeyId, SecretAccessKey: this.secretAccessKey, SessionToken: "", ProviderName: "", }, nil } func (this *AmazonCredentialProvider) IsExpired() bool { return false }