Files
2026-02-04 20:27:13 +08:00

217 lines
5.3 KiB
Go

// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build plus
package oss
import (
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ossconfigs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
"net/http"
"sync"
"time"
)
// SharedManager is the shared Manager instance used by this package.
var SharedManager = NewManager()

// init schedules SharedManager.GC to run once every 24 hours. It does so only
// when teaconst.IsMain is set. The loop is handed to goman.New (presumably a
// tracked goroutine; confirm in the goman package) and never terminates.
func init() {
	if !teaconst.IsMain {
		return
	}

	gcTicker := time.NewTicker(time.Hour * 24)
	goman.New(func() {
		for {
			<-gcTicker.C
			SharedManager.GC()
		}
	})
}
// Manager caches initialized OSS providers. The cache key is the unique id
// that OSSConfig.ParseRequest derives from each request.
type Manager struct {
	providerMap map[string]*Provider // unique id => *Provider
	locker      sync.RWMutex         // guards providerMap
}

// NewManager returns a Manager with an empty provider cache.
func NewManager() *Manager {
	var manager = &Manager{}
	manager.providerMap = make(map[string]*Provider)
	return manager
}
// FindProviderWithConfig returns the cached Provider for the request described
// by req/host and ossConfig. On first use it creates and initializes one.
//
// ossConfig.ParseRequest yields three values:
//   - the bucket name,
//   - the object key,
//   - the unique id used as the cache key.
//
// If the bucket name or key is empty, errNotFound is returned.
//
// On a cache miss, the concrete provider is selected by ossConfig.Type and
// wrapped with NewProvider. It is then initialized with the config options and
// the bucket name, and stored in the cache only if Init succeeds.
//
// On success it returns the provider, the parsed bucket name and the object key.
func (this *Manager) FindProviderWithConfig(req *http.Request, host string, ossConfig *ossconfigs.OSSConfig) (provider *Provider, objectBucketName string, objectKey string, err error) {
	if ossConfig == nil {
		return nil, "", "", errors.New("provider 'config' should not be nil")
	}

	rawOptions := ossConfig.Options
	if rawOptions == nil {
		return nil, "", "", errors.New("provider 'options' should not be nil")
	}
	options, implementsOptions := rawOptions.(ossconfigs.OSSOptions)
	if !implementsOptions {
		return nil, "", "", errors.New("provider 'options' should implement 'OSSOptions' interface")
	}

	bucketName, key, uniqueId := ossConfig.ParseRequest(req, host)
	if bucketName == "" || key == "" {
		return nil, "", "", errNotFound
	}

	// Fast path: look up an existing provider under the shared lock.
	this.locker.RLock()
	cached, found := this.providerMap[uniqueId]
	this.locker.RUnlock()
	if found {
		return cached, bucketName, key, nil
	}

	this.locker.Lock()
	defer this.locker.Unlock()

	// Look again: another request may have created it between RUnlock and Lock.
	if cached, found = this.providerMap[uniqueId]; found {
		return cached, bucketName, key, nil
	}

	var rawProvider ProviderInterface
	switch ossConfig.Type {
	case ossconfigs.OSSTypeTencentCOS:
		rawProvider = NewTencentCOSProvider()
	case ossconfigs.OSSTypeAliyunOSS:
		rawProvider = NewAliyunOSSProvider()
	case ossconfigs.OSSTypeHuaweiOBS:
		rawProvider = NewHuaweiOBSProvider()
	case ossconfigs.OSSTypeBaiduBOS:
		rawProvider = NewBaiduBOSProvider()
	case ossconfigs.OSSTypeQiniuKodo:
		rawProvider = NewQiniuKodoProvider()
	case ossconfigs.OSSTypeAmazonS3:
		rawProvider = NewAmazonS3Provider()
	case ossconfigs.OSSTypeB2:
		rawProvider = NewB2Provider()
	default:
		return nil, "", "", errors.New("invalid provider '" + ossConfig.Type + "'")
	}
	if rawProvider == nil {
		return nil, "", "", errors.New("invalid provider '" + ossConfig.Type + "'")
	}

	// Wrap the concrete implementation and tag it with its cache key.
	provider = NewProvider(rawProvider)
	provider.SetUniqueId(uniqueId)

	// Init runs while the write lock is held, so concurrent lookups wait for it.
	if err = provider.Init(options, bucketName); err != nil {
		return nil, "", "", err
	}

	this.providerMap[uniqueId] = provider
	return provider, bucketName, key, nil
}
// Head resolves the provider for the request and calls its Head method for the
// object key parsed from the request.
//
// It returns the provider response, the provider-native error code and the
// parsed bucket name. If the provider reports errNoBucket, the cached provider
// is evicted so that the next request builds a fresh one.
func (this *Manager) Head(req *http.Request, host string, ossConfig *ossconfigs.OSSConfig) (resp *http.Response, nativeErrCode string, nativeBucketName string, err error) {
	if ossConfig == nil {
		return nil, "", "", errors.New("provider config should not be nil")
	}
	provider, bucketName, key, err := this.FindProviderWithConfig(req, host, ossConfig)
	if err != nil {
		return nil, "", "", err
	}
	resp, nativeErrCode, err = provider.Head(key)
	this.evictProviderOnNoBucket(provider, err)
	return resp, nativeErrCode, bucketName, err
}

