// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn . //go:build plus package dnsclients import ( "context" "errors" "github.com/TeaOSLab/EdgeAPI/internal/dnsclients/dnstypes" "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" volcdns "github.com/volcengine/volc-sdk-golang/service/dns" "strconv" "strings" ) // VolcEngineProvider 火山引擎 type VolcEngineProvider struct { BaseProvider ProviderId int64 client *volcdns.Client } // Auth 认证 func (this *VolcEngineProvider) Auth(params maps.Map) error { var accessKeyId = params.GetString("accessKeyId") var accessKeySecret = params.GetString("accessKeySecret") if len(accessKeyId) == 0 { return errors.New("'accessKeyId' required") } if len(accessKeySecret) == 0 { return errors.New("'accessKeySecret' required") } var caller = volcdns.NewVolcCaller() if caller.Volc == nil { return errors.New("system error: caller.Volc == nil") } caller.Volc.SetAccessKey(accessKeyId) caller.Volc.SetSecretKey(accessKeySecret) this.client = volcdns.NewClient(caller) return nil } // MaskParams 对参数进行掩码 func (this *VolcEngineProvider) MaskParams(params maps.Map) { if params == nil { return } params["accessKeySecret"] = MaskString(params.GetString("accessKeySecret")) } // GetDomains 获取所有域名列表 func (this *VolcEngineProvider) GetDomains() (domains []string, err error) { var pageSizeInt = 500 var pageSize = strconv.Itoa(pageSizeInt) for i := 1; i < 10_000; i++ { var pageNumber = strconv.Itoa(i) zonesResp, err := this.client.ListZones(context.Background(), &volcdns.ListZonesRequest{ PageNumber: &pageNumber, PageSize: &pageSize, }) if err != nil { return nil, err } if zonesResp == nil { break } var zones = zonesResp.Zones for _, zone := range zones { domains = append(domains, *zone.ZoneName) } if len(zones) < pageSizeInt { break } } return } // GetRecords 获取域名列表 func (this *VolcEngineProvider) GetRecords(domain string) (records []*dnstypes.Record, err error) { zoneId, err := this.getZoneId(domain) if err != nil || zoneId <= 0 { return nil, err } var pageSizeInt = 500 var pageSize = strconv.Itoa(pageSizeInt) for i := 1; i < 10_000; i++ { recordsResp, err := this.client.ListRecords(context.Background(), &volcdns.ListRecordsRequest{ ZID: this.stringVal(types.String(zoneId)), PageSize: &pageSize, PageNumber: this.stringVal(strconv.Itoa(i)), }) if err != nil { return nil, err } if recordsResp == nil { break } for _, record := range recordsResp.Records { records = append(records, &dnstypes.Record{ Id: *record.RecordID, Name: *record.Host, Type: strings.ToUpper(*record.Type), Value: this.fixCNAME(record.Type, record.Value), Route: *record.Line, TTL: types.Int32(*record.TTL), }) } if len(recordsResp.Records) < pageSizeInt { break } } // 写入缓存 if this.ProviderId > 0 { sharedDomainRecordsCache.WriteDomainRecords(this.ProviderId, domain, records) } return } // GetRoutes 读取线路数据 func (this *VolcEngineProvider) GetRoutes(domain string) (routes []*dnstypes.Route, err error) { zoneId, err := this.getZoneId(domain) if err != nil || zoneId <= 0 { return nil, err } // 公共线路 var defaultRoute *dnstypes.Route { var pageSize = 2000 var pageSizeVal = this.int64Val(int64(pageSize)) for i := 1; i < 100; i++ { linesResp, err := this.client.ListLines(context.Background(), &volcdns.ListLinesRequest{ ZID: this.int64Val(zoneId), Hierarchy: this.stringVal("false"), PageSize: pageSizeVal, PageNumber: this.intVal(i), }) if err != nil { return nil, err } if linesResp == nil { break } for _, line := range linesResp.Lines { if *line.Value == "default" { defaultRoute = &dnstypes.Route{ Name: *line.Name, Code: *line.Value, } continue } routes = append(routes, &dnstypes.Route{ Name: *line.Name, Code: *line.Value, }) } if len(linesResp.Lines) < pageSize { break } } } // 自定义线路 { var pageSize = 500 var pageSizeVal = this.int64Val(int64(pageSize)) for i := 1; i < 100; i++ { linesResp, err := this.client.ListCustomLines(context.Background(), &volcdns.ListCustomLinesRequest{ PageSize: pageSizeVal, PageNumber: this.intVal(i), }) if err != nil { return nil, err } if linesResp == nil { break } for _, line := range linesResp.CustomerLines { routes = append(routes, &dnstypes.Route{ Name: "自定义:" + (*line.NameCN), Code: *line.Line, }) } if len(linesResp.CustomerLines) < pageSize { break } } } // 将default放在最前面 if defaultRoute != nil { routes = append([]*dnstypes.Route{defaultRoute}, routes...) } return } // QueryRecord 查询单个记录 func (this *VolcEngineProvider) QueryRecord(domain string, name string, recordType dnstypes.RecordType) (*dnstypes.Record, error) { // 从缓存中读取 if this.ProviderId > 0 { record, hasRecords, _ := sharedDomainRecordsCache.QueryDomainRecord(this.ProviderId, domain, name, recordType) if hasRecords { // 有效的搜索 return record, nil } } zoneId, err := this.getZoneId(domain) if err != nil || zoneId <= 0 { return nil, err } recordsResp, err := this.client.ListRecords(context.Background(), &volcdns.ListRecordsRequest{ ZID: this.int64Val(zoneId), Host: this.stringVal(name), Type: this.stringVal(recordType), SearchMode: this.stringVal("exact"), PageNumber: this.stringVal("1"), PageSize: this.stringVal(strconv.Itoa(500)), }) if err != nil || recordsResp == nil { return nil, err } for _, record := range recordsResp.Records { if *record.Host == name && *record.Type == recordType { return &dnstypes.Record{ Id: *record.RecordID, Name: name, Type: recordType, Value: this.fixCNAME(record.Type, record.Value), Route: *record.Line, TTL: types.Int32(*record.TTL), }, nil } } return nil, nil } // QueryRecords 查询多个记录 func (this *VolcEngineProvider) QueryRecords(domain string, name string, recordType dnstypes.RecordType) ([]*dnstypes.Record, error) { // 从缓存中读取 if this.ProviderId > 0 { records, hasRecords, _ := sharedDomainRecordsCache.QueryDomainRecords(this.ProviderId, domain, name, recordType) if hasRecords { // 有效的搜索 return records, nil } } zoneId, err := this.getZoneId(domain) if err != nil || zoneId <= 0 { return nil, err } var pageSizeInt = 500 var pageSizeVal = this.intVal(pageSizeInt) var result []*dnstypes.Record for i := 1; i < 1_000; i++ { recordsResp, err := this.client.ListRecords(context.Background(), &volcdns.ListRecordsRequest{ ZID: this.int64Val(zoneId), Host: this.stringVal(name), Type: this.stringVal(recordType), SearchMode: this.stringVal("exact"), PageNumber: this.stringVal(strconv.Itoa(i)), PageSize: pageSizeVal, }) if err != nil || recordsResp == nil { return nil, err } for _, record := range recordsResp.Records { if *record.Host == name && *record.Type == recordType { result = append(result, &dnstypes.Record{ Id: *record.RecordID, Name: name, Type: recordType, Value: this.fixCNAME(record.Type, record.Value), Route: *record.Line, TTL: types.Int32(*record.TTL), }) } } if len(recordsResp.Records) < pageSizeInt { break } } return result, nil } // AddRecord 设置记录 func (this *VolcEngineProvider) AddRecord(domain string, newRecord *dnstypes.Record) error { if newRecord == nil { return errors.New("invalid new record") } zoneId, err := this.getZoneId(domain) if err != nil || zoneId <= 0 { return err } // 在CHANGE记录后面加入点 if newRecord.Type == dnstypes.RecordTypeCNAME && !strings.HasSuffix(newRecord.Value, ".") { newRecord.Value += "." } var ttl = int64(newRecord.TTL) if ttl <= 0 { ttl = 600 } createResp, err := this.client.CreateRecord(context.Background(), &volcdns.CreateRecordRequest{ Host: this.stringVal(newRecord.Name), Line: this.stringVal(newRecord.Route), Remark: nil, TTL: &ttl, Type: this.stringVal(newRecord.Type), Value: this.stringVal(newRecord.Value), Weight: nil, ZID: &zoneId, }) if err != nil { return err } newRecord.Id = *createResp.RecordID // 加入缓存 if this.ProviderId > 0 { sharedDomainRecordsCache.AddDomainRecord(this.ProviderId, domain, newRecord) } return nil } // UpdateRecord 修改记录 func (this *VolcEngineProvider) UpdateRecord(domain string, record *dnstypes.Record, newRecord *dnstypes.Record) error { if record == nil { return errors.New("invalid record") } if newRecord == nil { return errors.New("invalid new record") } // 在CHANGE记录后面加入点 if newRecord.Type == dnstypes.RecordTypeCNAME && !strings.HasSuffix(newRecord.Value, ".") { newRecord.Value += "." } var ttl = int64(newRecord.TTL) if ttl <= 0 { ttl = 600 } _, err := this.client.UpdateRecord(context.Background(), &volcdns.UpdateRecordRequest{ Host: newRecord.Name, Line: newRecord.Route, RecordID: record.Id, TTL: &ttl, Type: this.stringVal(newRecord.Type), Value: this.stringVal(newRecord.Value), Weight: nil, }) if err != nil { return err } // 修改缓存 if this.ProviderId > 0 { sharedDomainRecordsCache.UpdateDomainRecord(this.ProviderId, domain, newRecord) } return nil } // DeleteRecord 删除记录 func (this *VolcEngineProvider) DeleteRecord(domain string, record *dnstypes.Record) error { if record == nil { return errors.New("invalid record to delete") } err := this.client.DeleteRecord(context.Background(), &volcdns.DeleteRecordRequest{ RecordID: this.stringVal(record.Id), }) if err != nil { // ignore not found error var topErr *volcdns.TOPError if errors.As(err, &topErr) && topErr != nil && topErr.Code == "ErrDBNotFound" { return nil } return err } // 删除缓存 if this.ProviderId > 0 { sharedDomainRecordsCache.DeleteDomainRecord(this.ProviderId, domain, record.Id) } return nil } // DefaultRoute 默认线路 func (this *VolcEngineProvider) DefaultRoute() string { return "default" } func (this *VolcEngineProvider) getZoneId(domain string) (int64, error) { zonesResp, err := this.client.ListZones(context.Background(), &volcdns.ListZonesRequest{ Key: &domain, SearchMode: this.stringVal("exact"), }) if err != nil || zonesResp == nil || len(zonesResp.Zones) == 0 { return 0, err } var zoneId int64 for _, zone := range zonesResp.Zones { if *zone.ZoneName == domain { zoneId = *zone.ZID break } } return zoneId, nil } func (this *VolcEngineProvider) stringVal(s string) *string { return &s } func (this *VolcEngineProvider) int64Val(i int64) *string { return this.stringVal(types.String(i)) } func (this *VolcEngineProvider) intVal(i int) *string { return this.stringVal(strconv.Itoa(i)) } func (this *VolcEngineProvider) fixCNAME(recordType *string, recordValue *string) string { // 修正Record if strings.ToUpper(*recordType) == dnstypes.RecordTypeCNAME && !strings.HasSuffix(*recordValue, ".") { recordValue = this.stringVal(*recordValue + ".") } return *recordValue }