Files
waf-platform/EdgeAPI/internal/dnsclients/provider_amazon_route53_plus.go
2026-02-04 20:27:13 +08:00

808 lines
22 KiB
Go

// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package dnsclients
import (
"encoding/base64"
"errors"
"fmt"
"github.com/TeaOSLab/EdgeAPI/internal/dnsclients/dnstypes"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/route53"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
"strings"
)
// AmazonRoute53Provider is the DNS provider implementation for Amazon Route 53.
//
// Throughout this provider recordSet.SetIdentifier != nil is always checked so
// that only the record sets created by this provider itself are considered.
type AmazonRoute53Provider struct {
BaseProvider
// ProviderId is used as the key into sharedDomainRecordsCache; the cache is
// only read and written when it is greater than zero.
ProviderId int64
// client is the Route 53 API client; it is assigned by Auth.
client *route53.Route53
}
// NewAmazonRoute53Provider returns a zero-valued provider. Its Route 53 client
// is not set until Auth has been called successfully.
func NewAmazonRoute53Provider() *AmazonRoute53Provider {
	return new(AmazonRoute53Provider)
}
// Auth validates the credential parameters and creates the Route 53 client.
//
// "accessKeyId" and "accessKeySecret" are required; "region" is optional and
// is only passed to the AWS session when it is non-empty.
func (this *AmazonRoute53Provider) Auth(params maps.Map) error {
	var keyId = params.GetString("accessKeyId")
	if keyId == "" {
		return errors.New("'accessKeyId' required")
	}

	var keySecret = params.GetString("accessKeySecret")
	if keySecret == "" {
		return errors.New("'accessKeySecret' required")
	}

	var config = &aws.Config{
		Credentials: credentials.NewCredentials(NewAmazonCredentialProvider(keyId, keySecret)),
	}
	if region := params.GetString("region"); region != "" {
		config.Region = aws.String(region)
	}

	awsSession, err := session.NewSession(config)
	if err != nil {
		return err
	}

	this.client = route53.New(awsSession)
	return nil
}
// MaskParams replaces the "accessKeySecret" entry of params, in place, with the
// result of MaskString. A nil map is left untouched.
func (this *AmazonRoute53Provider) MaskParams(params maps.Map) {
	if params != nil {
		var secret = params.GetString("accessKeySecret")
		params["accessKeySecret"] = MaskString(secret)
	}
}
// GetDomains lists the names of all hosted zones, following pagination, and
// returns them without the trailing dot.
func (this *AmazonRoute53Provider) GetDomains() (domains []string, err error) {
	var marker *string
	for {
		output, listErr := this.client.ListHostedZones(&route53.ListHostedZonesInput{
			Marker: marker,
		})
		if listErr != nil {
			return nil, listErr
		}

		for _, zone := range output.HostedZones {
			// skip malformed entries instead of panicking on a nil pointer
			if zone == nil || zone.Name == nil {
				continue
			}
			domains = append(domains, strings.TrimSuffix(*zone.Name, "."))
		}

		// Stop on the last page. A truncated page without a marker cannot be
		// continued, so stop there as well rather than requesting the first
		// page again forever.
		if !aws.BoolValue(output.IsTruncated) || output.NextMarker == nil {
			break
		}
		marker = output.NextMarker
	}
	return domains, nil
}
// GetRecords lists the records of a domain.
//
// Only record sets that carry a SetIdentifier (plus Name, Type and TTL) are
// returned; every value of such a set becomes one dnstypes.Record whose name is
// relative to the domain. When ProviderId > 0 the result is written to the
// shared records cache. If no hosted zone matches the domain, (nil, nil) is
// returned.
func (this *AmazonRoute53Provider) GetRecords(domain string) (records []*dnstypes.Record, err error) {
	zoneId, err := this.getZoneId(domain)
	if err != nil || zoneId == nil {
		return nil, err
	}

	var nextRecordIdentifier, nextRecordName, nextRecordType *string
	for {
		output, listErr := this.client.ListResourceRecordSets(&route53.ListResourceRecordSetsInput{
			HostedZoneId:          zoneId,
			StartRecordIdentifier: nextRecordIdentifier,
			StartRecordName:       nextRecordName,
			StartRecordType:       nextRecordType,
		})
		if listErr != nil {
			return nil, listErr
		}

		for _, recordSet := range output.ResourceRecordSets {
			// make sure the returned record set is complete
			if recordSet.SetIdentifier == nil || recordSet.Name == nil || recordSet.Type == nil || recordSet.TTL == nil {
				continue
			}

			// strip the zone suffix to get the name relative to the domain
			var recordName = strings.TrimSuffix(strings.TrimSuffix(*recordSet.Name, domain+"."), ".")
			for _, rawRecord := range recordSet.ResourceRecords {
				if rawRecord.Value == nil {
					continue
				}
				records = append(records, &dnstypes.Record{
					Id:    this.composeRecordId(*recordSet.SetIdentifier, recordName, *recordSet.Type, *rawRecord.Value),
					Name:  recordName,
					Type:  *recordSet.Type,
					Value: *rawRecord.Value,
					Route: this.composeGeoLocationCode(recordSet.GeoLocation),
					TTL:   types.Int32(*recordSet.TTL),
				})
			}
		}

		// Stop on the last page, or when the response gives no position to
		// continue from (which would otherwise restart from the first page).
		if !aws.BoolValue(output.IsTruncated) || output.NextRecordName == nil {
			break
		}
		nextRecordIdentifier = output.NextRecordIdentifier
		nextRecordName = output.NextRecordName
		nextRecordType = output.NextRecordType
	}

	// write to cache
	if this.ProviderId > 0 {
		sharedDomainRecordsCache.WriteDomainRecords(this.ProviderId, domain, records)
	}

	return records, nil
}
// GetRoutes lists the geo locations supported by Route 53 as routes, following
// pagination. The domain argument is not used by this implementation.
func (this *AmazonRoute53Provider) GetRoutes(domain string) (routes []*dnstypes.Route, err error) {
	var nextContinentCode, nextCountryCode, nextSubdivisionCode *string
	for {
		output, listErr := this.client.ListGeoLocations(&route53.ListGeoLocationsInput{
			StartContinentCode:   nextContinentCode,
			StartCountryCode:     nextCountryCode,
			StartSubdivisionCode: nextSubdivisionCode,
		})
		if listErr != nil {
			return nil, listErr
		}

		for _, location := range output.GeoLocationDetailsList {
			locationName, locationCode := this.composeGeoLocationDetail(location)
			routes = append(routes, &dnstypes.Route{
				Name: locationName,
				Code: locationCode,
			})
		}

		if !aws.BoolValue(output.IsTruncated) {
			break
		}
		// A truncated page without any continuation code cannot be continued;
		// stop instead of requesting the first page again forever.
		if output.NextContinentCode == nil && output.NextCountryCode == nil && output.NextSubdivisionCode == nil {
			break
		}
		nextContinentCode = output.NextContinentCode
		nextCountryCode = output.NextCountryCode
		nextSubdivisionCode = output.NextSubdivisionCode
	}
	return routes, nil
}
// QueryRecord looks up a single record by name and type.
//
// The shared records cache is consulted first when ProviderId > 0. Otherwise
// the record sets starting at "<name>.<domain>." are listed and the first value
// of the first matching record set that carries a SetIdentifier is returned.
// (nil, nil) is returned when nothing matches or no hosted zone is found.
func (this *AmazonRoute53Provider) QueryRecord(domain string, name string, recordType dnstypes.RecordType) (*dnstypes.Record, error) {
	// read from cache
	if this.ProviderId > 0 {
		record, hasRecords, _ := sharedDomainRecordsCache.QueryDomainRecord(this.ProviderId, domain, name, recordType)
		if hasRecords { // the cache holds data for this domain
			return record, nil
		}
	}

	zoneId, err := this.getZoneId(domain)
	if err != nil || zoneId == nil {
		return nil, err
	}

	var fullname = name + "." + domain + "."
	output, err := this.client.ListResourceRecordSets(&route53.ListResourceRecordSetsInput{
		HostedZoneId:    zoneId,
		StartRecordName: aws.String(fullname),
		StartRecordType: aws.String(recordType),
	})
	if err != nil {
		return nil, err
	}

	for _, recordSet := range output.ResourceRecordSets {
		// same completeness check as GetRecords, so that incomplete record
		// sets are skipped instead of causing a nil pointer dereference
		if recordSet.SetIdentifier == nil || recordSet.Name == nil || recordSet.Type == nil || recordSet.TTL == nil {
			continue
		}
		if *recordSet.Name != fullname || *recordSet.Type != recordType {
			continue
		}
		for _, rawRecord := range recordSet.ResourceRecords {
			if rawRecord.Value == nil {
				continue
			}
			return &dnstypes.Record{
				Id:    this.composeRecordId(*recordSet.SetIdentifier, name, *recordSet.Type, *rawRecord.Value),
				Name:  name,
				Type:  *recordSet.Type,
				Value: *rawRecord.Value,
				Route: this.composeGeoLocationCode(recordSet.GeoLocation),
				TTL:   types.Int32(*recordSet.TTL),
			}, nil
		}
	}

	return nil, nil
}
// QueryRecords looks up all records with the given name and type.
//
// The shared records cache is consulted first when ProviderId > 0. Otherwise
// the record sets starting at "<name>.<domain>." are listed and every value of
// every matching record set that carries a SetIdentifier is returned.
// (nil, nil) is returned when no hosted zone is found.
func (this *AmazonRoute53Provider) QueryRecords(domain string, name string, recordType dnstypes.RecordType) ([]*dnstypes.Record, error) {
	// read from cache
	if this.ProviderId > 0 {
		records, hasRecords, _ := sharedDomainRecordsCache.QueryDomainRecords(this.ProviderId, domain, name, recordType)
		if hasRecords { // the cache holds data for this domain
			return records, nil
		}
	}

	zoneId, err := this.getZoneId(domain)
	if err != nil || zoneId == nil {
		return nil, err
	}

	var fullname = name + "." + domain + "."
	output, err := this.client.ListResourceRecordSets(&route53.ListResourceRecordSetsInput{
		HostedZoneId:    zoneId,
		StartRecordName: aws.String(fullname),
		StartRecordType: aws.String(recordType),
	})
	if err != nil {
		return nil, err
	}

	var result []*dnstypes.Record
	for _, recordSet := range output.ResourceRecordSets {
		// same completeness check as GetRecords, so that incomplete record
		// sets are skipped instead of causing a nil pointer dereference
		if recordSet.SetIdentifier == nil || recordSet.Name == nil || recordSet.Type == nil || recordSet.TTL == nil {
			continue
		}
		if *recordSet.Name != fullname || *recordSet.Type != recordType {
			continue
		}
		for _, rawRecord := range recordSet.ResourceRecords {
			if rawRecord.Value == nil {
				continue
			}
			result = append(result, &dnstypes.Record{
				Id:    this.composeRecordId(*recordSet.SetIdentifier, name, *recordSet.Type, *rawRecord.Value),
				Name:  name,
				Type:  *recordSet.Type,
				Value: *rawRecord.Value,
				Route: this.composeGeoLocationCode(recordSet.GeoLocation),
				TTL:   types.Int32(*recordSet.TTL),
			})
		}
	}

	return result, nil
}
// AddRecord adds the value of newRecord to its record set, creating the record
// set when it does not exist yet.
//
// The record set is identified by name + MD5("name@type@route"). Existing values
// of that set are kept and the new value is appended, then the whole set is
// written with an UPSERT. If the value is already present nothing is sent to
// Route 53. On success newRecord.Id is filled in and, when ProviderId > 0, the
// record is added to the shared records cache. A TTL <= 0 is replaced by 600.
// When no hosted zone matches the domain, nil is returned without any change.
func (this *AmazonRoute53Provider) AddRecord(domain string, newRecord *dnstypes.Record) error {
	if newRecord == nil {
		return errors.New("invalid new record")
	}

	// append the trailing dot to CNAME values
	newRecord.Value = this.fixCNAME(newRecord.Type, newRecord.Value)

	zoneId, err := this.getZoneId(domain)
	if err != nil || zoneId == nil {
		return err
	}

	var fullname = newRecord.Name + "." + domain + "."
	var recordSetId = newRecord.Name + stringutil.Md5(newRecord.Name+"@"+newRecord.Type+"@"+newRecord.Route)

	// collect the values that already exist in the record set
	output, err := this.client.ListResourceRecordSets(&route53.ListResourceRecordSetsInput{
		HostedZoneId:    zoneId,
		StartRecordName: aws.String(fullname),
		StartRecordType: aws.String(newRecord.Type),
	})
	if err != nil {
		return fmt.Errorf("ListResourceRecordSets failed: %w", err)
	}

	var recordValues []*route53.ResourceRecord
	for _, recordSet := range output.ResourceRecordSets {
		if recordSet.SetIdentifier == nil || *recordSet.SetIdentifier != recordSetId {
			continue
		}
		for _, rawRecord := range recordSet.ResourceRecords {
			if rawRecord.Value == nil {
				continue
			}
			if *rawRecord.Value == newRecord.Value {
				// The value already exists, so there is nothing to change; the
				// caller still gets the id of the existing record.
				newRecord.Id = this.composeRecordId(recordSetId, newRecord.Name, newRecord.Type, newRecord.Value)
				return nil
			}
			recordValues = append(recordValues, &route53.ResourceRecord{
				Value: aws.String(*rawRecord.Value),
			})
		}
		break
	}

	recordValues = append(recordValues, &route53.ResourceRecord{
		Value: aws.String(newRecord.Value),
	})

	// translate the route code ("CONTINENT:xx@COUNTRY:yy@SUBDIVISION:zz") into a geo location
	var geoLocation *route53.GeoLocation
	var recordRoute = newRecord.Route
	if len(recordRoute) == 0 {
		recordRoute = this.DefaultRoute()
	}
	if recordRoute != this.DefaultRoute() {
		for _, piece := range strings.Split(recordRoute, "@") {
			if !strings.Contains(piece, ":") {
				continue
			}
			var pieces2 = strings.SplitN(piece, ":", 2)
			if len(pieces2) != 2 || pieces2[1] == "*" {
				continue
			}
			if geoLocation == nil {
				geoLocation = &route53.GeoLocation{}
			}
			switch pieces2[0] {
			case "CONTINENT":
				geoLocation.ContinentCode = aws.String(pieces2[1])
			case "COUNTRY":
				geoLocation.CountryCode = aws.String(pieces2[1])
			case "SUBDIVISION":
				geoLocation.SubdivisionCode = aws.String(pieces2[1])
			}
		}
	} else {
		// the default route is sent as country code "*"
		geoLocation = &route53.GeoLocation{
			CountryCode: aws.String(recordRoute),
		}
	}

	var ttl = newRecord.TTL
	if ttl <= 0 {
		ttl = 600
	}

	_, err = this.client.ChangeResourceRecordSets(&route53.ChangeResourceRecordSetsInput{
		ChangeBatch: &route53.ChangeBatch{
			Changes: []*route53.Change{
				{
					Action: aws.String("UPSERT"),
					ResourceRecordSet: &route53.ResourceRecordSet{
						GeoLocation:     geoLocation,
						Name:            aws.String(fullname),
						ResourceRecords: recordValues,
						SetIdentifier:   aws.String(recordSetId),
						TTL:             aws.Int64(int64(ttl)),
						Type:            aws.String(newRecord.Type),
					},
				},
			},
		},
		HostedZoneId: zoneId,
	})
	if err != nil {
		return err
	}

	newRecord.Id = this.composeRecordId(recordSetId, newRecord.Name, newRecord.Type, newRecord.Value)

	// add to cache
	if this.ProviderId > 0 {
		sharedDomainRecordsCache.AddDomainRecord(this.ProviderId, domain, newRecord)
	}

	return nil
}
// UpdateRecord replaces the value of an existing record.
//
// The record set is taken from the id of record. Within that set the old value
// (and any value equal to the new one, to avoid duplicates) is dropped, the new
// value is appended, and the set is written back with an UPSERT using the name,
// type, route and TTL of newRecord. A TTL <= 0 is replaced by 600. On success
// newRecord.Id is filled in and, when ProviderId > 0, the shared records cache
// is updated.
//
// If no hosted zone matches the domain, or the record set is not among the
// listed record sets, nothing is changed and nil is returned.
func (this *AmazonRoute53Provider) UpdateRecord(domain string, record *dnstypes.Record, newRecord *dnstypes.Record) error {
	if record == nil {
		return errors.New("invalid record")
	}
	if newRecord == nil {
		return errors.New("invalid new record")
	}

	// append the trailing dot to CNAME values
	newRecord.Value = this.fixCNAME(newRecord.Type, newRecord.Value)

	recordSetId, _, _, oldValue, err := this.decodeRecordId(record.Id)
	if err != nil {
		return err
	}

	zoneId, err := this.getZoneId(domain)
	if err != nil || zoneId == nil {
		return err
	}

	var fullname = newRecord.Name + "." + domain + "."

	// collect the values of the record set that should be kept
	output, err := this.client.ListResourceRecordSets(&route53.ListResourceRecordSetsInput{
		HostedZoneId:    zoneId,
		StartRecordName: aws.String(fullname),
		StartRecordType: aws.String(newRecord.Type),
	})
	if err != nil {
		return fmt.Errorf("ListResourceRecordSets failed: %w", err)
	}

	var recordValues []*route53.ResourceRecord
	var shouldChanged = false
	for _, recordSet := range output.ResourceRecordSets {
		if recordSet.SetIdentifier == nil || *recordSet.SetIdentifier != recordSetId {
			continue
		}
		shouldChanged = true
		for _, rawRecord := range recordSet.ResourceRecords {
			if rawRecord.Value == nil {
				continue
			}
			// skip the value being replaced, and the new value if it is already there
			if *rawRecord.Value == oldValue || *rawRecord.Value == newRecord.Value {
				continue
			}
			recordValues = append(recordValues, &route53.ResourceRecord{
				Value: aws.String(*rawRecord.Value),
			})
		}
		break
	}
	if !shouldChanged {
		return nil
	}

	recordValues = append(recordValues, &route53.ResourceRecord{
		Value: aws.String(newRecord.Value),
	})

	// translate the route code ("CONTINENT:xx@COUNTRY:yy@SUBDIVISION:zz") into a geo location
	var geoLocation *route53.GeoLocation
	var recordRoute = newRecord.Route
	if len(recordRoute) == 0 {
		recordRoute = this.DefaultRoute()
	}
	if recordRoute != this.DefaultRoute() {
		for _, piece := range strings.Split(recordRoute, "@") {
			if !strings.Contains(piece, ":") {
				continue
			}
			var pieces2 = strings.SplitN(piece, ":", 2)
			if len(pieces2) != 2 || pieces2[1] == "*" {
				continue
			}
			if geoLocation == nil {
				geoLocation = &route53.GeoLocation{}
			}
			switch pieces2[0] {
			case "CONTINENT":
				geoLocation.ContinentCode = aws.String(pieces2[1])
			case "COUNTRY":
				geoLocation.CountryCode = aws.String(pieces2[1])
			case "SUBDIVISION":
				geoLocation.SubdivisionCode = aws.String(pieces2[1])
			}
		}
	} else {
		// the default route is sent as country code "*"
		geoLocation = &route53.GeoLocation{
			CountryCode: aws.String(recordRoute),
		}
	}

	var ttl = newRecord.TTL
	if ttl <= 0 {
		ttl = 600
	}

	_, err = this.client.ChangeResourceRecordSets(&route53.ChangeResourceRecordSetsInput{
		ChangeBatch: &route53.ChangeBatch{
			Changes: []*route53.Change{
				{
					Action: aws.String("UPSERT"),
					ResourceRecordSet: &route53.ResourceRecordSet{
						GeoLocation:     geoLocation,
						Name:            aws.String(fullname),
						ResourceRecords: recordValues,
						SetIdentifier:   aws.String(recordSetId),
						TTL:             aws.Int64(int64(ttl)),
						Type:            aws.String(newRecord.Type),
					},
				},
			},
		},
		HostedZoneId: zoneId,
	})
	if err != nil {
		return err
	}

	newRecord.Id = this.composeRecordId(recordSetId, newRecord.Name, newRecord.Type, newRecord.Value)

	// update cache
	if this.ProviderId > 0 {
		sharedDomainRecordsCache.UpdateDomainRecord(this.ProviderId, domain, newRecord)
	}

	return nil
}
// DeleteRecord removes one value from its record set.
//
// The record set and the value are taken from the id of record. When other
// values remain, the set is rewritten without the value (UPSERT); when the
// value was the last one, the whole set is removed (DELETE). If the hosted zone,
// the record set or the value cannot be found, nil is returned without any
// change. After a successful change the record is removed from the shared
// records cache when ProviderId > 0.
func (this *AmazonRoute53Provider) DeleteRecord(domain string, record *dnstypes.Record) error {
	if record == nil {
		return errors.New("invalid record to delete")
	}

	recordSetId, recordName, recordType, recordValue, err := this.decodeRecordId(record.Id)
	if err != nil {
		return fmt.Errorf("decode record id failed: %w", err)
	}

	zoneId, err := this.getZoneId(domain)
	if err != nil || zoneId == nil {
		return err
	}

	output, err := this.client.ListResourceRecordSets(&route53.ListResourceRecordSetsInput{
		HostedZoneId:          zoneId,
		StartRecordIdentifier: aws.String(recordSetId),
		StartRecordName:       aws.String(recordName + "." + domain + "."),
		StartRecordType:       aws.String(recordType),
	})
	if err != nil {
		return fmt.Errorf("ListResourceRecordSets failed: %w", err)
	}

	// locate the record set the record belongs to
	var targetSet *route53.ResourceRecordSet
	for _, candidate := range output.ResourceRecordSets {
		if candidate.SetIdentifier != nil && *candidate.SetIdentifier == recordSetId {
			targetSet = candidate
			break
		}
	}
	if targetSet == nil {
		return nil
	}

	// split the values into the one to delete and the ones to keep
	var remaining []*route53.ResourceRecord
	var matched = false
	for _, raw := range targetSet.ResourceRecords {
		if *raw.Value == recordValue {
			matched = true
			continue
		}
		remaining = append(remaining, &route53.ResourceRecord{
			Value: aws.String(*raw.Value),
		})
	}
	if !matched {
		return nil
	}

	// A DELETE is sent with the record set exactly as it was listed; an UPSERT
	// carries the reduced list of values.
	var change = &route53.Change{ResourceRecordSet: targetSet}
	if len(remaining) > 0 {
		targetSet.ResourceRecords = remaining
		change.Action = aws.String("UPSERT")
	} else {
		change.Action = aws.String("DELETE")
	}

	_, err = this.client.ChangeResourceRecordSets(&route53.ChangeResourceRecordSetsInput{
		ChangeBatch: &route53.ChangeBatch{
			Changes: []*route53.Change{change},
		},
		HostedZoneId: zoneId,
	})
	if err != nil {
		return err
	}

	// remove from cache
	if this.ProviderId > 0 {
		sharedDomainRecordsCache.DeleteDomainRecord(this.ProviderId, domain, record.Id)
	}

	return nil
}
// DefaultRoute returns the code of the default route, "*". AddRecord and
// UpdateRecord use it when a record has no route and send it to Route 53 as the
// country code of the geo location.
func (this *AmazonRoute53Provider) DefaultRoute() string {
return "*"
}
// getZoneId resolves the hosted zone id of a domain.
//
// It lists hosted zones by name starting at "<domain>." and returns the id of
// the first zone whose name, without the trailing dot, equals the domain.
// (nil, nil) is returned when no zone matches.
func (this *AmazonRoute53Provider) getZoneId(domain string) (*string, error) {
	output, err := this.client.ListHostedZonesByName(&route53.ListHostedZonesByNameInput{
		DNSName: aws.String(domain + "."),
	})
	if err != nil {
		return nil, err
	}

	for _, zone := range output.HostedZones {
		var zoneName = strings.TrimSuffix(*zone.Name, ".")
		if zoneName != domain {
			continue
		}
		return zone.Id, nil
	}

	return nil, nil
}
// composeGeoLocationCode builds the route code of a geo location.
//
// The non-empty parts are joined with "@" in the order "CONTINENT:<code>",
// "COUNTRY:<code>", "SUBDIVISION:<code>". The country code "*" is emitted as a
// bare "*". A nil location yields an empty string.
func (this *AmazonRoute53Provider) composeGeoLocationCode(location *route53.GeoLocation) (locationCode string) {
	if location == nil {
		return ""
	}

	var parts []string

	if continent := aws.StringValue(location.ContinentCode); continent != "" {
		parts = append(parts, "CONTINENT:"+continent)
	}

	switch country := aws.StringValue(location.CountryCode); country {
	case "":
		// no country part
	case "*":
		parts = append(parts, "*")
	default:
		parts = append(parts, "COUNTRY:"+country)
	}

	if subdivision := aws.StringValue(location.SubdivisionCode); subdivision != "" {
		parts = append(parts, "SUBDIVISION:"+subdivision)
	}

	return strings.Join(parts, "@")
}
// composeGeoLocationDetail builds the display name and the route code of a geo
// location returned by ListGeoLocations.
//
// The name parts are joined with a space: "CONTINENT:<name>",
// "COUNTRY/REGION:<name>" (or just "Default" for the country name "Default")
// and "SUBDIVISION:<name>". The code parts are joined with "@":
// "CONTINENT:<code>", "COUNTRY:<code>" (or a bare "*" for the country code "*")
// and "SUBDIVISION:<code>". A nil location yields two empty strings.
func (this *AmazonRoute53Provider) composeGeoLocationDetail(location *route53.GeoLocationDetails) (locationName string, locationCode string) {
	if location == nil {
		return "", ""
	}

	// display name
	var nameParts []string
	if continentName := aws.StringValue(location.ContinentName); continentName != "" {
		nameParts = append(nameParts, "CONTINENT:"+continentName)
	}
	switch countryName := aws.StringValue(location.CountryName); countryName {
	case "":
		// no country part
	case "Default":
		nameParts = append(nameParts, "Default")
	default:
		nameParts = append(nameParts, "COUNTRY/REGION:"+countryName)
	}
	if subdivisionName := aws.StringValue(location.SubdivisionName); subdivisionName != "" {
		nameParts = append(nameParts, "SUBDIVISION:"+subdivisionName)
	}

	// route code
	var codeParts []string
	if continentCode := aws.StringValue(location.ContinentCode); continentCode != "" {
		codeParts = append(codeParts, "CONTINENT:"+continentCode)
	}
	switch countryCode := aws.StringValue(location.CountryCode); countryCode {
	case "":
		// no country part
	case "*":
		codeParts = append(codeParts, "*")
	default:
		codeParts = append(codeParts, "COUNTRY:"+countryCode)
	}
	if subdivisionCode := aws.StringValue(location.SubdivisionCode); subdivisionCode != "" {
		codeParts = append(codeParts, "SUBDIVISION:"+subdivisionCode)
	}

	return strings.Join(nameParts, " "), strings.Join(codeParts, "@")
}
// composeRecordId builds the id of a record: the standard base64 encoding of
// "<recordSetId>$<recordName>$<recordType>$<recordValue>".
func (this *AmazonRoute53Provider) composeRecordId(recordSetId string, recordName string, recordType string, recordValue string) string {
	var plainId = recordSetId + "$" + recordName + "$" + recordType + "$" + recordValue
	var encoder = base64.StdEncoding
	return encoder.EncodeToString([]byte(plainId))
}
// decodeRecordId splits a record id produced by composeRecordId back into its
// four parts.
//
// The id is base64-decoded and split on "$" into at most four pieces, so the
// record value may itself contain "$". An error is returned when the id is not
// valid base64, decodes to nothing, or has fewer than four pieces.
func (this *AmazonRoute53Provider) decodeRecordId(encodedRecordId string) (recordSetId string, recordName string, recordType string, recordValue string, err error) {
	raw, err := base64.StdEncoding.DecodeString(encodedRecordId)
	if err != nil {
		return "", "", "", "", err
	}

	var parts = strings.SplitN(string(raw), "$", 4)
	if len(raw) == 0 || len(parts) != 4 {
		return "", "", "", "", errors.New("invalid record id")
	}

	return parts[0], parts[1], parts[2], parts[3], nil
}
// fixCNAME returns recordValue with a trailing dot appended when recordType is
// CNAME (compared in upper case) and the value does not already end with a dot.
// Values of other record types are returned unchanged.
func (this *AmazonRoute53Provider) fixCNAME(recordType string, recordValue string) string {
	var isCNAME = strings.ToUpper(recordType) == dnstypes.RecordTypeCNAME
	if !isCNAME || strings.HasSuffix(recordValue, ".") {
		return recordValue
	}
	return recordValue + "."
}
// AmazonCredentialProvider supplies a fixed access key pair to the AWS SDK.
// Auth wraps it with credentials.NewCredentials when creating the session.
type AmazonCredentialProvider struct {
// accessKeyId is returned by Retrieve as the AccessKeyID.
accessKeyId string
// secretAccessKey is returned by Retrieve as the SecretAccessKey.
secretAccessKey string
}
// NewAmazonCredentialProvider creates a credential provider holding the given
// access key pair.
func NewAmazonCredentialProvider(accessKeyId string, secretAccessKey string) *AmazonCredentialProvider {
	var provider = new(AmazonCredentialProvider)
	provider.accessKeyId = accessKeyId
	provider.secretAccessKey = secretAccessKey
	return provider
}
// Retrieve returns the stored access key pair as a credentials.Value. The
// session token and provider name are left empty, and the error is always nil.
func (this *AmazonCredentialProvider) Retrieve() (credentials.Value, error) {
	var value credentials.Value
	value.AccessKeyID = this.accessKeyId
	value.SecretAccessKey = this.secretAccessKey
	return value, nil
}
// IsExpired always reports false: the provider only holds a fixed key pair and
// Retrieve never returns a session token.
func (this *AmazonCredentialProvider) IsExpired() bool {
return false
}