// Get resolves the provider for the request and calls its Get method for the
// object key parsed from the request.
//
// The return values and errNoBucket eviction behave as in Head.
func (this *Manager) Get(req *http.Request, host string, ossConfig *ossconfigs.OSSConfig) (resp *http.Response, nativeErrCode string, nativeBucketName string, err error) {
	if ossConfig == nil {
		return nil, "", "", errors.New("provider config should not be nil")
	}
	provider, bucketName, key, err := this.FindProviderWithConfig(req, host, ossConfig)
	if err != nil {
		return nil, "", "", err
	}
	resp, nativeErrCode, err = provider.Get(key)
	this.evictProviderOnNoBucket(provider, err)
	return resp, nativeErrCode, bucketName, err
}

// GetRange resolves the provider for the request and calls its GetRange method
// for the parsed object key. bytesRange is forwarded to the provider unchanged.
//
// The return values and errNoBucket eviction behave as in Head.
func (this *Manager) GetRange(req *http.Request, host string, bytesRange string, ossConfig *ossconfigs.OSSConfig) (resp *http.Response, nativeErrCode string, nativeBucketName string, err error) {
	if ossConfig == nil {
		return nil, "", "", errors.New("provider config should not be nil")
	}
	provider, bucketName, key, err := this.FindProviderWithConfig(req, host, ossConfig)
	if err != nil {
		return nil, "", "", err
	}
	resp, nativeErrCode, err = provider.GetRange(key, bytesRange)
	this.evictProviderOnNoBucket(provider, err)
	return resp, nativeErrCode, bucketName, err
}

// evictProviderOnNoBucket removes provider from the cache when err is (or wraps)
// errNoBucket, so the next request re-creates it.
//
// The entry is deleted only if the map still refers to this exact provider.
// A concurrent request may already have evicted it and stored a fresh provider
// under the same unique id, and that replacement must not be discarded.
func (this *Manager) evictProviderOnNoBucket(provider *Provider, err error) {
	if !errors.Is(err, errNoBucket) {
		return
	}
	var uniqueId = provider.UniqueId()
	this.locker.Lock()
	if this.providerMap[uniqueId] == provider {
		delete(this.providerMap, uniqueId)
	}
	this.locker.Unlock()
}
// GC evicts cached providers whose UpdatedAt is more than one day (86400
// seconds) before now.
//
// It is a no-op while the cache holds fewer than 1024 providers.
//
// Candidates are collected under the read lock so lookups are not blocked
// during the scan. Each candidate is then re-checked under the write lock
// before deletion, because it may have been refreshed or replaced in between.
// UpdatedAt is presumably refreshed by Provider on use; confirm in provider.go.
func (this *Manager) GC() {
	const minProvidersToCollect = 1024 // below this size, skip cleanup entirely
	const maxIdleSeconds = 86400       // one day

	// Find "expired" Provider instances.
	this.locker.RLock()
	if len(this.providerMap) < minProvidersToCollect {
		this.locker.RUnlock()
		return
	}
	var cutoff = fasttime.Now().Unix() - maxIdleSeconds
	var expiredKeys []string
	for key, provider := range this.providerMap {
		if provider.UpdatedAt < cutoff {
			expiredKeys = append(expiredKeys, key)
		}
	}
	this.locker.RUnlock()

	if len(expiredKeys) == 0 {
		return
	}

	// Delete them. Re-validate each entry, since it may have changed after
	// RUnlock.
	this.locker.Lock()
	for _, key := range expiredKeys {
		if provider, ok := this.providerMap[key]; ok && provider.UpdatedAt < cutoff {
			delete(this.providerMap, key)
		}
	}
	this.locker.Unlock()
}