// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn . //go:build plus package oss import ( "errors" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ossconfigs" teaconst "github.com/TeaOSLab/EdgeNode/internal/const" "github.com/TeaOSLab/EdgeNode/internal/utils/fasttime" "github.com/TeaOSLab/EdgeNode/internal/utils/goman" "net/http" "sync" "time" ) var SharedManager = NewManager() func init() { if !teaconst.IsMain { return } var ticker = time.NewTicker(24 * time.Hour) goman.New(func() { for range ticker.C { SharedManager.GC() } }) } type Manager struct { providerMap map[string]*Provider // unique id => *Provider locker sync.RWMutex } func NewManager() *Manager { return &Manager{ providerMap: map[string]*Provider{}, } } func (this *Manager) FindProviderWithConfig(req *http.Request, host string, ossConfig *ossconfigs.OSSConfig) (provider *Provider, objectBucketName string, objectKey string, err error) { if ossConfig == nil { return nil, "", "", errors.New("provider 'config' should not be nil") } var originOptions = ossConfig.Options if originOptions == nil { return nil, "", "", errors.New("provider 'options' should not be nil") } options, ok := originOptions.(ossconfigs.OSSOptions) if !ok { return nil, "", "", errors.New("provider 'options' should implement 'OSSOptions' interface") } bucketName, key, uniqueId := ossConfig.ParseRequest(req, host) if len(bucketName) == 0 || len(key) == 0 { err = errNotFound return } objectBucketName = bucketName objectKey = key // 查询已有 this.locker.RLock() provider, ok = this.providerMap[uniqueId] if ok { this.locker.RUnlock() return provider, bucketName, objectKey, nil } this.locker.RUnlock() this.locker.Lock() defer this.locker.Unlock() // 再次查询 provider, ok = this.providerMap[uniqueId] if ok { return provider, bucketName, objectKey, nil } var rawProvider ProviderInterface switch ossConfig.Type { case ossconfigs.OSSTypeTencentCOS: rawProvider = NewTencentCOSProvider() case ossconfigs.OSSTypeAliyunOSS: rawProvider = NewAliyunOSSProvider() case ossconfigs.OSSTypeHuaweiOBS: rawProvider = NewHuaweiOBSProvider() case ossconfigs.OSSTypeBaiduBOS: rawProvider = NewBaiduBOSProvider() case ossconfigs.OSSTypeQiniuKodo: rawProvider = NewQiniuKodoProvider() case ossconfigs.OSSTypeAmazonS3: rawProvider = NewAmazonS3Provider() case ossconfigs.OSSTypeB2: rawProvider = NewB2Provider() default: return nil, "", "", errors.New("invalid provider '" + ossConfig.Type + "'") } if rawProvider == nil { return nil, "", "", errors.New("invalid provider '" + ossConfig.Type + "'") } // 包装 provider = NewProvider(rawProvider) provider.SetUniqueId(uniqueId) // 初始化 err = provider.Init(options, bucketName) if err != nil { return nil, "", "", err } // 放入缓存 this.providerMap[uniqueId] = provider return } func (this *Manager) Head(req *http.Request, host string, ossConfig *ossconfigs.OSSConfig) (resp *http.Response, nativeErrCode string, nativeBucketName string, err error) { if ossConfig == nil { return nil, "", "", errors.New("provider config should not be nil") } provider, bucketName, key, err := this.FindProviderWithConfig(req, host, ossConfig) if err != nil { return nil, "", "", err } nativeBucketName = bucketName resp, nativeErrCode, err = provider.Head(key) if err == errNoBucket { this.locker.Lock() delete(this.providerMap, provider.UniqueId()) this.locker.Unlock() } return } func (this *Manager) Get(req *http.Request, host string, ossConfig *ossconfigs.OSSConfig) (resp *http.Response, nativeErrCode string, nativeBucketName string, err error) { if ossConfig == nil { return nil, "", "", errors.New("provider config should not be nil") } provider, bucketName, key, err := this.FindProviderWithConfig(req, host, ossConfig) if err != nil { return nil, "", "", err } nativeBucketName = bucketName resp, nativeErrCode, err = provider.Get(key) if err == errNoBucket { this.locker.Lock() delete(this.providerMap, provider.UniqueId()) this.locker.Unlock() } return } func (this *Manager) GetRange(req *http.Request, host string, bytesRange string, ossConfig *ossconfigs.OSSConfig) (resp *http.Response, nativeErrCode string, nativeBucketName string, err error) { if ossConfig == nil { return nil, "", "", errors.New("provider config should not be nil") } provider, bucketName, key, err := this.FindProviderWithConfig(req, host, ossConfig) if err != nil { return nil, "", "", err } nativeBucketName = bucketName resp, nativeErrCode, err = provider.GetRange(key, bytesRange) if err == errNoBucket { this.locker.Lock() delete(this.providerMap, provider.UniqueId()) this.locker.Unlock() } return } func (this *Manager) GC() { // 查询 this.locker.RLock() if len(this.providerMap) < 1024 { // 如果数量很小,则不需要清理 this.locker.RUnlock() return } // 查询"过期"的Provider实例 var expiredKeys = []string{} var currentTime = fasttime.Now().Unix() for key, provider := range this.providerMap { if provider.UpdatedAt < currentTime-86400 { expiredKeys = append(expiredKeys, key) } } this.locker.RUnlock() // 删除 if len(expiredKeys) > 0 { this.locker.Lock() for _, key := range expiredKeys { delete(this.providerMap, key) } this.locker.Unlock() } }