Files
waf-platform/EdgeAPI/internal/dnsclients/provider_volc_engine_plus.go

453 lines
11 KiB
Go

// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package dnsclients
import (
"context"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/dnsclients/dnstypes"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
volcdns "github.com/volcengine/volc-sdk-golang/service/dns"
"strconv"
"strings"
)
// VolcEngineProvider is the DNS provider for Volcengine DNS (火山引擎).
type VolcEngineProvider struct {
	BaseProvider

	// ProviderId identifies this provider in the shared domain-records cache;
	// the cache is only read or written when ProviderId is positive.
	ProviderId int64

	// client is the Volcengine DNS API client. It is created by Auth; the
	// other methods use it without a nil check, so Auth must succeed first.
	client *volcdns.Client
}
// Auth validates the credentials in params and builds the Volcengine DNS
// client used by every other method.
//
// params must contain non-empty "accessKeyId" and "accessKeySecret" entries.
// An error is returned if either is missing or if the SDK caller has no
// underlying Volc service instance.
func (this *VolcEngineProvider) Auth(params maps.Map) error {
	accessKeyId, accessKeySecret := params.GetString("accessKeyId"), params.GetString("accessKeySecret")

	switch {
	case len(accessKeyId) == 0:
		return errors.New("'accessKeyId' required")
	case len(accessKeySecret) == 0:
		return errors.New("'accessKeySecret' required")
	}

	caller := volcdns.NewVolcCaller()
	if caller.Volc == nil {
		return errors.New("system error: caller.Volc == nil")
	}
	caller.Volc.SetAccessKey(accessKeyId)
	caller.Volc.SetSecretKey(accessKeySecret)

	this.client = volcdns.NewClient(caller)
	return nil
}
// MaskParams replaces the "accessKeySecret" entry of params with the result
// of MaskString, leaving every other entry as is. A nil map is ignored.
func (this *VolcEngineProvider) MaskParams(params maps.Map) {
	if params != nil {
		secret := params.GetString("accessKeySecret")
		params["accessKeySecret"] = MaskString(secret)
	}
}
// GetDomains returns the names of all zones visible to the account.
//
// Zones are requested 500 per page. Paging stops at the first short page,
// at a nil response, or after 9,999 pages. Any API error aborts the listing
// and discards what was collected so far.
func (this *VolcEngineProvider) GetDomains() (domains []string, err error) {
	const perPage = 500
	perPageText := strconv.Itoa(perPage)

	for page := 1; page < 10_000; page++ {
		pageText := strconv.Itoa(page)
		resp, listErr := this.client.ListZones(context.Background(), &volcdns.ListZonesRequest{
			PageNumber: &pageText,
			PageSize:   &perPageText,
		})
		if listErr != nil {
			return nil, listErr
		}
		if resp == nil {
			return domains, nil
		}

		for _, zone := range resp.Zones {
			domains = append(domains, *zone.ZoneName)
		}
		if len(resp.Zones) < perPage {
			return domains, nil
		}
	}
	return domains, nil
}
// GetRecords returns all DNS records of domain.
//
// Records are requested 500 per page until a short page or a nil response.
// Record types are upper-cased and CNAME values get a trailing dot (see
// fixCNAME). When ProviderId is positive the full list is written to the
// shared domain-records cache. If the zone cannot be found, nil records are
// returned together with the lookup error, which may itself be nil.
func (this *VolcEngineProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) {
	zoneId, err := this.getZoneId(domain)
	if err != nil || zoneId <= 0 {
		return nil, err
	}

	const perPage = 500
	perPageText := strconv.Itoa(perPage)

	for page := 1; page < 10_000; page++ {
		resp, listErr := this.client.ListRecords(context.Background(), &volcdns.ListRecordsRequest{
			ZID:        this.int64Val(zoneId),
			PageSize:   &perPageText,
			PageNumber: this.intVal(page),
		})
		if listErr != nil {
			return nil, listErr
		}
		if resp == nil {
			break
		}

		for _, item := range resp.Records {
			records = append(records, &dnstypes.Record{
				Id:    *item.RecordID,
				Name:  *item.Host,
				Type:  strings.ToUpper(*item.Type),
				Value: this.fixCNAME(item.Type, item.Value),
				Route: *item.Line,
				TTL:   types.Int32(*item.TTL),
			})
		}
		if len(resp.Records) < perPage {
			break
		}
	}

	// Store the complete listing in the shared cache.
	if this.ProviderId > 0 {
		sharedDomainRecordsCache.WriteDomainRecords(this.ProviderId, domain, records)
	}
	return records, nil
}
// GetRoutes returns the resolution lines that can be used for domain.
//
// The result holds the zone's public lines (non-hierarchical, 2000 per page),
// followed by the account's custom lines (500 per page, names prefixed with
// "自定义:"). The public line whose code is "default", if any, is moved to
// the front. Each listing stops at a short page, a nil response, or 99 pages.
// Any API error aborts the call. If the zone cannot be found, nil routes are
// returned together with the lookup error, which may itself be nil.
func (this *VolcEngineProvider) GetRoutes(domain string) (routes []*dnstypes.Route, err error) {
	zoneId, err := this.getZoneId(domain)
	if err != nil || zoneId <= 0 {
		return nil, err
	}

	var defaultRoute *dnstypes.Route

	// Public lines of the zone.
	const publicPerPage = 2000
	publicPerPageVal := this.int64Val(publicPerPage)
	for page := 1; page < 100; page++ {
		resp, listErr := this.client.ListLines(context.Background(), &volcdns.ListLinesRequest{
			ZID:        this.int64Val(zoneId),
			Hierarchy:  this.stringVal("false"),
			PageSize:   publicPerPageVal,
			PageNumber: this.intVal(page),
		})
		if listErr != nil {
			return nil, listErr
		}
		if resp == nil {
			break
		}

		for _, line := range resp.Lines {
			route := &dnstypes.Route{
				Name: *line.Name,
				Code: *line.Value,
			}
			if route.Code == "default" {
				defaultRoute = route
			} else {
				routes = append(routes, route)
			}
		}
		if len(resp.Lines) < publicPerPage {
			break
		}
	}

	// Custom lines defined on the account.
	const customPerPage = 500
	customPerPageVal := this.int64Val(customPerPage)
	for page := 1; page < 100; page++ {
		resp, listErr := this.client.ListCustomLines(context.Background(), &volcdns.ListCustomLinesRequest{
			PageSize:   customPerPageVal,
			PageNumber: this.intVal(page),
		})
		if listErr != nil {
			return nil, listErr
		}
		if resp == nil {
			break
		}

		for _, line := range resp.CustomerLines {
			routes = append(routes, &dnstypes.Route{
				Name: "自定义:" + *line.NameCN,
				Code: *line.Line,
			})
		}
		if len(resp.CustomerLines) < customPerPage {
			break
		}
	}

	// The default line goes first.
	if defaultRoute != nil {
		routes = append([]*dnstypes.Route{defaultRoute}, routes...)
	}
	return routes, nil
}
// QueryRecord returns the first record of domain whose host equals name and
// whose type equals recordType, or nil if there is none.
//
// The type comparison is case-insensitive, consistent with GetRecords and
// fixCNAME, which both upper-case the type returned by the API. When
// ProviderId is positive, the shared cache is consulted first and its answer
// is used whenever it reports a valid search. Otherwise only the first page
// (up to 500 entries) of an exact-match search is examined. Entries missing
// a host or type are skipped instead of causing a nil-pointer panic.
func (this *VolcEngineProvider) QueryRecord(domain string, name string, recordType dnstypes.RecordType) (*dnstypes.Record, error) {
	// Serve from the cache when it has a valid answer.
	if this.ProviderId > 0 {
		record, hasRecords, _ := sharedDomainRecordsCache.QueryDomainRecord(this.ProviderId, domain, name, recordType)
		if hasRecords {
			return record, nil
		}
	}

	zoneId, err := this.getZoneId(domain)
	if err != nil || zoneId <= 0 {
		return nil, err
	}

	recordsResp, err := this.client.ListRecords(context.Background(), &volcdns.ListRecordsRequest{
		ZID:        this.int64Val(zoneId),
		Host:       this.stringVal(name),
		Type:       this.stringVal(recordType),
		SearchMode: this.stringVal("exact"),
		PageNumber: this.stringVal("1"),
		PageSize:   this.stringVal(strconv.Itoa(500)),
	})
	if err != nil || recordsResp == nil {
		return nil, err
	}

	for _, record := range recordsResp.Records {
		if record.Host == nil || record.Type == nil {
			continue
		}
		if *record.Host == name && strings.EqualFold(*record.Type, recordType) {
			return &dnstypes.Record{
				Id:    *record.RecordID,
				Name:  name,
				Type:  recordType,
				Value: this.fixCNAME(record.Type, record.Value),
				Route: *record.Line,
				TTL:   types.Int32(*record.TTL),
			}, nil
		}
	}
	return nil, nil
}
// QueryRecords returns every record of domain whose host equals name and
// whose type equals recordType.
//
// The type comparison is case-insensitive, consistent with GetRecords and
// fixCNAME. When ProviderId is positive, the shared cache is consulted first
// and its answer is used whenever it reports a valid search. Otherwise the
// exact-match search is paged 500 at a time until a short page, a nil
// response, or 999 pages. A nil response on a later page ends paging and
// keeps the records already collected; it no longer drops them. Entries
// missing a host or type are skipped instead of causing a nil-pointer panic.
func (this *VolcEngineProvider) QueryRecords(domain string, name string, recordType dnstypes.RecordType) ([]*dnstypes.Record, error) {
	// Serve from the cache when it has a valid answer.
	if this.ProviderId > 0 {
		records, hasRecords, _ := sharedDomainRecordsCache.QueryDomainRecords(this.ProviderId, domain, name, recordType)
		if hasRecords {
			return records, nil
		}
	}

	zoneId, err := this.getZoneId(domain)
	if err != nil || zoneId <= 0 {
		return nil, err
	}

	const perPage = 500
	var perPageVal = this.intVal(perPage)
	var result []*dnstypes.Record
	for page := 1; page < 1_000; page++ {
		recordsResp, listErr := this.client.ListRecords(context.Background(), &volcdns.ListRecordsRequest{
			ZID:        this.int64Val(zoneId),
			Host:       this.stringVal(name),
			Type:       this.stringVal(recordType),
			SearchMode: this.stringVal("exact"),
			PageNumber: this.stringVal(strconv.Itoa(page)),
			PageSize:   perPageVal,
		})
		if listErr != nil {
			return nil, listErr
		}
		if recordsResp == nil {
			// Same termination rule as GetRecords: keep what was collected.
			break
		}

		for _, record := range recordsResp.Records {
			if record.Host == nil || record.Type == nil {
				continue
			}
			if *record.Host == name && strings.EqualFold(*record.Type, recordType) {
				result = append(result, &dnstypes.Record{
					Id:    *record.RecordID,
					Name:  name,
					Type:  recordType,
					Value: this.fixCNAME(record.Type, record.Value),
					Route: *record.Line,
					TTL:   types.Int32(*record.TTL),
				})
			}
		}
		if len(recordsResp.Records) < perPage {
			break
		}
	}
	return result, nil
}
// AddRecord creates newRecord in the zone of domain.
//
// newRecord is modified in place: a CNAME value without a trailing dot gets
// one appended, and on success Id is set to the record ID returned by the
// API. A non-positive TTL is sent as 600. When ProviderId is positive the new
// record is added to the shared domain-records cache.
//
// An error is returned if newRecord is nil, the zone of domain cannot be
// found, the API call fails, or the response carries no record ID.
func (this *VolcEngineProvider) AddRecord(domain string, newRecord *dnstypes.Record) error {
	if newRecord == nil {
		return errors.New("invalid new record")
	}
	zoneId, err := this.getZoneId(domain)
	if err != nil {
		return err
	}
	if zoneId <= 0 {
		// Returning nil here would report success although nothing was created.
		return errors.New("can not find domain '" + domain + "'")
	}

	// Append a trailing dot to CNAME targets; values read back through
	// fixCNAME are normalised the same way.
	if newRecord.Type == dnstypes.RecordTypeCNAME && !strings.HasSuffix(newRecord.Value, ".") {
		newRecord.Value += "."
	}

	var ttl = int64(newRecord.TTL)
	if ttl <= 0 {
		ttl = 600
	}

	createResp, err := this.client.CreateRecord(context.Background(), &volcdns.CreateRecordRequest{
		Host:   this.stringVal(newRecord.Name),
		Line:   this.stringVal(newRecord.Route),
		Remark: nil,
		TTL:    &ttl,
		Type:   this.stringVal(newRecord.Type),
		Value:  this.stringVal(newRecord.Value),
		Weight: nil,
		ZID:    &zoneId,
	})
	if err != nil {
		return err
	}
	if createResp == nil || createResp.RecordID == nil {
		// Guard the dereference below. Without an ID the record could not be
		// cached, updated or deleted through this provider later.
		return errors.New("create record: empty record id in response")
	}
	newRecord.Id = *createResp.RecordID

	// Add to the cache.
	if this.ProviderId > 0 {
		sharedDomainRecordsCache.AddDomainRecord(this.ProviderId, domain, newRecord)
	}
	return nil
}
// UpdateRecord overwrites the record identified by record.Id with the host,
// line, TTL, type and value of newRecord.
//
// newRecord is modified in place: a CNAME value without a trailing dot gets
// one appended, and on success its Id is set to record.Id. A non-positive TTL
// is sent as 600. When ProviderId is positive the cached copy is updated.
// An error is returned if either record is nil or the API call fails.
func (this *VolcEngineProvider) UpdateRecord(domain string, record *dnstypes.Record, newRecord *dnstypes.Record) error {
	if record == nil {
		return errors.New("invalid record")
	}
	if newRecord == nil {
		return errors.New("invalid new record")
	}

	// Append a trailing dot to CNAME targets.
	if newRecord.Type == dnstypes.RecordTypeCNAME && !strings.HasSuffix(newRecord.Value, ".") {
		newRecord.Value += "."
	}

	var ttl = int64(newRecord.TTL)
	if ttl <= 0 {
		ttl = 600
	}

	_, err := this.client.UpdateRecord(context.Background(), &volcdns.UpdateRecordRequest{
		Host:     newRecord.Name,
		Line:     newRecord.Route,
		RecordID: record.Id,
		TTL:      &ttl,
		Type:     this.stringVal(newRecord.Type),
		Value:    this.stringVal(newRecord.Value),
		Weight:   nil,
	})
	if err != nil {
		return err
	}

	// The updated record keeps record.Id on the server, but newRecord may
	// arrive without an Id. Copy it over before updating the cache, which
	// presumably locates the entry by Id (NOTE(review): confirm in the cache).
	newRecord.Id = record.Id

	// Update the cache.
	if this.ProviderId > 0 {
		sharedDomainRecordsCache.UpdateDomainRecord(this.ProviderId, domain, newRecord)
	}
	return nil
}
// DeleteRecord deletes record (identified by record.Id) from domain.
//
// A "not found" error from the API (TOPError code "ErrDBNotFound") counts as
// success, because the record is gone either way. In that case the cached
// copy is now evicted too, so later cache lookups no longer return a record
// that does not exist. When ProviderId is positive the record is removed from
// the shared cache. An error is returned if record is nil or the API call
// fails for any other reason.
func (this *VolcEngineProvider) DeleteRecord(domain string, record *dnstypes.Record) error {
	if record == nil {
		return errors.New("invalid record to delete")
	}

	err := this.client.DeleteRecord(context.Background(), &volcdns.DeleteRecordRequest{
		RecordID: this.stringVal(record.Id),
	})
	if err != nil {
		var topErr *volcdns.TOPError
		isNotFound := errors.As(err, &topErr) && topErr != nil && topErr.Code == "ErrDBNotFound"
		if !isNotFound {
			return err
		}
		// Not found: fall through so the stale cache entry is removed as well.
	}

	// Remove from the cache.
	if this.ProviderId > 0 {
		sharedDomainRecordsCache.DeleteDomainRecord(this.ProviderId, domain, record.Id)
	}
	return nil
}
// DefaultRoute returns the code of Volcengine's default resolution line.
// GetRoutes places the route with this code at the front of its result.
func (this *VolcEngineProvider) DefaultRoute() string {
	const defaultLineCode = "default"
	return defaultLineCode
}
// getZoneId looks up the numeric zone ID (ZID) of domain using an exact-match
// zone search.
//
// It returns 0 and the API error if the lookup fails, and 0 with a nil error
// if no zone has exactly that name. Entries with a missing name or ID are
// skipped instead of causing a nil-pointer panic.
func (this *VolcEngineProvider) getZoneId(domain string) (int64, error) {
	zonesResp, err := this.client.ListZones(context.Background(), &volcdns.ListZonesRequest{
		Key:        &domain,
		SearchMode: this.stringVal("exact"),
	})
	if err != nil || zonesResp == nil {
		return 0, err
	}

	for _, zone := range zonesResp.Zones {
		if zone.ZoneName == nil || zone.ZID == nil {
			continue
		}
		if *zone.ZoneName == domain {
			return *zone.ZID, nil
		}
	}
	return 0, nil
}
// stringVal returns a pointer to a fresh copy of s. The SDK request structs
// take most of their fields as *string.
func (this *VolcEngineProvider) stringVal(s string) *string {
	ptr := new(string)
	*ptr = s
	return ptr
}
// int64Val converts i to its string form with types.String and returns a
// pointer to it. It is used for numeric request parameters, such as ZID and
// PageSize, that the SDK takes as *string.
func (this *VolcEngineProvider) int64Val(i int64) *string {
	text := types.String(i)
	return &text
}
// intVal returns a pointer to the decimal string form of i (strconv.Itoa).
func (this *VolcEngineProvider) intVal(i int) *string {
	text := strconv.Itoa(i)
	return &text
}
// fixCNAME returns the record value from an API entry. If the record type,
// upper-cased, is CNAME and the value lacks a trailing dot, a dot is
// appended. This keeps read values consistent with the dot that
// AddRecord/UpdateRecord add on write.
//
// A nil value yields "" and a nil type is treated as non-CNAME, so incomplete
// API entries do not cause a nil-pointer panic.
func (this *VolcEngineProvider) fixCNAME(recordType *string, recordValue *string) string {
	if recordValue == nil {
		return ""
	}
	value := *recordValue
	if recordType != nil && strings.ToUpper(*recordType) == dnstypes.RecordTypeCNAME && !strings.HasSuffix(value, ".") {
		value += "."
	}
	return value
}